// facet_reflect/poke/enum_.rs

1use facet_core::{Def, EnumRepr, EnumType, Facet, FieldError};
2
3use crate::{ReflectError, peek::VariantError};
4
5use super::Poke;
6
7/// Lets you mutate an enum's fields.
8pub struct PokeEnum<'mem, 'facet> {
9    /// The internal data storage for the enum
10    ///
11    /// Note that this stores both the discriminant and the variant data
12    /// (if any), and the layout depends on the enum representation.
13    pub(crate) value: Poke<'mem, 'facet>,
14
15    /// The definition of the enum.
16    pub(crate) ty: EnumType,
17}
18
19impl core::fmt::Debug for PokeEnum<'_, '_> {
20    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
21        write!(f, "{:?}", self.value)
22    }
23}
24
25impl<'mem, 'facet> PokeEnum<'mem, 'facet> {
26    /// Returns the enum definition
27    #[inline(always)]
28    pub fn ty(&self) -> EnumType {
29        self.ty
30    }
31
32    /// Returns the enum representation
33    #[inline(always)]
34    pub fn enum_repr(&self) -> EnumRepr {
35        self.ty.enum_repr
36    }
37
38    /// Returns the enum variants
39    #[inline(always)]
40    pub fn variants(&self) -> &'static [facet_core::Variant] {
41        self.ty.variants
42    }
43
44    /// Returns the number of variants in this enum
45    #[inline(always)]
46    pub fn variant_count(&self) -> usize {
47        self.ty.variants.len()
48    }
49
50    /// Returns the variant name at the given index
51    #[inline(always)]
52    pub fn variant_name(&self, index: usize) -> Option<&'static str> {
53        self.ty.variants.get(index).map(|variant| variant.name)
54    }
55
56    /// Returns the discriminant value for the current enum value
57    ///
58    /// Note: For `RustNPO` (null pointer optimization) types, there is no explicit
59    /// discriminant stored in memory. In this case, 0 is returned. Use
60    /// [`variant_index()`](Self::variant_index) to determine the active variant for NPO types.
61    #[inline]
62    pub fn discriminant(&self) -> i64 {
63        // Read the discriminant based on the enum representation
64        match self.ty.enum_repr {
65            // For RustNPO types, there is no explicit discriminant stored in memory.
66            // The variant is determined by niche optimization (e.g., null pointer pattern).
67            // Return 0 since that's the declared discriminant for NPO variants.
68            // This also prevents UB when reading from zero-sized types.
69            EnumRepr::RustNPO => 0,
70            EnumRepr::U8 => unsafe { self.value.data().read::<u8>() as i64 },
71            EnumRepr::U16 => unsafe { self.value.data().read::<u16>() as i64 },
72            EnumRepr::U32 => unsafe { self.value.data().read::<u32>() as i64 },
73            EnumRepr::U64 => unsafe { self.value.data().read::<u64>() as i64 },
74            EnumRepr::USize => unsafe { self.value.data().read::<usize>() as i64 },
75            EnumRepr::I8 => unsafe { self.value.data().read::<i8>() as i64 },
76            EnumRepr::I16 => unsafe { self.value.data().read::<i16>() as i64 },
77            EnumRepr::I32 => unsafe { self.value.data().read::<i32>() as i64 },
78            EnumRepr::I64 => unsafe { self.value.data().read::<i64>() },
79            EnumRepr::ISize => unsafe { self.value.data().read::<isize>() as i64 },
80        }
81    }
82
83    /// Returns the variant index for this enum value
84    #[inline]
85    pub fn variant_index(&self) -> Result<usize, VariantError> {
86        if self.ty.enum_repr == EnumRepr::RustNPO {
87            // For Option<T> types with niche optimization, use the OptionVTable
88            // to correctly determine if the value is Some or None.
89            if let Def::Option(option_def) = self.value.shape.def {
90                let is_some = unsafe { (option_def.vtable.is_some)(self.value.data()) };
91                return Ok(self
92                    .ty
93                    .variants
94                    .iter()
95                    .position(|variant| {
96                        let has_fields = !variant.data.fields.is_empty();
97                        has_fields == is_some
98                    })
99                    .expect("No variant found matching Option state"));
100            }
101
102            // Fallback for other RustNPO types (e.g., Option<&T> where all-zeros means None)
103            let layout = self
104                .value
105                .shape
106                .layout
107                .sized_layout()
108                .expect("Unsized enums in NPO repr are unsupported");
109
110            let data = self.value.data();
111            let slice = unsafe { core::slice::from_raw_parts(data.as_byte_ptr(), layout.size()) };
112            let all_zero = slice.iter().all(|v| *v == 0);
113
114            Ok(self
115                .ty
116                .variants
117                .iter()
118                .position(|variant| {
119                    // Find the maximum end bound
120                    let mut max_offset = 0;
121
122                    for field in variant.data.fields {
123                        let offset = field.offset
124                            + field
125                                .shape()
126                                .layout
127                                .sized_layout()
128                                .map(|v| v.size())
129                                .unwrap_or(0);
130                        max_offset = core::cmp::max(max_offset, offset);
131                    }
132
133                    // If we are all zero, then find the enum variant that has no size,
134                    // otherwise, the one with size.
135                    if all_zero {
136                        max_offset == 0
137                    } else {
138                        max_offset != 0
139                    }
140                })
141                .expect("No variant found with matching discriminant"))
142        } else {
143            let discriminant = self.discriminant();
144
145            // Find the variant with matching discriminant using position method
146            Ok(self
147                .ty
148                .variants
149                .iter()
150                .position(|variant| variant.discriminant == Some(discriminant))
151                .expect("No variant found with matching discriminant"))
152        }
153    }
154
155    /// Returns the active variant
156    #[inline]
157    pub fn active_variant(&self) -> Result<&'static facet_core::Variant, VariantError> {
158        let index = self.variant_index()?;
159        Ok(&self.ty.variants[index])
160    }
161
162    /// Returns the name of the active variant for this enum value
163    #[inline]
164    pub fn variant_name_active(&self) -> Result<&'static str, VariantError> {
165        Ok(self.active_variant()?.name)
166    }
167
168    /// Returns a Poke handle to a field of a tuple or struct variant by index
169    pub fn field(&mut self, index: usize) -> Result<Option<Poke<'_, 'facet>>, VariantError> {
170        let variant = self.active_variant()?;
171        let fields = &variant.data.fields;
172
173        if index >= fields.len() {
174            return Ok(None);
175        }
176
177        let field = &fields[index];
178        let field_data = unsafe { self.value.data.field(field.offset) };
179        Ok(Some(unsafe {
180            Poke::from_raw_parts(field_data, field.shape())
181        }))
182    }
183
184    /// Returns the index of a field in the active variant by name
185    pub fn field_index(&self, field_name: &str) -> Result<Option<usize>, VariantError> {
186        let variant = self.active_variant()?;
187        Ok(variant
188            .data
189            .fields
190            .iter()
191            .position(|f| f.name == field_name))
192    }
193
194    /// Returns a Poke handle to a field of a tuple or struct variant by name
195    pub fn field_by_name(
196        &mut self,
197        field_name: &str,
198    ) -> Result<Option<Poke<'_, 'facet>>, VariantError> {
199        let index_opt = self.field_index(field_name)?;
200        match index_opt {
201            Some(index) => self.field(index),
202            None => Ok(None),
203        }
204    }
205
206    /// Sets a field of the current variant by index.
207    ///
208    /// Returns an error if:
209    /// - The parent enum is not POD
210    /// - The index is out of bounds
211    /// - The value type doesn't match the field type
212    pub fn set_field<T: Facet<'facet>>(
213        &mut self,
214        index: usize,
215        value: T,
216    ) -> Result<(), ReflectError> {
217        // Check that the parent enum is POD before allowing field mutation
218        if !self.value.shape.is_pod() {
219            return Err(ReflectError::NotPod {
220                shape: self.value.shape,
221            });
222        }
223
224        let variant = self
225            .active_variant()
226            .map_err(|_| ReflectError::OperationFailed {
227                shape: self.value.shape,
228                operation: "get active variant",
229            })?;
230        let fields = &variant.data.fields;
231
232        let field = fields.get(index).ok_or(ReflectError::FieldError {
233            shape: self.value.shape,
234            field_error: FieldError::IndexOutOfBounds {
235                index,
236                bound: fields.len(),
237            },
238        })?;
239
240        let field_shape = field.shape();
241        if field_shape != T::SHAPE {
242            return Err(ReflectError::WrongShape {
243                expected: field_shape,
244                actual: T::SHAPE,
245            });
246        }
247
248        unsafe {
249            let field_ptr = self.value.data.field(field.offset);
250            // Drop the old value and write the new one
251            field_shape.call_drop_in_place(field_ptr);
252            core::ptr::write(field_ptr.as_mut_byte_ptr() as *mut T, value);
253        }
254
255        Ok(())
256    }
257
258    /// Sets a field of the current variant by name.
259    ///
260    /// Returns an error if:
261    /// - The parent enum is not POD
262    /// - No field with the given name exists
263    /// - The value type doesn't match the field type
264    pub fn set_field_by_name<T: Facet<'facet>>(
265        &mut self,
266        name: &str,
267        value: T,
268    ) -> Result<(), ReflectError> {
269        let index = self
270            .field_index(name)
271            .map_err(|_| ReflectError::OperationFailed {
272                shape: self.value.shape,
273                operation: "get active variant",
274            })?;
275
276        let index = index.ok_or(ReflectError::FieldError {
277            shape: self.value.shape,
278            field_error: FieldError::NoSuchField,
279        })?;
280
281        self.set_field(index, value)
282    }
283
284    /// Gets a read-only view of a field by index.
285    pub fn peek_field(
286        &self,
287        index: usize,
288    ) -> Result<Option<crate::Peek<'_, 'facet>>, VariantError> {
289        let variant = self.active_variant()?;
290        let fields = &variant.data.fields;
291
292        if index >= fields.len() {
293            return Ok(None);
294        }
295
296        let field = &fields[index];
297        let field_data = unsafe { self.value.data.as_const().field(field.offset) };
298        Ok(Some(unsafe {
299            crate::Peek::unchecked_new(field_data, field.shape())
300        }))
301    }
302
303    /// Gets a read-only view of a field by name.
304    pub fn peek_field_by_name(
305        &self,
306        field_name: &str,
307    ) -> Result<Option<crate::Peek<'_, 'facet>>, VariantError> {
308        let index_opt = self.field_index(field_name)?;
309        match index_opt {
310            Some(index) => self.peek_field(index),
311            None => Ok(None),
312        }
313    }
314
315    /// Converts this back into the underlying `Poke`.
316    #[inline]
317    pub fn into_inner(self) -> Poke<'mem, 'facet> {
318        self.value
319    }
320
321    /// Returns a read-only `PeekEnum` view.
322    #[inline]
323    pub fn as_peek_enum(&self) -> crate::PeekEnum<'_, 'facet> {
324        crate::PeekEnum {
325            value: self.value.as_peek(),
326            ty: self.ty,
327        }
328    }
329}