async_compatibility_layer/async_primitives/subscribable_mutex.rs

1use crate::art::{async_timeout, future::to, stream};
2use crate::channel::{unbounded, UnboundedReceiver, UnboundedSender};
3use async_lock::{Mutex, MutexGuard};
4use futures::{stream::FuturesOrdered, Future, FutureExt};
5use std::{fmt, time::Duration};
6use tracing::warn;
7
8#[cfg(not(async_executor_impl = "tokio"))]
9use async_std::prelude::StreamExt;
10#[cfg(async_executor_impl = "tokio")]
11use tokio_stream::StreamExt;
12
13/// A mutex that can register subscribers to be notified. This works in the same way as [`Mutex`], but has some additional functions:
14///
15/// [`Self::subscribe`] will return a [`Receiver`] which can be used to be notified of changes.
16///
17/// [`Self::notify_change_subscribers`] will notify all `Receiver` that are registered with the `subscribe` function.
18#[derive(Default)]
19pub struct SubscribableMutex<T: ?Sized> {
20    /// A list of subscribers of this mutex.
21    subscribers: Mutex<Vec<UnboundedSender<()>>>,
22    /// The inner mutex holding the value.
23    /// Note that because of the `T: ?Sized` constraint, this must be the last field in this struct.
24    mutex: Mutex<T>,
25}
26
27impl<T> SubscribableMutex<T> {
28    /// Create a new mutex with the value T
29    pub fn new(t: T) -> Self {
30        Self {
31            mutex: Mutex::new(t),
32            subscribers: Mutex::default(),
33        }
34    }
35
36    /// Acquires the mutex.
37    ///
38    /// Returns a guard that releases the mutex when dropped.
39    ///
40    /// Direct usage of this function may result in unintentional deadlocks.
41    /// Consider using one of the following functions instead:
42    /// - `modify` to edit the inner value.
43    /// - `set` to set the inner value.
44    /// - `compare_and_set` compare the inner value with a given value, and if they match, update the value to the second value.
45    /// - `copied` and `cloned` gets a copy or clone of the inner value
46    #[deprecated(note = "Consider using a different function instead")]
47    pub async fn lock(&self) -> MutexGuard<'_, T> {
48        self.mutex.lock().await
49    }
50
51    /// Notify the subscribers that a change has occured. Subscribers can be registered by calling [`Self::subscribe`].
52    ///
53    /// Subscribers cannot be removed as they have no unique identifying information. Instead this function will simply remove all senders that fail to deliver their message.
54    pub async fn notify_change_subscribers(&self) {
55        let mut lock = self.subscribers.lock().await;
56        // We currently don't have a way to remove subscribers, so we'll remove them when they fail to deliver their message.
57        let mut idx_to_remove = Vec::new();
58        for (idx, sender) in lock.iter().enumerate() {
59            if sender.send(()).await.is_err() {
60                idx_to_remove.push(idx);
61            }
62        }
63        // Make sure to reverse `idx_to_remove`, or else the first index to remove will make the other indexes invalid
64        for idx in idx_to_remove.into_iter().rev() {
65            lock.remove(idx);
66        }
67    }
68
69    /// Create a [`Receiver`] that will be notified every time a thread calls [`Self::notify_change_subscribers`]
70    pub async fn subscribe(&self) -> UnboundedReceiver<()> {
71        let (sender, receiver) = unbounded();
72        self.subscribers.lock().await.push(sender);
73        receiver
74    }
75
76    /// Modify the internal value, then notify all subscribers that the value is updated.
77    pub async fn modify<F>(&self, cb: F)
78    where
79        F: FnOnce(&mut T),
80    {
81        let mut lock = self.mutex.lock().await;
82        cb(&mut *lock);
83        drop(lock);
84        self.notify_change_subscribers().await;
85    }
86
87    /// Set the new inner value, discarding the old ones. This will also notify all subscribers.
88    pub async fn set(&self, val: T) {
89        let mut lock = self.mutex.lock().await;
90        *lock = val;
91        drop(lock);
92        self.notify_change_subscribers().await;
93    }
94
95    /// Wait until `condition` returns `true`. Will block until then.
96    pub async fn wait_until<F>(&self, mut f: F)
97    where
98        F: FnMut(&T) -> bool,
99    {
100        let receiver = {
101            let lock = self.mutex.lock().await;
102            // Check if we already match the condition. If we do we don't have to subscribe at all.
103            if f(&*lock) {
104                return;
105            }
106            // note: don't drop the lock yet, we want to make sure we subscribe first
107            let receiver = self.subscribe().await;
108            drop(lock);
109            receiver
110        };
111        loop {
112            receiver
113                .recv()
114                .await
115                .expect("`SubscribableMutex::wait_until` was still running when it was dropped");
116            let lock = self.mutex.lock().await;
117            if f(&*lock) {
118                return;
119            }
120        }
121    }
122
123    /// Wait until `f` returns `true`. Signal on `ready_chan`
124    /// once has begun to listen
125    async fn wait_until_with_trigger_inner<'a, F>(
126        &self,
127        mut f: F,
128        ready_chan: futures::channel::oneshot::Sender<()>,
129    ) where
130        F: FnMut(&T) -> bool + 'a,
131    {
132        let receiver = self.subscribe().await;
133        if ready_chan.send(()).is_err() {
134            warn!("unable to notify that channel is ready");
135        };
136        loop {
137            receiver
138                .recv()
139                .await
140                .expect("`SubscribableMutex::wait_until` was still running when it was dropped");
141            let lock = self.mutex.lock().await;
142            if f(&*lock) {
143                return;
144            }
145            drop(lock);
146        }
147    }
148
149    /// Wait until `f` returns `true`. Turns a stream with two ordered
150    /// events. The first event indicates that the stream is now listening for
151    /// the state change, and the second event indicates that `f` has become true
152    pub fn wait_until_with_trigger<'a, F>(
153        &'a self,
154        f: F,
155    ) -> FuturesOrdered<impl Future<Output = ()> + 'a>
156    where
157        F: FnMut(&T) -> bool + 'a,
158    {
159        let (s, r) = futures::channel::oneshot::channel::<()>();
160        let mut result = FuturesOrdered::new();
161        let f1 = r.map(|_| ()).left_future();
162        let f2 = self.wait_until_with_trigger_inner(f, s).right_future();
163        result.push_back(f1);
164        result.push_back(f2);
165        result
166    }
167
168    /// Same functionality as `Self::wait_until_with_trigger`, except
169    /// with timeout `timeout` on both events in stream
170    pub fn wait_timeout_until_with_trigger<'a, F>(
171        &'a self,
172        timeout: Duration,
173        f: F,
174    ) -> stream::to::Timeout<FuturesOrdered<impl Future<Output = ()> + 'a>>
175    where
176        F: FnMut(&T) -> bool + 'a,
177    {
178        self.wait_until_with_trigger(f).timeout(timeout)
179    }
180
181    /// Wait `timeout` until `f` returns `true`. Will return `Ok(())` if the function returned `true` before the time elapsed.
182    /// Notifies caller over `ready_chan` when has begun to listen for changes to the
183    /// internal state (locked within the [`Mutex`])
184    ///
185    /// # Errors
186    ///
187    /// Returns an error when this function timed out.
188    pub async fn wait_timeout_until<F>(&self, timeout: Duration, f: F) -> to::Result<()>
189    where
190        F: FnMut(&T) -> bool,
191    {
192        async_timeout(timeout, self.wait_until(f)).await
193    }
194}
195
196impl<T: PartialEq> SubscribableMutex<T> {
197    /// Compare the value of this mutex. If the value is equal to `compare`, it will be set to `set` and all subscribers will be notified
198    pub async fn compare_and_set(&self, compare: T, set: T) {
199        let mut lock = self.mutex.lock().await;
200        if *lock == compare {
201            *lock = set;
202            drop(lock);
203            self.notify_change_subscribers().await;
204        }
205    }
206}
207
208impl<T: Clone> SubscribableMutex<T> {
209    /// Return a clone of the current value of `T`
210    pub async fn cloned(&self) -> T {
211        self.mutex.lock().await.clone()
212    }
213}
214
215impl<T: Copy> SubscribableMutex<T> {
216    /// Return a copy of the current value of `T`
217    pub async fn copied(&self) -> T {
218        *self.mutex.lock().await
219    }
220}
221
222impl<T: fmt::Debug> fmt::Debug for SubscribableMutex<T> {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        /// Helper struct to be shown when the inner mutex is locked.
225        struct Locked;
226        impl fmt::Debug for Locked {
227            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228                f.write_str("<locked>")
229            }
230        }
231
232        match self.mutex.try_lock() {
233            None => f
234                .debug_struct("SubscribableMutex")
235                .field("data", &Locked)
236                .finish(),
237            Some(guard) => f
238                .debug_struct("SubscribableMutex")
239                .field("data", &&*guard)
240                .finish(),
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::SubscribableMutex;
    use crate::art::{async_sleep, async_spawn, async_timeout};
    use std::{sync::Arc, time::Duration};

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_wait_timeout_until() {
        let mutex: Arc<SubscribableMutex<usize>> = Arc::default();

        // Writer task: stores 0, 1, ..., 10 with a 100ms pause before each store (~1.1s total).
        let writer = Arc::clone(&mutex);
        async_spawn(async move {
            for value in 0..=10 {
                async_sleep(Duration::from_millis(100)).await;
                writer.set(value).await;
            }
        });

        // Allow up to 2 seconds for the final value to appear.
        let outcome = mutex
            .wait_timeout_until(Duration::from_secs(2), |current| *current == 10)
            .await;
        assert_eq!(outcome, Ok(()));
        assert_eq!(mutex.copied().await, 10);
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_wait_timeout_until_fail() {
        let mutex: Arc<SubscribableMutex<usize>> = Arc::default();

        // Writer task: stores 0, 1, ..., 9 and therefore never reaches the awaited value 10.
        let writer = Arc::clone(&mutex);
        async_spawn(async move {
            for value in 0..10 {
                async_sleep(Duration::from_millis(100)).await;
                writer.set(value).await;
            }
        });

        let outcome = mutex
            .wait_timeout_until(Duration::from_secs(2), |current| *current == 10)
            .await;
        assert!(outcome.is_err());
        assert_eq!(mutex.copied().await, 9);
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_compare_and_set() {
        let mutex = SubscribableMutex::new(5usize);
        let notifications = mutex.subscribe().await;

        assert_eq!(mutex.copied().await, 5);

        // Matching comparison: the value is replaced and subscribers are notified.
        mutex.compare_and_set(5, 10).await;
        assert_eq!(mutex.copied().await, 10);
        assert!(notifications.try_recv().is_ok());

        // Non-matching comparison: the value is kept and nobody is notified.
        mutex.compare_and_set(5, 20).await;
        assert_eq!(mutex.copied().await, 10);
        assert!(notifications.try_recv().is_err());
    }

    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
    async fn test_subscriber() {
        let mutex = SubscribableMutex::new(5usize);
        let notifications = mutex.subscribe().await;

        // Nothing has changed yet, so nothing is queued.
        assert!(notifications.try_recv().is_err());

        // A change is visible immediately through the non-blocking receive.
        mutex.set(10).await;
        assert_eq!(notifications.try_recv(), Ok(()));

        // A change is also visible through the awaiting receive.
        mutex.set(20).await;
        assert_eq!(
            async_timeout(Duration::from_millis(10), notifications.recv()).await,
            Ok(Ok(()))
        );

        // Exactly one subscriber is registered at this point.
        assert_eq!(mutex.subscribers.lock().await.len(), 1);

        // Once the receiver is gone, the next notification prunes its sender.
        drop(notifications);
        mutex.notify_change_subscribers().await;
        assert_eq!(mutex.subscribers.lock().await.len(), 0);
    }
}