generic_cursors/
mutex.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    pin::Pin,
5    sync::{Mutex, MutexGuard, PoisonError, TryLockError, TryLockResult},
6};
7
/// A cursor into a recursive data structure whose nodes are each protected by their own
/// [`Mutex`], implemented as a stack of `MutexGuard`s.
///
/// The bottom entry guards the root mutex. Every later entry guards a mutex that was reached
/// from the entry below it (or injected with a `'root` lifetime). Only the top element is
/// reachable through the public API, so a parent's data is never accessed while a guard that
/// borrows from it is still on the stack.
pub struct MutexGuardStack<'root, T: ?Sized> {
    /// Ensures this `MutexGuardStack` does not outlive its root; it behaves like an
    /// exclusive borrow of the root for `'root`.
    lifetime: PhantomData<&'root mut T>,
    /// The stack of guards. Each one borrows from the one prior, except the first, which
    /// guards the root and is never popped while the stack is in use.
    ///
    /// Note: for every guard except the first, the `'root` lifetime is a "lie". It is used
    /// because `MutexGuard` has no raw-pointer counterpart. This is sound only because guards
    /// are released top-down, and a guard is handed out (by consuming the stack) only after
    /// every guard beneath it has been leaked and therefore stays locked.
    data: Vec<MutexGuard<'root, T>>,
}
16
/// What [`MutexGuardStack::move_with`] (or `move_with_async`) should do next, as decided by
/// the caller's closure after inspecting the current top element.
///
/// `'root` is the lifetime of the stack's root. `'this` is the lifetime of the borrow of the
/// current top element that is handed to the closure.
#[derive(Debug)]
pub enum MoveDecision<'root, 'this, T: ?Sized> {
    /// Pop the current top and move back to its parent.
    Ascend,
    /// Keep the current top.
    Stay,
    /// Try-lock a mutex borrowed from the current top element and push it.
    Descend(&'this Mutex<T>),
    /// Try-lock a mutex that lives as long as the root and push it.
    Inject(&'root Mutex<T>),
}
23
/// Why a `move_with` / `move_with_async` call could not carry out its [`MoveDecision`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MoveError {
    /// The closure asked to ascend, but the stack is already at its root.
    AscendAtRoot,
    /// The target mutex is poisoned and `ignore_poison` was `false`.
    Poisoned,
    /// The target mutex is currently locked, so `try_lock` would have blocked.
    WouldBlock,
}

impl std::fmt::Display for MoveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            MoveError::AscendAtRoot => "cannot ascend: already at the root",
            MoveError::Poisoned => "target mutex is poisoned",
            MoveError::WouldBlock => "target mutex is already locked",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for MoveError {}
29
30impl<'root, T: ?Sized> MutexGuardStack<'root, T> {
31    /// Create a new MutRefStack from a mutable reference to the root
32    /// of a recursive data structure.
33    pub fn new(root: &'root Mutex<T>) -> TryLockResult<Self> {
34        let root: *const Mutex<T> = root;
35        let guard = unsafe { (*root).try_lock() };
36        match guard {
37            Ok(guard) => Ok(Self {
38                lifetime: PhantomData,
39                data: vec![guard],
40            }),
41            Err(TryLockError::Poisoned(guard)) => {
42                Err(TryLockError::Poisoned(PoisonError::new(Self {
43                    lifetime: PhantomData,
44                    data: vec![guard.into_inner()],
45                })))
46            }
47            Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock),
48        }
49    }
50
51    pub fn raw_top_mut(&mut self) -> *mut T {
52        let guard: *mut MutexGuard<T> = self.data.last_mut().unwrap();
53        unsafe { &mut **guard }
54    }
55
56    /// Obtain a shared reference to the top of the stack.
57    pub fn top(&self) -> &T {
58        &*self.data.last().unwrap()
59    }
60
61    /// Obtain a mutable reference to the top of the stack.
62    pub fn top_mut(&mut self) -> &mut T {
63        &mut *self.data.last_mut().unwrap()
64    }
65
66    /// Is this MutRefStack currently at its root?
67    pub fn is_at_root(&self) -> bool {
68        self.data.len() == 1
69    }
70
71    fn handle_trylock_result(
72        &mut self,
73        guard: TryLockResult<MutexGuard<'root, T>>,
74        ignore_poison: bool,
75    ) -> Result<&mut T, TryLockError<()>> {
76        match (guard, ignore_poison) {
77            (Ok(guard), _) => {
78                self.data.push(guard);
79                Ok(self.top_mut())
80            }
81            (Err(TryLockError::Poisoned(guard)), true) => {
82                self.data.push(guard.into_inner());
83                Ok(self.top_mut())
84            }
85            (Err(TryLockError::Poisoned(_guard)), false) => {
86                Err(TryLockError::Poisoned(PoisonError::new(())))
87            }
88            (Err(TryLockError::WouldBlock), _) => Err(TryLockError::WouldBlock),
89        }
90    }
91
92    fn handle_move_trylock_result(
93        &mut self,
94        guard: TryLockResult<MutexGuard<'root, T>>,
95        ignore_poison: bool,
96    ) -> Result<&mut T, MoveError> {
97        match (guard, ignore_poison) {
98            (Ok(guard), _) => {
99                self.data.push(guard);
100                Ok(self.top_mut())
101            }
102            (Err(TryLockError::Poisoned(guard)), true) => {
103                self.data.push(guard.into_inner());
104                Ok(self.top_mut())
105            }
106            (Err(TryLockError::Poisoned(_guard)), false) => Err(MoveError::Poisoned),
107            (Err(TryLockError::WouldBlock), _) => Err(MoveError::WouldBlock),
108        }
109    }
110
111    /// Inject a new reference to the top of the stack. The reference still must live
112    /// as long as the root of the stack.
113    pub fn inject_top(
114        &mut self,
115        new_top: &'root Mutex<T>,
116        ignore_poison: bool,
117    ) -> Result<&mut T, TryLockError<()>> {
118        let new_top: *const Mutex<T> = new_top;
119        let guard = unsafe { (*new_top).try_lock() };
120        self.handle_trylock_result(guard, ignore_poison)
121    }
122
123    /// Inject a new reference to the top of the stack. The reference still must live
124    /// as long as the root of the stack.
125    pub fn inject_with(
126        &mut self,
127        f: impl FnOnce(&mut T) -> Option<&'root Mutex<T>>,
128        ignore_poison: bool,
129    ) -> Option<Result<&mut T, TryLockError<()>>> {
130        let old_top: *mut T = self.raw_top_mut();
131        let new_top: &Mutex<T> = unsafe { f(&mut *old_top)? };
132        let new_top: *const Mutex<T> = new_top;
133        let guard = unsafe { (*new_top).try_lock() };
134        Some(self.handle_trylock_result(guard, ignore_poison))
135    }
136
137    /// Descend into the recursive data structure, returning a mutable reference to the new top element.
138    /// Rust's borrow checker enforces that the closure cannot inject any lifetime (other than `'static`),
139    /// because the closure must work for any lifetime `'node`.
140    pub fn descend_with(
141        &mut self,
142        f: impl for<'node> FnOnce(&'node mut T) -> Option<&'node Mutex<T>>,
143        ignore_poison: bool,
144    ) -> Option<Result<&mut T, TryLockError<()>>> {
145        let old_top: *mut T = self.raw_top_mut();
146        let new_top: &Mutex<T> = unsafe { f(&mut *old_top)? };
147        let new_top: *const Mutex<T> = new_top;
148        let guard = unsafe { (*new_top).try_lock() };
149        Some(self.handle_trylock_result(guard, ignore_poison))
150    }
151
152    /// Ascend back up from the recursive data structure, returning a mutable reference to the new top element, if it changed.
153    /// If we are not currently at the root, ascend and return a reference to the new top.
154    /// If we are already at the root, returns None (the top is the root and does not change).
155    pub fn ascend(&mut self) -> Option<&mut T> {
156        match self.data.len() {
157            0 => unreachable!("root pointer must always exist"),
158            1 => None,
159            _ => {
160                self.data.pop();
161                Some(self.top_mut())
162            }
163        }
164    }
165
166    /// Ascend back up from the recursive data structure while the given closure returns `true`, returning a mutable reference to the new top element.
167    /// If we are not currently at the root, and the predicate returns `true`, ascend and continue.
168    /// If we are already at the root, or if the predicate returned false, returns a reference to the top element.
169    pub fn ascend_while<P>(&mut self, mut predicate: P) -> &mut T
170    where
171        P: FnMut(&mut T) -> bool,
172    {
173        while !self.is_at_root() && predicate(self.top_mut()) {
174            let Some(_) = self.ascend() else {
175                unreachable!();
176            };
177        }
178        self.top_mut()
179    }
180
181    /// Ascend from, descend from, inject a new stack top, or stay at the current node,
182    /// based on the return value of the closure.
183    pub fn move_with<F>(&mut self, f: F, ignore_poison: bool) -> Result<&mut T, MoveError>
184    where
185        F: for<'a> FnOnce(&'a mut T) -> MoveDecision<'root, 'a, T>,
186    {
187        let old_top: *mut T = self.raw_top_mut();
188        let result = unsafe { f(&mut *old_top) };
189        match result {
190            MoveDecision::Ascend => self.ascend().ok_or(MoveError::AscendAtRoot),
191            MoveDecision::Stay => Ok(self.top_mut()),
192            MoveDecision::Inject(new_top) | MoveDecision::Descend(new_top) => {
193                let new_top: *const Mutex<T> = new_top;
194                let guard = unsafe { (*new_top).try_lock() };
195                self.handle_move_trylock_result(guard, ignore_poison)
196            }
197        }
198    }
199
200    pub async fn move_with_async<F>(
201        &mut self,
202        f: F,
203        ignore_poison: bool,
204    ) -> Result<&mut T, MoveError>
205    where
206        F: for<'a> FnOnce(
207            &'a mut T,
208        )
209            -> Pin<Box<dyn Future<Output = MoveDecision<'root, 'a, T>> + 'a>>,
210    {
211        let old_top: *mut T = self.raw_top_mut();
212        let result = unsafe { f(&mut *old_top) }.await;
213        match result {
214            MoveDecision::Ascend => self.ascend().ok_or(MoveError::AscendAtRoot),
215            MoveDecision::Stay => Ok(self.top_mut()),
216            MoveDecision::Inject(new_top) | MoveDecision::Descend(new_top) => {
217                let new_top: *const Mutex<T> = new_top;
218                let guard = unsafe { (*new_top).try_lock() };
219                self.handle_move_trylock_result(guard, ignore_poison)
220            }
221        }
222    }
223
224    /// Return reference to the top element of this stack, forgetting about the stack entirely.
225    /// Note that this leaks all `MutexGuard`s above the top.
226    pub fn into_top(mut self) -> MutexGuard<'root, T> {
227        let ret = self.data.pop().unwrap();
228        unsafe {
229            // We need to not drop the parent MutexGuards, if any
230            self.data.set_len(0);
231        }
232        ret
233    }
234
235    /// Pop all `MutexGuard`s off the stack and go back to the root.
236    pub fn to_root(&mut self) -> &mut T {
237        for _ in 1..self.data.len() {
238            // We need to drop the MutexGuard's in the reverse order.
239            // Vec::truncate does not specify drop order, but it's probably wrong anyway.
240            self.data.pop();
241        }
242        self.top_mut()
243    }
244}
245
246impl<'root, T: ?Sized> Drop for MutexGuardStack<'root, T> {
247    fn drop(&mut self) {
248        for _ in 0..self.data.len() {
249            // We need to drop the MutexGuard's in the reverse order.
250            // Vec::truncate does not specify drop order, but it's probably wrong anyway.
251            self.data.pop();
252        }
253    }
254}