orchestrator_lock/lib.rs
1//! `orchestrator_lock` provides a specialized mutex implementation for scenarios
2//! where fine-grained control over mutex access is required. Unlike a standard
3//! mutex where any code with a reference can attempt to acquire the lock, this
4//! implementation separates the concerns of lock orchestration from lock usage.
5//!
6//! # Core Concepts
7//!
8//! * **OrchestratorMutex**: The central coordinator that owns the protected value and
9//! controls access to it.
10//!
11//! * **Granter**: A capability token that allows the orchestrator to grant lock access
12//! to a specific locker.
13//!
14//! * **MutexLocker**: The component that can acquire and use the lock, but only when
15//! explicitly granted access by the orchestrator.
16//!
17//! * **MutexGuard**: Provides access to the protected value, similar to a standard
18//! mutex guard.
19//!
20//! # Example
21//! ```
22//! use tokio::time::Duration;
23//!
24//! use orchestrator_lock::OrchestratorMutex;
25//!
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() {
28//! // Create a shared counter with initial value 0
29//! let mut orchestrator = OrchestratorMutex::new(0);
30//!
31//! // Create two granter/locker pairs
32//! let (mut granter1, mut locker1) = orchestrator.add_locker();
33//! let (mut granter2, mut locker2) = orchestrator.add_locker();
34//!
35//! // Task 1: Increments by 1 each time
36//! let task1 = tokio::spawn(async move {
37//! let expected = [0, 2, 6];
38//! for i in 0..3 {
39//! if let Some(mut guard) = locker1.acquire().await {
40//! assert_eq!(*guard, expected[i]);
41//! *guard += 1;
42//! tokio::time::sleep(Duration::from_millis(10)).await;
43//! }
44//! }
45//! locker1
46//! });
47//!
48//! // Task 2: Multiplies by 2 each time
49//! let task2 = tokio::spawn(async move {
50//! let expected = [1, 3, 7];
51//! for i in 0..3 {
52//! if let Some(mut guard) = locker2.acquire().await {
53//! assert_eq!(*guard, expected[i]);
54//! *guard *= 2;
55//! tokio::time::sleep(Duration::from_millis(10)).await;
56//! }
57//! }
58//! locker2
59//! });
60//!
61//! // Orchestration: Alternate between the two tasks
62//! for i in 0..3 {
63//! // Grant access to task 1
64//! let task1_holding = orchestrator.grant_access(&mut granter1).await.unwrap();
65//! task1_holding.await;
66//!
67//! // Grant access to task 2
68//! let task2_holding = orchestrator.grant_access(&mut granter2).await.unwrap();
69//! task2_holding.await;
70//! }
71//! assert_eq!(*orchestrator.acquire().await, 14);
72//! // Clean up
73//! let _ = task1.await.unwrap();
74//! let _ = task2.await.unwrap();
75//! }
76//! ```
77
78use std::ops::{Deref, DerefMut};
79use std::sync::{Arc, Weak};
80
81use awaitable_bool::AwaitableBool;
82use tokio::select;
83use tokio::sync::Mutex;
84
85pub mod error {
86 #[derive(Debug, PartialEq, Eq)]
87 pub struct GrantError;
88
89 #[derive(Debug, PartialEq, Eq)]
90 pub enum TryAcquireError {
91 /// [grant_access](super::OrchestratorMutex::grant_access) has not been
92 /// called with the corresponding [Granter](super::Granter).
93 AccessDenied,
94 /// Either the corresponding [Granter](super::Granter) or the
95 /// [OrchestratorMutex](super::OrchestratorMutex) has been dropped.
96 Inaccessible,
97 }
98}
99
100pub struct OrchestratorMutex<T> {
101 inner: Arc<Mutex<T>>,
102 dropped: Arc<AwaitableBool>,
103}
104
105impl<T> Drop for OrchestratorMutex<T> {
106 fn drop(&mut self) {
107 self.dropped.set_true();
108 }
109}
110
111pub struct OwnedMutexGuard<T> {
112 // Field ordering ensures that the inner guard is dropped before the
113 // finished notification is sent.
114 inner: tokio::sync::OwnedMutexGuard<T>,
115 _finished: Finished,
116}
117
118struct Finished(Arc<AwaitableBool>);
119
120pub struct Granter<T> {
121 inner: Weak<Mutex<T>>,
122 tx: relay_channel::Sender<OwnedMutexGuard<T>>,
123}
124
125pub struct MutexLocker<T> {
126 mutex_dropped: Arc<AwaitableBool>,
127 rx: relay_channel::Receiver<OwnedMutexGuard<T>>,
128}
129
130pub struct MutexGuard<'a, T> {
131 guard: OwnedMutexGuard<T>,
132 _locker: &'a mut MutexLocker<T>,
133}
134
135impl<T> OrchestratorMutex<T> {
136 pub fn new(value: T) -> Self {
137 let dropped = Arc::new(AwaitableBool::new(false));
138 Self {
139 inner: Arc::new(Mutex::new(value)),
140 dropped,
141 }
142 }
143
144 pub fn add_locker(&self) -> (Granter<T>, MutexLocker<T>) {
145 let (tx, rx) = relay_channel::channel();
146 let inner = Arc::downgrade(&self.inner);
147 let mutex_dropped = self.dropped.clone();
148 (Granter { inner, tx }, MutexLocker { mutex_dropped, rx })
149 }
150
151 /// Directly acquire the underlying lock.
152 pub async fn acquire(&self) -> tokio::sync::MutexGuard<'_, T> {
153 self.inner.lock().await
154 }
155
156 pub fn blocking_acquire(&self) -> tokio::sync::MutexGuard<'_, T> {
157 self.inner.blocking_lock()
158 }
159
160 /// Attempt to acquire the underlying lock, failing if the lock is already
161 /// held.
162 pub fn try_acquire(&self) -> Result<tokio::sync::MutexGuard<'_, T>, tokio::sync::TryLockError> {
163 self.inner.try_lock()
164 }
165
166 /// Grants lock access to the [MutexLocker] corresponding to the provided
167 /// [Granter].
168 ///
169 /// This function returns [Ok] once the corresponding [MutexLocker] has
170 /// called
171 /// [acquire](MutexLocker::acquire) (or [Err] if the [MutexLocker] has been
172 /// dropped).
173 /// The [Ok] variant contains a future which waits for the acquiring task to
174 /// drop its [MutexGuard].
175 ///
176 /// If the future in the [Ok] variant is dropped, the next call to
177 /// [grant_access](Self::grant_access) will have to wait for the current
178 /// [MutexGuard] to be dropped before it can grant access to the next
179 /// [MutexLocker]. If this is called multiple times in parallel, the
180 /// order in which the [MutexLocker]s are granted access is unspecified.
181 ///
182 /// # Panics
183 /// Panics if `granter` was created from a different [OrchestratorMutex].
184 pub async fn grant_access(
185 &self,
186 granter: &mut Granter<T>,
187 ) -> Result<impl Future<Output = ()>, error::GrantError> {
188 assert!(
189 Weak::ptr_eq(&granter.inner, &Arc::downgrade(&self.inner)),
190 "Granter is not associated with this OrchestratorMutex"
191 );
192 let inner_guard = self.inner.clone().lock_owned().await;
193 let finished = Arc::new(AwaitableBool::new(false));
194 let guard = OwnedMutexGuard {
195 inner: inner_guard,
196 _finished: Finished(Arc::clone(&finished)),
197 };
198 match granter.tx.send(guard).await {
199 Ok(()) => Ok(async move { finished.wait_true().await }),
200 Err(relay_channel::error::SendError(_)) => Err(error::GrantError),
201 }
202 }
203}
204
205impl<T> MutexLocker<T> {
206 /// Returns [None] if either the corresponding [Granter] or the
207 /// [OrchestratorMutex] has been dropped.
208 pub async fn acquire(&mut self) -> Option<MutexGuard<'_, T>> {
209 let result = select! {
210 result = self.rx.recv() => result,
211 () = self.mutex_dropped.wait_true() => None,
212 };
213 Some(MutexGuard {
214 guard: result?,
215 _locker: self,
216 })
217 }
218
219 pub fn try_acquire(&mut self) -> Result<MutexGuard<'_, T>, error::TryAcquireError> {
220 match self.rx.try_recv() {
221 Ok(guard) => Ok(MutexGuard {
222 guard,
223 _locker: self,
224 }),
225 Err(relay_channel::error::TryRecvError::Empty) => {
226 if self.mutex_dropped.is_true() {
227 Err(error::TryAcquireError::Inaccessible)
228 } else {
229 Err(error::TryAcquireError::AccessDenied)
230 }
231 }
232 Err(relay_channel::error::TryRecvError::Disconnected) => {
233 Err(error::TryAcquireError::Inaccessible)
234 }
235 }
236 }
237}
238
239impl<T> Deref for OwnedMutexGuard<T> {
240 type Target = T;
241
242 fn deref(&self) -> &Self::Target {
243 &self.inner
244 }
245}
246
247impl<T> DerefMut for OwnedMutexGuard<T> {
248 fn deref_mut(&mut self) -> &mut Self::Target {
249 &mut self.inner
250 }
251}
252
253impl Drop for Finished {
254 fn drop(&mut self) {
255 self.0.set_true();
256 }
257}
258
259impl<T> MutexGuard<'_, T> {
260 /// The lifetime parameter on [MutexGuard] is only for convenience (to help
261 /// avoid having multiple parallel calls to [acquire](MutexLocker::acquire)
262 /// and [try_acquire](MutexLocker::try_acquire)). The caller can choose to
263 /// instead
264 /// use this function to unwrap the underlying [OwnedMutexGuard] if it's
265 /// more convenient not to deal with the lifetime.
266 pub fn into_owned_guard(this: Self) -> OwnedMutexGuard<T> {
267 this.guard
268 }
269}
270
271impl<T> Deref for MutexGuard<'_, T> {
272 type Target = T;
273
274 fn deref(&self) -> &Self::Target {
275 &self.guard
276 }
277}
278
279impl<T> DerefMut for MutexGuard<'_, T> {
280 fn deref_mut(&mut self) -> &mut Self::Target {
281 &mut self.guard
282 }
283}
284
285#[cfg(test)]
286mod test;