async_compatibility_layer/async_primitives/
subscribable_mutex.rs1use crate::art::{async_timeout, future::to, stream};
2use crate::channel::{unbounded, UnboundedReceiver, UnboundedSender};
3use async_lock::{Mutex, MutexGuard};
4use futures::{stream::FuturesOrdered, Future, FutureExt};
5use std::{fmt, time::Duration};
6use tracing::warn;
7
8#[cfg(not(async_executor_impl = "tokio"))]
9use async_std::prelude::StreamExt;
10#[cfg(async_executor_impl = "tokio")]
11use tokio_stream::StreamExt;
12
13#[derive(Default)]
19pub struct SubscribableMutex<T: ?Sized> {
20 subscribers: Mutex<Vec<UnboundedSender<()>>>,
22 mutex: Mutex<T>,
25}
26
27impl<T> SubscribableMutex<T> {
28 pub fn new(t: T) -> Self {
30 Self {
31 mutex: Mutex::new(t),
32 subscribers: Mutex::default(),
33 }
34 }
35
36 #[deprecated(note = "Consider using a different function instead")]
47 pub async fn lock(&self) -> MutexGuard<'_, T> {
48 self.mutex.lock().await
49 }
50
51 pub async fn notify_change_subscribers(&self) {
55 let mut lock = self.subscribers.lock().await;
56 let mut idx_to_remove = Vec::new();
58 for (idx, sender) in lock.iter().enumerate() {
59 if sender.send(()).await.is_err() {
60 idx_to_remove.push(idx);
61 }
62 }
63 for idx in idx_to_remove.into_iter().rev() {
65 lock.remove(idx);
66 }
67 }
68
69 pub async fn subscribe(&self) -> UnboundedReceiver<()> {
71 let (sender, receiver) = unbounded();
72 self.subscribers.lock().await.push(sender);
73 receiver
74 }
75
76 pub async fn modify<F>(&self, cb: F)
78 where
79 F: FnOnce(&mut T),
80 {
81 let mut lock = self.mutex.lock().await;
82 cb(&mut *lock);
83 drop(lock);
84 self.notify_change_subscribers().await;
85 }
86
87 pub async fn set(&self, val: T) {
89 let mut lock = self.mutex.lock().await;
90 *lock = val;
91 drop(lock);
92 self.notify_change_subscribers().await;
93 }
94
95 pub async fn wait_until<F>(&self, mut f: F)
97 where
98 F: FnMut(&T) -> bool,
99 {
100 let receiver = {
101 let lock = self.mutex.lock().await;
102 if f(&*lock) {
104 return;
105 }
106 let receiver = self.subscribe().await;
108 drop(lock);
109 receiver
110 };
111 loop {
112 receiver
113 .recv()
114 .await
115 .expect("`SubscribableMutex::wait_until` was still running when it was dropped");
116 let lock = self.mutex.lock().await;
117 if f(&*lock) {
118 return;
119 }
120 }
121 }
122
123 async fn wait_until_with_trigger_inner<'a, F>(
126 &self,
127 mut f: F,
128 ready_chan: futures::channel::oneshot::Sender<()>,
129 ) where
130 F: FnMut(&T) -> bool + 'a,
131 {
132 let receiver = self.subscribe().await;
133 if ready_chan.send(()).is_err() {
134 warn!("unable to notify that channel is ready");
135 };
136 loop {
137 receiver
138 .recv()
139 .await
140 .expect("`SubscribableMutex::wait_until` was still running when it was dropped");
141 let lock = self.mutex.lock().await;
142 if f(&*lock) {
143 return;
144 }
145 drop(lock);
146 }
147 }
148
149 pub fn wait_until_with_trigger<'a, F>(
153 &'a self,
154 f: F,
155 ) -> FuturesOrdered<impl Future<Output = ()> + 'a>
156 where
157 F: FnMut(&T) -> bool + 'a,
158 {
159 let (s, r) = futures::channel::oneshot::channel::<()>();
160 let mut result = FuturesOrdered::new();
161 let f1 = r.map(|_| ()).left_future();
162 let f2 = self.wait_until_with_trigger_inner(f, s).right_future();
163 result.push_back(f1);
164 result.push_back(f2);
165 result
166 }
167
168 pub fn wait_timeout_until_with_trigger<'a, F>(
171 &'a self,
172 timeout: Duration,
173 f: F,
174 ) -> stream::to::Timeout<FuturesOrdered<impl Future<Output = ()> + 'a>>
175 where
176 F: FnMut(&T) -> bool + 'a,
177 {
178 self.wait_until_with_trigger(f).timeout(timeout)
179 }
180
181 pub async fn wait_timeout_until<F>(&self, timeout: Duration, f: F) -> to::Result<()>
189 where
190 F: FnMut(&T) -> bool,
191 {
192 async_timeout(timeout, self.wait_until(f)).await
193 }
194}
195
196impl<T: PartialEq> SubscribableMutex<T> {
197 pub async fn compare_and_set(&self, compare: T, set: T) {
199 let mut lock = self.mutex.lock().await;
200 if *lock == compare {
201 *lock = set;
202 drop(lock);
203 self.notify_change_subscribers().await;
204 }
205 }
206}
207
208impl<T: Clone> SubscribableMutex<T> {
209 pub async fn cloned(&self) -> T {
211 self.mutex.lock().await.clone()
212 }
213}
214
215impl<T: Copy> SubscribableMutex<T> {
216 pub async fn copied(&self) -> T {
218 *self.mutex.lock().await
219 }
220}
221
222impl<T: fmt::Debug> fmt::Debug for SubscribableMutex<T> {
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 struct Locked;
226 impl fmt::Debug for Locked {
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 f.write_str("<locked>")
229 }
230 }
231
232 match self.mutex.try_lock() {
233 None => f
234 .debug_struct("SubscribableMutex")
235 .field("data", &Locked)
236 .finish(),
237 Some(guard) => f
238 .debug_struct("SubscribableMutex")
239 .field("data", &&*guard)
240 .finish(),
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::SubscribableMutex;
    use crate::art::{async_sleep, async_spawn, async_timeout};
    use std::{sync::Arc, time::Duration};

    /// Spawns a background task that stores `0, 1, ..., last` into `mutex`,
    /// sleeping 100ms before each store.
    fn spawn_counter(mutex: &Arc<SubscribableMutex<usize>>, last: usize) {
        let writer = Arc::clone(mutex);
        async_spawn(async move {
            for value in 0..=last {
                async_sleep(Duration::from_millis(100)).await;
                writer.set(value).await;
            }
        });
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_wait_timeout_until() {
        let mutex: Arc<SubscribableMutex<usize>> = Arc::default();
        // Counts up to 10 in roughly 1.1s, well inside the 2s budget below.
        spawn_counter(&mutex, 10);

        let outcome = mutex
            .wait_timeout_until(Duration::from_secs(2), |value| *value == 10)
            .await;
        assert_eq!(outcome, Ok(()));
        assert_eq!(mutex.copied().await, 10);
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_wait_timeout_until_fail() {
        let mutex: Arc<SubscribableMutex<usize>> = Arc::default();
        // Stops at 9, so waiting for 10 has to time out.
        spawn_counter(&mutex, 9);

        let outcome = mutex
            .wait_timeout_until(Duration::from_secs(2), |value| *value == 10)
            .await;
        assert!(outcome.is_err());
        assert_eq!(mutex.copied().await, 9);
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_compare_and_set() {
        let mutex = SubscribableMutex::new(5usize);
        let subscriber = mutex.subscribe().await;
        assert_eq!(mutex.copied().await, 5);

        // Matching comparison: the value changes and the subscriber hears about it.
        mutex.compare_and_set(5, 10).await;
        assert_eq!(mutex.copied().await, 10);
        assert!(subscriber.try_recv().is_ok());

        // Stale comparison: nothing changes and nothing is sent.
        mutex.compare_and_set(5, 20).await;
        assert_eq!(mutex.copied().await, 10);
        assert!(subscriber.try_recv().is_err());
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_subscriber() {
        let mutex = SubscribableMutex::new(5usize);
        let subscriber = mutex.subscribe().await;
        // No change has happened yet.
        assert!(subscriber.try_recv().is_err());

        mutex.set(10).await;
        assert_eq!(subscriber.try_recv(), Ok(()));

        mutex.set(20).await;
        let received = async_timeout(Duration::from_millis(10), subscriber.recv()).await;
        assert_eq!(received, Ok(Ok(())));

        assert_eq!(mutex.subscribers.lock().await.len(), 1);

        // With the receiver gone, the next notification prunes its sender.
        drop(subscriber);
        mutex.notify_change_subscribers().await;
        assert_eq!(mutex.subscribers.lock().await.len(), 0);
    }
}