orchestrator_lock/
lib.rs

1//! `orchestrator_lock` provides a specialized mutex implementation for scenarios
2//! where fine-grained control over mutex access is required. Unlike a standard
3//! mutex where any code with a reference can attempt to acquire the lock, this
4//! implementation separates the concerns of lock orchestration from lock usage.
5//!
6//! # Core Concepts
7//!
8//! * **OrchestratorMutex**: The central coordinator that owns the protected value and
9//!   controls access to it.
10//!
11//! * **Granter**: A capability token that allows the orchestrator to grant lock access
12//!   to a specific locker.
13//!
14//! * **MutexLocker**: The component that can acquire and use the lock, but only when
15//!   explicitly granted access by the orchestrator.
16//!
17//! * **MutexGuard**: Provides access to the protected value, similar to a standard
18//!   mutex guard.
19//!
20//! # Example
21//! ```
22//! use tokio::time::Duration;
23//!
24//! use orchestrator_lock::OrchestratorMutex;
25//!
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() {
28//!     // Create a shared counter with initial value 0
29//!     let mut orchestrator = OrchestratorMutex::new(0);
30//!     
31//!     // Create two granter/locker pairs
32//!     let (mut granter1, mut locker1) = orchestrator.add_locker();
33//!     let (mut granter2, mut locker2) = orchestrator.add_locker();
34//!     
35//!     // Task 1: Increments by 1 each time
36//!     let task1 = tokio::spawn(async move {
37//!         let expected = [0, 2, 6];
38//!         for i in 0..3 {
39//!             if let Some(mut guard) = locker1.acquire().await {
40//!                 assert_eq!(*guard, expected[i]);
41//!                 *guard += 1;
42//!                 tokio::time::sleep(Duration::from_millis(10)).await;
43//!             }
44//!         }
45//!         locker1
46//!     });
47//!     
48//!     // Task 2: Multiplies by 2 each time
49//!     let task2 = tokio::spawn(async move {
50//!         let expected = [1, 3, 7];
51//!         for i in 0..3 {
52//!             if let Some(mut guard) = locker2.acquire().await {
53//!                 assert_eq!(*guard, expected[i]);
54//!                 *guard *= 2;
55//!                 tokio::time::sleep(Duration::from_millis(10)).await;
56//!             }
57//!         }
58//!         locker2
59//!     });
60//!     
61//!     // Orchestration: Alternate between the two tasks
62//!     for i in 0..3 {
63//!         // Grant access to task 1
64//!         let task1_holding = orchestrator.grant_access(&mut granter1).await.unwrap();
65//!         task1_holding.await;
66//!         
67//!         // Grant access to task 2
68//!         let task2_holding = orchestrator.grant_access(&mut granter2).await.unwrap();
69//!         task2_holding.await;
70//!     }
71//!     assert_eq!(*orchestrator.acquire().await, 14);
72//!     // Clean up
73//!     let _ = task1.await.unwrap();
74//!     let _ = task2.await.unwrap();
75//! }
76//! ```
77
78use std::ops::{Deref, DerefMut};
79use std::sync::{Arc, Weak};
80
81use awaitable_bool::AwaitableBool;
82use tokio::select;
83use tokio::sync::Mutex;
84
pub mod error {
    //! Error types returned by [OrchestratorMutex](super::OrchestratorMutex)
    //! and [MutexLocker](super::MutexLocker).

    use std::fmt;

    /// Returned by [grant_access](super::OrchestratorMutex::grant_access) when
    /// the [MutexLocker](super::MutexLocker) paired with the given
    /// [Granter](super::Granter) has been dropped, so the lock could not be
    /// handed over.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct GrantError;

    impl fmt::Display for GrantError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("the MutexLocker paired with this Granter has been dropped")
        }
    }

    impl std::error::Error for GrantError {}

    /// Reason why [MutexLocker::try_acquire](super::MutexLocker::try_acquire)
    /// could not return a guard.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum TryAcquireError {
        /// [grant_access](super::OrchestratorMutex::grant_access) has not been
        /// called with the corresponding [Granter](super::Granter).
        AccessDenied,
        /// Either the corresponding [Granter](super::Granter) or the
        /// [OrchestratorMutex](super::OrchestratorMutex) has been dropped.
        Inaccessible,
    }

    impl fmt::Display for TryAcquireError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let message = match self {
                Self::AccessDenied => "access has not been granted to this MutexLocker",
                Self::Inaccessible => "the Granter or the OrchestratorMutex has been dropped",
            };
            f.write_str(message)
        }
    }

    impl std::error::Error for TryAcquireError {}
}
99
/// Owner of the protected value and coordinator of which [MutexLocker] may
/// lock it.
///
/// A [MutexLocker] only obtains the lock after
/// [grant_access](OrchestratorMutex::grant_access) is called with its paired
/// [Granter]. The orchestrator itself can also lock the value directly via
/// [acquire](OrchestratorMutex::acquire).
pub struct OrchestratorMutex<T> {
    // The shared mutex. Granters hold only a `Weak` to it (see `add_locker`).
    inner: Arc<Mutex<T>>,
    // Set to true when this orchestrator is dropped, so lockers waiting in
    // `MutexLocker::acquire` stop waiting and get `None`.
    dropped: Arc<AwaitableBool>,
}
104
105impl<T> Drop for OrchestratorMutex<T> {
106    fn drop(&mut self) {
107        self.dropped.set_true();
108    }
109}
110
/// A held lock on the protected value, produced by
/// [OrchestratorMutex::grant_access] and delivered to a [MutexLocker].
///
/// Unlike [MutexGuard], this type has no lifetime parameter, so it can be
/// stored or moved freely (see [MutexGuard::into_owned_guard]). Dropping it
/// releases the lock and then resolves the future returned by
/// [grant_access](OrchestratorMutex::grant_access).
pub struct OwnedMutexGuard<T> {
    // Field ordering ensures that the inner guard is dropped before the
    // finished notification is sent.
    inner: tokio::sync::OwnedMutexGuard<T>,
    _finished: Finished,
}
117
/// Sets the wrapped flag to `true` when dropped. The future returned by
/// [OrchestratorMutex::grant_access] waits on this flag.
struct Finished(Arc<AwaitableBool>);
119
/// Capability, created by [OrchestratorMutex::add_locker], that lets the
/// orchestrator grant lock access to one specific [MutexLocker].
pub struct Granter<T> {
    // Identifies the originating mutex (checked in `grant_access`). It is a
    // `Weak`, so the granter does not keep the mutex alive.
    inner: Weak<Mutex<T>>,
    // Sending half used to hand a guard to the paired locker.
    tx: relay_channel::Sender<OwnedMutexGuard<T>>,
}
124
/// The locking side of a [Granter]/[MutexLocker] pair created by
/// [OrchestratorMutex::add_locker]. It can only obtain the lock after the
/// orchestrator grants access through the paired [Granter].
pub struct MutexLocker<T> {
    // Shared with the orchestrator; becomes true when it is dropped.
    mutex_dropped: Arc<AwaitableBool>,
    // Receiving half for guards sent by the paired `Granter`.
    rx: relay_channel::Receiver<OwnedMutexGuard<T>>,
}
129
/// Access to the protected value, obtained from [MutexLocker::acquire] or
/// [MutexLocker::try_acquire].
///
/// It mutably borrows its [MutexLocker], so one locker cannot have two
/// outstanding `MutexGuard`s. Use [MutexGuard::into_owned_guard] to give up
/// that borrow while keeping the lock.
pub struct MutexGuard<'a, T> {
    guard: OwnedMutexGuard<T>,
    // Held only to tie this guard to an exclusive borrow of the locker.
    _locker: &'a mut MutexLocker<T>,
}
134
135impl<T> OrchestratorMutex<T> {
136    pub fn new(value: T) -> Self {
137        let dropped = Arc::new(AwaitableBool::new(false));
138        Self {
139            inner: Arc::new(Mutex::new(value)),
140            dropped,
141        }
142    }
143
144    pub fn add_locker(&self) -> (Granter<T>, MutexLocker<T>) {
145        let (tx, rx) = relay_channel::channel();
146        let inner = Arc::downgrade(&self.inner);
147        let mutex_dropped = self.dropped.clone();
148        (Granter { inner, tx }, MutexLocker { mutex_dropped, rx })
149    }
150
151    /// Directly acquire the underlying lock.
152    pub async fn acquire(&self) -> tokio::sync::MutexGuard<'_, T> {
153        self.inner.lock().await
154    }
155
156    pub fn blocking_acquire(&self) -> tokio::sync::MutexGuard<'_, T> {
157        self.inner.blocking_lock()
158    }
159
160    /// Attempt to acquire the underlying lock, failing if the lock is already
161    /// held.
162    pub fn try_acquire(&self) -> Result<tokio::sync::MutexGuard<'_, T>, tokio::sync::TryLockError> {
163        self.inner.try_lock()
164    }
165
166    /// Grants lock access to the [MutexLocker] corresponding to the provided
167    /// [Granter].
168    ///
169    /// This function returns [Ok] once the corresponding [MutexLocker] has
170    /// called
171    /// [acquire](MutexLocker::acquire) (or [Err] if the [MutexLocker] has been
172    /// dropped).
173    /// The [Ok] variant contains a future which waits for the acquiring task to
174    /// drop its [MutexGuard].
175    ///
176    /// If the future in the [Ok] variant is dropped, the next call to
177    /// [grant_access](Self::grant_access) will have to wait for the current
178    /// [MutexGuard] to be dropped before it can grant access to the next
179    /// [MutexLocker]. If this is called multiple times in parallel, the
180    /// order in which the [MutexLocker]s are granted access is unspecified.
181    ///
182    /// # Panics
183    /// Panics if `granter` was created from a different [OrchestratorMutex].
184    pub async fn grant_access(
185        &self,
186        granter: &mut Granter<T>,
187    ) -> Result<impl Future<Output = ()>, error::GrantError> {
188        assert!(
189            Weak::ptr_eq(&granter.inner, &Arc::downgrade(&self.inner)),
190            "Granter is not associated with this OrchestratorMutex"
191        );
192        let inner_guard = self.inner.clone().lock_owned().await;
193        let finished = Arc::new(AwaitableBool::new(false));
194        let guard = OwnedMutexGuard {
195            inner: inner_guard,
196            _finished: Finished(Arc::clone(&finished)),
197        };
198        match granter.tx.send(guard).await {
199            Ok(()) => Ok(async move { finished.wait_true().await }),
200            Err(relay_channel::error::SendError(_)) => Err(error::GrantError),
201        }
202    }
203}
204
205impl<T> MutexLocker<T> {
206    /// Returns [None] if either the corresponding [Granter] or the
207    /// [OrchestratorMutex] has been dropped.
208    pub async fn acquire(&mut self) -> Option<MutexGuard<'_, T>> {
209        let result = select! {
210            result = self.rx.recv() => result,
211            () = self.mutex_dropped.wait_true() => None,
212        };
213        Some(MutexGuard {
214            guard: result?,
215            _locker: self,
216        })
217    }
218
219    pub fn try_acquire(&mut self) -> Result<MutexGuard<'_, T>, error::TryAcquireError> {
220        match self.rx.try_recv() {
221            Ok(guard) => Ok(MutexGuard {
222                guard,
223                _locker: self,
224            }),
225            Err(relay_channel::error::TryRecvError::Empty) => {
226                if self.mutex_dropped.is_true() {
227                    Err(error::TryAcquireError::Inaccessible)
228                } else {
229                    Err(error::TryAcquireError::AccessDenied)
230                }
231            }
232            Err(relay_channel::error::TryRecvError::Disconnected) => {
233                Err(error::TryAcquireError::Inaccessible)
234            }
235        }
236    }
237}
238
239impl<T> Deref for OwnedMutexGuard<T> {
240    type Target = T;
241
242    fn deref(&self) -> &Self::Target {
243        &self.inner
244    }
245}
246
247impl<T> DerefMut for OwnedMutexGuard<T> {
248    fn deref_mut(&mut self) -> &mut Self::Target {
249        &mut self.inner
250    }
251}
252
253impl Drop for Finished {
254    fn drop(&mut self) {
255        self.0.set_true();
256    }
257}
258
259impl<T> MutexGuard<'_, T> {
260    /// The lifetime parameter on [MutexGuard] is only for convenience (to help
261    /// avoid having multiple parallel calls to [acquire](MutexLocker::acquire)
262    /// and [try_acquire](MutexLocker::try_acquire)). The caller can choose to
263    /// instead
264    /// use this function to unwrap the underlying [OwnedMutexGuard] if it's
265    /// more convenient not to deal with the lifetime.
266    pub fn into_owned_guard(this: Self) -> OwnedMutexGuard<T> {
267        this.guard
268    }
269}
270
271impl<T> Deref for MutexGuard<'_, T> {
272    type Target = T;
273
274    fn deref(&self) -> &Self::Target {
275        &self.guard
276    }
277}
278
279impl<T> DerefMut for MutexGuard<'_, T> {
280    fn deref_mut(&mut self) -> &mut Self::Target {
281        &mut self.guard
282    }
283}
284
285#[cfg(test)]
286mod test;