// facet_reflect/poke/enum_.rs

1use facet_core::{Def, EnumRepr, EnumType, Facet, FieldError};
2use facet_path::Path;
3
4use crate::{ReflectError, ReflectErrorKind, peek::VariantError};
5
6use super::Poke;
7
8/// Lets you mutate an enum's fields.
9pub struct PokeEnum<'mem, 'facet> {
10    /// The internal data storage for the enum
11    ///
12    /// Note that this stores both the discriminant and the variant data
13    /// (if any), and the layout depends on the enum representation.
14    pub(crate) value: Poke<'mem, 'facet>,
15
16    /// The definition of the enum.
17    pub(crate) ty: EnumType,
18}
19
20impl core::fmt::Debug for PokeEnum<'_, '_> {
21    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
22        write!(f, "{:?}", self.value)
23    }
24}
25
26impl<'mem, 'facet> PokeEnum<'mem, 'facet> {
27    fn err(&self, kind: ReflectErrorKind) -> ReflectError {
28        ReflectError::new(kind, Path::new(self.value.shape))
29    }
30
31    /// Returns the enum definition
32    #[inline(always)]
33    pub const fn ty(&self) -> EnumType {
34        self.ty
35    }
36
37    /// Returns the enum representation
38    #[inline(always)]
39    pub const fn enum_repr(&self) -> EnumRepr {
40        self.ty.enum_repr
41    }
42
43    /// Returns the enum variants
44    #[inline(always)]
45    pub const fn variants(&self) -> &'static [facet_core::Variant] {
46        self.ty.variants
47    }
48
49    /// Returns the number of variants in this enum
50    #[inline(always)]
51    pub const fn variant_count(&self) -> usize {
52        self.ty.variants.len()
53    }
54
55    /// Returns the variant name at the given index
56    #[inline(always)]
57    pub fn variant_name(&self, index: usize) -> Option<&'static str> {
58        self.ty.variants.get(index).map(|variant| variant.name)
59    }
60
61    /// Returns the discriminant value for the current enum value
62    ///
63    /// Note: For `RustNPO` (null pointer optimization) types, there is no explicit
64    /// discriminant stored in memory. In this case, 0 is returned. Use
65    /// [`variant_index()`](Self::variant_index) to determine the active variant for NPO types.
66    #[inline]
67    pub fn discriminant(&self) -> i64 {
68        // Read the discriminant based on the enum representation
69        match self.ty.enum_repr {
70            EnumRepr::Rust => {
71                panic!("cannot read discriminant from Rust enum with unspecified layout")
72            }
73            // For RustNPO types, there is no explicit discriminant stored in memory.
74            // The variant is determined by niche optimization (e.g., null pointer pattern).
75            // Return 0 since that's the declared discriminant for NPO variants.
76            // This also prevents UB when reading from zero-sized types.
77            EnumRepr::RustNPO => 0,
78            EnumRepr::U8 => unsafe { self.value.data().read::<u8>() as i64 },
79            EnumRepr::U16 => unsafe { self.value.data().read::<u16>() as i64 },
80            EnumRepr::U32 => unsafe { self.value.data().read::<u32>() as i64 },
81            EnumRepr::U64 => unsafe { self.value.data().read::<u64>() as i64 },
82            EnumRepr::USize => unsafe { self.value.data().read::<usize>() as i64 },
83            EnumRepr::I8 => unsafe { self.value.data().read::<i8>() as i64 },
84            EnumRepr::I16 => unsafe { self.value.data().read::<i16>() as i64 },
85            EnumRepr::I32 => unsafe { self.value.data().read::<i32>() as i64 },
86            EnumRepr::I64 => unsafe { self.value.data().read::<i64>() },
87            EnumRepr::ISize => unsafe { self.value.data().read::<isize>() as i64 },
88        }
89    }
90
91    /// Returns the variant index for this enum value
92    #[inline]
93    pub fn variant_index(&self) -> Result<usize, VariantError> {
94        // For Option<T> types, use the OptionVTable to correctly determine if the value is Some or None.
95        // This handles both RustNPO (niche-optimized) and Rust (non-niche) representations.
96        if let Def::Option(option_def) = self.value.shape.def {
97            let is_some = unsafe { (option_def.vtable.is_some)(self.value.data()) };
98            return Ok(self
99                .ty
100                .variants
101                .iter()
102                .position(|variant| {
103                    let has_fields = !variant.data.fields.is_empty();
104                    has_fields == is_some
105                })
106                .expect("No variant found matching Option state"));
107        }
108
109        if self.ty.enum_repr == EnumRepr::RustNPO {
110            // Fallback for other RustNPO types (e.g., Option<&T> where all-zeros means None)
111            let layout = self
112                .value
113                .shape
114                .layout
115                .sized_layout()
116                .expect("Unsized enums in NPO repr are unsupported");
117
118            let data = self.value.data();
119            let slice = unsafe { core::slice::from_raw_parts(data.as_byte_ptr(), layout.size()) };
120            let all_zero = slice.iter().all(|v| *v == 0);
121
122            Ok(self
123                .ty
124                .variants
125                .iter()
126                .position(|variant| {
127                    // Find the maximum end bound
128                    let mut max_offset = 0;
129
130                    for field in variant.data.fields {
131                        let offset = field.offset
132                            + field
133                                .shape()
134                                .layout
135                                .sized_layout()
136                                .map(|v| v.size())
137                                .unwrap_or(0);
138                        max_offset = core::cmp::max(max_offset, offset);
139                    }
140
141                    // If we are all zero, then find the enum variant that has no size,
142                    // otherwise, the one with size.
143                    if all_zero {
144                        max_offset == 0
145                    } else {
146                        max_offset != 0
147                    }
148                })
149                .expect("No variant found with matching discriminant"))
150        } else {
151            let discriminant = self.discriminant();
152
153            // Find the variant with matching discriminant using position method
154            Ok(self
155                .ty
156                .variants
157                .iter()
158                .position(|variant| variant.discriminant == Some(discriminant))
159                .expect("No variant found with matching discriminant"))
160        }
161    }
162
163    /// Returns the active variant
164    #[inline]
165    pub fn active_variant(&self) -> Result<&'static facet_core::Variant, VariantError> {
166        let index = self.variant_index()?;
167        Ok(&self.ty.variants[index])
168    }
169
170    /// Returns the name of the active variant for this enum value
171    #[inline]
172    pub fn variant_name_active(&self) -> Result<&'static str, VariantError> {
173        Ok(self.active_variant()?.name)
174    }
175
176    /// Returns a Poke handle to a field of a tuple or struct variant by index
177    pub fn field(&mut self, index: usize) -> Result<Option<Poke<'_, 'facet>>, VariantError> {
178        let variant = self.active_variant()?;
179        let fields = &variant.data.fields;
180
181        if index >= fields.len() {
182            return Ok(None);
183        }
184
185        let field = &fields[index];
186        let field_data = unsafe { self.value.data.field(field.offset) };
187        Ok(Some(unsafe {
188            Poke::from_raw_parts(field_data, field.shape())
189        }))
190    }
191
192    /// Returns the index of a field in the active variant by name
193    pub fn field_index(&self, field_name: &str) -> Result<Option<usize>, VariantError> {
194        let variant = self.active_variant()?;
195        Ok(variant
196            .data
197            .fields
198            .iter()
199            .position(|f| f.name == field_name))
200    }
201
202    /// Returns a Poke handle to a field of a tuple or struct variant by name
203    pub fn field_by_name(
204        &mut self,
205        field_name: &str,
206    ) -> Result<Option<Poke<'_, 'facet>>, VariantError> {
207        let index_opt = self.field_index(field_name)?;
208        match index_opt {
209            Some(index) => self.field(index),
210            None => Ok(None),
211        }
212    }
213
214    /// Sets a field of the current variant by index.
215    ///
216    /// Returns an error if:
217    /// - The parent enum is not POD
218    /// - The index is out of bounds
219    /// - The value type doesn't match the field type
220    pub fn set_field<T: Facet<'facet>>(
221        &mut self,
222        index: usize,
223        value: T,
224    ) -> Result<(), ReflectError> {
225        // Check that the parent enum is POD before allowing field mutation
226        if !self.value.shape.is_pod() {
227            return Err(self.err(ReflectErrorKind::NotPod {
228                shape: self.value.shape,
229            }));
230        }
231
232        let variant = self.active_variant().map_err(|_| {
233            self.err(ReflectErrorKind::OperationFailed {
234                shape: self.value.shape,
235                operation: "get active variant",
236            })
237        })?;
238        let fields = &variant.data.fields;
239
240        let field = fields.get(index).ok_or_else(|| {
241            self.err(ReflectErrorKind::FieldError {
242                shape: self.value.shape,
243                field_error: FieldError::IndexOutOfBounds {
244                    index,
245                    bound: fields.len(),
246                },
247            })
248        })?;
249
250        let field_shape = field.shape();
251        if field_shape != T::SHAPE {
252            return Err(self.err(ReflectErrorKind::WrongShape {
253                expected: field_shape,
254                actual: T::SHAPE,
255            }));
256        }
257
258        unsafe {
259            let field_ptr = self.value.data.field(field.offset);
260            // Drop the old value and write the new one
261            field_shape.call_drop_in_place(field_ptr);
262            core::ptr::write(field_ptr.as_mut_byte_ptr() as *mut T, value);
263        }
264
265        Ok(())
266    }
267
268    /// Sets a field of the current variant by name.
269    ///
270    /// Returns an error if:
271    /// - The parent enum is not POD
272    /// - No field with the given name exists
273    /// - The value type doesn't match the field type
274    pub fn set_field_by_name<T: Facet<'facet>>(
275        &mut self,
276        name: &str,
277        value: T,
278    ) -> Result<(), ReflectError> {
279        let index = self.field_index(name).map_err(|_| {
280            self.err(ReflectErrorKind::OperationFailed {
281                shape: self.value.shape,
282                operation: "get active variant",
283            })
284        })?;
285
286        let index = index.ok_or_else(|| {
287            self.err(ReflectErrorKind::FieldError {
288                shape: self.value.shape,
289                field_error: FieldError::NoSuchField,
290            })
291        })?;
292
293        self.set_field(index, value)
294    }
295
296    /// Gets a read-only view of a field by index.
297    pub fn peek_field(
298        &self,
299        index: usize,
300    ) -> Result<Option<crate::Peek<'_, 'facet>>, VariantError> {
301        let variant = self.active_variant()?;
302        let fields = &variant.data.fields;
303
304        if index >= fields.len() {
305            return Ok(None);
306        }
307
308        let field = &fields[index];
309        let field_data = unsafe { self.value.data.as_const().field(field.offset) };
310        Ok(Some(unsafe {
311            crate::Peek::unchecked_new(field_data, field.shape())
312        }))
313    }
314
315    /// Gets a read-only view of a field by name.
316    pub fn peek_field_by_name(
317        &self,
318        field_name: &str,
319    ) -> Result<Option<crate::Peek<'_, 'facet>>, VariantError> {
320        let index_opt = self.field_index(field_name)?;
321        match index_opt {
322            Some(index) => self.peek_field(index),
323            None => Ok(None),
324        }
325    }
326
327    /// Converts this back into the underlying `Poke`.
328    #[inline]
329    pub const fn into_inner(self) -> Poke<'mem, 'facet> {
330        self.value
331    }
332
333    /// Returns a read-only `PeekEnum` view.
334    #[inline]
335    pub fn as_peek_enum(&self) -> crate::PeekEnum<'_, 'facet> {
336        crate::PeekEnum {
337            value: self.value.as_peek(),
338            ty: self.ty,
339        }
340    }
341}