Skip to main content

commonware_utils/sync/
mod.rs

1//! Utilities for working with synchronization primitives.
2//!
3//! # Choosing A Lock
4//!
5//! Prefer blocking locks for shared data:
6//! - [Mutex]
7//! - [RwLock]
8//!
9//! Use async locks only when you must hold a lock guard across an `.await` point:
10//! - [AsyncMutex]
11//! - [AsyncRwLock]
12//! - [UpgradableAsyncRwLock] when you need to read first and then conditionally upgrade to write
13//!   without allowing another writer to slip in between.
14//!
15//! Async locks are more expensive and should generally be reserved for coordination around
16//! asynchronous I/O resources. For plain in-memory data, blocking locks are usually the right
17//! default.
18//!
19//! Do not hold blocking lock guards across `.await`.
20//!
21//! Async lock guards may span `.await` when needed, but keep those critical sections as small as
22//! possible because long-held guards increase contention and deadlock risk.
23
24use core::ops::{Deref, DerefMut};
25pub use parking_lot::{
26    Condvar, Mutex, MutexGuard, Once, RwLock, RwLockReadGuard, RwLockWriteGuard,
27};
28pub use tokio::sync::{
29    Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock,
30    RwLockReadGuard as AsyncRwLockReadGuard, RwLockWriteGuard as AsyncRwLockWriteGuard,
31};
32
33/// A Tokio-based async rwlock with an upgradable read mode.
34///
35/// All `write` and `upgradable_read` acquisitions take an internal async mutex ("gate") first.
36/// This ensures that upgrading from read to write does not allow another writer to slip in.
37pub struct UpgradableAsyncRwLock<T> {
38    rw: tokio::sync::RwLock<T>,
39    gate: tokio::sync::Mutex<()>,
40}
41
42impl<T> UpgradableAsyncRwLock<T> {
43    /// Create a new lock wrapping `value`.
44    pub fn new(value: T) -> Self {
45        Self {
46            rw: tokio::sync::RwLock::new(value),
47            gate: tokio::sync::Mutex::new(()),
48        }
49    }
50
51    /// Acquire a shared read guard.
52    pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
53        self.rw.read().await
54    }
55
56    /// Acquire an exclusive write guard.
57    ///
58    /// Writers are serialized through the internal gate.
59    pub async fn write(&self) -> UpgradableAsyncRwLockWriteGuard<'_, T> {
60        let gate_guard = self.gate.lock().await;
61        let guard = self.rw.write().await;
62        UpgradableAsyncRwLockWriteGuard {
63            lock: self,
64            guard,
65            gate_guard,
66        }
67    }
68
69    /// Acquire an upgradable read guard.
70    ///
71    /// This allows shared reads, then a later [UpgradableAsyncRwLockUpgradableReadGuard::upgrade]
72    /// to exclusive write while holding the same gate token.
73    pub async fn upgradable_read(&self) -> UpgradableAsyncRwLockUpgradableReadGuard<'_, T> {
74        let gate_guard = self.gate.lock().await;
75        let guard = self.rw.read().await;
76        UpgradableAsyncRwLockUpgradableReadGuard {
77            lock: self,
78            guard,
79            gate_guard,
80        }
81    }
82
83    /// Consume the lock and return the wrapped value.
84    pub fn into_inner(self) -> T {
85        self.rw.into_inner()
86    }
87}
88
89/// Exclusive write guard for [UpgradableAsyncRwLock].
90pub struct UpgradableAsyncRwLockWriteGuard<'a, T> {
91    lock: &'a UpgradableAsyncRwLock<T>,
92    guard: tokio::sync::RwLockWriteGuard<'a, T>,
93    gate_guard: tokio::sync::MutexGuard<'a, ()>,
94}
95
96impl<'a, T> UpgradableAsyncRwLockWriteGuard<'a, T> {
97    /// Downgrade to an upgradable read guard while retaining the internal gate token.
98    pub fn downgrade_to_upgradable(self) -> UpgradableAsyncRwLockUpgradableReadGuard<'a, T> {
99        let Self {
100            lock,
101            guard,
102            gate_guard,
103        } = self;
104        let guard = tokio::sync::RwLockWriteGuard::downgrade(guard);
105        UpgradableAsyncRwLockUpgradableReadGuard {
106            lock,
107            guard,
108            gate_guard,
109        }
110    }
111}
112
113impl<T> Deref for UpgradableAsyncRwLockWriteGuard<'_, T> {
114    type Target = T;
115
116    fn deref(&self) -> &Self::Target {
117        &self.guard
118    }
119}
120
121impl<T> DerefMut for UpgradableAsyncRwLockWriteGuard<'_, T> {
122    fn deref_mut(&mut self) -> &mut Self::Target {
123        &mut self.guard
124    }
125}
126
127/// Upgradable read guard for [UpgradableAsyncRwLock].
128pub struct UpgradableAsyncRwLockUpgradableReadGuard<'a, T> {
129    lock: &'a UpgradableAsyncRwLock<T>,
130    guard: tokio::sync::RwLockReadGuard<'a, T>,
131    gate_guard: tokio::sync::MutexGuard<'a, ()>,
132}
133
134impl<'a, T> UpgradableAsyncRwLockUpgradableReadGuard<'a, T> {
135    /// Upgrade this guard to an exclusive writer.
136    pub async fn upgrade(self) -> UpgradableAsyncRwLockWriteGuard<'a, T> {
137        let Self {
138            lock,
139            guard,
140            gate_guard,
141        } = self;
142        drop(guard);
143        let guard = lock.rw.write().await;
144        UpgradableAsyncRwLockWriteGuard {
145            lock,
146            guard,
147            gate_guard,
148        }
149    }
150}
151
152impl<T> Deref for UpgradableAsyncRwLockUpgradableReadGuard<'_, T> {
153    type Target = T;
154
155    fn deref(&self) -> &Self::Target {
156        &self.guard
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::{AsyncRwLock, UpgradableAsyncRwLock};
    use futures::{pin_mut, FutureExt};

    /// The re-exported async rwlock allows concurrent readers and then a writer.
    #[test]
    fn test_async_rwlock() {
        futures::executor::block_on(async {
            let lock = AsyncRwLock::new(100u64);

            // Both read guards are released at the end of this scope.
            {
                let first = lock.read().await;
                let second = lock.read().await;
                assert_eq!(*first + *second, 200);
            }

            let mut exclusive = lock.write().await;
            *exclusive += 1;

            assert_eq!(*exclusive, 101);
        });
    }

    /// A writer cannot make progress while an upgradable read guard is held.
    #[test]
    fn test_upgradable_read_blocks_write() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(1u64);
            let held = lock.upgradable_read().await;

            // A single poll must not complete while `held` is alive.
            let pending = lock.write();
            pin_mut!(pending);
            assert!(pending.as_mut().now_or_never().is_none());

            drop(held);

            let mut exclusive = pending.await;
            *exclusive = 2;
            drop(exclusive);

            assert_eq!(*lock.read().await, 2);
        });
    }

    /// Plain readers may coexist with an upgradable reader.
    #[test]
    fn test_read_allowed_during_upgradable_read() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(5u64);
            let held = lock.upgradable_read().await;
            let shared = lock.read().await;
            assert_eq!(*held, 5);
            assert_eq!(*shared, 5);
        });
    }

    /// A writer queued before an upgrade only runs after the upgraded guard is released.
    #[test]
    fn test_upgrade_prevents_writer_interleaving() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(1u64);
            let held = lock.upgradable_read().await;

            // Competing writer: reports the value it observed, then overwrites it.
            let competitor = async {
                let mut exclusive = lock.write().await;
                let seen = *exclusive;
                *exclusive = 7;
                seen
            };
            pin_mut!(competitor);
            assert!(competitor.as_mut().now_or_never().is_none());

            let mut upgraded = held.upgrade().await;
            *upgraded = 5;
            drop(upgraded);

            // The competitor must see the upgraded writer's value, not the initial one.
            assert_eq!(competitor.await, 5);
        });
    }

    /// Downgrading a writer keeps other writers blocked until the downgraded guard is dropped.
    #[test]
    fn test_downgrade_to_upgradable() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(10u64);
            let mut exclusive = lock.write().await;
            *exclusive = 11;

            let held = exclusive.downgrade_to_upgradable();
            let pending = lock.write();
            pin_mut!(pending);
            assert!(pending.as_mut().now_or_never().is_none());
            drop(held);

            let exclusive = pending.await;
            assert_eq!(*exclusive, 11);
        });
    }
}
253}