facet_poke/
enum_.rs

1use core::ptr::NonNull;
2use facet_trait::{
3    EnumDef, EnumRepr, Facet, FieldError, Opaque, OpaqueUninit, Shape, ShapeExt as _, VariantKind,
4};
5
6use super::{ISet, Poke, PokeValue};
7
8/// Represents an enum before a variant has been selected
9pub struct PokeEnumNoVariant<'mem> {
10    data: OpaqueUninit<'mem>,
11    shape: &'static Shape,
12    def: EnumDef,
13}
14
15impl<'mem> PokeEnumNoVariant<'mem> {
16    /// Coerce back into a `PokeValue`
17    #[inline(always)]
18    pub fn into_value(self) -> PokeValue<'mem> {
19        unsafe { PokeValue::new(self.data, self.shape) }
20    }
21
22    /// Shape getter
23    #[inline(always)]
24    pub fn shape(&self) -> &'static Shape {
25        self.shape
26    }
27    /// Creates a new PokeEnumNoVariant from raw data
28    ///
29    /// # Safety
30    ///
31    /// The data buffer must match the size and alignment of the enum shape described by shape
32    pub(crate) unsafe fn new(
33        data: OpaqueUninit<'mem>,
34        shape: &'static Shape,
35        def: EnumDef,
36    ) -> Self {
37        Self { data, shape, def }
38    }
39
40    /// Sets the variant of an enum by name.
41    ///
42    /// # Errors
43    ///
44    /// Returns an error if:
45    /// - No variant with the given name exists.
46    pub fn set_variant_by_name(self, variant_name: &str) -> Result<PokeEnum<'mem>, FieldError> {
47        let variant_index = self
48            .def
49            .variants
50            .iter()
51            .enumerate()
52            .find(|(_, v)| v.name == variant_name)
53            .map(|(i, _)| i)
54            .ok_or(FieldError::NoSuchStaticField)?;
55
56        self.set_variant_by_index(variant_index)
57    }
58
59    /// Sets the variant of an enum by index.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if:
64    /// - The index is out of bounds.
65    pub fn set_variant_by_index(self, variant_index: usize) -> Result<PokeEnum<'mem>, FieldError> {
66        if variant_index >= self.def.variants.len() {
67            return Err(FieldError::IndexOutOfBounds);
68        }
69
70        // Get the current variant info
71        let variant = &self.def.variants[variant_index];
72
73        // Prepare memory for the enum
74        unsafe {
75            // Zero out the memory first to ensure clean state
76            core::ptr::write_bytes(self.data.as_mut_ptr(), 0, self.shape.layout.size());
77
78            // Set up the discriminant (tag)
79            // For enums in Rust, the first bytes contain the discriminant
80            let discriminant_value = match &variant.discriminant {
81                // If we have an explicit discriminant, use it
82                Some(discriminant) => *discriminant,
83                // Otherwise, use the variant index directly
84                None => variant_index as i64,
85            };
86
87            // Write the discriminant value based on the representation
88            match self.def.repr {
89                EnumRepr::U8 => {
90                    let tag_ptr = self.data.as_mut_ptr();
91                    *tag_ptr = discriminant_value as u8;
92                }
93                EnumRepr::U16 => {
94                    let tag_ptr = self.data.as_mut_ptr() as *mut u16;
95                    *tag_ptr = discriminant_value as u16;
96                }
97                EnumRepr::U32 => {
98                    let tag_ptr = self.data.as_mut_ptr() as *mut u32;
99                    *tag_ptr = discriminant_value as u32;
100                }
101                EnumRepr::U64 => {
102                    let tag_ptr = self.data.as_mut_ptr() as *mut u64;
103                    *tag_ptr = discriminant_value as u64;
104                }
105                EnumRepr::USize => {
106                    let tag_ptr = self.data.as_mut_ptr() as *mut usize;
107                    *tag_ptr = discriminant_value as usize;
108                }
109                EnumRepr::I8 => {
110                    let tag_ptr = self.data.as_mut_ptr() as *mut i8;
111                    *tag_ptr = discriminant_value as i8;
112                }
113                EnumRepr::I16 => {
114                    let tag_ptr = self.data.as_mut_ptr() as *mut i16;
115                    *tag_ptr = discriminant_value as i16;
116                }
117                EnumRepr::I32 => {
118                    let tag_ptr = self.data.as_mut_ptr() as *mut i32;
119                    *tag_ptr = discriminant_value as i32;
120                }
121                EnumRepr::I64 => {
122                    let tag_ptr = self.data.as_mut_ptr() as *mut i64;
123                    *tag_ptr = discriminant_value;
124                }
125                EnumRepr::ISize => {
126                    let tag_ptr = self.data.as_mut_ptr() as *mut isize;
127                    *tag_ptr = discriminant_value as isize;
128                }
129                EnumRepr::Default => {
130                    // Use a heuristic based on the number of variants
131                    if self.def.variants.len() <= 256 {
132                        // Can fit in a u8
133                        let tag_ptr = self.data.as_mut_ptr();
134                        *tag_ptr = discriminant_value as u8;
135                    } else if self.def.variants.len() <= 65536 {
136                        // Can fit in a u16
137                        let tag_ptr = self.data.as_mut_ptr() as *mut u16;
138                        *tag_ptr = discriminant_value as u16;
139                    } else {
140                        // Default to u32
141                        let tag_ptr = self.data.as_mut_ptr() as *mut u32;
142                        *tag_ptr = discriminant_value as u32;
143                    }
144                }
145                _ => {
146                    panic!("Unsupported enum representation: {:?}", self.def.repr);
147                }
148            }
149        }
150
151        // Create PokeEnum with the selected variant
152        Ok(PokeEnum {
153            data: self.data,
154            iset: Default::default(),
155            shape: self.shape,
156            def: self.def,
157            selected_variant: variant_index,
158        })
159    }
160}
161
162/// Allows poking an enum with a selected variant (setting fields, etc.)
163pub struct PokeEnum<'mem> {
164    data: OpaqueUninit<'mem>,
165    iset: ISet,
166    shape: &'static Shape,
167    def: EnumDef,
168    selected_variant: usize,
169}
170
171impl<'mem> PokeEnum<'mem> {
172    /// Returns the currently selected variant index
173    pub fn selected_variant_index(&self) -> usize {
174        self.selected_variant
175    }
176
177    /// Get a field writer for a field in the currently selected variant.
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if:
182    /// - The field name doesn't exist in the selected variant.
183    /// - The selected variant is a unit variant (which has no fields).
184    pub fn variant_field_by_name<'s>(&'s mut self, name: &str) -> Result<Poke<'s>, FieldError> {
185        let variant = &self.def.variants[self.selected_variant];
186
187        // Find the field in the variant
188        match &variant.kind {
189            VariantKind::Unit => {
190                // Unit variants have no fields
191                Err(FieldError::NoSuchStaticField)
192            }
193            VariantKind::Tuple { fields } => {
194                // For tuple variants, find the field by name
195                let field = fields
196                    .iter()
197                    .find(|f| f.name == name)
198                    .ok_or(FieldError::NoSuchStaticField)?;
199
200                // Get the field's address
201                let field_data = unsafe { self.data.field_uninit(field.offset) };
202                let poke = unsafe { Poke::unchecked_new(field_data, field.shape) };
203                Ok(poke)
204            }
205            VariantKind::Struct { fields } => {
206                // For struct variants, find the field by name
207                let field = fields
208                    .iter()
209                    .find(|f| f.name == name)
210                    .ok_or(FieldError::NoSuchStaticField)?;
211
212                // Get the field's address
213                let field_data = unsafe { self.data.field_uninit(field.offset) };
214                let poke = unsafe { Poke::unchecked_new(field_data, field.shape) };
215                Ok(poke)
216            }
217            _ => {
218                panic!("Unsupported enum variant kind: {:?}", variant.kind);
219            }
220        }
221    }
222
223    /// Marks a field in the current variant as initialized.
224    ///
225    /// # Safety
226    ///
227    /// The caller must ensure that the field is not already initialized.
228    pub unsafe fn mark_field_as_initialized(&mut self, field_index: usize) {
229        self.iset.set(field_index);
230    }
231
232    /// Checks if all required fields in the enum are initialized.
233    ///
234    /// # Panics
235    ///
236    /// Panics if any field in the selected variant is not initialized.
237    pub fn assert_all_fields_initialized(&self) {
238        let variant = &self.def.variants[self.selected_variant];
239
240        // Check if all fields of the selected variant are initialized
241        match &variant.kind {
242            VariantKind::Unit => {
243                // Unit variants don't have fields, so they're always fully initialized
244            }
245            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
246                // Check each field
247                for (field_index, field) in fields.iter().enumerate() {
248                    if !self.iset.has(field_index) {
249                        panic!(
250                            "Field '{}' of variant '{}' was not initialized. Complete schema:\n{}",
251                            field.name, variant.name, self.shape
252                        );
253                    }
254                }
255            }
256            _ => {
257                panic!("Unsupported enum variant kind: {:?}", variant.kind);
258            }
259        }
260    }
261
262    fn assert_matching_shape<T: Facet>(&self) {
263        if !self.shape.is_type::<T>() {
264            panic!(
265                "This is a partial \x1b[1;34m{}\x1b[0m, you can't build a \x1b[1;32m{}\x1b[0m out of it",
266                self.shape,
267                T::SHAPE,
268            );
269        }
270    }
271
272    /// Asserts that every field in the selected variant has been initialized and forgets the PokeEnum.
273    ///
274    /// This method is only used when the origin is borrowed.
275    /// If this method is not called, all fields will be freed when the PokeEnum is dropped.
276    ///
277    /// # Panics
278    ///
279    /// This function will panic if any required field is not initialized.
280    pub fn build_in_place(self) -> Opaque<'mem> {
281        // ensure all fields are initialized
282        self.assert_all_fields_initialized();
283        let data = unsafe { self.data.assume_init() };
284        // prevent field drops when the PokeEnum is dropped
285        core::mem::forget(self);
286        data
287    }
288
289    /// Builds a value of type `T` from the PokeEnum.
290    ///
291    /// # Panics
292    ///
293    /// This function will panic if:
294    /// - Not all fields in the selected variant have been initialized.
295    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
296    pub fn build<T: Facet>(self) -> T {
297        self.assert_all_fields_initialized();
298        self.assert_matching_shape::<T>();
299
300        let result = unsafe {
301            let ptr = self.data.as_ptr() as *const T;
302            core::ptr::read(ptr)
303        };
304        core::mem::forget(self);
305        result
306    }
307
308    /// Build that PokeEnum into a boxed completed shape.
309    ///
310    /// # Panics
311    ///
312    /// This function will panic if:
313    /// - Not all fields in the selected variant have been initialized.
314    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
315    pub fn build_boxed<T: Facet>(self) -> Box<T> {
316        self.assert_all_fields_initialized();
317        self.assert_matching_shape::<T>();
318
319        let boxed = unsafe { Box::from_raw(self.data.as_mut_ptr() as *mut T) };
320        core::mem::forget(self);
321        boxed
322    }
323
324    /// Moves the contents of this `PokeEnum` into a target memory location.
325    ///
326    /// # Safety
327    ///
328    /// The target pointer must be valid and properly aligned,
329    /// and must be large enough to hold the value.
330    /// The caller is responsible for ensuring that the target memory is properly deallocated
331    /// when it's no longer needed.
332    pub unsafe fn move_into(self, target: NonNull<u8>) {
333        self.assert_all_fields_initialized();
334        unsafe {
335            core::ptr::copy_nonoverlapping(
336                self.data.as_mut_ptr(),
337                target.as_ptr(),
338                self.shape.layout.size(),
339            );
340        }
341        core::mem::forget(self);
342    }
343}
344
345impl Drop for PokeEnum<'_> {
346    fn drop(&mut self) {
347        let variant = &self.def.variants[self.selected_variant];
348
349        // Drop fields based on the variant kind
350        match &variant.kind {
351            VariantKind::Unit => {
352                // Unit variants have no fields to drop
353            }
354            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
355                // Drop each initialized field
356                for (field_index, field) in fields.iter().enumerate() {
357                    if self.iset.has(field_index) {
358                        if let Some(drop_fn) = field.shape.vtable.drop_in_place {
359                            unsafe {
360                                drop_fn(self.data.field_init(field.offset));
361                            }
362                        }
363                    }
364                }
365            }
366            _ => {
367                panic!("Unsupported enum variant kind: {:?}", variant.kind);
368            }
369        }
370    }
371}
372
/// Errors that can occur when selecting an enum variant by index or by name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum VariantError {
    /// `variant_by_index` was called with an index that is out of bounds.
    IndexOutOfBounds,

    /// `variant_by_name` or `variant_by_index` was called on a non-enum type.
    NotAnEnum,

    /// `variant_by_name` was called with a name that doesn't match any variant.
    NoSuchVariant,
}

/// `VariantError` has no underlying cause, so the default `source` (`None`) applies.
impl std::error::Error for VariantError {}

impl core::fmt::Display for VariantError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::IndexOutOfBounds => "Variant index out of bounds",
            Self::NotAnEnum => "Not an enum",
            Self::NoSuchVariant => "No such variant",
        };
        f.write_str(message)
    }
}