facet_poke/
enum_.rs

1use core::ptr::NonNull;
2use facet_core::{EnumDef, EnumRepr, Facet, FieldError, Opaque, OpaqueUninit, Shape, VariantKind};
3
4use crate::Guard;
5
6use super::{ISet, PokeValueUninit};
7
8/// Represents an enum before a variant has been selected
9pub struct PokeEnumNoVariant<'mem> {
10    data: OpaqueUninit<'mem>,
11    shape: &'static Shape,
12    def: EnumDef,
13}
14
15impl<'mem> PokeEnumNoVariant<'mem> {
16    /// Coerce back into a `PokeValue`
17    #[inline(always)]
18    pub fn into_value(self) -> PokeValueUninit<'mem> {
19        unsafe { PokeValueUninit::new(self.data, self.shape) }
20    }
21
22    /// Shape getter
23    #[inline(always)]
24    pub fn shape(&self) -> &'static Shape {
25        self.shape
26    }
27    /// Creates a new PokeEnumNoVariant from raw data
28    ///
29    /// # Safety
30    ///
31    /// The data buffer must match the size and alignment of the enum shape described by shape
32    pub(crate) unsafe fn new(
33        data: OpaqueUninit<'mem>,
34        shape: &'static Shape,
35        def: EnumDef,
36    ) -> Self {
37        Self { data, shape, def }
38    }
39
40    /// Sets the variant of an enum by name.
41    ///
42    /// # Errors
43    ///
44    /// Returns an error if:
45    /// - No variant with the given name exists.
46    pub fn set_variant_by_name(self, variant_name: &str) -> Result<PokeEnum<'mem>, FieldError> {
47        let variant_index = self
48            .def
49            .variants
50            .iter()
51            .enumerate()
52            .find(|(_, v)| v.name == variant_name)
53            .map(|(i, _)| i)
54            .ok_or(FieldError::NoSuchStaticField)?;
55
56        self.set_variant_by_index(variant_index)
57    }
58
59    /// Sets the variant of an enum by index.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if:
64    /// - The index is out of bounds.
65    pub fn set_variant_by_index(self, variant_index: usize) -> Result<PokeEnum<'mem>, FieldError> {
66        if variant_index >= self.def.variants.len() {
67            return Err(FieldError::IndexOutOfBounds);
68        }
69
70        // Get the current variant info
71        let variant = &self.def.variants[variant_index];
72
73        // Prepare memory for the enum
74        unsafe {
75            // Zero out the memory first to ensure clean state
76            core::ptr::write_bytes(self.data.as_mut_bytes(), 0, self.shape.layout.size());
77
78            // Set up the discriminant (tag)
79            // For enums in Rust, the first bytes contain the discriminant
80            let discriminant_value = match &variant.discriminant {
81                // If we have an explicit discriminant, use it
82                Some(discriminant) => *discriminant,
83                // Otherwise, use the variant index directly
84                None => variant_index as i64,
85            };
86
87            // Write the discriminant value based on the representation
88            match self.def.repr {
89                EnumRepr::U8 => {
90                    let tag_ptr = self.data.as_mut_bytes();
91                    *tag_ptr = discriminant_value as u8;
92                }
93                EnumRepr::U16 => {
94                    let tag_ptr = self.data.as_mut_bytes() as *mut u16;
95                    *tag_ptr = discriminant_value as u16;
96                }
97                EnumRepr::U32 => {
98                    let tag_ptr = self.data.as_mut_bytes() as *mut u32;
99                    *tag_ptr = discriminant_value as u32;
100                }
101                EnumRepr::U64 => {
102                    let tag_ptr = self.data.as_mut_bytes() as *mut u64;
103                    *tag_ptr = discriminant_value as u64;
104                }
105                EnumRepr::USize => {
106                    let tag_ptr = self.data.as_mut_bytes() as *mut usize;
107                    *tag_ptr = discriminant_value as usize;
108                }
109                EnumRepr::I8 => {
110                    let tag_ptr = self.data.as_mut_bytes() as *mut i8;
111                    *tag_ptr = discriminant_value as i8;
112                }
113                EnumRepr::I16 => {
114                    let tag_ptr = self.data.as_mut_bytes() as *mut i16;
115                    *tag_ptr = discriminant_value as i16;
116                }
117                EnumRepr::I32 => {
118                    let tag_ptr = self.data.as_mut_bytes() as *mut i32;
119                    *tag_ptr = discriminant_value as i32;
120                }
121                EnumRepr::I64 => {
122                    let tag_ptr = self.data.as_mut_bytes() as *mut i64;
123                    *tag_ptr = discriminant_value;
124                }
125                EnumRepr::ISize => {
126                    let tag_ptr = self.data.as_mut_bytes() as *mut isize;
127                    *tag_ptr = discriminant_value as isize;
128                }
129                _ => {
130                    panic!("Unsupported enum representation: {:?}", self.def.repr);
131                }
132            }
133        }
134
135        // Create PokeEnum with the selected variant
136        Ok(PokeEnum {
137            data: self.data,
138            iset: Default::default(),
139            shape: self.shape,
140            def: self.def,
141            selected_variant: variant_index,
142        })
143    }
144}
145
146/// Allows poking an enum with a selected variant (setting fields, etc.)
147pub struct PokeEnum<'mem> {
148    data: OpaqueUninit<'mem>,
149    iset: ISet,
150    shape: &'static Shape,
151    def: EnumDef,
152    selected_variant: usize,
153}
154
155impl<'mem> PokeEnum<'mem> {
156    /// Coerce back into a `PokeValue`
157    #[inline(always)]
158    pub fn into_value(self) -> PokeValueUninit<'mem> {
159        unsafe { PokeValueUninit::new(self.data, self.shape) }
160    }
161
162    /// Creates a new PokeEnum from raw data
163    ///
164    /// # Safety
165    ///
166    /// The data buffer must match the size and alignment of the enum shape described by shape
167    #[allow(dead_code)]
168    pub(crate) unsafe fn new(
169        data: OpaqueUninit<'mem>,
170        shape: &'static Shape,
171        def: EnumDef,
172        selected_variant: usize,
173    ) -> Self {
174        Self {
175            data,
176            iset: Default::default(),
177            shape,
178            def,
179            selected_variant,
180        }
181    }
182
183    #[inline(always)]
184    /// Shape getter
185    pub fn shape(&self) -> &'static Shape {
186        self.shape
187    }
188
189    /// Returns the currently selected variant index
190    pub fn selected_variant_index(&self) -> usize {
191        self.selected_variant
192    }
193
194    /// Gets a field by name in the currently selected variant.
195    ///
196    /// # Errors
197    ///
198    /// Returns an error if:
199    /// - The field name doesn't exist in the selected variant.
200    /// - The selected variant is a unit variant (which has no fields).
201    pub fn field_by_name(
202        &self,
203        name: &str,
204    ) -> Result<(usize, crate::PokeUninit<'mem>), FieldError> {
205        let variant = &self.def.variants[self.selected_variant];
206
207        // Find the field in the variant
208        match &variant.kind {
209            VariantKind::Unit => {
210                // Unit variants have no fields
211                Err(FieldError::NoSuchStaticField)
212            }
213            VariantKind::Tuple { fields } => {
214                // For tuple variants, find the field by name
215                let (index, field) = fields
216                    .iter()
217                    .enumerate()
218                    .find(|(_, f)| f.name == name)
219                    .ok_or(FieldError::NoSuchStaticField)?;
220
221                // Get the field's address
222                let field_data = unsafe { self.data.field_uninit(field.offset) };
223                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
224                Ok((index, poke))
225            }
226            VariantKind::Struct { fields } => {
227                // For struct variants, find the field by name
228                let (index, field) = fields
229                    .iter()
230                    .enumerate()
231                    .find(|(_, f)| f.name == name)
232                    .ok_or(FieldError::NoSuchStaticField)?;
233
234                // Get the field's address
235                let field_data = unsafe { self.data.field_uninit(field.offset) };
236                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
237                Ok((index, poke))
238            }
239            _ => {
240                panic!("Unsupported enum variant kind: {:?}", variant.kind);
241            }
242        }
243    }
244
245    /// Get a field writer for a tuple field by index in the currently selected variant.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if:
250    /// - The index is out of bounds.
251    /// - The selected variant is not a tuple variant.
252    pub fn tuple_field(&self, index: usize) -> Result<crate::PokeUninit<'mem>, FieldError> {
253        let variant = &self.def.variants[self.selected_variant];
254
255        // Make sure we're working with a tuple variant
256        match &variant.kind {
257            VariantKind::Tuple { fields } => {
258                // Check if the index is valid
259                if index >= fields.len() {
260                    return Err(FieldError::IndexOutOfBounds);
261                }
262
263                // Get the field at the specified index
264                let field = &fields[index];
265
266                // Get the field's address
267                let field_data = unsafe { self.data.field_uninit(field.offset) };
268                let poke = unsafe { crate::PokeUninit::unchecked_new(field_data, field.shape) };
269                Ok(poke)
270            }
271            _ => {
272                // Not a tuple variant
273                Err(FieldError::NoSuchStaticField)
274            }
275        }
276    }
277
278    /// Marks a field in the current variant as initialized.
279    ///
280    /// # Safety
281    ///
282    /// The caller must ensure that the field is initialized. Only call this after writing to
283    /// an address gotten through [`Self::field_by_name`] or [`Self::tuple_field`].
284    pub unsafe fn mark_initialized(&mut self, field_index: usize) {
285        self.iset.set(field_index);
286    }
287
288    /// Checks if all required fields in the enum are initialized.
289    ///
290    /// # Panics
291    ///
292    /// Panics if any field in the selected variant is not initialized.
293    pub fn assert_all_fields_initialized(&self) {
294        let variant = &self.def.variants[self.selected_variant];
295
296        // Check if all fields of the selected variant are initialized
297        match &variant.kind {
298            VariantKind::Unit => {
299                // Unit variants don't have fields, so they're always fully initialized
300            }
301            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
302                // Check each field
303                for (field_index, field) in fields.iter().enumerate() {
304                    if !self.iset.has(field_index) {
305                        panic!(
306                            "Field '{}' of variant '{}' was not initialized. Complete schema:\n{}",
307                            field.name, variant.name, self.shape
308                        );
309                    }
310                }
311            }
312            _ => {
313                panic!("Unsupported enum variant kind: {:?}", variant.kind);
314            }
315        }
316    }
317
318    fn assert_matching_shape<T: Facet>(&self) {
319        if !self.shape.is_type::<T>() {
320            panic!(
321                "This is a partial \x1b[1;34m{}\x1b[0m, you can't build a \x1b[1;32m{}\x1b[0m out of it",
322                self.shape,
323                T::SHAPE,
324            );
325        }
326    }
327
328    /// Asserts that every field in the selected variant has been initialized and forgets the PokeEnum.
329    ///
330    /// This method is only used when the origin is borrowed.
331    /// If this method is not called, all fields will be freed when the PokeEnum is dropped.
332    ///
333    /// # Panics
334    ///
335    /// This function will panic if any required field is not initialized.
336    pub fn build_in_place(self) -> Opaque<'mem> {
337        // ensure all fields are initialized
338        self.assert_all_fields_initialized();
339        let data = unsafe { self.data.assume_init() };
340        // prevent field drops when the PokeEnum is dropped
341        core::mem::forget(self);
342        data
343    }
344
345    /// Builds a value of type `T` from the PokeEnum, then deallocates the memory
346    /// that this PokeEnum was pointing to.
347    ///
348    /// # Panics
349    ///
350    /// This function will panic if:
351    /// - Not all fields in the selected variant have been initialized.
352    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
353    pub fn build<T: Facet>(self, guard: Option<Guard>) -> T {
354        let mut guard = guard;
355        let this = self;
356        // this changes drop order: guard must be dropped _after_ this.
357
358        this.assert_all_fields_initialized();
359        this.assert_matching_shape::<T>();
360        if let Some(guard) = &guard {
361            guard.shape.assert_type::<T>();
362        }
363
364        let result = unsafe {
365            let ptr = this.data.as_mut_bytes() as *const T;
366            core::ptr::read(ptr)
367        };
368        guard.take(); // dealloc
369        core::mem::forget(this);
370        result
371    }
372
373    /// Build that PokeEnum into a boxed completed shape.
374    ///
375    /// # Panics
376    ///
377    /// This function will panic if:
378    /// - Not all fields in the selected variant have been initialized.
379    /// - The generic type parameter T does not match the shape that this PokeEnum is building.
380    pub fn build_boxed<T: Facet>(self) -> alloc::boxed::Box<T> {
381        self.assert_all_fields_initialized();
382        self.assert_matching_shape::<T>();
383
384        let boxed = unsafe { alloc::boxed::Box::from_raw(self.data.as_mut_bytes() as *mut T) };
385        core::mem::forget(self);
386        boxed
387    }
388
389    /// Moves the contents of this `PokeEnum` into a target memory location.
390    ///
391    /// # Safety
392    ///
393    /// The target pointer must be valid and properly aligned,
394    /// and must be large enough to hold the value.
395    /// The caller is responsible for ensuring that the target memory is properly deallocated
396    /// when it's no longer needed.
397    pub unsafe fn move_into(self, target: NonNull<u8>) {
398        self.assert_all_fields_initialized();
399        unsafe {
400            core::ptr::copy_nonoverlapping(
401                self.data.as_mut_bytes(),
402                target.as_ptr(),
403                self.shape.layout.size(),
404            );
405        }
406        core::mem::forget(self);
407    }
408}
409
410impl Drop for PokeEnum<'_> {
411    fn drop(&mut self) {
412        let variant = &self.def.variants[self.selected_variant];
413
414        // Drop fields based on the variant kind
415        match &variant.kind {
416            VariantKind::Unit => {
417                // Unit variants have no fields to drop
418            }
419            VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
420                // Drop each initialized field
421                for (field_index, field) in fields.iter().enumerate() {
422                    if self.iset.has(field_index) {
423                        if let Some(drop_fn) = field.shape.vtable.drop_in_place {
424                            unsafe {
425                                drop_fn(self.data.field_init(field.offset));
426                            }
427                        }
428                    }
429                }
430            }
431            _ => {
432                panic!("Unsupported enum variant kind: {:?}", variant.kind);
433            }
434        }
435    }
436}
437
/// Errors that can occur when looking up an enum variant by index or by name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum VariantError {
    /// `variant_by_index` was called with an index that is out of bounds.
    IndexOutOfBounds,

    /// `variant_by_name` or `variant_by_index` was called on a non-enum type.
    NotAnEnum,

    /// `variant_by_name` was called with a name that doesn't match any variant.
    NoSuchVariant,
}

impl core::error::Error for VariantError {}

impl core::fmt::Display for VariantError {
    /// Writes a short, fixed, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::IndexOutOfBounds => "Variant index out of bounds",
            Self::NotAnEnum => "Not an enum",
            Self::NoSuchVariant => "No such variant",
        };
        f.write_str(message)
    }
}