use core::marker::PhantomData;
use facet_core::{Def, Facet, PtrConst, PtrMut, Shape, Type, UserType};
use crate::ReflectError;
use super::PokeStruct;
/// A mutable, type-erased handle to a value: a raw pointer to the data paired
/// with the [`Shape`] that describes it.
///
/// * `'mem` is the lifetime of the exclusive borrow of the underlying value.
/// * `'facet` is the lifetime parameter given to [`Facet`] by the poked type.
pub struct Poke<'mem, 'facet> {
    /// Type-erased pointer to the value being poked.
    pub(crate) data: PtrMut,

    /// Shape describing the value behind `data`.
    pub(crate) shape: &'static Shape,

    /// Zero-sized marker: `&'mem mut ()` ties this handle to the exclusive
    /// borrow, and `fn(&'facet ()) -> &'facet ()` makes `'facet` invariant.
    // The marker type is verbose only because it spells out borrow and
    // variance; hiding it behind an alias would not make it clearer.
    #[allow(clippy::type_complexity)]
    _marker: PhantomData<(&'mem mut (), fn(&'facet ()) -> &'facet ())>,
}
impl<'mem, 'facet> Poke<'mem, 'facet> {
pub fn new<T: Facet<'facet>>(t: &'mem mut T) -> Self {
Self {
data: PtrMut::new(t as *mut T as *mut u8),
shape: T::SHAPE,
_marker: PhantomData,
}
}
pub unsafe fn from_raw_parts(data: PtrMut, shape: &'static Shape) -> Self {
Self {
data,
shape,
_marker: PhantomData,
}
}
#[inline(always)]
pub const fn shape(&self) -> &'static Shape {
self.shape
}
#[inline(always)]
pub const fn data(&self) -> PtrConst {
self.data.as_const()
}
#[inline(always)]
pub const fn data_mut(&mut self) -> PtrMut {
self.data
}
#[inline]
pub const fn is_struct(&self) -> bool {
matches!(self.shape.ty, Type::User(UserType::Struct(_)))
}
#[inline]
pub const fn is_enum(&self) -> bool {
matches!(self.shape.ty, Type::User(UserType::Enum(_)))
}
#[inline]
pub const fn is_scalar(&self) -> bool {
matches!(self.shape.def, Def::Scalar)
}
pub const fn into_struct(self) -> Result<PokeStruct<'mem, 'facet>, ReflectError> {
match self.shape.ty {
Type::User(UserType::Struct(struct_type)) => Ok(PokeStruct {
value: self,
ty: struct_type,
}),
_ => Err(ReflectError::WasNotA {
expected: "struct",
actual: self.shape,
}),
}
}
pub const fn into_enum(self) -> Result<super::PokeEnum<'mem, 'facet>, ReflectError> {
match self.shape.ty {
Type::User(UserType::Enum(enum_type)) => Ok(super::PokeEnum {
value: self,
ty: enum_type,
}),
_ => Err(ReflectError::WasNotA {
expected: "enum",
actual: self.shape,
}),
}
}
pub fn get<T: Facet<'facet>>(&self) -> Result<&T, ReflectError> {
if self.shape != T::SHAPE {
return Err(ReflectError::WrongShape {
expected: self.shape,
actual: T::SHAPE,
});
}
Ok(unsafe { self.data.as_const().get::<T>() })
}
pub fn get_mut<T: Facet<'facet>>(&mut self) -> Result<&mut T, ReflectError> {
if self.shape != T::SHAPE {
return Err(ReflectError::WrongShape {
expected: self.shape,
actual: T::SHAPE,
});
}
Ok(unsafe { self.data.as_mut::<T>() })
}
pub fn set<T: Facet<'facet>>(&mut self, value: T) -> Result<(), ReflectError> {
if self.shape != T::SHAPE {
return Err(ReflectError::WrongShape {
expected: self.shape,
actual: T::SHAPE,
});
}
unsafe {
self.shape.call_drop_in_place(self.data);
core::ptr::write(self.data.as_mut_byte_ptr() as *mut T, value);
}
Ok(())
}
#[inline]
pub fn as_peek(&self) -> crate::Peek<'_, 'facet> {
unsafe { crate::Peek::unchecked_new(self.data.as_const(), self.shape) }
}
}
/// Debug output names the poked shape only, as `Poke<shape>`; the pointed-to
/// data is not printed.
impl core::fmt::Debug for Poke<'_, '_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let shape = self.shape;
        f.write_fmt(format_args!("Poke<{}>", shape))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reading through `get` and overwriting through `set` on a primitive.
    #[test]
    fn poke_primitive_get_set() {
        let mut number: i32 = 42;
        let mut handle = Poke::new(&mut number);

        let current = handle.get::<i32>().unwrap();
        assert_eq!(*current, 42);

        handle.set(100i32).unwrap();
        assert_eq!(number, 100);
    }

    /// Writing through the reference handed out by `get_mut`.
    #[test]
    fn poke_primitive_get_mut() {
        let mut number: i32 = 42;
        let mut handle = Poke::new(&mut number);

        let slot = handle.get_mut::<i32>().unwrap();
        *slot = 99;

        assert_eq!(number, 99);
    }

    /// `get` with a mismatched type reports `WrongShape`.
    #[test]
    fn poke_wrong_type_fails() {
        let mut number: i32 = 42;
        let handle = Poke::new(&mut number);

        let outcome = handle.get::<u32>();
        assert!(matches!(outcome, Err(ReflectError::WrongShape { .. })));
    }

    /// `set` with a mismatched type reports `WrongShape`.
    #[test]
    fn poke_set_wrong_type_fails() {
        let mut number: i32 = 42;
        let mut handle = Poke::new(&mut number);

        let outcome = handle.set(42u32);
        assert!(matches!(outcome, Err(ReflectError::WrongShape { .. })));
    }

    /// `set` on a heap-owning value replaces the old contents.
    #[test]
    fn poke_string_drop_and_replace() {
        let mut text = String::from("hello");
        let mut handle = Poke::new(&mut text);

        handle.set(String::from("world")).unwrap();
        assert_eq!(text, "world");
    }
}