// facet_reflect/poke/enum_.rs

1use core::ptr::NonNull;
2#[cfg(feature = "alloc")]
3extern crate alloc;
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6
7use facet_core::{
8    EnumDef, EnumRepr, Facet, FieldError, Opaque, OpaqueUninit, Shape, Variant, VariantKind,
9};
10
11use crate::Guard;
12
13use super::{ISet, PokeValueUninit};
14
15/// Represents an enum before a variant has been selected
16pub struct PokeEnumNoVariant<'mem> {
17    data: OpaqueUninit<'mem>,
18    shape: &'static Shape,
19    def: EnumDef,
20}
21
22impl<'mem> PokeEnumNoVariant<'mem> {
23    /// Coerce back into a `PokeValue`
24    #[inline(always)]
25    pub fn into_value(self) -> PokeValueUninit<'mem> {
26        unsafe { PokeValueUninit::new(self.data, self.shape) }
27    }
28
29    /// Shape getter
30    #[inline(always)]
31    pub fn shape(&self) -> &'static Shape {
32        self.shape
33    }
34
35    /// Creates a new PokeEnumNoVariant from raw data
36    ///
37    /// # Safety
38    ///
39    /// The data buffer must match the size and alignment of the enum shape described by shape
40    pub(crate) unsafe fn new(
41        data: OpaqueUninit<'mem>,
42        shape: &'static Shape,
43        def: EnumDef,
44    ) -> Self {
45        Self { data, shape, def }
46    }
47
48    /// Sets the variant of an enum by name.
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if:
53    /// - No variant with the given name exists.
54    pub fn set_variant_by_name(self, variant_name: &str) -> Result<PokeEnum<'mem>, FieldError> {
55        let variant_index = self
56            .def
57            .variants
58            .iter()
59            .enumerate()
60            .find(|(_, v)| v.name == variant_name)
61            .map(|(i, _)| i)
62            .ok_or(FieldError::NoSuchStaticField)?;
63
64        self.set_variant_by_index(variant_index)
65    }
66
67    /// Sets the variant of an enum by index.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if:
72    /// - The index is out of bounds.
73    pub fn set_variant_by_index(self, variant_index: usize) -> Result<PokeEnum<'mem>, FieldError> {
74        if variant_index >= self.def.variants.len() {
75            return Err(FieldError::IndexOutOfBounds);
76        }
77
78        // Get the current variant info
79        let variant = &self.def.variants[variant_index];
80
81        // Prepare memory for the enum
82        unsafe {
83            // Zero out the memory first to ensure clean state
84            core::ptr::write_bytes(self.data.as_mut_bytes(), 0, self.shape.layout.size());
85
86            // Set up the discriminant (tag)
87            // For enums in Rust, the first bytes contain the discriminant
88            let discriminant_value = match &variant.discriminant {
89                // If we have an explicit discriminant, use it
90                Some(discriminant) => *discriminant,
91                // Otherwise, use the variant index directly
92                None => variant_index as i64,
93            };
94
95            // Write the discriminant value based on the representation
96            match self.def.repr {
97                EnumRepr::U8 => {
98                    let tag_ptr = self.data.as_mut_bytes();
99                    *tag_ptr = discriminant_value as u8;
100                }
101                EnumRepr::U16 => {
102                    let tag_ptr = self.data.as_mut_bytes() as *mut u16;
103                    *tag_ptr = discriminant_value as u16;
104                }
105                EnumRepr::U32 => {
106                    let tag_ptr = self.data.as_mut_bytes() as *mut u32;
107                    *tag_ptr = discriminant_value as u32;
108                }
109                EnumRepr::U64 => {
110                    let tag_ptr = self.data.as_mut_bytes() as *mut u64;
111                    *tag_ptr = discriminant_value as u64;
112                }
113                EnumRepr::USize => {
114                    let tag_ptr = self.data.as_mut_bytes() as *mut usize;
115                    *tag_ptr = discriminant_value as usize;
116                }
117                EnumRepr::I8 => {
118                    let tag_ptr = self.data.as_mut_bytes() as *mut i8;
119                    *tag_ptr = discriminant_value as i8;
120                }
121                EnumRepr::I16 => {
122                    let tag_ptr = self.data.as_mut_bytes() as *mut i16;
123                    *tag_ptr = discriminant_value as i16;
124                }
125                EnumRepr::I32 => {
126                    let tag_ptr = self.data.as_mut_bytes() as *mut i32;
127                    *tag_ptr = discriminant_value as i32;
128                }
129                EnumRepr::I64 => {
130                    let tag_ptr = self.data.as_mut_bytes() as *mut i64;
131                    *tag_ptr = discriminant_value;
132                }
133                EnumRepr::ISize => {
134                    let tag_ptr = self.data.as_mut_bytes() as *mut isize;
135                    *tag_ptr = discriminant_value as isize;
136                }
137                _ => {
138                    panic!("Unsupported enum representation: {:?}", self.def.repr);
139                }
140            }
141        }
142
143        // Create PokeEnum with the selected variant
144        Ok(PokeEnum {
145            data: self.data,
146            iset: Default::default(),
147            shape: self.shape,
148            def: self.def,
149            selected_variant: variant_index,
150        })
151    }
152
153    /// Try to find the variant by name.
154    pub fn variant_by_name(&self, name: &str) -> Option<&Variant> {
155        self.def
156            .variants
157            .iter()
158            .find(|variant| variant.name == name)
159    }
160
161    /// Whether the enum has a variant.
162    pub fn contains_variant_with_name(&self, name: &str) -> bool {
163        self.def.variants.iter().any(|variant| variant.name == name)
164    }
165}
166
167/// Allows poking an enum with a selected variant (setting fields, etc.)
168pub struct PokeEnum<'mem> {
169    /// The internal data storage for the enum
170    ///
171    /// Note that this stores both the discriminant and the variant data
172    /// (if any), and the layout depends on the enum representation.
173    /// Use [`Self::variant_data`] to get a pointer to the variant data.
174    data: OpaqueUninit<'mem>,
175    iset: ISet,
176    shape: &'static Shape,
177    def: EnumDef,
178    selected_variant: usize,
179}
180
181impl<'mem> PokeEnum<'mem> {
182    /// Coerce back into a `PokeValue`
183    #[inline(always)]
184    pub fn into_value(self) -> PokeValueUninit<'mem> {
185        unsafe { PokeValueUninit::new(self.data, self.shape) }
186    }
187
188    pub(crate) fn variant_data(&self) -> OpaqueUninit<'mem> {
189        let variant_offset = self.def.variants[self.selected_variant].offset;
190        unsafe { self.data.field_uninit(variant_offset) }
191    }
192
193    /// Creates a new PokeEnum from raw data
194    ///
195    /// # Safety
196    ///
197    /// The data buffer must match the size and alignment of the enum shape described by shape
198    #[allow(dead_code)]
199    pub(crate) unsafe fn new(
200        data: OpaqueUninit<'mem>,
201        shape: &'static Shape,
202        def: EnumDef,
203        selected_variant: usize,
204    ) -> Self {
205        Self {
206            data,
207            iset: Default::default(),
208            shape,
209            def,
210            selected_variant,
211        }
212    }
213
214    #[inline(always)]
215    /// Shape getter
216    pub fn shape(&self) -> &'static Shape {
217        self.shape
218    }
219
220    /// Returns the currently selected variant index
221    pub fn selected_variant_index(&self) -> usize {
222        self.selected_variant
223    }
224
225    /// Gets a field by name in the currently selected variant.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if:
230    /// - The field name doesn't exist in the selected variant.
231    /// - The selected variant is a unit variant (which has no fields).
232    pub fn field_by_name(
233        &self,
234        name: &str,
235    ) -> Result<(usize, crate::PokeUninit<'mem>), FieldError> {
236        let variant = &self.def.variants[self.selected_variant];
237
238        // Find the field in the variant
239        match &variant.kind {
240            VariantKind::Unit => {
241                // Unit variants have no fields
242                Err(FieldError::NoSuchStaticField)
243            }
244            VariantKind::Tuple { fields } => {
245                // For tuple variants, find the field by name
246                let (index, field) = fields
247                    .iter()
248                    .enumerate()
249                    .find(|(_, f)| f.name == name)
250                    .ok_or(FieldError::NoSuchStaticField)?;
251
252                // Get the field's address
253                let field_data = unsafe { self.variant_data().field_uninit(field.offset) };
254                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
255                Ok((index, poke))
256            }
257            VariantKind::Struct { fields } => {
258                // For struct variants, find the field by name
259                let (index, field) = fields
260                    .iter()
261                    .enumerate()
262                    .find(|(_, f)| f.name == name)
263                    .ok_or(FieldError::NoSuchStaticField)?;
264
265                // Get the field's address
266                let field_data = unsafe { self.variant_data().field_uninit(field.offset) };
267                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
268                Ok((index, poke))
269            }
270            _ => {
271                panic!("Unsupported enum variant kind: {:?}", variant.kind);
272            }
273        }
274    }
275
276    /// Get a field writer for a tuple field by index in the currently selected variant.
277    ///
278    /// # Errors
279    ///
280    /// Returns an error if:
281    /// - The index is out of bounds.
282    /// - The selected variant is not a tuple variant.
283    pub fn tuple_field(&self, index: usize) -> Result<crate::PokeUninit<'mem>, FieldError> {
284        let variant = &self.def.variants[self.selected_variant];
285
286        // Make sure we're working with a tuple variant
287        match &variant.kind {
288            VariantKind::Tuple { fields } => {
289                // Check if the index is valid
290                if index >= fields.len() {
291                    return Err(FieldError::IndexOutOfBounds);
292                }
293
294                // Get the field at the specified index
295                let field = &fields[index];
296
297                // Get the field's address
298                let field_data = unsafe { self.variant_data().field_uninit(field.offset) };
299                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
300                Ok(poke)
301            }
302            _ => {
303                // Not a tuple variant
304                Err(FieldError::NoSuchStaticField)
305            }
306        }
307    }
308
309    /// Marks a field in the current variant as initialized.
310    ///
311    /// # Safety
312    ///
313    /// The caller must ensure that the field is initialized. Only call this after writing to
314    /// an address gotten through [`Self::field_by_name`] or [`Self::tuple_field`].
315    pub unsafe fn mark_initialized(&mut self, field_index: usize) {
316        self.iset.set(field_index);
317    }
318
319    /// Checks if all required fields in the enum are initialized.
320    ///
321    /// # Panics
322    ///
323    /// Panics if any field in the selected variant is not initialized.
324    pub fn assert_all_fields_initialized(&self) {
325        let variant = &self.def.variants[self.selected_variant];
326
327        // Check if all fields of the selected variant are initialized
328        match &variant.kind {
329            VariantKind::Unit => {
330                // Unit variants don't have fields, so they're always fully initialized
331            }
332            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
333                // Check each field
334                for (field_index, field) in fields.iter().enumerate() {
335                    if !self.iset.has(field_index) {
336                        panic!(
337                            "Field '{}' of variant '{}' was not initialized. Complete schema:\n{}",
338                            field.name, variant.name, self.shape
339                        );
340                    }
341                }
342            }
343            _ => {
344                panic!("Unsupported enum variant kind: {:?}", variant.kind);
345            }
346        }
347    }
348
349    fn assert_matching_shape<T: Facet>(&self) {
350        if !self.shape.is_type::<T>() {
351            panic!(
352                "This is a partial \x1b[1;34m{}\x1b[0m, you can't build a \x1b[1;32m{}\x1b[0m out of it",
353                self.shape,
354                T::SHAPE,
355            );
356        }
357    }
358
359    /// Asserts that every field in the selected variant has been initialized and forgets the PokeEnum.
360    ///
361    /// This method is only used when the origin is borrowed.
362    /// If this method is not called, all fields will be freed when the PokeEnum is dropped.
363    ///
364    /// # Panics
365    ///
366    /// This function will panic if any required field is not initialized.
367    pub fn build_in_place(self) -> Opaque<'mem> {
368        // ensure all fields are initialized
369        self.assert_all_fields_initialized();
370        let data = unsafe { self.data.assume_init() };
371        // prevent field drops when the PokeEnum is dropped
372        core::mem::forget(self);
373        data
374    }
375
376    /// Builds a value of type `T` from the PokeEnum, then deallocates the memory
377    /// that this PokeEnum was pointing to.
378    ///
379    /// # Panics
380    ///
381    /// This function will panic if:
382    /// - Not all fields in the selected variant have been initialized.
383    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
384    pub fn build<T: Facet>(self, guard: Option<Guard>) -> T {
385        let mut guard = guard;
386        let this = self;
387        // this changes drop order: guard must be dropped _after_ this.
388
389        this.assert_all_fields_initialized();
390        this.assert_matching_shape::<T>();
391        if let Some(guard) = &guard {
392            guard.shape.assert_type::<T>();
393        }
394
395        let result = unsafe {
396            let ptr = this.data.as_mut_bytes() as *const T;
397            core::ptr::read(ptr)
398        };
399        guard.take(); // dealloc
400        core::mem::forget(this);
401        result
402    }
403
404    /// Build that PokeEnum into a boxed completed shape.
405    ///
406    /// # Panics
407    ///
408    /// This function will panic if:
409    /// - Not all fields in the selected variant have been initialized.
410    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
411    #[cfg(feature = "alloc")]
412    pub fn build_boxed<T: Facet>(self) -> Box<T> {
413        self.assert_all_fields_initialized();
414        self.assert_matching_shape::<T>();
415
416        let boxed = unsafe { Box::from_raw(self.data.as_mut_bytes() as *mut T) };
417        core::mem::forget(self);
418        boxed
419    }
420
421    /// Moves the contents of this `PokeEnum` into a target memory location.
422    ///
423    /// # Safety
424    ///
425    /// The target pointer must be valid and properly aligned,
426    /// and must be large enough to hold the value.
427    /// The caller is responsible for ensuring that the target memory is properly deallocated
428    /// when it's no longer needed.
429    pub unsafe fn move_into(self, target: NonNull<u8>) {
430        self.assert_all_fields_initialized();
431        unsafe {
432            core::ptr::copy_nonoverlapping(
433                self.data.as_mut_bytes(),
434                target.as_ptr(),
435                self.shape.layout.size(),
436            );
437        }
438        core::mem::forget(self);
439    }
440}
441
442impl Drop for PokeEnum<'_> {
443    fn drop(&mut self) {
444        let variant = &self.def.variants[self.selected_variant];
445
446        // Drop fields based on the variant kind
447        match &variant.kind {
448            VariantKind::Unit => {
449                // Unit variants have no fields to drop
450            }
451            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
452                // Drop each initialized field
453                for (field_index, field) in fields.iter().enumerate() {
454                    if self.iset.has(field_index) {
455                        if let Some(drop_fn) = field.shape.vtable.drop_in_place {
456                            unsafe {
457                                drop_fn(self.variant_data().field_init(field.offset));
458                            }
459                        }
460                    }
461                }
462            }
463            _ => {
464                panic!("Unsupported enum variant kind: {:?}", variant.kind);
465            }
466        }
467    }
468}
469
/// Errors that can occur while looking up an enum variant by index or by name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum VariantError {
    /// The index passed to `variant_by_index` is past the last variant.
    IndexOutOfBounds,

    /// `variant_by_name` or `variant_by_index` was invoked on a type that is not an enum.
    NotAnEnum,

    /// No variant carries the name passed to `variant_by_name`.
    NoSuchVariant,
}
483
484impl core::error::Error for VariantError {}
485
486impl core::fmt::Display for VariantError {
487    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
488        match self {
489            VariantError::IndexOutOfBounds => write!(f, "Variant index out of bounds"),
490            VariantError::NotAnEnum => write!(f, "Not an enum"),
491            VariantError::NoSuchVariant => write!(f, "No such variant"),
492        }
493    }
494}