use derive_more::{Deref, DerefMut, From};
use facet::{Def, PointerFlags, PtrConst, ReadLockResult, Shape, WriteLockResult};
use facet_reflect::{Peek, Poke};
/// A reflected value that is either read-only ([`Peek`]) or mutable ([`Poke`]).
///
/// `From` is derived, so a `Peek` or a `Poke` converts into the matching variant with `.into()`.
#[derive(From)]
#[repr(C)]
pub enum MaybeMut<'mem, 'facet> {
    /// Read-only access to the value.
    Not(Peek<'mem, 'facet>),
    /// Mutable access to the value.
    Mut(Poke<'mem, 'facet>),
}
impl<'mem, 'facet> MaybeMut<'mem, 'facet> {
    /// Returns a read-only view of the value without consuming it.
    ///
    /// A `Not` value is copied out directly; a `Mut` value is viewed through
    /// [`Poke::as_peek`].
    pub fn as_peek(&'mem self) -> Peek<'mem, 'facet> {
        match self {
            MaybeMut::Mut(poke) => poke.as_peek(),
            MaybeMut::Not(peek) => *peek,
        }
    }

    /// Consumes the value and returns a read-only view, giving up mutability if it had any.
    pub fn into_peek(self) -> Peek<'mem, 'facet> {
        match self {
            Self::Mut(poke) => poke.into_peek(),
            Self::Not(peek) => peek,
        }
    }

    /// Returns the reflected [`Shape`] of the value, whichever variant it is.
    pub fn shape(&self) -> &'static Shape {
        let view = self.as_peek();
        view.shape()
    }
}
/// Error returned by [`MaybeMut::write`] and [`MaybeMut::read`] when no [`Guard`] could be
/// produced.
///
/// Its `Display` output is the message of its [`MakeLockErrorKind`].
#[derive(Debug, thiserror::Error)]
#[error("{kind}")]
pub struct MakeLockError<'mem, 'facet> {
    /// Read-only view of the value that locking was attempted on. In `write`/`read` this is the
    /// value after `innermost_peek` unwrapping, not necessarily the caller's original value.
    pub unchanged: Peek<'mem, 'facet>,
    /// Why no guard could be produced.
    pub kind: MakeLockErrorKind,
}
#[derive(Debug, thiserror::Error)]
pub enum MakeLockErrorKind {
#[error("type cannot be locked")]
NotLockable,
#[error("locking of type failed")]
LockFailure,
}
/// The lock result kept alive inside a [`Guard`], as returned by a pointer vtable's
/// `write_fn`/`lock_fn` or `read_fn`.
///
/// `From` is derived so either result converts into this enum with `.into()`.
#[derive(From)]
pub(crate) enum LockGuardType {
    /// Result of a `write_fn` or `lock_fn` call.
    Write(WriteLockResult),
    /// Result of a `read_fn` call.
    Read(ReadLockResult),
}
impl LockGuardType {
pub fn data_const(&self) -> PtrConst {
match self {
Self::Write(w) => w.data_const(),
Self::Read(r) => *r.data(),
}
}
}
/// A value obtained from [`MaybeMut::write`] or [`MaybeMut::read`], paired with the lock result
/// (if any) that was acquired to reach it.
///
/// Dereferences to the wrapped [`MaybeMut`]. Marked `#[must_use]` because discarding a
/// `Guard` right away (e.g. `value.write()?;`) also drops the lock result it owns.
#[must_use = "the lock result is dropped together with this `Guard`"]
#[derive(Deref, DerefMut)]
pub struct Guard<'lock_mem, 'facet> {
    /// Lock result owned by this guard; `None` when no lock had to be taken.
    /// Fields drop in declaration order, so this is dropped before `data`.
    _guard: Option<LockGuardType>,
    /// The (possibly locked) value this guard gives access to.
    #[deref]
    #[deref_mut]
    data: MaybeMut<'lock_mem, 'facet>,
}
impl<'mem, 'facet> MaybeMut<'mem, 'facet> {
    /// Tries to obtain mutable access to the value, taking a write lock when needed.
    ///
    /// - `Mut` whose innermost value (via `innermost_peek`) is a [`Def::Pointer`] flagged
    ///   [`PointerFlags::LOCK`]: downgraded to `Not` and handled by the `Not` path, so the lock
    ///   is acquired through the pointer vtable.
    /// - Any other `Mut`: returned as-is in a [`Guard`] that holds no lock.
    /// - `Not`: the innermost value must be a [`Def::Pointer`] whose vtable has a `write_fn`
    ///   (preferred) or a `lock_fn`. On success the guard owns the lock result and exposes the
    ///   locked data as a `Mut` [`Poke`] of the pointer's `inner` shape.
    ///
    /// # Errors
    ///
    /// - [`MakeLockErrorKind::NotLockable`] if the innermost value is not a pointer, or its
    ///   vtable has neither `write_fn` nor `lock_fn`.
    /// - [`MakeLockErrorKind::LockFailure`] if the chosen function returns an error.
    ///
    /// In both cases `unchanged` is the innermost value as a read-only [`Peek`]; a `Mut` input
    /// does not get its [`Poke`] back.
    ///
    /// # Panics
    ///
    /// If a successfully locked pointer's shape has no `inner` shape.
    pub fn write<'lock>(self) -> Result<Guard<'lock, 'facet>, MakeLockError<'mem, 'facet>>
    where
        'mem: 'lock,
    {
        match self {
            MaybeMut::Mut(v) => {
                // Even with mutable access, a lock-type pointer is routed through its vtable
                // lock function instead of being returned directly.
                if let Def::Pointer(p) = v.as_peek().innermost_peek().shape().def
                    && p.flags.contains(PointerFlags::LOCK)
                {
                    Self::Not(v.into_peek()).write()
                } else {
                    Ok(Guard {
                        _guard: None,
                        data: v.into(),
                    })
                }
            }
            MaybeMut::Not(v) => {
                let v = v.innermost_peek();
                let shape = v.shape();
                let def = shape.def;
                let Def::Pointer(pointer) = def else {
                    return Err(MakeLockError {
                        unchanged: v,
                        kind: MakeLockErrorKind::NotLockable,
                    });
                };
                // Prefer the dedicated write function; fall back to the generic lock function.
                let lock_fn = match (pointer.vtable.write_fn, pointer.vtable.lock_fn) {
                    (Some(write_fn), _) => write_fn,
                    (_, Some(lock_fn)) => lock_fn,
                    _ => {
                        return Err(MakeLockError {
                            unchanged: v,
                            kind: MakeLockErrorKind::NotLockable,
                        });
                    }
                };
                // SAFETY (review): `v.data()` points at a value of `shape`, and `lock_fn` was
                // taken from that same shape's pointer vtable. This is presumably the contract
                // the vtable function requires; confirm against the facet vtable docs.
                let res = unsafe { lock_fn(v.data()) };
                let Ok(lock) = res else {
                    return Err(MakeLockError {
                        unchanged: v,
                        kind: MakeLockErrorKind::LockFailure,
                    });
                };
                // SAFETY (review): the data pointer comes from `lock` and is paired with the
                // pointer's `inner` shape; `lock` is moved into the returned `Guard` with it.
                // NOTE(review): `'lock` is chosen by the caller (only `'mem: 'lock` is required),
                // so the `Poke` is not tied to a borrow of the `Guard`. Through `DerefMut` a
                // caller could `mem::replace` it out and keep it after the guard is dropped.
                // Confirm this cannot outlive the lock.
                let poke: Poke<'lock, 'facet> = unsafe {
                    Poke::from_raw_parts(
                        *lock.data(),
                        shape
                            .inner
                            .expect("a smart pointer always has an inner shape"),
                    )
                };
                let value: MaybeMut<'lock, 'facet> = MaybeMut::Mut(poke);
                Ok(Guard {
                    data: value,
                    _guard: Some(LockGuardType::Write(lock)),
                })
            }
        }
    }

    /// Obtains read-only access to the value, taking a lock when the innermost value is a
    /// lockable pointer.
    ///
    /// Mutability is always dropped: `self` is converted with [`MaybeMut::into_peek`] and then
    /// unwrapped with `innermost_peek`.
    ///
    /// - Not a [`Def::Pointer`]: returned as `Not` in a [`Guard`] that holds no lock.
    /// - A pointer whose vtable has a `read_fn`: locked with it.
    /// - Otherwise, a pointer whose vtable has a `lock_fn`: locked with it.
    ///
    /// On success the guard owns the lock result and exposes the locked data as a `Not`
    /// [`Peek`] of the pointer's `inner` shape.
    ///
    /// # Errors
    ///
    /// - [`MakeLockErrorKind::NotLockable`] if the value is a pointer whose vtable has neither
    ///   `read_fn` nor `lock_fn`.
    /// - [`MakeLockErrorKind::LockFailure`] if the chosen function returns an error.
    ///
    /// # Panics
    ///
    /// If a successfully locked pointer's shape has no `inner` shape.
    ///
    /// NOTE(review): a non-pointer value passes through, but a pointer with neither
    /// `read_fn` nor `lock_fn` is an error. Meanwhile `write` passes a `Mut` pointer that lacks
    /// [`PointerFlags::LOCK`] through unlocked. Confirm this asymmetry is intended.
    pub fn read<'lock>(self) -> Result<Guard<'lock, 'facet>, MakeLockError<'mem, 'facet>>
    where
        'mem: 'lock,
    {
        let peek = self.into_peek();
        let v = peek.innermost_peek();
        let shape = v.shape();
        let def = shape.def;
        // Non-pointer values need no lock: hand them back read-only.
        let Def::Pointer(pointer) = def else {
            return Ok(Guard {
                _guard: None,
                data: MaybeMut::Not(v),
            });
        };
        // Prefer the read function; fall back to the generic lock function. Both results are
        // converted into `LockGuardType` so the guard can own either kind.
        let res: Result<LockGuardType, _> = if let Some(read_fn) = pointer.vtable.read_fn {
            // SAFETY (review): `v.data()` points at a value of `shape`, whose pointer vtable
            // supplied `read_fn`; confirm this meets the facet vtable contract.
            unsafe { read_fn(v.data()) }.map(Into::into)
        } else if let Some(lock_fn) = pointer.vtable.lock_fn {
            // SAFETY (review): as above, for the vtable's `lock_fn`.
            unsafe { lock_fn(v.data()) }.map(Into::into)
        } else {
            return Err(MakeLockError {
                unchanged: v,
                kind: MakeLockErrorKind::NotLockable,
            });
        };
        let Ok(lock) = res else {
            return Err(MakeLockError {
                unchanged: v,
                kind: MakeLockErrorKind::LockFailure,
            });
        };
        // SAFETY (review): the data pointer comes from `lock` and is paired with the pointer's
        // `inner` shape; `lock` is moved into the returned `Guard` with it.
        // NOTE(review): `Peek` is `Copy` and `'lock` is caller-chosen, so a copy of this `Peek`
        // can be taken out of the `Guard` and kept after the guard is dropped. Confirm this
        // cannot outlive the lock.
        let peek: Peek<'lock, 'facet> = unsafe {
            Peek::unchecked_new(
                lock.data_const(),
                shape
                    .inner
                    .expect("a smart pointer always has an inner shape"),
            )
        };
        let value = MaybeMut::Not(peek);
        Ok(Guard {
            _guard: Some(lock),
            data: value,
        })
    }
}
#[cfg(test)]
mod tests {
    use facet::{Def, Facet, KnownPointer};
    use facet_reflect::Peek;

    #[derive(Debug, Facet)]
    struct Foo {
        value: String,
    }

    // `&Foo` must be reflected as a shared-reference `Def::Pointer`, both in its static shape
    // and when reached through a `Peek` built from a `&&Foo`.
    #[facet_testhelpers::test]
    fn shared_reference() {
        let a = Foo {
            value: "aaaa".to_string(),
        };

        // The actual `Def` is reported in the failure message instead of printed on every run.
        let static_def = <&Foo as Facet<'_>>::SHAPE.def;
        assert!(
            matches!(static_def, Def::Pointer(p) if p.known == Some(KnownPointer::SharedReference)),
            "`&Foo` should be a shared-reference pointer, got {static_def:#?}"
        );

        let ref_a: &Foo = &a;
        let ref_ref_a: &&Foo = &ref_a;
        let peek = Peek::new(ref_ref_a);
        let peek_def = peek.shape().def;
        assert!(
            matches!(peek_def, Def::Pointer(p) if p.known == Some(KnownPointer::SharedReference)),
            "peeked value should be a shared-reference pointer, got {peek_def:#?}"
        );
    }
}