Skip to main content

facet_maybe_mut/
maybe_mut.rs

1use derive_more::{Deref, DerefMut, From};
2use facet::{Def, PointerFlags, PtrConst, ReadLockResult, Shape, WriteLockResult};
3use facet_reflect::{Peek, Poke};
4
/// Some reference to a type that implements [`Facet`](facet::Facet) that may be
/// `mut` or not.
///
/// [`MaybeMut::as_peek`] / [`MaybeMut::into_peek`] give read-only access for
/// either variant. [`MaybeMut::write`] / [`MaybeMut::read`] produce a [`Guard`],
/// locking lockable pointer types where needed.
///
/// `From<Peek>` and `From<Poke>` are derived, so either can be turned into a
/// `MaybeMut` with `.into()`.
#[derive(From)]
#[repr(C)]
pub enum MaybeMut<'mem, 'facet> {
    /// Read-only access to the value.
    Not(Peek<'mem, 'facet>),
    /// Mutable access to the value.
    Mut(Poke<'mem, 'facet>),
}
13
14impl<'mem, 'facet> MaybeMut<'mem, 'facet> {
15    /// Returns a readonly/immutable version of the inner type
16    pub fn as_peek(&'mem self) -> Peek<'mem, 'facet> {
17        match self {
18            Self::Not(peek) => *peek,
19            Self::Mut(poke) => poke.as_peek(),
20        }
21    }
22
23    pub fn into_peek(self) -> Peek<'mem, 'facet> {
24        match self {
25            MaybeMut::Not(n) => n,
26            MaybeMut::Mut(m) => m.into_peek(),
27        }
28    }
29
30    /// Returns the [`Shape`] of the underlying type
31    ///
32    /// The [`Shape`] is the same for [`Mut`](Self::Mut) and [`Not`](Self::Not)
33    pub fn shape(&self) -> &'static Shape {
34        self.as_peek().shape()
35    }
36}
37
/// Error returned by [`MaybeMut::write`] and [`MaybeMut::read`] when no
/// [`Guard`] could be created.
///
/// Its `Display` output is that of [`kind`](Self::kind).
#[derive(Debug, thiserror::Error)]
#[error("{kind}")]
pub struct MakeLockError<'mem, 'facet> {
    /// Read-only view of the value on which locking was attempted. It is handed
    /// back so the caller keeps access to the value.
    ///
    /// Note: `write`/`read` store the value here *after*
    /// [`Peek::innermost_peek`] unwrapping, so its shape may differ from the
    /// shape of the input.
    pub unchanged: Peek<'mem, 'facet>,
    /// Why no guard could be created.
    pub kind: MakeLockErrorKind,
}
44
/// Reason carried by a [`MakeLockError`].
#[derive(Debug, thiserror::Error)]
pub enum MakeLockErrorKind {
    /// The underlying type is not one we can lock to get from a `&T` to a
    /// `&mut T`. After unwrapping transparent wrappers it is either not a
    /// pointer at all, or a pointer whose vtable offers no suitable lock
    /// function.
    #[error("type cannot be locked")]
    NotLockable,
    /// The underlying type could be locked, but the lock function provided in
    /// its vtable returned an error.
    #[error("locking of type failed")]
    LockFailure,
}
55
/// A lock handle obtained from a pointer vtable. It must be kept alive for as
/// long as the locked data is accessed.
///
/// Write locks (and exclusive locks) yield a [`WriteLockResult`], whose data
/// pointer is a [`PtrMut`](facet::PtrMut). Read locks yield a
/// [`ReadLockResult`], whose data pointer is a [`PtrConst`]. This enum lets both
/// kinds be stored in the same place, e.g. inside a [`Guard`]. `From` is
/// derived for both result types.
#[derive(From)]
pub(crate) enum LockGuardType {
    /// A write/exclusive lock (mutable data pointer).
    Write(WriteLockResult),
    /// A shared read lock (const data pointer).
    Read(ReadLockResult),
}
63
64impl LockGuardType {
65    /// Safety
66    ///
67    /// This is the raw pointer returned from the lock which is already
68    /// available via [`Guard`]. Creating a new [`Peek`] or [`Poke`] from this
69    /// [`PtrConst`] is UB.
70    pub fn data_const(&self) -> PtrConst {
71        match self {
72            Self::Write(w) => w.data_const(),
73            Self::Read(r) => *r.data(),
74        }
75    }
76}
77
/// Access to a value together with the lock (if any) protecting it.
///
/// Holds the lock guard and the data it protects. Dropping the `Guard` drops
/// the lock guard and thereby frees the lock. The `Guard` dereferences
/// (mutably, too) to the contained [`MaybeMut`].
///
/// # Note
///
/// The contained [`MaybeMut`] is NOT guaranteed to be [`Mut`](MaybeMut::Mut).
///
/// For example, an `RwLock` also needs a lock and guard for a read.
///
/// NOTE(review): `data` carries the lifetime `'lock_mem`, which is not tied to
/// borrows of the `Guard` itself. [`Peek`] is `Copy`, so a `Peek` can be copied
/// out through `Deref`, or a whole [`MaybeMut`] swapped out through `DerefMut`.
/// Either can then be used after the `Guard` — and with it the lock — has been
/// dropped. Callers must not do this. Consider exposing the data only through
/// borrows of the guard.
#[derive(Deref, DerefMut)]
pub struct Guard<'lock_mem, 'facet> {
    /// Dropping the guard handles freeing the lock.
    ///
    /// If this is `None`, `data` can be accessed directly and there is no lock
    /// that must be freed.
    ///
    /// Struct fields are dropped in declaration order, so this lock guard is
    /// dropped before `data`.
    ///
    /// SAFETY: The pointer inside the [`LockGuardType`] MUST NOT be used,
    /// since the data is already (possibly mutably) available via `data`.
    _guard: Option<LockGuardType>,
    /// The (possibly locked) data this guard dereferences to.
    #[deref]
    #[deref_mut]
    data: MaybeMut<'lock_mem, 'facet>,
}
100
101impl<'mem, 'facet> MaybeMut<'mem, 'facet> {
102    /// Try to turn [`MaybeMut::Not`] into [`MaybeMut::Mut`]
103    ///
104    /// The returned [`MaybeMut`] may contain a different [`Shape`].
105    /// Which exact [`Shape`] it is, depends on what the input type was.
106    ///
107    /// One edge case is if you pass a `&mut Arc<RwLock<String>` the type will
108    /// not be changed to `&mut String`. But if you pass a `&Arc<RwLock<String>`
109    /// due to locking etc, it will be a `&mut String`.
110    ///
111    /// If the underlying type is something that can be write locked,
112    /// for example an `RwLock` or `Mutex`, this method creates a lock on it.
113    ///
114    /// If we already have [`MaybeMut::Mut`] this is a no-op.
115    ///
116    /// If we have [`MaybeMut::Not`] and the [`Shape`] of
117    /// `T` does not contain a [`PointerDef`](facet::PointerDef) which
118    /// has a vtable with a `write_fn` we can call with `&T`, this method
119    /// returns [`Err(MaybeMut::Not)`](Err). In this case, besides the lookup,
120    /// it is also a no-op.
121    ///
122    /// # Note
123    ///
124    /// It is very important that you drop the [`Guard`] as soon as possible
125    /// to free the lock
126    ///
127    /// [`Shape`]: facet::Shape
128    pub fn write<'lock>(self) -> Result<Guard<'lock, 'facet>, MakeLockError<'mem, 'facet>>
129    where
130        'mem: 'lock,
131    {
132        match self {
133            // if we already have a mut this is a no op
134            MaybeMut::Mut(v) => {
135                // but only if this is a type that is not a smart pointer that can be locked
136                if let Def::Pointer(p) = v.as_peek().innermost_peek().shape().def
137                // restrict downgrading to Peek only if there is a a lock _somewhere_
138                    && p.flags.contains(PointerFlags::LOCK)
139                {
140                    Self::Not(v.into_peek()).write()
141                } else {
142                    Ok(Guard {
143                        _guard: None,
144                        data: v.into(),
145                    })
146                }
147            }
148            // this is where it gets interesting
149            MaybeMut::Not(v) => {
150                // SAFETY: v.innermost_peek() unwraps all transparent wrappers like Arc or Rc until something that needs
151                // locking is reached which is all we care about
152                // FIXME: naively using innermost_peek is a bad idea i think.
153                // for example, in the UI if there is a NonZero<u32> this will peek
154                // up tu u32. Then, we will perhaps display an editable u32 which
155                // can be set to zero. Now what? We broke it boys
156                let v = v.innermost_peek();
157                // the shape of the pointer type (if it is one) but derefence smart pointers that can so without locking
158                // e.g. Arc<T> AND also &T
159                let shape = v.shape();
160                let def = shape.def;
161
162                // short cirucit if it is not a pointer. in these
163                // cases we wont be able to reach something like
164                // RwLock or Mutex
165                let Def::Pointer(pointer) = def else {
166                    return Err(MakeLockError {
167                        unchanged: v,
168                        kind: MakeLockErrorKind::NotLockable,
169                    });
170                };
171
172                // we dont care if we lock it (Mutex) or write lock it (RwLock)
173                let lock_fn = match (pointer.vtable.write_fn, pointer.vtable.lock_fn) {
174                    (Some(write_fn), _) => write_fn,
175                    (_, Some(lock_fn)) => lock_fn,
176                    _ => {
177                        return Err(MakeLockError {
178                            unchanged: v,
179                            kind: MakeLockErrorKind::NotLockable,
180                        });
181                    }
182                };
183                // SAFETY: v.innermost_peek() unwraps all transparent wrappers like Arc or Rc until something that needs
184                // locking is reached which is also the same type we get the lock_fn from
185                let res = unsafe { lock_fn(v.data()) };
186                let Ok(lock) = res else {
187                    return Err(MakeLockError {
188                        unchanged: v,
189                        kind: MakeLockErrorKind::LockFailure,
190                    });
191                };
192
193                // SAFETY: creates access via the PtrMut returned from locking
194                // the smart pointer. 'mem outlives 'lock this means
195                // the returned SmartPointer<'mem> also outlives the mutable Poke<'lock>
196                let poke: Poke<'lock, 'facet> = unsafe {
197                    Poke::from_raw_parts(
198                        // if the input type was Arc<RwLock<String>> this willbe
199                        // a pointer to a String
200                        *lock.data(),
201                        shape
202                            .inner
203                            .expect("a smart pointer always has an inner shape"),
204                    )
205                };
206                let value: MaybeMut<'lock, 'facet> = MaybeMut::Mut(poke);
207                Ok(Guard {
208                    data: value,
209                    _guard: Some(LockGuardType::Write(lock)),
210                })
211            }
212        }
213    }
214
215    /// Returns a [`Guard`] with a lock that is sufficent for reading.
216    ///
217    /// In case of `RwLock` it is locked to read. If it is a `Mutex`, it must
218    /// be exclusively locked to write but we only consider it being read which
219    /// is safe
220    pub fn read<'lock>(self) -> Result<Guard<'lock, 'facet>, MakeLockError<'mem, 'facet>>
221    where
222        'mem: 'lock,
223    {
224        let peek = self.into_peek();
225        // unwrap smart pointers
226        let v = peek.innermost_peek();
227        // the shape of the pointer type (if it is one) but derefence smart pointers that can so without locking
228        // e.g. Arc<T>
229        let shape = v.shape();
230        let def = shape.def;
231
232        // short cirucit if it is not a pointer. in these
233        // cases we wont be able to reach something like
234        // RwLock or Mutex
235        // In this case, we just return the reference to the underlying type
236        let Def::Pointer(pointer) = def else {
237            return Ok(Guard {
238                _guard: None,
239                data: MaybeMut::Not(v),
240            });
241        };
242
243        // we dont care if we lock it (Mutex) or read lock it (RwLock)
244        let res: Result<LockGuardType, _> = if let Some(read_fn) = pointer.vtable.read_fn {
245            unsafe { read_fn(v.data()) }.map(Into::into)
246        } else if let Some(lock_fn) = pointer.vtable.lock_fn {
247            unsafe { lock_fn(v.data()) }.map(Into::into)
248        } else {
249            return Err(MakeLockError {
250                unchanged: v,
251                kind: MakeLockErrorKind::NotLockable,
252            });
253        };
254
255        let Ok(lock) = res else {
256            return Err(MakeLockError {
257                unchanged: v,
258                kind: MakeLockErrorKind::LockFailure,
259            });
260        };
261        // SAFETY: creates access via the PtrMut returned from locking
262        // the smart pointer. 'lock outlives 'mem this means
263        // the returned mutable Poke<'lock> also outlives the SmartPointer<'mem>
264        let peek: Peek<'lock, 'facet> = unsafe {
265            Peek::unchecked_new(
266                lock.data_const(),
267                shape
268                    .inner
269                    .expect("a smart pointer always has an inner shape"),
270            )
271        };
272        let value = MaybeMut::Not(peek);
273        Ok(Guard {
274            _guard: Some(lock),
275            data: value,
276        })
277    }
278}
279
#[cfg(test)]
mod tests {
    use facet::{Def, Facet, KnownPointer};
    use facet_reflect::Peek;

    #[derive(Debug, Facet)]
    struct Foo {
        value: String,
    }

    /// `&Foo` is described as a shared-reference pointer. This holds both for
    /// its static `SHAPE` and for a `Peek` over a `&Foo` value.
    #[facet_testhelpers::test]
    fn shared_reference() {
        let a = Foo {
            value: "aaaa".to_string(),
        };
        assert!(matches!(
            <&Foo as Facet<'_>>::SHAPE.def,
            Def::Pointer(p) if p.known == Some(KnownPointer::SharedReference)
        ));

        let ref_a: &Foo = &a;
        let ref_ref_a: &&Foo = &ref_a;
        // Peeking through `&&Foo` yields a view of the `&Foo` value, whose
        // shape must again be a shared-reference pointer.
        let peek = Peek::new(ref_ref_a);
        assert!(matches!(
            peek.shape().def,
            Def::Pointer(p) if p.known == Some(KnownPointer::SharedReference)
        ));
    }
}
307}