1use std::{
2 future::Future,
3 marker::PhantomData,
4 pin::Pin,
5 sync::{Mutex, MutexGuard, PoisonError, TryLockError, TryLockResult},
6};
7
/// A stack of [`MutexGuard`]s forming a locked path that starts at a root mutex.
///
/// The first guard always belongs to the root mutex passed to `MutexGuardStack::new`, and the
/// last guard is the current "top". A guard may refer to a mutex stored inside the value
/// guarded by the entry below it, so guards must be released strictly from the top down.
#[derive(Debug)]
pub struct MutexGuardStack<'root, T: ?Sized> {
    /// Ties the stack's type to the `'root` borrow of the root mutex and to `T`.
    lifetime: PhantomData<&'root mut T>,
    /// Held guards, root first; never empty while the stack is in use.
    data: Vec<MutexGuard<'root, T>>,
}
16
/// The move that a `MutexGuardStack::move_with` / `move_with_async` callback asks for next.
///
/// `'root` is the lifetime of the stack's root borrow. `'this` is the callback's borrow of the
/// current top value.
#[derive(Debug)]
pub enum MoveDecision<'root, 'this, T: ?Sized> {
    /// Pop the current top guard and return to the entry below it.
    Ascend,
    /// Leave the stack unchanged.
    Stay,
    /// Try-lock a mutex reachable from the current top value (valid for `'this`) and push it.
    Descend(&'this Mutex<T>),
    /// Try-lock a mutex valid for the whole `'root` borrow and push it.
    Inject(&'root Mutex<T>),
}
23
/// Why a move requested through `MutexGuardStack::move_with` (or `move_with_async`) failed.
///
/// In every case the stack is left exactly as it was before the move.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MoveError {
    /// `MoveDecision::Ascend` was requested while the stack was already at its root.
    AscendAtRoot,
    /// The target mutex is poisoned and poisoning was not being ignored.
    Poisoned,
    /// The target mutex is already locked (possibly by this very stack).
    WouldBlock,
}

impl std::fmt::Display for MoveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            MoveError::AscendAtRoot => "cannot ascend: the stack is already at its root",
            MoveError::Poisoned => "the target mutex is poisoned",
            MoveError::WouldBlock => "the target mutex is already locked",
        };
        f.write_str(message)
    }
}

impl std::error::Error for MoveError {}
29
30impl<'root, T: ?Sized> MutexGuardStack<'root, T> {
31 pub fn new(root: &'root Mutex<T>) -> TryLockResult<Self> {
34 let root: *const Mutex<T> = root;
35 let guard = unsafe { (*root).try_lock() };
36 match guard {
37 Ok(guard) => Ok(Self {
38 lifetime: PhantomData,
39 data: vec![guard],
40 }),
41 Err(TryLockError::Poisoned(guard)) => {
42 Err(TryLockError::Poisoned(PoisonError::new(Self {
43 lifetime: PhantomData,
44 data: vec![guard.into_inner()],
45 })))
46 }
47 Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock),
48 }
49 }
50
51 pub fn raw_top_mut(&mut self) -> *mut T {
52 let guard: *mut MutexGuard<T> = self.data.last_mut().unwrap();
53 unsafe { &mut **guard }
54 }
55
56 pub fn top(&self) -> &T {
58 &*self.data.last().unwrap()
59 }
60
61 pub fn top_mut(&mut self) -> &mut T {
63 &mut *self.data.last_mut().unwrap()
64 }
65
66 pub fn is_at_root(&self) -> bool {
68 self.data.len() == 1
69 }
70
71 fn handle_trylock_result(
72 &mut self,
73 guard: TryLockResult<MutexGuard<'root, T>>,
74 ignore_poison: bool,
75 ) -> Result<&mut T, TryLockError<()>> {
76 match (guard, ignore_poison) {
77 (Ok(guard), _) => {
78 self.data.push(guard);
79 Ok(self.top_mut())
80 }
81 (Err(TryLockError::Poisoned(guard)), true) => {
82 self.data.push(guard.into_inner());
83 Ok(self.top_mut())
84 }
85 (Err(TryLockError::Poisoned(_guard)), false) => {
86 Err(TryLockError::Poisoned(PoisonError::new(())))
87 }
88 (Err(TryLockError::WouldBlock), _) => Err(TryLockError::WouldBlock),
89 }
90 }
91
92 fn handle_move_trylock_result(
93 &mut self,
94 guard: TryLockResult<MutexGuard<'root, T>>,
95 ignore_poison: bool,
96 ) -> Result<&mut T, MoveError> {
97 match (guard, ignore_poison) {
98 (Ok(guard), _) => {
99 self.data.push(guard);
100 Ok(self.top_mut())
101 }
102 (Err(TryLockError::Poisoned(guard)), true) => {
103 self.data.push(guard.into_inner());
104 Ok(self.top_mut())
105 }
106 (Err(TryLockError::Poisoned(_guard)), false) => Err(MoveError::Poisoned),
107 (Err(TryLockError::WouldBlock), _) => Err(MoveError::WouldBlock),
108 }
109 }
110
111 pub fn inject_top(
114 &mut self,
115 new_top: &'root Mutex<T>,
116 ignore_poison: bool,
117 ) -> Result<&mut T, TryLockError<()>> {
118 let new_top: *const Mutex<T> = new_top;
119 let guard = unsafe { (*new_top).try_lock() };
120 self.handle_trylock_result(guard, ignore_poison)
121 }
122
123 pub fn inject_with(
126 &mut self,
127 f: impl FnOnce(&mut T) -> Option<&'root Mutex<T>>,
128 ignore_poison: bool,
129 ) -> Option<Result<&mut T, TryLockError<()>>> {
130 let old_top: *mut T = self.raw_top_mut();
131 let new_top: &Mutex<T> = unsafe { f(&mut *old_top)? };
132 let new_top: *const Mutex<T> = new_top;
133 let guard = unsafe { (*new_top).try_lock() };
134 Some(self.handle_trylock_result(guard, ignore_poison))
135 }
136
137 pub fn descend_with(
141 &mut self,
142 f: impl for<'node> FnOnce(&'node mut T) -> Option<&'node Mutex<T>>,
143 ignore_poison: bool,
144 ) -> Option<Result<&mut T, TryLockError<()>>> {
145 let old_top: *mut T = self.raw_top_mut();
146 let new_top: &Mutex<T> = unsafe { f(&mut *old_top)? };
147 let new_top: *const Mutex<T> = new_top;
148 let guard = unsafe { (*new_top).try_lock() };
149 Some(self.handle_trylock_result(guard, ignore_poison))
150 }
151
152 pub fn ascend(&mut self) -> Option<&mut T> {
156 match self.data.len() {
157 0 => unreachable!("root pointer must always exist"),
158 1 => None,
159 _ => {
160 self.data.pop();
161 Some(self.top_mut())
162 }
163 }
164 }
165
166 pub fn ascend_while<P>(&mut self, mut predicate: P) -> &mut T
170 where
171 P: FnMut(&mut T) -> bool,
172 {
173 while !self.is_at_root() && predicate(self.top_mut()) {
174 let Some(_) = self.ascend() else {
175 unreachable!();
176 };
177 }
178 self.top_mut()
179 }
180
181 pub fn move_with<F>(&mut self, f: F, ignore_poison: bool) -> Result<&mut T, MoveError>
184 where
185 F: for<'a> FnOnce(&'a mut T) -> MoveDecision<'root, 'a, T>,
186 {
187 let old_top: *mut T = self.raw_top_mut();
188 let result = unsafe { f(&mut *old_top) };
189 match result {
190 MoveDecision::Ascend => self.ascend().ok_or(MoveError::AscendAtRoot),
191 MoveDecision::Stay => Ok(self.top_mut()),
192 MoveDecision::Inject(new_top) | MoveDecision::Descend(new_top) => {
193 let new_top: *const Mutex<T> = new_top;
194 let guard = unsafe { (*new_top).try_lock() };
195 self.handle_move_trylock_result(guard, ignore_poison)
196 }
197 }
198 }
199
200 pub async fn move_with_async<F>(
201 &mut self,
202 f: F,
203 ignore_poison: bool,
204 ) -> Result<&mut T, MoveError>
205 where
206 F: for<'a> FnOnce(
207 &'a mut T,
208 )
209 -> Pin<Box<dyn Future<Output = MoveDecision<'root, 'a, T>> + 'a>>,
210 {
211 let old_top: *mut T = self.raw_top_mut();
212 let result = unsafe { f(&mut *old_top) }.await;
213 match result {
214 MoveDecision::Ascend => self.ascend().ok_or(MoveError::AscendAtRoot),
215 MoveDecision::Stay => Ok(self.top_mut()),
216 MoveDecision::Inject(new_top) | MoveDecision::Descend(new_top) => {
217 let new_top: *const Mutex<T> = new_top;
218 let guard = unsafe { (*new_top).try_lock() };
219 self.handle_move_trylock_result(guard, ignore_poison)
220 }
221 }
222 }
223
224 pub fn into_top(mut self) -> MutexGuard<'root, T> {
227 let ret = self.data.pop().unwrap();
228 unsafe {
229 self.data.set_len(0);
231 }
232 ret
233 }
234
235 pub fn to_root(&mut self) -> &mut T {
237 for _ in 1..self.data.len() {
238 self.data.pop();
241 }
242 self.top_mut()
243 }
244}
245
246impl<'root, T: ?Sized> Drop for MutexGuardStack<'root, T> {
247 fn drop(&mut self) {
248 for _ in 0..self.data.len() {
249 self.data.pop();
252 }
253 }
254}