facet_reflect/poke/
enum_.rs

1use core::ptr::NonNull;
2#[cfg(feature = "alloc")]
3extern crate alloc;
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6
7use facet_core::{EnumDef, EnumRepr, Facet, FieldError, Opaque, OpaqueUninit, Shape, VariantKind};
8
9use crate::Guard;
10
11use super::{ISet, PokeValueUninit};
12
13/// Represents an enum before a variant has been selected
14pub struct PokeEnumNoVariant<'mem> {
15    data: OpaqueUninit<'mem>,
16    shape: &'static Shape,
17    def: EnumDef,
18}
19
20impl<'mem> PokeEnumNoVariant<'mem> {
21    /// Coerce back into a `PokeValue`
22    #[inline(always)]
23    pub fn into_value(self) -> PokeValueUninit<'mem> {
24        unsafe { PokeValueUninit::new(self.data, self.shape) }
25    }
26
27    /// Shape getter
28    #[inline(always)]
29    pub fn shape(&self) -> &'static Shape {
30        self.shape
31    }
32    /// Creates a new PokeEnumNoVariant from raw data
33    ///
34    /// # Safety
35    ///
36    /// The data buffer must match the size and alignment of the enum shape described by shape
37    pub(crate) unsafe fn new(
38        data: OpaqueUninit<'mem>,
39        shape: &'static Shape,
40        def: EnumDef,
41    ) -> Self {
42        Self { data, shape, def }
43    }
44
45    /// Sets the variant of an enum by name.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error if:
50    /// - No variant with the given name exists.
51    pub fn set_variant_by_name(self, variant_name: &str) -> Result<PokeEnum<'mem>, FieldError> {
52        let variant_index = self
53            .def
54            .variants
55            .iter()
56            .enumerate()
57            .find(|(_, v)| v.name == variant_name)
58            .map(|(i, _)| i)
59            .ok_or(FieldError::NoSuchStaticField)?;
60
61        self.set_variant_by_index(variant_index)
62    }
63
64    /// Sets the variant of an enum by index.
65    ///
66    /// # Errors
67    ///
68    /// Returns an error if:
69    /// - The index is out of bounds.
70    pub fn set_variant_by_index(self, variant_index: usize) -> Result<PokeEnum<'mem>, FieldError> {
71        if variant_index >= self.def.variants.len() {
72            return Err(FieldError::IndexOutOfBounds);
73        }
74
75        // Get the current variant info
76        let variant = &self.def.variants[variant_index];
77
78        // Prepare memory for the enum
79        unsafe {
80            // Zero out the memory first to ensure clean state
81            core::ptr::write_bytes(self.data.as_mut_bytes(), 0, self.shape.layout.size());
82
83            // Set up the discriminant (tag)
84            // For enums in Rust, the first bytes contain the discriminant
85            let discriminant_value = match &variant.discriminant {
86                // If we have an explicit discriminant, use it
87                Some(discriminant) => *discriminant,
88                // Otherwise, use the variant index directly
89                None => variant_index as i64,
90            };
91
92            // Write the discriminant value based on the representation
93            match self.def.repr {
94                EnumRepr::U8 => {
95                    let tag_ptr = self.data.as_mut_bytes();
96                    *tag_ptr = discriminant_value as u8;
97                }
98                EnumRepr::U16 => {
99                    let tag_ptr = self.data.as_mut_bytes() as *mut u16;
100                    *tag_ptr = discriminant_value as u16;
101                }
102                EnumRepr::U32 => {
103                    let tag_ptr = self.data.as_mut_bytes() as *mut u32;
104                    *tag_ptr = discriminant_value as u32;
105                }
106                EnumRepr::U64 => {
107                    let tag_ptr = self.data.as_mut_bytes() as *mut u64;
108                    *tag_ptr = discriminant_value as u64;
109                }
110                EnumRepr::USize => {
111                    let tag_ptr = self.data.as_mut_bytes() as *mut usize;
112                    *tag_ptr = discriminant_value as usize;
113                }
114                EnumRepr::I8 => {
115                    let tag_ptr = self.data.as_mut_bytes() as *mut i8;
116                    *tag_ptr = discriminant_value as i8;
117                }
118                EnumRepr::I16 => {
119                    let tag_ptr = self.data.as_mut_bytes() as *mut i16;
120                    *tag_ptr = discriminant_value as i16;
121                }
122                EnumRepr::I32 => {
123                    let tag_ptr = self.data.as_mut_bytes() as *mut i32;
124                    *tag_ptr = discriminant_value as i32;
125                }
126                EnumRepr::I64 => {
127                    let tag_ptr = self.data.as_mut_bytes() as *mut i64;
128                    *tag_ptr = discriminant_value;
129                }
130                EnumRepr::ISize => {
131                    let tag_ptr = self.data.as_mut_bytes() as *mut isize;
132                    *tag_ptr = discriminant_value as isize;
133                }
134                _ => {
135                    panic!("Unsupported enum representation: {:?}", self.def.repr);
136                }
137            }
138        }
139
140        // Create PokeEnum with the selected variant
141        Ok(PokeEnum {
142            data: self.data,
143            iset: Default::default(),
144            shape: self.shape,
145            def: self.def,
146            selected_variant: variant_index,
147        })
148    }
149}
150
151/// Allows poking an enum with a selected variant (setting fields, etc.)
152pub struct PokeEnum<'mem> {
153    /// The internal data storage for the enum
154    ///
155    /// Note that this stores both the discriminant and the variant data
156    /// (if any), and the layout depends on the enum representation.
157    /// Use [`Self::variant_data`] to get a pointer to the variant data.
158    data: OpaqueUninit<'mem>,
159    iset: ISet,
160    shape: &'static Shape,
161    def: EnumDef,
162    selected_variant: usize,
163}
164
165impl<'mem> PokeEnum<'mem> {
166    /// Coerce back into a `PokeValue`
167    #[inline(always)]
168    pub fn into_value(self) -> PokeValueUninit<'mem> {
169        unsafe { PokeValueUninit::new(self.data, self.shape) }
170    }
171
172    pub(crate) fn variant_data(&self) -> OpaqueUninit<'mem> {
173        let variant_offset = self.def.variants[self.selected_variant].offset;
174        unsafe { self.data.field_uninit(variant_offset) }
175    }
176
177    /// Creates a new PokeEnum from raw data
178    ///
179    /// # Safety
180    ///
181    /// The data buffer must match the size and alignment of the enum shape described by shape
182    #[allow(dead_code)]
183    pub(crate) unsafe fn new(
184        data: OpaqueUninit<'mem>,
185        shape: &'static Shape,
186        def: EnumDef,
187        selected_variant: usize,
188    ) -> Self {
189        Self {
190            data,
191            iset: Default::default(),
192            shape,
193            def,
194            selected_variant,
195        }
196    }
197
198    #[inline(always)]
199    /// Shape getter
200    pub fn shape(&self) -> &'static Shape {
201        self.shape
202    }
203
204    /// Returns the currently selected variant index
205    pub fn selected_variant_index(&self) -> usize {
206        self.selected_variant
207    }
208
209    /// Gets a field by name in the currently selected variant.
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if:
214    /// - The field name doesn't exist in the selected variant.
215    /// - The selected variant is a unit variant (which has no fields).
216    pub fn field_by_name(
217        &self,
218        name: &str,
219    ) -> Result<(usize, crate::PokeUninit<'mem>), FieldError> {
220        let variant = &self.def.variants[self.selected_variant];
221
222        // Find the field in the variant
223        match &variant.kind {
224            VariantKind::Unit => {
225                // Unit variants have no fields
226                Err(FieldError::NoSuchStaticField)
227            }
228            VariantKind::Tuple { fields } => {
229                // For tuple variants, find the field by name
230                let (index, field) = fields
231                    .iter()
232                    .enumerate()
233                    .find(|(_, f)| f.name == name)
234                    .ok_or(FieldError::NoSuchStaticField)?;
235
236                // Get the field's address
237                let field_data = unsafe { self.variant_data().field_uninit(field.offset) };
238                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
239                Ok((index, poke))
240            }
241            VariantKind::Struct { fields } => {
242                // For struct variants, find the field by name
243                let (index, field) = fields
244                    .iter()
245                    .enumerate()
246                    .find(|(_, f)| f.name == name)
247                    .ok_or(FieldError::NoSuchStaticField)?;
248
249                // Get the field's address
250                let field_data = unsafe { self.variant_data().field_uninit(field.offset) };
251                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
252                Ok((index, poke))
253            }
254            _ => {
255                panic!("Unsupported enum variant kind: {:?}", variant.kind);
256            }
257        }
258    }
259
260    /// Get a field writer for a tuple field by index in the currently selected variant.
261    ///
262    /// # Errors
263    ///
264    /// Returns an error if:
265    /// - The index is out of bounds.
266    /// - The selected variant is not a tuple variant.
267    pub fn tuple_field(&self, index: usize) -> Result<crate::PokeUninit<'mem>, FieldError> {
268        let variant = &self.def.variants[self.selected_variant];
269
270        // Make sure we're working with a tuple variant
271        match &variant.kind {
272            VariantKind::Tuple { fields } => {
273                // Check if the index is valid
274                if index >= fields.len() {
275                    return Err(FieldError::IndexOutOfBounds);
276                }
277
278                // Get the field at the specified index
279                let field = &fields[index];
280
281                // Get the field's address
282                let field_data = unsafe { self.variant_data().field_uninit(field.offset) };
283                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
284                Ok(poke)
285            }
286            _ => {
287                // Not a tuple variant
288                Err(FieldError::NoSuchStaticField)
289            }
290        }
291    }
292
293    /// Marks a field in the current variant as initialized.
294    ///
295    /// # Safety
296    ///
297    /// The caller must ensure that the field is initialized. Only call this after writing to
298    /// an address gotten through [`Self::field_by_name`] or [`Self::tuple_field`].
299    pub unsafe fn mark_initialized(&mut self, field_index: usize) {
300        self.iset.set(field_index);
301    }
302
303    /// Checks if all required fields in the enum are initialized.
304    ///
305    /// # Panics
306    ///
307    /// Panics if any field in the selected variant is not initialized.
308    pub fn assert_all_fields_initialized(&self) {
309        let variant = &self.def.variants[self.selected_variant];
310
311        // Check if all fields of the selected variant are initialized
312        match &variant.kind {
313            VariantKind::Unit => {
314                // Unit variants don't have fields, so they're always fully initialized
315            }
316            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
317                // Check each field
318                for (field_index, field) in fields.iter().enumerate() {
319                    if !self.iset.has(field_index) {
320                        panic!(
321                            "Field '{}' of variant '{}' was not initialized. Complete schema:\n{}",
322                            field.name, variant.name, self.shape
323                        );
324                    }
325                }
326            }
327            _ => {
328                panic!("Unsupported enum variant kind: {:?}", variant.kind);
329            }
330        }
331    }
332
333    fn assert_matching_shape<T: Facet>(&self) {
334        if !self.shape.is_type::<T>() {
335            panic!(
336                "This is a partial \x1b[1;34m{}\x1b[0m, you can't build a \x1b[1;32m{}\x1b[0m out of it",
337                self.shape,
338                T::SHAPE,
339            );
340        }
341    }
342
343    /// Asserts that every field in the selected variant has been initialized and forgets the PokeEnum.
344    ///
345    /// This method is only used when the origin is borrowed.
346    /// If this method is not called, all fields will be freed when the PokeEnum is dropped.
347    ///
348    /// # Panics
349    ///
350    /// This function will panic if any required field is not initialized.
351    pub fn build_in_place(self) -> Opaque<'mem> {
352        // ensure all fields are initialized
353        self.assert_all_fields_initialized();
354        let data = unsafe { self.data.assume_init() };
355        // prevent field drops when the PokeEnum is dropped
356        core::mem::forget(self);
357        data
358    }
359
360    /// Builds a value of type `T` from the PokeEnum, then deallocates the memory
361    /// that this PokeEnum was pointing to.
362    ///
363    /// # Panics
364    ///
365    /// This function will panic if:
366    /// - Not all fields in the selected variant have been initialized.
367    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
368    pub fn build<T: Facet>(self, guard: Option<Guard>) -> T {
369        let mut guard = guard;
370        let this = self;
371        // this changes drop order: guard must be dropped _after_ this.
372
373        this.assert_all_fields_initialized();
374        this.assert_matching_shape::<T>();
375        if let Some(guard) = &guard {
376            guard.shape.assert_type::<T>();
377        }
378
379        let result = unsafe {
380            let ptr = this.data.as_mut_bytes() as *const T;
381            core::ptr::read(ptr)
382        };
383        guard.take(); // dealloc
384        core::mem::forget(this);
385        result
386    }
387
388    /// Build that PokeEnum into a boxed completed shape.
389    ///
390    /// # Panics
391    ///
392    /// This function will panic if:
393    /// - Not all fields in the selected variant have been initialized.
394    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
395    #[cfg(feature = "alloc")]
396    pub fn build_boxed<T: Facet>(self) -> Box<T> {
397        self.assert_all_fields_initialized();
398        self.assert_matching_shape::<T>();
399
400        let boxed = unsafe { Box::from_raw(self.data.as_mut_bytes() as *mut T) };
401        core::mem::forget(self);
402        boxed
403    }
404
405    /// Moves the contents of this `PokeEnum` into a target memory location.
406    ///
407    /// # Safety
408    ///
409    /// The target pointer must be valid and properly aligned,
410    /// and must be large enough to hold the value.
411    /// The caller is responsible for ensuring that the target memory is properly deallocated
412    /// when it's no longer needed.
413    pub unsafe fn move_into(self, target: NonNull<u8>) {
414        self.assert_all_fields_initialized();
415        unsafe {
416            core::ptr::copy_nonoverlapping(
417                self.data.as_mut_bytes(),
418                target.as_ptr(),
419                self.shape.layout.size(),
420            );
421        }
422        core::mem::forget(self);
423    }
424}
425
426impl Drop for PokeEnum<'_> {
427    fn drop(&mut self) {
428        let variant = &self.def.variants[self.selected_variant];
429
430        // Drop fields based on the variant kind
431        match &variant.kind {
432            VariantKind::Unit => {
433                // Unit variants have no fields to drop
434            }
435            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
436                // Drop each initialized field
437                for (field_index, field) in fields.iter().enumerate() {
438                    if self.iset.has(field_index) {
439                        if let Some(drop_fn) = field.shape.vtable.drop_in_place {
440                            unsafe {
441                                drop_fn(self.variant_data().field_init(field.offset));
442                            }
443                        }
444                    }
445                }
446            }
447            _ => {
448                panic!("Unsupported enum variant kind: {:?}", variant.kind);
449            }
450        }
451    }
452}
453
/// Reasons why selecting an enum variant by index or by name can fail.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum VariantError {
    /// `variant_by_index` was called with an index that is out of bounds.
    IndexOutOfBounds,

    /// `variant_by_name` or `variant_by_index` was called on a non-enum type.
    NotAnEnum,

    /// `variant_by_name` was called with a name that doesn't match any variant.
    NoSuchVariant,
}

impl core::error::Error for VariantError {}

impl core::fmt::Display for VariantError {
    /// Writes a short, human-readable description of the error (no padding is applied).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::IndexOutOfBounds => "Variant index out of bounds",
            Self::NotAnEnum => "Not an enum",
            Self::NoSuchVariant => "No such variant",
        };
        f.write_str(message)
    }
}