1use core::ptr::NonNull;
2use facet_trait::{
3 EnumDef, EnumRepr, Facet, FieldError, Opaque, OpaqueUninit, Shape, ShapeExt as _, VariantKind,
4};
5
6use super::{ISet, Poke, PokeValue};
7
8pub struct PokeEnumNoVariant<'mem> {
10 data: OpaqueUninit<'mem>,
11 shape: &'static Shape,
12 def: EnumDef,
13}
14
15impl<'mem> PokeEnumNoVariant<'mem> {
16 #[inline(always)]
18 pub fn into_value(self) -> PokeValue<'mem> {
19 unsafe { PokeValue::new(self.data, self.shape) }
20 }
21
22 #[inline(always)]
24 pub fn shape(&self) -> &'static Shape {
25 self.shape
26 }
27 pub(crate) unsafe fn new(
33 data: OpaqueUninit<'mem>,
34 shape: &'static Shape,
35 def: EnumDef,
36 ) -> Self {
37 Self { data, shape, def }
38 }
39
40 pub fn set_variant_by_name(self, variant_name: &str) -> Result<PokeEnum<'mem>, FieldError> {
47 let variant_index = self
48 .def
49 .variants
50 .iter()
51 .enumerate()
52 .find(|(_, v)| v.name == variant_name)
53 .map(|(i, _)| i)
54 .ok_or(FieldError::NoSuchStaticField)?;
55
56 self.set_variant_by_index(variant_index)
57 }
58
59 pub fn set_variant_by_index(self, variant_index: usize) -> Result<PokeEnum<'mem>, FieldError> {
66 if variant_index >= self.def.variants.len() {
67 return Err(FieldError::IndexOutOfBounds);
68 }
69
70 let variant = &self.def.variants[variant_index];
72
73 unsafe {
75 core::ptr::write_bytes(self.data.as_mut_ptr(), 0, self.shape.layout.size());
77
78 let discriminant_value = match &variant.discriminant {
81 Some(discriminant) => *discriminant,
83 None => variant_index as i64,
85 };
86
87 match self.def.repr {
89 EnumRepr::U8 => {
90 let tag_ptr = self.data.as_mut_ptr();
91 *tag_ptr = discriminant_value as u8;
92 }
93 EnumRepr::U16 => {
94 let tag_ptr = self.data.as_mut_ptr() as *mut u16;
95 *tag_ptr = discriminant_value as u16;
96 }
97 EnumRepr::U32 => {
98 let tag_ptr = self.data.as_mut_ptr() as *mut u32;
99 *tag_ptr = discriminant_value as u32;
100 }
101 EnumRepr::U64 => {
102 let tag_ptr = self.data.as_mut_ptr() as *mut u64;
103 *tag_ptr = discriminant_value as u64;
104 }
105 EnumRepr::USize => {
106 let tag_ptr = self.data.as_mut_ptr() as *mut usize;
107 *tag_ptr = discriminant_value as usize;
108 }
109 EnumRepr::I8 => {
110 let tag_ptr = self.data.as_mut_ptr() as *mut i8;
111 *tag_ptr = discriminant_value as i8;
112 }
113 EnumRepr::I16 => {
114 let tag_ptr = self.data.as_mut_ptr() as *mut i16;
115 *tag_ptr = discriminant_value as i16;
116 }
117 EnumRepr::I32 => {
118 let tag_ptr = self.data.as_mut_ptr() as *mut i32;
119 *tag_ptr = discriminant_value as i32;
120 }
121 EnumRepr::I64 => {
122 let tag_ptr = self.data.as_mut_ptr() as *mut i64;
123 *tag_ptr = discriminant_value;
124 }
125 EnumRepr::ISize => {
126 let tag_ptr = self.data.as_mut_ptr() as *mut isize;
127 *tag_ptr = discriminant_value as isize;
128 }
129 EnumRepr::Default => {
130 if self.def.variants.len() <= 256 {
132 let tag_ptr = self.data.as_mut_ptr();
134 *tag_ptr = discriminant_value as u8;
135 } else if self.def.variants.len() <= 65536 {
136 let tag_ptr = self.data.as_mut_ptr() as *mut u16;
138 *tag_ptr = discriminant_value as u16;
139 } else {
140 let tag_ptr = self.data.as_mut_ptr() as *mut u32;
142 *tag_ptr = discriminant_value as u32;
143 }
144 }
145 _ => {
146 panic!("Unsupported enum representation: {:?}", self.def.repr);
147 }
148 }
149 }
150
151 Ok(PokeEnum {
153 data: self.data,
154 iset: Default::default(),
155 shape: self.shape,
156 def: self.def,
157 selected_variant: variant_index,
158 })
159 }
160}
161
162pub struct PokeEnum<'mem> {
164 data: OpaqueUninit<'mem>,
165 iset: ISet,
166 shape: &'static Shape,
167 def: EnumDef,
168 selected_variant: usize,
169}
170
171impl<'mem> PokeEnum<'mem> {
172 pub fn selected_variant_index(&self) -> usize {
174 self.selected_variant
175 }
176
177 pub fn variant_field_by_name<'s>(&'s mut self, name: &str) -> Result<Poke<'s>, FieldError> {
185 let variant = &self.def.variants[self.selected_variant];
186
187 match &variant.kind {
189 VariantKind::Unit => {
190 Err(FieldError::NoSuchStaticField)
192 }
193 VariantKind::Tuple { fields } => {
194 let field = fields
196 .iter()
197 .find(|f| f.name == name)
198 .ok_or(FieldError::NoSuchStaticField)?;
199
200 let field_data = unsafe { self.data.field_uninit(field.offset) };
202 let poke = unsafe { Poke::unchecked_new(field_data, field.shape) };
203 Ok(poke)
204 }
205 VariantKind::Struct { fields } => {
206 let field = fields
208 .iter()
209 .find(|f| f.name == name)
210 .ok_or(FieldError::NoSuchStaticField)?;
211
212 let field_data = unsafe { self.data.field_uninit(field.offset) };
214 let poke = unsafe { Poke::unchecked_new(field_data, field.shape) };
215 Ok(poke)
216 }
217 _ => {
218 panic!("Unsupported enum variant kind: {:?}", variant.kind);
219 }
220 }
221 }
222
223 pub unsafe fn mark_field_as_initialized(&mut self, field_index: usize) {
229 self.iset.set(field_index);
230 }
231
232 pub fn assert_all_fields_initialized(&self) {
238 let variant = &self.def.variants[self.selected_variant];
239
240 match &variant.kind {
242 VariantKind::Unit => {
243 }
245 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
246 for (field_index, field) in fields.iter().enumerate() {
248 if !self.iset.has(field_index) {
249 panic!(
250 "Field '{}' of variant '{}' was not initialized. Complete schema:\n{}",
251 field.name, variant.name, self.shape
252 );
253 }
254 }
255 }
256 _ => {
257 panic!("Unsupported enum variant kind: {:?}", variant.kind);
258 }
259 }
260 }
261
262 fn assert_matching_shape<T: Facet>(&self) {
263 if !self.shape.is_type::<T>() {
264 panic!(
265 "This is a partial \x1b[1;34m{}\x1b[0m, you can't build a \x1b[1;32m{}\x1b[0m out of it",
266 self.shape,
267 T::SHAPE,
268 );
269 }
270 }
271
272 pub fn build_in_place(self) -> Opaque<'mem> {
281 self.assert_all_fields_initialized();
283 let data = unsafe { self.data.assume_init() };
284 core::mem::forget(self);
286 data
287 }
288
289 pub fn build<T: Facet>(self) -> T {
297 self.assert_all_fields_initialized();
298 self.assert_matching_shape::<T>();
299
300 let result = unsafe {
301 let ptr = self.data.as_ptr() as *const T;
302 core::ptr::read(ptr)
303 };
304 core::mem::forget(self);
305 result
306 }
307
308 pub fn build_boxed<T: Facet>(self) -> Box<T> {
316 self.assert_all_fields_initialized();
317 self.assert_matching_shape::<T>();
318
319 let boxed = unsafe { Box::from_raw(self.data.as_mut_ptr() as *mut T) };
320 core::mem::forget(self);
321 boxed
322 }
323
324 pub unsafe fn move_into(self, target: NonNull<u8>) {
333 self.assert_all_fields_initialized();
334 unsafe {
335 core::ptr::copy_nonoverlapping(
336 self.data.as_mut_ptr(),
337 target.as_ptr(),
338 self.shape.layout.size(),
339 );
340 }
341 core::mem::forget(self);
342 }
343}
344
345impl Drop for PokeEnum<'_> {
346 fn drop(&mut self) {
347 let variant = &self.def.variants[self.selected_variant];
348
349 match &variant.kind {
351 VariantKind::Unit => {
352 }
354 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
355 for (field_index, field) in fields.iter().enumerate() {
357 if self.iset.has(field_index) {
358 if let Some(drop_fn) = field.shape.vtable.drop_in_place {
359 unsafe {
360 drop_fn(self.data.field_init(field.offset));
361 }
362 }
363 }
364 }
365 }
366 _ => {
367 panic!("Unsupported enum variant kind: {:?}", variant.kind);
368 }
369 }
370 }
371}
372
/// Errors describing a failed enum-variant lookup.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum VariantError {
    /// A variant index was outside the valid range.
    IndexOutOfBounds,

    /// The value is not an enum.
    NotAnEnum,

    /// No variant matched the request.
    NoSuchVariant,
}

impl std::error::Error for VariantError {}

impl core::fmt::Display for VariantError {
    /// Writes the fixed, human-readable message for this error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::IndexOutOfBounds => "Variant index out of bounds",
            Self::NotAnEnum => "Not an enum",
            Self::NoSuchVariant => "No such variant",
        };
        f.write_str(message)
    }
}