use core::marker::PhantomData;
use facet_core::{Def, Facet, PtrConst, PtrMut, Shape, Type, UserType, Variance};
use facet_path::{Path, PathAccessError, PathStep};
use crate::{ReflectError, ReflectErrorKind, peek::VariantError};
use super::{PokeList, PokeStruct};
/// Zero-sized marker carried by [`Poke`].
///
/// `&'mem mut ()` ties the handle to an exclusive borrow lasting `'mem`, while
/// `fn(&'facet ()) -> &'facet ()` mentions `'facet` in both argument and return position,
/// which makes the handle invariant in `'facet`.
type PokeMarker<'mem, 'facet> = PhantomData<(&'mem mut (), fn(&'facet ()) -> &'facet ())>;

/// A mutable, type-erased handle to a value: a raw pointer together with the [`Shape`]
/// describing what it points to.
pub struct Poke<'mem, 'facet> {
    /// Pointer to the reflected value.
    pub(crate) data: PtrMut,
    /// Shape describing the value behind `data`.
    pub(crate) shape: &'static Shape,
    /// Carries the `'mem` / `'facet` lifetimes without storing any data.
    _marker: PokeMarker<'mem, 'facet>,
}
impl<'mem, 'facet> Poke<'mem, 'facet> {
    /// Creates a `Poke` that mutably borrows `t` for `'mem`, using `T::SHAPE` as its shape.
    pub fn new<T: Facet<'facet>>(t: &'mem mut T) -> Self {
        Self {
            data: PtrMut::new(t as *mut T as *mut u8),
            shape: T::SHAPE,
            _marker: PhantomData,
        }
    }

    /// Creates a `Poke` from a raw pointer and the shape describing its pointee.
    ///
    /// # Safety
    ///
    /// `data` must point to an initialized value whose type is exactly the one described by
    /// `shape`. Nothing else may read or write that value for `'mem`, because the returned
    /// handle can overwrite it (`set`) and hand out `&mut` references to it (`get_mut`).
    /// The caller-chosen `'facet` must not outlive any borrow stored inside the value.
    pub unsafe fn from_raw_parts(data: PtrMut, shape: &'static Shape) -> Self {
        Self {
            data,
            shape,
            _marker: PhantomData,
        }
    }

    /// Returns the shape describing the value behind this handle.
    #[inline(always)]
    pub const fn shape(&self) -> &'static Shape {
        self.shape
    }

    /// Returns the pointer to the value as a read-only `PtrConst`.
    #[inline(always)]
    pub const fn data(&self) -> PtrConst {
        self.data.as_const()
    }

    /// Builds a `ReflectError` of `kind` whose path is rooted at this value's shape.
    #[inline]
    fn err(&self, kind: ReflectErrorKind) -> ReflectError {
        ReflectError::new(kind, Path::new(self.shape))
    }

    /// Returns the raw mutable pointer to the value (requires `&mut self`).
    #[inline(always)]
    pub const fn data_mut(&mut self) -> PtrMut {
        self.data
    }

    /// Returns the variance computed for this value's shape (`Shape::computed_variance`).
    #[inline]
    pub fn variance(&self) -> Variance {
        self.shape.computed_variance()
    }

    /// Reborrows this handle with a shorter `'facet` lifetime when the shape's variance
    /// reports that it can shrink (`Variance::can_shrink`); returns `None` otherwise.
    ///
    /// NOTE(review): a `Poke` grants write access (`set`, `get_mut`), and writable places are
    /// invariant in the lifetimes of their contents. Shrinking `'facet` through a mutable
    /// handle may let a shorter-lived value be written into storage that the original owner
    /// still treats as valid for `'facet`. Confirm that `can_shrink` is strict enough here
    /// (e.g. only for lifetime-independent shapes).
    #[inline]
    pub fn try_reborrow<'shorter>(&mut self) -> Option<Poke<'_, 'shorter>>
    where
        'facet: 'shorter,
    {
        if self.variance().can_shrink() {
            Some(Poke {
                data: self.data,
                shape: self.shape,
                _marker: PhantomData,
            })
        } else {
            None
        }
    }

    /// Returns `true` if the shape's type is a user-defined struct.
    #[inline]
    pub const fn is_struct(&self) -> bool {
        matches!(self.shape.ty, Type::User(UserType::Struct(_)))
    }

    /// Returns `true` if the shape's type is a user-defined enum.
    #[inline]
    pub const fn is_enum(&self) -> bool {
        matches!(self.shape.ty, Type::User(UserType::Enum(_)))
    }

    /// Returns `true` if the shape's definition is `Def::Scalar`.
    #[inline]
    pub const fn is_scalar(&self) -> bool {
        matches!(self.shape.def, Def::Scalar)
    }

    /// Converts this handle into a struct-specific [`PokeStruct`].
    ///
    /// # Errors
    ///
    /// Returns `ReflectErrorKind::WasNotA` if the shape's type is not a struct.
    pub fn into_struct(self) -> Result<PokeStruct<'mem, 'facet>, ReflectError> {
        match self.shape.ty {
            Type::User(UserType::Struct(struct_type)) => Ok(PokeStruct {
                value: self,
                ty: struct_type,
            }),
            _ => Err(self.err(ReflectErrorKind::WasNotA {
                expected: "struct",
                actual: self.shape,
            })),
        }
    }

    /// Converts this handle into an enum-specific `PokeEnum`.
    ///
    /// # Errors
    ///
    /// Returns `ReflectErrorKind::WasNotA` if the shape's type is not an enum.
    pub fn into_enum(self) -> Result<super::PokeEnum<'mem, 'facet>, ReflectError> {
        match self.shape.ty {
            Type::User(UserType::Enum(enum_type)) => Ok(super::PokeEnum {
                value: self,
                ty: enum_type,
            }),
            _ => Err(self.err(ReflectErrorKind::WasNotA {
                expected: "enum",
                actual: self.shape,
            })),
        }
    }

    /// Converts this handle into a list-specific [`PokeList`].
    ///
    /// # Errors
    ///
    /// Returns `ReflectErrorKind::WasNotA` if the shape's definition is not `Def::List`.
    #[inline]
    pub fn into_list(self) -> Result<PokeList<'mem, 'facet>, ReflectError> {
        if let Def::List(def) = self.shape.def {
            // SAFETY: `def` was read from `self.shape.def`, so it describes this very value.
            return Ok(unsafe { PokeList::new(self, def) });
        }
        Err(self.err(ReflectErrorKind::WasNotA {
            expected: "list",
            actual: self.shape,
        }))
    }

    /// Borrows the value as a `&T`.
    ///
    /// # Errors
    ///
    /// Returns `ReflectErrorKind::WrongShape` if `T::SHAPE` differs from this value's shape.
    pub fn get<T: Facet<'facet>>(&self) -> Result<&T, ReflectError> {
        if self.shape != T::SHAPE {
            return Err(self.err(ReflectErrorKind::WrongShape {
                expected: self.shape,
                actual: T::SHAPE,
            }));
        }
        // SAFETY: equal shapes mean the pointee is a `T`; the borrow is tied to `&self`, so
        // no mutable access can be made through this handle while it lives.
        Ok(unsafe { self.data.as_const().get::<T>() })
    }

    /// Borrows the value as a `&mut T`.
    ///
    /// # Errors
    ///
    /// Returns `ReflectErrorKind::WrongShape` if `T::SHAPE` differs from this value's shape.
    pub fn get_mut<T: Facet<'facet>>(&mut self) -> Result<&mut T, ReflectError> {
        if self.shape != T::SHAPE {
            return Err(self.err(ReflectErrorKind::WrongShape {
                expected: self.shape,
                actual: T::SHAPE,
            }));
        }
        // SAFETY: equal shapes mean the pointee is a `T`; the borrow is tied to `&mut self`,
        // so it is the only access made through this handle while it lives.
        Ok(unsafe { self.data.as_mut::<T>() })
    }

    /// Replaces the value behind this handle with `value` and drops the previous value.
    ///
    /// # Errors
    ///
    /// Returns `ReflectErrorKind::WrongShape` if `T::SHAPE` differs from this value's shape.
    /// In that case nothing is dropped or written.
    pub fn set<T: Facet<'facet>>(&mut self, value: T) -> Result<(), ReflectError> {
        if self.shape != T::SHAPE {
            return Err(self.err(ReflectErrorKind::WrongShape {
                expected: self.shape,
                actual: T::SHAPE,
            }));
        }
        // SAFETY: the shape check proves the pointee is an initialized `T`, and this handle
        // has exclusive access to it for `'mem`.
        //
        // `ptr::replace` moves `value` into the slot *before* the old value is dropped. If the
        // old value's `Drop` panics, the slot still holds a valid `T`. Dropping in place first
        // and writing afterwards would leave an already-dropped value behind during unwinding,
        // and the owner would then drop it a second time.
        let old = unsafe { core::ptr::replace(self.data.as_mut_byte_ptr() as *mut T, value) };
        drop(old);
        Ok(())
    }

    /// Returns a read-only `Peek` view that borrows this handle.
    #[inline]
    pub fn as_peek(&self) -> crate::Peek<'_, 'facet> {
        // SAFETY: `data` points to a value described by `shape`; the view borrows `self`, so
        // no mutation can happen through this handle while the view lives.
        unsafe { crate::Peek::unchecked_new(self.data.as_const(), self.shape) }
    }

    /// Consumes this handle and returns a read-only `Peek` view valid for `'mem`.
    #[inline]
    pub fn into_peek(self) -> crate::Peek<'mem, 'facet> {
        // SAFETY: `data` points to a value described by `shape`; the handle is consumed, so
        // the view is the only remaining access for `'mem`.
        unsafe { crate::Peek::unchecked_new(self.data.as_const(), self.shape) }
    }

    /// Walks `path` from this value and returns a mutable handle to the value it addresses.
    ///
    /// # Errors
    ///
    /// Returns `PathAccessError::RootShapeMismatch` if `path` is rooted at a different shape.
    /// Otherwise, returns the first error produced while applying a step (wrong step kind,
    /// out-of-bounds index, variant mismatch, `None` option, or an unreachable target).
    pub fn at_path_mut(self, path: &Path) -> Result<Poke<'mem, 'facet>, PathAccessError> {
        if self.shape != path.shape {
            return Err(PathAccessError::RootShapeMismatch {
                expected: path.shape,
                actual: self.shape,
            });
        }
        let (data, shape) = path.steps().iter().enumerate().try_fold(
            (self.data, self.shape),
            |(data, shape), (step_index, step)| apply_step_mut(data, shape, *step, step_index),
        )?;
        // SAFETY: every step derives `data` from the previous pointer (a field, element or
        // option payload inside the value this handle exclusively borrows for `'mem`), and
        // returns the shape describing that sub-value.
        Ok(unsafe { Poke::from_raw_parts(data, shape) })
    }
}
fn apply_step_mut(
data: PtrMut,
shape: &'static Shape,
step: PathStep,
step_index: usize,
) -> Result<(PtrMut, &'static Shape), PathAccessError> {
match step {
PathStep::Field(idx) => {
let idx = idx as usize;
match shape.ty {
Type::User(UserType::Struct(sd)) => {
if idx >= sd.fields.len() {
return Err(PathAccessError::IndexOutOfBounds {
step,
step_index,
shape,
index: idx,
bound: sd.fields.len(),
});
}
let field = &sd.fields[idx];
let field_data = unsafe { data.field(field.offset) };
Ok((field_data, field.shape()))
}
Type::User(UserType::Enum(enum_type)) => {
let variant_idx = variant_index_from_raw(data.as_const(), shape, enum_type)
.map_err(|_| PathAccessError::WrongStepKind {
step,
step_index,
shape,
})?;
let variant = &enum_type.variants[variant_idx];
if idx >= variant.data.fields.len() {
return Err(PathAccessError::IndexOutOfBounds {
step,
step_index,
shape,
index: idx,
bound: variant.data.fields.len(),
});
}
let field = &variant.data.fields[idx];
let field_data = unsafe { data.field(field.offset) };
Ok((field_data, field.shape()))
}
_ => Err(PathAccessError::WrongStepKind {
step,
step_index,
shape,
}),
}
}
PathStep::Variant(expected_idx) => {
let expected_idx = expected_idx as usize;
let enum_type = match shape.ty {
Type::User(UserType::Enum(et)) => et,
_ => {
return Err(PathAccessError::WrongStepKind {
step,
step_index,
shape,
});
}
};
if expected_idx >= enum_type.variants.len() {
return Err(PathAccessError::IndexOutOfBounds {
step,
step_index,
shape,
index: expected_idx,
bound: enum_type.variants.len(),
});
}
let actual_idx =
variant_index_from_raw(data.as_const(), shape, enum_type).map_err(|_| {
PathAccessError::WrongStepKind {
step,
step_index,
shape,
}
})?;
if actual_idx != expected_idx {
return Err(PathAccessError::VariantMismatch {
step_index,
shape,
expected_variant: expected_idx,
actual_variant: actual_idx,
});
}
Ok((data, shape))
}
PathStep::Index(idx) => {
let idx = idx as usize;
match shape.def {
Def::List(def) => {
let get_mut_fn = def.vtable.get_mut.ok_or(PathAccessError::WrongStepKind {
step,
step_index,
shape,
})?;
let len = unsafe { (def.vtable.len)(data.as_const()) };
let item = unsafe { get_mut_fn(data, idx, shape) };
item.map(|ptr| (ptr, def.t()))
.ok_or(PathAccessError::IndexOutOfBounds {
step,
step_index,
shape,
index: idx,
bound: len,
})
}
Def::Array(def) => {
let elem_shape = def.t();
let layout = elem_shape.layout.sized_layout().map_err(|_| {
PathAccessError::WrongStepKind {
step,
step_index,
shape,
}
})?;
let len = def.n;
if idx >= len {
return Err(PathAccessError::IndexOutOfBounds {
step,
step_index,
shape,
index: idx,
bound: len,
});
}
let elem_data = unsafe { data.field(layout.size() * idx) };
Ok((elem_data, elem_shape))
}
_ => Err(PathAccessError::WrongStepKind {
step,
step_index,
shape,
}),
}
}
PathStep::OptionSome => {
if let Def::Option(option_def) = shape.def {
let is_some = unsafe { (option_def.vtable.is_some)(data.as_const()) };
if !is_some {
return Err(PathAccessError::OptionIsNone { step_index, shape });
}
let inner_raw_ptr = unsafe { (option_def.vtable.get_value)(data.as_const()) };
assert!(
!inner_raw_ptr.is_null(),
"is_some was true but get_value returned null"
);
let inner_ptr_const = facet_core::PtrConst::new_sized(inner_raw_ptr);
let offset = unsafe {
inner_ptr_const
.as_byte_ptr()
.offset_from(data.as_const().as_byte_ptr())
} as usize;
let inner_data = unsafe { data.field(offset) };
Ok((inner_data, option_def.t()))
} else {
Err(PathAccessError::WrongStepKind {
step,
step_index,
shape,
})
}
}
PathStep::MapKey(_) | PathStep::MapValue(_) => {
if matches!(shape.def, Def::Map(_)) {
Err(PathAccessError::MissingTarget {
step,
step_index,
shape,
})
} else {
Err(PathAccessError::WrongStepKind {
step,
step_index,
shape,
})
}
}
PathStep::Deref => {
if matches!(shape.def, Def::Pointer(_)) {
Err(PathAccessError::MissingTarget {
step,
step_index,
shape,
})
} else {
Err(PathAccessError::WrongStepKind {
step,
step_index,
shape,
})
}
}
PathStep::Inner => Err(PathAccessError::MissingTarget {
step,
step_index,
shape,
}),
PathStep::Proxy => Err(PathAccessError::MissingTarget {
step,
step_index,
shape,
}),
}
}
/// Determines which variant of `enum_type` is currently stored at `data`.
///
/// There are three strategies:
/// - Shapes with `Def::Option` are answered through the option vtable.
/// - Enums using `EnumRepr::RustNPO` are classified by whether every byte of the value is zero.
/// - Every other representation reads an integer discriminant of the declared width and
///   looks it up among the variants' discriminants.
///
/// # Errors
///
/// Returns `VariantError::Unsized` if an NPO enum's layout is not sized.
///
/// # Panics
///
/// Panics for `EnumRepr::Rust`, whose discriminant cannot be read. Also panics if no
/// variant matches the observed option state, zero-ness, or discriminant.
fn variant_index_from_raw(
    data: PtrConst,
    shape: &'static Shape,
    enum_type: facet_core::EnumType,
) -> Result<usize, VariantError> {
    use facet_core::EnumRepr;

    if let Def::Option(option_def) = shape.def {
        // SAFETY: the caller passes a pointer to a value described by `shape`.
        let is_some = unsafe { (option_def.vtable.is_some)(data) };
        // The field-less variant stands for `None`; the one carrying fields stands for `Some`.
        let index = enum_type
            .variants
            .iter()
            .position(|variant| variant.data.fields.is_empty() != is_some)
            .expect("No variant found matching Option state");
        return Ok(index);
    }

    let repr = enum_type.enum_repr;
    let found = if repr == EnumRepr::RustNPO {
        let size = shape
            .layout
            .sized_layout()
            .map_err(|_| VariantError::Unsized)?
            .size();
        // SAFETY: `data` points to a live value of `shape`, whose sized layout spans `size` bytes.
        let bytes = unsafe { core::slice::from_raw_parts(data.as_byte_ptr(), size) };
        let all_zero_bytes = bytes.iter().all(|&byte| byte == 0);
        // An all-zero value picks the variant whose fields end at offset 0 (no payload bytes).
        // Any other value picks a variant whose fields extend past offset 0.
        enum_type.variants.iter().position(|variant| {
            let payload_end = variant
                .data
                .fields
                .iter()
                .map(|field| {
                    field.offset
                        + field
                            .shape()
                            .layout
                            .sized_layout()
                            .map(|layout| layout.size())
                            .unwrap_or(0)
                })
                .max()
                .unwrap_or(0);
            (payload_end == 0) == all_zero_bytes
        })
    } else {
        // NOTE(review): the reads below assume the discriminant sits at the start of the value
        // with exactly the width named by the repr — confirm for every repr that reaches here.
        let discriminant = match repr {
            EnumRepr::Rust => {
                panic!("cannot read discriminant from Rust enum with unspecified layout")
            }
            // Handled by the branch above; kept so the match stays exhaustive.
            EnumRepr::RustNPO => 0,
            EnumRepr::U8 => unsafe { data.read::<u8>() as i64 },
            EnumRepr::U16 => unsafe { data.read::<u16>() as i64 },
            EnumRepr::U32 => unsafe { data.read::<u32>() as i64 },
            EnumRepr::U64 => unsafe { data.read::<u64>() as i64 },
            EnumRepr::USize => unsafe { data.read::<usize>() as i64 },
            EnumRepr::I8 => unsafe { data.read::<i8>() as i64 },
            EnumRepr::I16 => unsafe { data.read::<i16>() as i64 },
            EnumRepr::I32 => unsafe { data.read::<i32>() as i64 },
            EnumRepr::I64 => unsafe { data.read::<i64>() },
            EnumRepr::ISize => unsafe { data.read::<isize>() as i64 },
        };
        enum_type
            .variants
            .iter()
            .position(|variant| variant.discriminant == Some(discriminant))
    };

    Ok(found.expect("No variant found with matching discriminant"))
}
impl core::fmt::Debug for Poke<'_, '_> {
    /// Renders as `Poke<…>` around the `Display` form of the handle's shape.
    /// The pointed-to value itself is never read.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let shape = self.shape;
        f.write_fmt(format_args!("Poke<{shape}>"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` failed with `ReflectErrorKind::WrongShape`.
    #[track_caller]
    fn assert_wrong_shape<T>(result: Result<T, ReflectError>) {
        assert!(matches!(
            result,
            Err(ReflectError {
                kind: ReflectErrorKind::WrongShape { .. },
                ..
            })
        ));
    }

    #[test]
    fn poke_primitive_get_set() {
        let mut value: i32 = 42;
        let mut poke = Poke::new(&mut value);
        let current = *poke.get::<i32>().unwrap();
        assert_eq!(current, 42);
        poke.set(100i32).unwrap();
        assert_eq!(value, 100);
    }

    #[test]
    fn poke_primitive_get_mut() {
        let mut value: i32 = 42;
        let mut poke = Poke::new(&mut value);
        let slot = poke.get_mut::<i32>().unwrap();
        *slot = 99;
        assert_eq!(value, 99);
    }

    #[test]
    fn poke_wrong_type_fails() {
        let mut value: i32 = 42;
        let poke = Poke::new(&mut value);
        assert_wrong_shape(poke.get::<u32>());
    }

    #[test]
    fn poke_set_wrong_type_fails() {
        let mut value: i32 = 42;
        let mut poke = Poke::new(&mut value);
        assert_wrong_shape(poke.set(42u32));
    }

    #[test]
    fn poke_string_drop_and_replace() {
        let mut text = String::from("hello");
        let mut poke = Poke::new(&mut text);
        poke.set(String::from("world")).unwrap();
        assert_eq!(text, "world");
    }
}