1use core::marker::PhantomData;
2
3use facet_core::{Def, Facet, PtrConst, PtrMut, Shape, Type, UserType, Variance};
4use facet_path::{Path, PathAccessError, PathStep};
5
6use crate::{ReflectError, ReflectErrorKind, peek::VariantError};
7
8use super::{PokeList, PokeStruct};
9
/// A mutable, type-erased view of a value: a raw pointer paired with the
/// [`Shape`] that describes what it points to.
///
/// `'mem` is the lifetime of the mutable borrow the view is tied to, and
/// `'facet` is the lifetime parameter of the `Facet` type being viewed.
pub struct Poke<'mem, 'facet> {
    /// Pointer to the underlying value.
    pub(crate) data: PtrMut,

    /// Shape describing the value behind `data`.
    pub(crate) shape: &'static Shape,

    // Zero-sized marker: `&'mem mut ()` ties the view to a mutable borrow of
    // lifetime `'mem`, and `fn(&'facet ()) -> &'facet ()` mentions `'facet` in
    // both argument and return position, which makes the struct invariant
    // over `'facet`. The nested tuple/fn type is what trips the clippy lint.
    #[allow(clippy::type_complexity)]
    _marker: PhantomData<(&'mem mut (), fn(&'facet ()) -> &'facet ())>,
}
59
60impl<'mem, 'facet> Poke<'mem, 'facet> {
61 pub fn new<T: Facet<'facet>>(t: &'mem mut T) -> Self {
66 Self {
67 data: PtrMut::new(t as *mut T as *mut u8),
68 shape: T::SHAPE,
69 _marker: PhantomData,
70 }
71 }
72
73 pub unsafe fn from_raw_parts(data: PtrMut, shape: &'static Shape) -> Self {
80 Self {
81 data,
82 shape,
83 _marker: PhantomData,
84 }
85 }
86
87 #[inline(always)]
89 pub const fn shape(&self) -> &'static Shape {
90 self.shape
91 }
92
93 #[inline(always)]
95 pub const fn data(&self) -> PtrConst {
96 self.data.as_const()
97 }
98
99 #[inline]
101 fn err(&self, kind: ReflectErrorKind) -> ReflectError {
102 ReflectError::new(kind, Path::new(self.shape))
103 }
104
105 #[inline(always)]
107 pub const fn data_mut(&mut self) -> PtrMut {
108 self.data
109 }
110
111 #[inline]
113 pub fn variance(&self) -> Variance {
114 self.shape.computed_variance()
115 }
116
117 #[inline]
125 pub fn try_reborrow<'shorter>(&mut self) -> Option<Poke<'_, 'shorter>>
126 where
127 'facet: 'shorter,
128 {
129 if self.variance().can_shrink() {
130 Some(Poke {
131 data: self.data,
132 shape: self.shape,
133 _marker: PhantomData,
134 })
135 } else {
136 None
137 }
138 }
139
140 #[inline]
142 pub const fn is_struct(&self) -> bool {
143 matches!(self.shape.ty, Type::User(UserType::Struct(_)))
144 }
145
146 #[inline]
148 pub const fn is_enum(&self) -> bool {
149 matches!(self.shape.ty, Type::User(UserType::Enum(_)))
150 }
151
152 #[inline]
154 pub const fn is_scalar(&self) -> bool {
155 matches!(self.shape.def, Def::Scalar)
156 }
157
158 pub fn into_struct(self) -> Result<PokeStruct<'mem, 'facet>, ReflectError> {
160 match self.shape.ty {
161 Type::User(UserType::Struct(struct_type)) => Ok(PokeStruct {
162 value: self,
163 ty: struct_type,
164 }),
165 _ => Err(self.err(ReflectErrorKind::WasNotA {
166 expected: "struct",
167 actual: self.shape,
168 })),
169 }
170 }
171
172 pub fn into_enum(self) -> Result<super::PokeEnum<'mem, 'facet>, ReflectError> {
174 match self.shape.ty {
175 Type::User(UserType::Enum(enum_type)) => Ok(super::PokeEnum {
176 value: self,
177 ty: enum_type,
178 }),
179 _ => Err(self.err(ReflectErrorKind::WasNotA {
180 expected: "enum",
181 actual: self.shape,
182 })),
183 }
184 }
185
186 #[inline]
188 pub fn into_list(self) -> Result<PokeList<'mem, 'facet>, ReflectError> {
189 if let Def::List(def) = self.shape.def {
190 return Ok(unsafe { PokeList::new(self, def) });
194 }
195
196 Err(self.err(ReflectErrorKind::WasNotA {
197 expected: "list",
198 actual: self.shape,
199 }))
200 }
201
202 pub fn get<T: Facet<'facet>>(&self) -> Result<&T, ReflectError> {
206 if self.shape != T::SHAPE {
207 return Err(self.err(ReflectErrorKind::WrongShape {
208 expected: self.shape,
209 actual: T::SHAPE,
210 }));
211 }
212 Ok(unsafe { self.data.as_const().get::<T>() })
213 }
214
215 pub fn get_mut<T: Facet<'facet>>(&mut self) -> Result<&mut T, ReflectError> {
219 if self.shape != T::SHAPE {
220 return Err(self.err(ReflectErrorKind::WrongShape {
221 expected: self.shape,
222 actual: T::SHAPE,
223 }));
224 }
225 Ok(unsafe { self.data.as_mut::<T>() })
226 }
227
228 pub fn set<T: Facet<'facet>>(&mut self, value: T) -> Result<(), ReflectError> {
232 if self.shape != T::SHAPE {
233 return Err(self.err(ReflectErrorKind::WrongShape {
234 expected: self.shape,
235 actual: T::SHAPE,
236 }));
237 }
238 unsafe {
239 self.shape.call_drop_in_place(self.data);
241 core::ptr::write(self.data.as_mut_byte_ptr() as *mut T, value);
242 }
243 Ok(())
244 }
245
246 #[inline]
248 pub fn as_peek(&self) -> crate::Peek<'_, 'facet> {
249 unsafe { crate::Peek::unchecked_new(self.data.as_const(), self.shape) }
250 }
251
252 #[inline]
254 pub fn into_peek(self) -> crate::Peek<'mem, 'facet> {
255 unsafe { crate::Peek::unchecked_new(self.data.as_const(), self.shape) }
256 }
257
258 pub fn at_path_mut(self, path: &Path) -> Result<Poke<'mem, 'facet>, PathAccessError> {
283 if self.shape != path.shape {
284 return Err(PathAccessError::RootShapeMismatch {
285 expected: path.shape,
286 actual: self.shape,
287 });
288 }
289
290 let mut data = self.data;
291 let mut shape: &'static Shape = self.shape;
292
293 for (step_index, step) in path.steps().iter().enumerate() {
294 let (new_data, new_shape) = apply_step_mut(data, shape, *step, step_index)?;
295 data = new_data;
296 shape = new_shape;
297 }
298
299 Ok(unsafe { Poke::from_raw_parts(data, shape) })
300 }
301}
302
303fn apply_step_mut(
308 data: PtrMut,
309 shape: &'static Shape,
310 step: PathStep,
311 step_index: usize,
312) -> Result<(PtrMut, &'static Shape), PathAccessError> {
313 match step {
314 PathStep::Field(idx) => {
315 let idx = idx as usize;
316 match shape.ty {
317 Type::User(UserType::Struct(sd)) => {
318 if idx >= sd.fields.len() {
319 return Err(PathAccessError::IndexOutOfBounds {
320 step,
321 step_index,
322 shape,
323 index: idx,
324 bound: sd.fields.len(),
325 });
326 }
327 let field = &sd.fields[idx];
328 let field_data = unsafe { data.field(field.offset) };
329 Ok((field_data, field.shape()))
330 }
331 Type::User(UserType::Enum(enum_type)) => {
332 let variant_idx = variant_index_from_raw(data.as_const(), shape, enum_type)
334 .map_err(|_| PathAccessError::WrongStepKind {
335 step,
336 step_index,
337 shape,
338 })?;
339 let variant = &enum_type.variants[variant_idx];
340 if idx >= variant.data.fields.len() {
341 return Err(PathAccessError::IndexOutOfBounds {
342 step,
343 step_index,
344 shape,
345 index: idx,
346 bound: variant.data.fields.len(),
347 });
348 }
349 let field = &variant.data.fields[idx];
350 let field_data = unsafe { data.field(field.offset) };
351 Ok((field_data, field.shape()))
352 }
353 _ => Err(PathAccessError::WrongStepKind {
354 step,
355 step_index,
356 shape,
357 }),
358 }
359 }
360
361 PathStep::Variant(expected_idx) => {
362 let expected_idx = expected_idx as usize;
363 let enum_type = match shape.ty {
364 Type::User(UserType::Enum(et)) => et,
365 _ => {
366 return Err(PathAccessError::WrongStepKind {
367 step,
368 step_index,
369 shape,
370 });
371 }
372 };
373
374 if expected_idx >= enum_type.variants.len() {
375 return Err(PathAccessError::IndexOutOfBounds {
376 step,
377 step_index,
378 shape,
379 index: expected_idx,
380 bound: enum_type.variants.len(),
381 });
382 }
383
384 let actual_idx =
385 variant_index_from_raw(data.as_const(), shape, enum_type).map_err(|_| {
386 PathAccessError::WrongStepKind {
387 step,
388 step_index,
389 shape,
390 }
391 })?;
392
393 if actual_idx != expected_idx {
394 return Err(PathAccessError::VariantMismatch {
395 step_index,
396 shape,
397 expected_variant: expected_idx,
398 actual_variant: actual_idx,
399 });
400 }
401
402 Ok((data, shape))
404 }
405
406 PathStep::Index(idx) => {
407 let idx = idx as usize;
408 match shape.def {
409 Def::List(def) => {
410 let get_mut_fn = def.vtable.get_mut.ok_or(PathAccessError::WrongStepKind {
411 step,
412 step_index,
413 shape,
414 })?;
415 let len = unsafe { (def.vtable.len)(data.as_const()) };
416 let item = unsafe { get_mut_fn(data, idx, shape) };
417 item.map(|ptr| (ptr, def.t()))
418 .ok_or(PathAccessError::IndexOutOfBounds {
419 step,
420 step_index,
421 shape,
422 index: idx,
423 bound: len,
424 })
425 }
426 Def::Array(def) => {
427 let elem_shape = def.t();
429 let layout = elem_shape.layout.sized_layout().map_err(|_| {
430 PathAccessError::WrongStepKind {
431 step,
432 step_index,
433 shape,
434 }
435 })?;
436 let len = def.n;
437 if idx >= len {
438 return Err(PathAccessError::IndexOutOfBounds {
439 step,
440 step_index,
441 shape,
442 index: idx,
443 bound: len,
444 });
445 }
446 let elem_data = unsafe { data.field(layout.size() * idx) };
447 Ok((elem_data, elem_shape))
448 }
449 _ => Err(PathAccessError::WrongStepKind {
450 step,
451 step_index,
452 shape,
453 }),
454 }
455 }
456
457 PathStep::OptionSome => {
458 if let Def::Option(option_def) = shape.def {
459 let is_some = unsafe { (option_def.vtable.is_some)(data.as_const()) };
461 if !is_some {
462 return Err(PathAccessError::OptionIsNone { step_index, shape });
463 }
464 let inner_raw_ptr = unsafe { (option_def.vtable.get_value)(data.as_const()) };
468 assert!(
469 !inner_raw_ptr.is_null(),
470 "is_some was true but get_value returned null"
471 );
472 let inner_ptr_const = facet_core::PtrConst::new_sized(inner_raw_ptr);
473 let offset = unsafe {
475 inner_ptr_const
476 .as_byte_ptr()
477 .offset_from(data.as_const().as_byte_ptr())
478 } as usize;
479 let inner_data = unsafe { data.field(offset) };
480 Ok((inner_data, option_def.t()))
481 } else {
482 Err(PathAccessError::WrongStepKind {
483 step,
484 step_index,
485 shape,
486 })
487 }
488 }
489
490 PathStep::MapKey(_) | PathStep::MapValue(_) => {
491 if matches!(shape.def, Def::Map(_)) {
492 Err(PathAccessError::MissingTarget {
493 step,
494 step_index,
495 shape,
496 })
497 } else {
498 Err(PathAccessError::WrongStepKind {
499 step,
500 step_index,
501 shape,
502 })
503 }
504 }
505
506 PathStep::Deref => {
507 if matches!(shape.def, Def::Pointer(_)) {
508 Err(PathAccessError::MissingTarget {
509 step,
510 step_index,
511 shape,
512 })
513 } else {
514 Err(PathAccessError::WrongStepKind {
515 step,
516 step_index,
517 shape,
518 })
519 }
520 }
521
522 PathStep::Inner => Err(PathAccessError::MissingTarget {
523 step,
524 step_index,
525 shape,
526 }),
527
528 PathStep::Proxy => Err(PathAccessError::MissingTarget {
529 step,
530 step_index,
531 shape,
532 }),
533 }
534}
535
536fn variant_index_from_raw(
539 data: PtrConst,
540 shape: &'static Shape,
541 enum_type: facet_core::EnumType,
542) -> Result<usize, VariantError> {
543 use facet_core::EnumRepr;
544
545 if let Def::Option(option_def) = shape.def {
547 let is_some = unsafe { (option_def.vtable.is_some)(data) };
548 return Ok(enum_type
549 .variants
550 .iter()
551 .position(|variant| {
552 let has_fields = !variant.data.fields.is_empty();
553 has_fields == is_some
554 })
555 .expect("No variant found matching Option state"));
556 }
557
558 if enum_type.enum_repr == EnumRepr::RustNPO {
559 let layout = shape
560 .layout
561 .sized_layout()
562 .map_err(|_| VariantError::Unsized)?;
563 let slice = unsafe { core::slice::from_raw_parts(data.as_byte_ptr(), layout.size()) };
564 let all_zero = slice.iter().all(|v| *v == 0);
565
566 Ok(enum_type
567 .variants
568 .iter()
569 .position(|variant| {
570 let mut max_offset = 0;
571 for field in variant.data.fields {
572 let offset = field.offset
573 + field
574 .shape()
575 .layout
576 .sized_layout()
577 .map(|v| v.size())
578 .unwrap_or(0);
579 max_offset = core::cmp::max(max_offset, offset);
580 }
581 if all_zero {
582 max_offset == 0
583 } else {
584 max_offset != 0
585 }
586 })
587 .expect("No variant found with matching discriminant"))
588 } else {
589 let discriminant = match enum_type.enum_repr {
590 EnumRepr::Rust => {
591 panic!("cannot read discriminant from Rust enum with unspecified layout")
592 }
593 EnumRepr::RustNPO => 0,
594 EnumRepr::U8 => unsafe { data.read::<u8>() as i64 },
595 EnumRepr::U16 => unsafe { data.read::<u16>() as i64 },
596 EnumRepr::U32 => unsafe { data.read::<u32>() as i64 },
597 EnumRepr::U64 => unsafe { data.read::<u64>() as i64 },
598 EnumRepr::USize => unsafe { data.read::<usize>() as i64 },
599 EnumRepr::I8 => unsafe { data.read::<i8>() as i64 },
600 EnumRepr::I16 => unsafe { data.read::<i16>() as i64 },
601 EnumRepr::I32 => unsafe { data.read::<i32>() as i64 },
602 EnumRepr::I64 => unsafe { data.read::<i64>() },
603 EnumRepr::ISize => unsafe { data.read::<isize>() as i64 },
604 };
605
606 Ok(enum_type
607 .variants
608 .iter()
609 .position(|variant| variant.discriminant == Some(discriminant))
610 .expect("No variant found with matching discriminant"))
611 }
612}
613
614impl core::fmt::Debug for Poke<'_, '_> {
615 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
616 write!(f, "Poke<{}>", self.shape)
617 }
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `result` is an error of kind `WrongShape`.
    fn is_wrong_shape<T>(result: &Result<T, ReflectError>) -> bool {
        matches!(
            result,
            Err(ReflectError {
                kind: ReflectErrorKind::WrongShape { .. },
                ..
            })
        )
    }

    /// Reading through `get` and overwriting through `set` both reach the
    /// original variable.
    #[test]
    fn poke_primitive_get_set() {
        let mut value = 42_i32;
        let mut poke = Poke::new(&mut value);

        let before = *poke.get::<i32>().unwrap();
        assert_eq!(before, 42);

        poke.set(100_i32).unwrap();
        assert_eq!(value, 100);
    }

    /// Writes through the reference returned by `get_mut` are visible in the
    /// original variable.
    #[test]
    fn poke_primitive_get_mut() {
        let mut value = 42_i32;
        let mut poke = Poke::new(&mut value);

        let slot = poke.get_mut::<i32>().unwrap();
        *slot = 99;
        assert_eq!(value, 99);
    }

    /// `get` with a type other than the viewed one reports `WrongShape`.
    #[test]
    fn poke_wrong_type_fails() {
        let mut value = 42_i32;
        let poke = Poke::new(&mut value);

        let outcome = poke.get::<u32>();
        assert!(is_wrong_shape(&outcome));
    }

    /// `set` with a type other than the viewed one reports `WrongShape`.
    #[test]
    fn poke_set_wrong_type_fails() {
        let mut value = 42_i32;
        let mut poke = Poke::new(&mut value);

        let outcome = poke.set(42_u32);
        assert!(is_wrong_shape(&outcome));
    }

    /// `set` on a heap-owning type replaces the old value with the new one.
    #[test]
    fn poke_string_drop_and_replace() {
        let mut text = String::from("hello");
        let mut poke = Poke::new(&mut text);

        poke.set(String::from("world")).unwrap();
        assert_eq!(text, "world");
    }
}