orchestrator_lock/
lib.rs

1//! `orchestrator_lock` provides a specialized mutex implementation for scenarios
2//! where fine-grained control over mutex access is required. Unlike a standard
3//! mutex where any code with a reference can attempt to acquire the lock, this
4//! implementation separates the concerns of lock orchestration from lock usage.
5//!
6//! # Core Concepts
7//!
8//! * **OrchestratorMutex**: The central coordinator that owns the protected value and
9//!   controls access to it.
10//!
11//! * **Granter**: A capability token that allows the orchestrator to grant lock access
12//!   to a specific locker.
13//!
14//! * **MutexLocker**: The component that can acquire and use the lock, but only when
15//!   explicitly granted access by the orchestrator.
16//!
17//! * **MutexGuard**: Provides access to the protected value, similar to a standard
18//!   mutex guard.
19//!
20//! # Example
21//! ```
22//! use tokio::time::Duration;
23//!
24//! use orchestrator_lock::OrchestratorMutex;
25//!
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() {
28//!     // Create a shared counter with initial value 0
//!     let orchestrator = OrchestratorMutex::new(0);
30//!     
31//!     // Create two granter/locker pairs
32//!     let (mut granter1, mut locker1) = orchestrator.add_locker();
33//!     let (mut granter2, mut locker2) = orchestrator.add_locker();
34//!     
35//!     // Task 1: Increments by 1 each time
36//!     let task1 = tokio::spawn(async move {
37//!         let expected = [0, 2, 6];
38//!         for i in 0..3 {
39//!             if let Some(mut guard) = locker1.acquire().await {
40//!                 assert_eq!(*guard, expected[i]);
41//!                 *guard += 1;
42//!                 tokio::time::sleep(Duration::from_millis(10)).await;
43//!             }
44//!         }
45//!         locker1
46//!     });
47//!     
48//!     // Task 2: Multiplies by 2 each time
49//!     let task2 = tokio::spawn(async move {
50//!         let expected = [1, 3, 7];
51//!         for i in 0..3 {
52//!             if let Some(mut guard) = locker2.acquire().await {
53//!                 assert_eq!(*guard, expected[i]);
54//!                 *guard *= 2;
55//!                 tokio::time::sleep(Duration::from_millis(10)).await;
56//!             }
57//!         }
58//!         locker2
59//!     });
60//!     
61//!     // Orchestration: Alternate between the two tasks
62//!     let expected = [1, 2, 3, 6, 7, 14];
63//!     for i in 0..3 {
64//!         // Grant access to task 1
65//!         let grant_future = orchestrator.grant_access(&mut granter1).await.unwrap();
66//!         let guard = grant_future.await;
67//!         assert_eq!(*guard, expected[i * 2]);
68//!         drop(guard);
69//!         
70//!         // Grant access to task 2
71//!         let grant_future = orchestrator.grant_access(&mut granter2).await.unwrap();
72//!         let guard = grant_future.await;
73//!         assert_eq!(*guard, expected[i * 2 + 1]);
74//!         drop(guard);
75//!     }
76//!     // Clean up
77//!     let _ = task1.await.unwrap();
78//!     let _ = task2.await.unwrap();
79//! }
80//! ```
81
82use std::ops::{Deref, DerefMut};
83use std::sync::{Arc, Weak};
84
85use tokio::select;
86use tokio::sync::{Mutex, OwnedMutexGuard};
87use tokio_util::sync::{CancellationToken, DropGuard};
88
pub mod error {
    //! Error types returned by the fallible operations of this crate.

    use std::fmt;

    /// Returned by [grant_access](super::OrchestratorMutex::grant_access) when
    /// the [MutexLocker](super::MutexLocker) paired with the given
    /// [Granter](super::Granter) has been dropped, so access cannot be handed
    /// over.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct GrantError;

    impl fmt::Display for GrantError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("the MutexLocker paired with this Granter has been dropped")
        }
    }

    impl std::error::Error for GrantError {}

    /// Returned by [try_acquire](super::MutexLocker::try_acquire) when no
    /// guard could be obtained.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum TryAcquireError {
        /// [grant_access](super::OrchestratorMutex::grant_access) has not been
        /// called with the corresponding [Granter](super::Granter).
        AccessDenied,
        /// Either the corresponding [Granter](super::Granter) or the
        /// [OrchestratorMutex](super::OrchestratorMutex) has been dropped.
        Inaccessible,
    }

    impl fmt::Display for TryAcquireError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let msg = match self {
                TryAcquireError::AccessDenied => "access has not been granted to this MutexLocker",
                TryAcquireError::Inaccessible => {
                    "the Granter or the OrchestratorMutex has been dropped"
                }
            };
            f.write_str(msg)
        }
    }

    impl std::error::Error for TryAcquireError {}
}
103
/// Owns a value behind a mutex and decides which [MutexLocker] may lock it
/// next.
///
/// Create [Granter]/[MutexLocker] pairs with
/// [add_locker](OrchestratorMutex::add_locker), and hand the lock to a locker
/// with [grant_access](OrchestratorMutex::grant_access).
pub struct OrchestratorMutex<T> {
    /// The protected value. Lockers receive owned guards to this mutex; each
    /// [Granter] only keeps a weak reference to it.
    inner: Arc<Mutex<T>>,
    /// Cancelled when this orchestrator is dropped. A clone is handed to every
    /// [MutexLocker] so a waiting `acquire` can give up.
    dropped: CancellationToken,
    // Cancels `dropped` when dropped. Fields drop in declaration order, so the
    // token is cancelled after `inner` has been released.
    _drop_guard: DropGuard,
}
109
/// Capability passed to [OrchestratorMutex::grant_access] to hand the lock to
/// one specific [MutexLocker], namely the one created alongside it by
/// [OrchestratorMutex::add_locker].
pub struct Granter<T> {
    /// Weak handle to the orchestrator's mutex; `grant_access` compares it
    /// against its own mutex to reject granters from another orchestrator.
    inner: Weak<Mutex<T>>,
    /// Sends owned guards to the paired [MutexLocker].
    tx: relay_channel::Sender<OwnedMutexGuard<T>>,
}
114
/// Receives lock access granted by an [OrchestratorMutex].
///
/// Call [acquire](MutexLocker::acquire) or
/// [try_acquire](MutexLocker::try_acquire) to obtain a [MutexGuard].
pub struct MutexLocker<T> {
    /// Clone of the orchestrator's token; cancelled when the
    /// [OrchestratorMutex] is dropped.
    mutex_dropped: CancellationToken,
    /// Receives owned guards sent by the paired [Granter].
    rx: relay_channel::Receiver<OwnedMutexGuard<T>>,
}
119
120pub struct MutexGuard<'a, T> {
121    guard: OwnedMutexGuard<T>,
122    _locker: &'a mut MutexLocker<T>,
123}
124
125impl<T> OrchestratorMutex<T> {
126    pub fn new(value: T) -> Self {
127        let dropped = CancellationToken::new();
128        let drop_guard = dropped.clone().drop_guard();
129        Self {
130            inner: Arc::new(Mutex::new(value)),
131            dropped,
132            _drop_guard: drop_guard,
133        }
134    }
135
136    pub fn add_locker(&self) -> (Granter<T>, MutexLocker<T>) {
137        let (tx, rx) = relay_channel::channel();
138        let inner = Arc::downgrade(&self.inner);
139        let mutex_dropped = self.dropped.clone();
140        (Granter { inner, tx }, MutexLocker { mutex_dropped, rx })
141    }
142
143    /// Directly acquire the underlying lock.
144    pub async fn acquire(&self) -> tokio::sync::MutexGuard<'_, T> {
145        self.inner.lock().await
146    }
147
148    /// Attempt to acquire the underlying lock, failing if the lock is already
149    /// held.
150    pub fn try_acquire(&self) -> Result<tokio::sync::MutexGuard<'_, T>, tokio::sync::TryLockError> {
151        self.inner.try_lock()
152    }
153
154    /// Grants lock access to the [MutexLocker] corresponding to the provided
155    /// [Granter].
156    ///
157    /// This function returns [Ok] once the corresponding [MutexLocker] has
158    /// called
159    /// [acquire](MutexLocker::acquire) (or [Err] if the [MutexLocker] has been
160    /// dropped).
161    /// The [Ok] variant contains a future which waits for the [MutexGuard] to
162    /// be dropped and then re-acquires the mutex so the caller can see what
163    /// changes were made.
164    ///
165    /// If the future in the [Ok] variant is dropped, the next call to
166    /// [grant_access](Self::grant_access) will have to wait for the current
167    /// [MutexGuard] to be dropped before it can grant access to the next
168    /// [MutexLocker]. If this is called multiple times in parallel, the
169    /// order in which the [MutexLocker]s are granted access is unspecified.
170    ///
171    /// # Panics
172    /// Panics if `granter` was created from a different [OrchestratorMutex].
173    pub async fn grant_access(
174        &self,
175        granter: &mut Granter<T>,
176    ) -> Result<impl Future<Output = tokio::sync::MutexGuard<'_, T>>, error::GrantError> {
177        assert!(
178            Weak::ptr_eq(&granter.inner, &Arc::downgrade(&self.inner)),
179            "Granter is not associated with this OrchestratorMutex"
180        );
181        match granter.tx.send(self.inner.clone().lock_owned().await).await {
182            Ok(()) => Ok(self.inner.lock()),
183            Err(relay_channel::error::SendError(_)) => Err(error::GrantError),
184        }
185    }
186}
187
188impl<T> MutexLocker<T> {
189    /// Returns [None] if either the corresponding [Granter] or the
190    /// [OrchestratorMutex] has been dropped.
191    pub async fn acquire(&mut self) -> Option<MutexGuard<'_, T>> {
192        let result = select! {
193            result = self.rx.recv() => result,
194            () = self.mutex_dropped.cancelled() => None,
195        };
196        Some(MutexGuard {
197            guard: result?,
198            _locker: self,
199        })
200    }
201
202    pub fn try_acquire(&mut self) -> Result<MutexGuard<'_, T>, error::TryAcquireError> {
203        match self.rx.try_recv() {
204            Ok(guard) => Ok(MutexGuard {
205                guard,
206                _locker: self,
207            }),
208            Err(relay_channel::error::TryRecvError::Empty) => {
209                if self.mutex_dropped.is_cancelled() {
210                    Err(error::TryAcquireError::Inaccessible)
211                } else {
212                    Err(error::TryAcquireError::AccessDenied)
213                }
214            }
215            Err(relay_channel::error::TryRecvError::Disconnected) => {
216                Err(error::TryAcquireError::Inaccessible)
217            }
218        }
219    }
220}
221
222impl<T> MutexGuard<'_, T> {
223    /// The lifetime parameter on [MutexGuard] is only for convenience (to help
224    /// avoid having multiple parallel calls to [acquire](MutexLocker::acquire)
225    /// and [try_acquire](MutexLocker::try_acquire)). The caller can choose to
226    /// instead
227    /// use this function to unwrap the underlying [OwnedMutexGuard] if it's
228    /// more convenient not to deal with the lifetime.
229    pub fn into_owned_guard(this: Self) -> OwnedMutexGuard<T> {
230        this.guard
231    }
232}
233
234impl<T> Deref for MutexGuard<'_, T> {
235    type Target = T;
236
237    fn deref(&self) -> &Self::Target {
238        &*self.guard
239    }
240}
241
242impl<T> DerefMut for MutexGuard<'_, T> {
243    fn deref_mut(&mut self) -> &mut Self::Target {
244        &mut *self.guard
245    }
246}
247
// Unit tests; compiled only for `cargo test` and loaded from `test.rs`
// (or `test/mod.rs`) alongside this file.
#[cfg(test)]
mod test;