orchestrator_lock/lib.rs
1//! `orchestrator_lock` provides a specialized mutex implementation for scenarios
2//! where fine-grained control over mutex access is required. Unlike a standard
3//! mutex where any code with a reference can attempt to acquire the lock, this
4//! implementation separates the concerns of lock orchestration from lock usage.
5//!
6//! # Core Concepts
7//!
8//! * **OrchestratorMutex**: The central coordinator that owns the protected value and
9//! controls access to it.
10//!
11//! * **Granter**: A capability token that allows the orchestrator to grant lock access
12//! to a specific locker.
13//!
14//! * **MutexLocker**: The component that can acquire and use the lock, but only when
15//! explicitly granted access by the orchestrator.
16//!
17//! * **MutexGuard**: Provides access to the protected value, similar to a standard
18//! mutex guard.
19//!
20//! # Example
21//! ```
22//! use tokio::time::Duration;
23//!
24//! use orchestrator_lock::OrchestratorMutex;
25//!
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() {
28//! // Create a shared counter with initial value 0
29//! let mut orchestrator = OrchestratorMutex::new(0);
30//!
31//! // Create two granter/locker pairs
32//! let (mut granter1, mut locker1) = orchestrator.add_locker();
33//! let (mut granter2, mut locker2) = orchestrator.add_locker();
34//!
35//! // Task 1: Increments by 1 each time
36//! let task1 = tokio::spawn(async move {
37//! let expected = [0, 2, 6];
38//! for i in 0..3 {
39//! if let Some(mut guard) = locker1.acquire().await {
40//! assert_eq!(*guard, expected[i]);
41//! *guard += 1;
42//! tokio::time::sleep(Duration::from_millis(10)).await;
43//! }
44//! }
45//! locker1
46//! });
47//!
48//! // Task 2: Multiplies by 2 each time
49//! let task2 = tokio::spawn(async move {
50//! let expected = [1, 3, 7];
51//! for i in 0..3 {
52//! if let Some(mut guard) = locker2.acquire().await {
53//! assert_eq!(*guard, expected[i]);
54//! *guard *= 2;
55//! tokio::time::sleep(Duration::from_millis(10)).await;
56//! }
57//! }
58//! locker2
59//! });
60//!
61//! // Orchestration: Alternate between the two tasks
62//! let expected = [1, 2, 3, 6, 7, 14];
63//! for i in 0..3 {
64//! // Grant access to task 1
65//! let grant_future = orchestrator.grant_access(&mut granter1).await.unwrap();
66//! let guard = grant_future.await;
67//! assert_eq!(*guard, expected[i * 2]);
68//! drop(guard);
69//!
70//! // Grant access to task 2
71//! let grant_future = orchestrator.grant_access(&mut granter2).await.unwrap();
72//! let guard = grant_future.await;
73//! assert_eq!(*guard, expected[i * 2 + 1]);
74//! drop(guard);
75//! }
76//! // Clean up
77//! let _ = task1.await.unwrap();
78//! let _ = task2.await.unwrap();
79//! }
80//! ```
81
82use std::ops::{Deref, DerefMut};
83use std::sync::{Arc, Weak};
84
85use tokio::select;
86use tokio::sync::{Mutex, OwnedMutexGuard};
87use tokio_util::sync::{CancellationToken, DropGuard};
88
/// Error types returned by [OrchestratorMutex](super::OrchestratorMutex) and
/// [MutexLocker](super::MutexLocker) operations.
pub mod error {
    /// Returned by [grant_access](super::OrchestratorMutex::grant_access) when
    /// the [MutexLocker](super::MutexLocker) corresponding to the provided
    /// [Granter](super::Granter) has been dropped.
    #[derive(Debug, PartialEq, Eq)]
    pub struct GrantError;

    /// Returned by [try_acquire](super::MutexLocker::try_acquire) when no
    /// grant is currently available.
    #[derive(Debug, PartialEq, Eq)]
    pub enum TryAcquireError {
        /// [grant_access](super::OrchestratorMutex::grant_access) has not been
        /// called with the corresponding [Granter](super::Granter).
        AccessDenied,
        /// Either the corresponding [Granter](super::Granter) or the
        /// [OrchestratorMutex](super::OrchestratorMutex) has been dropped.
        Inaccessible,
    }
}
103
/// The central coordinator: owns the protected value and decides, via
/// [grant_access](Self::grant_access), which [MutexLocker] may lock it next.
pub struct OrchestratorMutex<T> {
    // The protected value; each Granter holds a Weak reference to this so
    // grant_access can verify the granter belongs to this mutex.
    inner: Arc<Mutex<T>>,
    // Cancelled (via `_drop_guard`) when this struct is dropped, waking any
    // MutexLocker blocked in `acquire`.
    dropped: CancellationToken,
    // Cancels `dropped` when this struct is dropped.
    _drop_guard: DropGuard,
}
109
/// A capability token that allows the [OrchestratorMutex] to grant lock access
/// to one specific [MutexLocker] (its channel peer).
pub struct Granter<T> {
    // Weak handle to the orchestrator's mutex; only used to assert that this
    // granter is passed back to the OrchestratorMutex that created it.
    inner: Weak<Mutex<T>>,
    // Hands an already-acquired owned guard over to the paired MutexLocker.
    tx: relay_channel::Sender<OwnedMutexGuard<T>>,
}
114
/// The component that can acquire and use the lock, but only when explicitly
/// granted access by the orchestrator through the paired [Granter].
pub struct MutexLocker<T> {
    // Cancelled when the OrchestratorMutex is dropped, so a pending `acquire`
    // can return None instead of waiting forever.
    mutex_dropped: CancellationToken,
    // Receives the owned guard sent by the paired Granter.
    rx: relay_channel::Receiver<OwnedMutexGuard<T>>,
}
119
/// Provides access to the protected value, similar to a standard mutex guard.
pub struct MutexGuard<'a, T> {
    // The underlying tokio guard that actually holds the lock.
    guard: OwnedMutexGuard<T>,
    // Mutable borrow of the locker; exists only to prevent parallel calls to
    // MutexLocker::acquire / try_acquire while this guard is alive. Can be
    // discarded with `MutexGuard::into_owned_guard`.
    _locker: &'a mut MutexLocker<T>,
}
124
125impl<T> OrchestratorMutex<T> {
126 pub fn new(value: T) -> Self {
127 let dropped = CancellationToken::new();
128 let drop_guard = dropped.clone().drop_guard();
129 Self {
130 inner: Arc::new(Mutex::new(value)),
131 dropped,
132 _drop_guard: drop_guard,
133 }
134 }
135
136 pub fn add_locker(&self) -> (Granter<T>, MutexLocker<T>) {
137 let (tx, rx) = relay_channel::channel();
138 let inner = Arc::downgrade(&self.inner);
139 let mutex_dropped = self.dropped.clone();
140 (Granter { inner, tx }, MutexLocker { mutex_dropped, rx })
141 }
142
143 /// Directly acquire the underlying lock.
144 pub async fn acquire(&self) -> tokio::sync::MutexGuard<'_, T> {
145 self.inner.lock().await
146 }
147
148 /// Attempt to acquire the underlying lock, failing if the lock is already
149 /// held.
150 pub fn try_acquire(&self) -> Result<tokio::sync::MutexGuard<'_, T>, tokio::sync::TryLockError> {
151 self.inner.try_lock()
152 }
153
154 /// Grants lock access to the [MutexLocker] corresponding to the provided
155 /// [Granter].
156 ///
157 /// This function returns [Ok] once the corresponding [MutexLocker] has
158 /// called
159 /// [acquire](MutexLocker::acquire) (or [Err] if the [MutexLocker] has been
160 /// dropped).
161 /// The [Ok] variant contains a future which waits for the [MutexGuard] to
162 /// be dropped and then re-acquires the mutex so the caller can see what
163 /// changes were made.
164 ///
165 /// If the future in the [Ok] variant is dropped, the next call to
166 /// [grant_access](Self::grant_access) will have to wait for the current
167 /// [MutexGuard] to be dropped before it can grant access to the next
168 /// [MutexLocker]. If this is called multiple times in parallel, the
169 /// order in which the [MutexLocker]s are granted access is unspecified.
170 ///
171 /// # Panics
172 /// Panics if `granter` was created from a different [OrchestratorMutex].
173 pub async fn grant_access(
174 &self,
175 granter: &mut Granter<T>,
176 ) -> Result<impl Future<Output = tokio::sync::MutexGuard<'_, T>>, error::GrantError> {
177 assert!(
178 Weak::ptr_eq(&granter.inner, &Arc::downgrade(&self.inner)),
179 "Granter is not associated with this OrchestratorMutex"
180 );
181 match granter.tx.send(self.inner.clone().lock_owned().await).await {
182 Ok(()) => Ok(self.inner.lock()),
183 Err(relay_channel::error::SendError(_)) => Err(error::GrantError),
184 }
185 }
186}
187
188impl<T> MutexLocker<T> {
189 /// Returns [None] if either the corresponding [Granter] or the
190 /// [OrchestratorMutex] has been dropped.
191 pub async fn acquire(&mut self) -> Option<MutexGuard<'_, T>> {
192 let result = select! {
193 result = self.rx.recv() => result,
194 () = self.mutex_dropped.cancelled() => None,
195 };
196 Some(MutexGuard {
197 guard: result?,
198 _locker: self,
199 })
200 }
201
202 pub fn try_acquire(&mut self) -> Result<MutexGuard<'_, T>, error::TryAcquireError> {
203 match self.rx.try_recv() {
204 Ok(guard) => Ok(MutexGuard {
205 guard,
206 _locker: self,
207 }),
208 Err(relay_channel::error::TryRecvError::Empty) => {
209 if self.mutex_dropped.is_cancelled() {
210 Err(error::TryAcquireError::Inaccessible)
211 } else {
212 Err(error::TryAcquireError::AccessDenied)
213 }
214 }
215 Err(relay_channel::error::TryRecvError::Disconnected) => {
216 Err(error::TryAcquireError::Inaccessible)
217 }
218 }
219 }
220}
221
222impl<T> MutexGuard<'_, T> {
223 /// The lifetime parameter on [MutexGuard] is only for convenience (to help
224 /// avoid having multiple parallel calls to [acquire](MutexLocker::acquire)
225 /// and [try_acquire](MutexLocker::try_acquire)). The caller can choose to
226 /// instead
227 /// use this function to unwrap the underlying [OwnedMutexGuard] if it's
228 /// more convenient not to deal with the lifetime.
229 pub fn into_owned_guard(this: Self) -> OwnedMutexGuard<T> {
230 this.guard
231 }
232}
233
234impl<T> Deref for MutexGuard<'_, T> {
235 type Target = T;
236
237 fn deref(&self) -> &Self::Target {
238 &*self.guard
239 }
240}
241
242impl<T> DerefMut for MutexGuard<'_, T> {
243 fn deref_mut(&mut self) -> &mut Self::Target {
244 &mut *self.guard
245 }
246}
247
248#[cfg(test)]
249mod test;