commonware_utils/sync/
mod.rs1use core::ops::{Deref, DerefMut};
25pub use parking_lot::{
26 Condvar, Mutex, MutexGuard, Once, RwLock, RwLockReadGuard, RwLockWriteGuard,
27};
28pub use tokio::sync::{
29 Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock,
30 RwLockReadGuard as AsyncRwLockReadGuard, RwLockWriteGuard as AsyncRwLockWriteGuard,
31};
32
33pub struct UpgradableAsyncRwLock<T> {
38 rw: tokio::sync::RwLock<T>,
39 gate: tokio::sync::Mutex<()>,
40}
41
42impl<T> UpgradableAsyncRwLock<T> {
43 pub fn new(value: T) -> Self {
45 Self {
46 rw: tokio::sync::RwLock::new(value),
47 gate: tokio::sync::Mutex::new(()),
48 }
49 }
50
51 pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
53 self.rw.read().await
54 }
55
56 pub async fn write(&self) -> UpgradableAsyncRwLockWriteGuard<'_, T> {
60 let gate_guard = self.gate.lock().await;
61 let guard = self.rw.write().await;
62 UpgradableAsyncRwLockWriteGuard {
63 lock: self,
64 guard,
65 gate_guard,
66 }
67 }
68
69 pub async fn upgradable_read(&self) -> UpgradableAsyncRwLockUpgradableReadGuard<'_, T> {
74 let gate_guard = self.gate.lock().await;
75 let guard = self.rw.read().await;
76 UpgradableAsyncRwLockUpgradableReadGuard {
77 lock: self,
78 guard,
79 gate_guard,
80 }
81 }
82
83 pub fn into_inner(self) -> T {
85 self.rw.into_inner()
86 }
87}
88
89pub struct UpgradableAsyncRwLockWriteGuard<'a, T> {
91 lock: &'a UpgradableAsyncRwLock<T>,
92 guard: tokio::sync::RwLockWriteGuard<'a, T>,
93 gate_guard: tokio::sync::MutexGuard<'a, ()>,
94}
95
96impl<'a, T> UpgradableAsyncRwLockWriteGuard<'a, T> {
97 pub fn downgrade_to_upgradable(self) -> UpgradableAsyncRwLockUpgradableReadGuard<'a, T> {
99 let Self {
100 lock,
101 guard,
102 gate_guard,
103 } = self;
104 let guard = tokio::sync::RwLockWriteGuard::downgrade(guard);
105 UpgradableAsyncRwLockUpgradableReadGuard {
106 lock,
107 guard,
108 gate_guard,
109 }
110 }
111}
112
113impl<T> Deref for UpgradableAsyncRwLockWriteGuard<'_, T> {
114 type Target = T;
115
116 fn deref(&self) -> &Self::Target {
117 &self.guard
118 }
119}
120
121impl<T> DerefMut for UpgradableAsyncRwLockWriteGuard<'_, T> {
122 fn deref_mut(&mut self) -> &mut Self::Target {
123 &mut self.guard
124 }
125}
126
127pub struct UpgradableAsyncRwLockUpgradableReadGuard<'a, T> {
129 lock: &'a UpgradableAsyncRwLock<T>,
130 guard: tokio::sync::RwLockReadGuard<'a, T>,
131 gate_guard: tokio::sync::MutexGuard<'a, ()>,
132}
133
134impl<'a, T> UpgradableAsyncRwLockUpgradableReadGuard<'a, T> {
135 pub async fn upgrade(self) -> UpgradableAsyncRwLockWriteGuard<'a, T> {
137 let Self {
138 lock,
139 guard,
140 gate_guard,
141 } = self;
142 drop(guard);
143 let guard = lock.rw.write().await;
144 UpgradableAsyncRwLockWriteGuard {
145 lock,
146 guard,
147 gate_guard,
148 }
149 }
150}
151
152impl<T> Deref for UpgradableAsyncRwLockUpgradableReadGuard<'_, T> {
153 type Target = T;
154
155 fn deref(&self) -> &Self::Target {
156 &self.guard
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::{AsyncRwLock, UpgradableAsyncRwLock};
    use futures::{pin_mut, FutureExt};

    #[test]
    fn test_async_rwlock() {
        futures::executor::block_on(async {
            let lock = AsyncRwLock::new(100u64);

            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            drop((r1, r2));
            let mut writer = lock.write().await;
            *writer += 1;

            assert_eq!(*writer, 101);
        });
    }

    #[test]
    fn test_upgradable_read_blocks_write() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(1u64);
            let upgradable = lock.upgradable_read().await;

            let write = lock.write();
            pin_mut!(write);
            assert!(write.as_mut().now_or_never().is_none());

            drop(upgradable);

            let mut write = write.await;
            *write = 2;
            drop(write);

            assert_eq!(*lock.read().await, 2);
        });
    }

    #[test]
    fn test_read_allowed_during_upgradable_read() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(5u64);
            let upgradable = lock.upgradable_read().await;
            let reader = lock.read().await;
            assert_eq!(*upgradable, 5);
            assert_eq!(*reader, 5);
        });
    }

    #[test]
    fn test_upgrade_prevents_writer_interleaving() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(1u64);
            let upgradable = lock.upgradable_read().await;

            let writer = async {
                let mut writer = lock.write().await;
                let observed = *writer;
                *writer = 7;
                observed
            };
            pin_mut!(writer);
            assert!(writer.as_mut().now_or_never().is_none());

            let mut upgraded = upgradable.upgrade().await;
            *upgraded = 5;
            drop(upgraded);

            assert_eq!(writer.await, 5);
        });
    }

    #[test]
    fn test_downgrade_to_upgradable() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(10u64);
            let mut writer = lock.write().await;
            *writer = 11;

            let upgradable = writer.downgrade_to_upgradable();
            let writer = lock.write();
            pin_mut!(writer);
            assert!(writer.as_mut().now_or_never().is_none());
            drop(upgradable);

            let writer = writer.await;
            assert_eq!(*writer, 11);
        });
    }

    // An upgrade must not complete while a plain reader still holds the lock.
    #[test]
    fn test_upgrade_waits_for_readers() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(3u64);
            let upgradable = lock.upgradable_read().await;
            let reader = lock.read().await;

            let upgrade = upgradable.upgrade();
            pin_mut!(upgrade);
            assert!(upgrade.as_mut().now_or_never().is_none());

            drop(reader);

            let mut upgraded = upgrade.await;
            *upgraded += 1;
            drop(upgraded);

            assert_eq!(*lock.read().await, 4);
        });
    }

    // After downgrading, plain readers are admitted alongside the upgradable guard.
    #[test]
    fn test_read_allowed_after_downgrade() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(20u64);
            let mut writer = lock.write().await;
            *writer = 21;

            let upgradable = writer.downgrade_to_upgradable();
            let reader = lock.read().await;
            assert_eq!(*reader, 21);
            assert_eq!(*upgradable, 21);
        });
    }

    // A held write guard blocks plain readers until it is dropped.
    #[test]
    fn test_write_blocks_read() {
        futures::executor::block_on(async {
            let lock = UpgradableAsyncRwLock::new(0u64);
            let mut writer = lock.write().await;
            *writer = 9;

            let reader = lock.read();
            pin_mut!(reader);
            assert!(reader.as_mut().now_or_never().is_none());

            drop(writer);
            assert_eq!(*reader.await, 9);
        });
    }
}