Skip to main content

facet_maybe_mut/
maybe_mut.rs

1use derive_more::{Deref, DerefMut, From};
2use facet::{Def, PointerFlags, PtrConst, ReadLockResult, Shape, WriteLockResult};
3use facet_reflect::{Peek, Poke};
4
5/// Some reference to a type that implements [`Facet`](facet::Facet) that may be
6/// `mut` or not.
7#[derive(From)]
8#[repr(C)]
9pub enum MaybeMut<'mem, 'facet> {
10    Not(Peek<'mem, 'facet>),
11    Mut(Poke<'mem, 'facet>),
12}
13
14impl<'mem, 'facet> MaybeMut<'mem, 'facet> {
15    /// Returns a readonly/immutable version of the inner type
16    pub fn as_peek(&'mem self) -> Peek<'mem, 'facet> {
17        match self {
18            Self::Not(peek) => *peek,
19            Self::Mut(poke) => poke.as_peek(),
20        }
21    }
22
23    pub fn into_peek(self) -> Peek<'mem, 'facet> {
24        match self {
25            MaybeMut::Not(n) => n,
26            MaybeMut::Mut(m) => {
27                let m: Poke<'mem, 'facet> = m;
28                let peek: Peek<'mem, 'facet> = unsafe { Peek::unchecked_new(m.data(), m.shape()) };
29                peek
30            }
31        }
32    }
33
34    /// Returns the [`Shape`] of the underlying type
35    ///
36    /// The [`Shape`] is the same for [`Mut`](Self::Mut) and [`Not`](Self::Not)
37    pub fn shape(&self) -> &'static Shape {
38        self.as_peek().shape()
39    }
40}
41
42#[derive(Debug, thiserror::Error)]
43#[error("{kind}")]
44pub struct MakeLockError<'mem, 'facet> {
45    pub unchanged: Peek<'mem, 'facet>,
46    pub kind: MakeLockErrorKind,
47}
48
49#[derive(Debug, thiserror::Error)]
50pub enum MakeLockErrorKind {
51    /// The underlying type is not a type that we can lock from a `&T` to `&mut T` (but it is more complicated...)
52    #[error("type cannot be locked")]
53    NotLockable,
54    /// The underlying type could be locked but the provided lock method in the
55    /// vtable returned an error.
56    #[error("locking of type failed")]
57    LockFailure,
58}
59
60/// Depending on whether this is a read or write lock, `P` will be either
61/// [`PtrConst`] or [`PtrMut`]. This enum makes `P` dynamic
62#[derive(From)]
63pub(crate) enum LockGuardType {
64    Write(WriteLockResult),
65    Read(ReadLockResult),
66}
67
68impl LockGuardType {
69    pub fn data_const(&self) -> PtrConst {
70        match self {
71            Self::Write(w) => w.data_const(),
72            Self::Read(r) => *r.data(),
73        }
74    }
75}
76
77/// Contains the guard, the data ptr, and drop vtable to free the lock
78///
79/// # Note
80///
81/// The contained [`MaybeMut`] is NOT guaranteed to be [`Mut`](MaybeMut::Mut)
82///
83/// For example, RwLock also needs a lock and guard for a read.
84///
85#[derive(Deref, DerefMut)]
86pub struct Guard<'lock_mem, 'facet> {
87    /// Dropping the guard handles freeing the lock
88    ///
89    /// If this is None, the `data` can be accessed directly and there is no
90    /// lock that must be freeed
91    ///
92    /// SAFETY: The pointer inside the [`LockGuardType`] MUST NOT be used
93    /// since the data is already (mutable) available via `data`
94    _guard: Option<LockGuardType>,
95    #[deref]
96    #[deref_mut]
97    data: MaybeMut<'lock_mem, 'facet>,
98}
99
100impl<'mem, 'facet> MaybeMut<'mem, 'facet> {
101    /// Try to turn [`MaybeMut::Not`] into [`MaybeMut::Mut`]
102    ///
103    /// The returned [`MaybeMut`] may contain a different [`Shape`].
104    /// Which exact [`Shape`] it is, depends on what the input type was.
105    ///
106    /// One edge case is if you pass a `&mut Arc<RwLock<String>` the type will
107    /// not be changed to `&mut String`. But if you pass a `&Arc<RwLock<String>`
108    /// due to locking etc, it will be a `&mut String`.
109    ///
110    /// If the underlying type is something that can be write locked,
111    /// for example an `RwLock` or `Mutex`, this method creates a lock on it.
112    ///
113    /// If we already have [`MaybeMut::Mut`] this is a no-op.
114    ///
115    /// If we have [`MaybeMut::Not`] and the [`Shape`] of
116    /// `T` does not contain a [`PointerDef`](facet::PointerDef) which
117    /// has a vtable with a `write_fn` we can call with `&T`, this method
118    /// returns [`Err(MaybeMut::Not)`](Err). In this case, besides the lookup,
119    /// it is also a no-op.
120    ///
121    /// # Note
122    ///
123    /// It is very important that you drop the [`Guard`] as soon as possible
124    /// to free the lock
125    ///
126    /// [`Shape`]: facet::Shape
127    pub fn write<'lock>(self) -> Result<Guard<'lock, 'facet>, MakeLockError<'mem, 'facet>>
128    where
129        'mem: 'lock,
130    {
131        match self {
132            // if we already have a mut this is a no op
133            MaybeMut::Mut(v) => {
134                // but only if this is a type that is not a smart pointer that can be locked
135                if let Def::Pointer(p) = v.as_peek().innermost_peek().shape().def
136                // restrict downgrading to Peek only if there is a a lock _somewhere_
137                    && p.flags.contains(PointerFlags::LOCK)
138                {
139                    Self::Not(v.into_peek()).write()
140                } else {
141                    Ok(Guard {
142                        _guard: None,
143                        data: v.into(),
144                    })
145                }
146            }
147            // this is where it gets interesting
148            MaybeMut::Not(v) => {
149                // SAFETY: v.innermost_peek() unwraps all transparent wrappers like Arc or Rc until something that needs
150                // locking is reached which is all we care about
151                // FIXME: naively using innermost_peek is a bad idea i think.
152                // for example, in the UI if there is a NonZero<u32> this will peek
153                // up tu u32. Then, we will perhaps display an editable u32 which
154                // can be set to zero. Now what? We broke it boys
155                let v = v.innermost_peek();
156                // the shape of the pointer type (if it is one) but derefence smart pointers that can so without locking
157                // e.g. Arc<T> AND also &T
158                let shape = v.shape();
159                let def = shape.def;
160
161                // short cirucit if it is not a pointer. in these
162                // cases we wont be able to reach something like
163                // RwLock or Mutex
164                let Def::Pointer(pointer) = def else {
165                    return Err(MakeLockError {
166                        unchanged: v,
167                        kind: MakeLockErrorKind::NotLockable,
168                    });
169                };
170
171                // we dont care if we lock it (Mutex) or write lock it (RwLock)
172                let lock_fn = match (pointer.vtable.write_fn, pointer.vtable.lock_fn) {
173                    (Some(write_fn), _) => write_fn,
174                    (_, Some(lock_fn)) => lock_fn,
175                    _ => {
176                        return Err(MakeLockError {
177                            unchanged: v,
178                            kind: MakeLockErrorKind::NotLockable,
179                        });
180                    }
181                };
182                // SAFETY: v.innermost_peek() unwraps all transparent wrappers like Arc or Rc until something that needs
183                // locking is reached which is also the same type we get the lock_fn from
184                let res = unsafe { lock_fn(v.data()) };
185                let Ok(lock) = res else {
186                    return Err(MakeLockError {
187                        unchanged: v,
188                        kind: MakeLockErrorKind::LockFailure,
189                    });
190                };
191
192                // SAFETY: creates access via the PtrMut returned from locking
193                // the smart pointer. 'mem outlives 'lock this means
194                // the returned SmartPointer<'mem> also outlives the mutable Poke<'lock>
195                let poke: Poke<'lock, 'facet> = unsafe {
196                    Poke::from_raw_parts(
197                        // if the input type was Arc<RwLock<String>> this willbe
198                        // a pointer to a String
199                        *lock.data(),
200                        shape
201                            .inner
202                            .expect("a smart pointer always has an inner shape"),
203                    )
204                };
205                let value: MaybeMut<'lock, 'facet> = MaybeMut::Mut(poke);
206                Ok(Guard {
207                    data: value,
208                    _guard: Some(LockGuardType::Write(lock)),
209                })
210            }
211        }
212    }
213
214    /// Returns a [`Guard`] with a lock that is sufficent for reading.
215    ///
216    /// In case of `RwLock` it is locked to read. If it is a `Mutex`, it must
217    /// be exclusively locked to write but we only consider it being read which
218    /// is safe
219    pub fn read<'lock>(self) -> Result<Guard<'lock, 'facet>, MakeLockError<'mem, 'facet>>
220    where
221        'mem: 'lock,
222    {
223        let peek = self.into_peek();
224        // unwrap smart pointers
225        let v = peek.innermost_peek();
226        // the shape of the pointer type (if it is one) but derefence smart pointers that can so without locking
227        // e.g. Arc<T>
228        let shape = v.shape();
229        let def = shape.def;
230
231        // short cirucit if it is not a pointer. in these
232        // cases we wont be able to reach something like
233        // RwLock or Mutex
234        // In this case, we just return the reference to the underlying type
235        let Def::Pointer(pointer) = def else {
236            return Ok(Guard {
237                _guard: None,
238                data: MaybeMut::Not(v),
239            });
240        };
241
242        // we dont care if we lock it (Mutex) or read lock it (RwLock)
243        let res: Result<LockGuardType, _> = if let Some(read_fn) = pointer.vtable.read_fn {
244            unsafe { read_fn(v.data()) }.map(Into::into)
245        } else if let Some(lock_fn) = pointer.vtable.lock_fn {
246            unsafe { lock_fn(v.data()) }.map(Into::into)
247        } else {
248            return Err(MakeLockError {
249                unchanged: v,
250                kind: MakeLockErrorKind::NotLockable,
251            });
252        };
253
254        let Ok(lock) = res else {
255            return Err(MakeLockError {
256                unchanged: v,
257                kind: MakeLockErrorKind::LockFailure,
258            });
259        };
260        // SAFETY: creates access via the PtrMut returned from locking
261        // the smart pointer. 'lock outlives 'mem this means
262        // the returned mutable Poke<'lock> also outlives the SmartPointer<'mem>
263        let peek: Peek<'lock, 'facet> = unsafe {
264            Peek::unchecked_new(
265                lock.data_const(),
266                shape
267                    .inner
268                    .expect("a smart pointer always has an inner shape"),
269            )
270        };
271        let value = MaybeMut::Not(peek);
272        Ok(Guard {
273            _guard: Some(lock),
274            data: value,
275        })
276    }
277}
278
#[cfg(test)]
mod tests {
    use facet::{Def, Facet, KnownPointer};
    use facet_reflect::Peek;

    #[derive(Debug, Facet)]
    struct Foo {
        value: String,
    }

    /// `&Foo` is described as a shared-reference pointer, both through its
    /// `SHAPE` and through a [`Peek`] of a `&&Foo` (whose shape is `&Foo`).
    #[facet_testhelpers::test]
    fn shared_reference() {
        let a = Foo {
            value: "aaaa".to_string(),
        };
        let ref_shape_def = <&Foo as Facet<'_>>::SHAPE.def;
        assert!(
            matches!(ref_shape_def, Def::Pointer(p) if p.known == Some(KnownPointer::SharedReference)),
            "unexpected def for `&Foo`: {ref_shape_def:#?}"
        );

        let ref_a: &Foo = &a;
        let ref_ref_a: &&Foo = &ref_a;
        let peek = Peek::new(ref_ref_a);
        let peek_def = peek.shape().def;
        assert!(
            matches!(peek_def, Def::Pointer(p) if p.known == Some(KnownPointer::SharedReference)),
            "unexpected def for a peek of `&&Foo`: {peek_def:#?}"
        );
    }
}