fast_async_mutex/
mutex.rs

1use crate::inner::Inner;
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8/// The simple Mutex, which will provide unique access to you data between multiple threads/futures.
9#[derive(Debug)]
10pub struct Mutex<T: ?Sized> {
11    inner: Inner<T>,
12}
13
14impl<T> Mutex<T> {
15    /// Create a new `Mutex`
16    #[inline]
17    pub const fn new(data: T) -> Mutex<T> {
18        Mutex {
19            inner: Inner::new(data),
20        }
21    }
22}
23
24impl<T: ?Sized> Mutex<T> {
25    /// Acquires the mutex.
26    ///
27    /// Returns a guard that releases the mutex and wake the next locker when dropped.
28    ///
29    /// # Examples
30    ///
31    /// ```
32    /// use fast_async_mutex::mutex::Mutex;
33    ///
34    /// #[tokio::main]
35    /// async fn main() {
36    ///     let mutex = Mutex::new(10);
37    ///     let guard = mutex.lock().await;
38    ///     assert_eq!(*guard, 10);
39    /// }
40    /// ```
41    #[inline]
42    pub const fn lock(&self) -> MutexGuardFuture<T> {
43        MutexGuardFuture {
44            mutex: &self,
45            is_realized: false,
46        }
47    }
48
49    /// Acquires the mutex.
50    ///
51    /// Returns a guard that releases the mutex and wake the next locker when dropped.
52    /// `MutexOwnedGuardFuture` have a `'static` lifetime, but requires the `Arc<Mutex<T>>` type
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// use fast_async_mutex::mutex::Mutex;
58    /// use std::sync::Arc;
59    /// #[tokio::main]
60    /// async fn main() {
61    ///     let mutex = Arc::new(Mutex::new(10));
62    ///     let guard = mutex.lock_owned().await;
63    ///     assert_eq!(*guard, 10);
64    /// }
65    /// ```
66    #[inline]
67    pub fn lock_owned(self: &Arc<Self>) -> MutexOwnedGuardFuture<T> {
68        MutexOwnedGuardFuture {
69            mutex: self.clone(),
70            is_realized: false,
71        }
72    }
73}
74
75/// The Simple Mutex Guard
76/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard internally borrows the Mutex, so the mutex will not be dropped while a guard exists.
77/// The lock is automatically released and waked the next locker whenever the guard is dropped, at which point lock will succeed yet again.
78#[derive(Debug)]
79pub struct MutexGuard<'a, T: ?Sized> {
80    mutex: &'a Mutex<T>,
81}
82
83#[derive(Debug)]
84pub struct MutexGuardFuture<'a, T: ?Sized> {
85    mutex: &'a Mutex<T>,
86    is_realized: bool,
87}
88
89/// An owned handle to a held Mutex.
90/// This guard is only available from a Mutex that is wrapped in an `Arc`. It is identical to `MutexGuard`, except that rather than borrowing the `Mutex`, it clones the `Arc`, incrementing the reference count. This means that unlike `MutexGuard`, it will have the `'static` lifetime.
91/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard internally keeps a reference-couned pointer to the original `Mutex`, so even if the lock goes away, the guard remains valid.
92/// The lock is automatically released and waked the next locker whenever the guard is dropped, at which point lock will succeed yet again.
93#[derive(Debug)]
94pub struct MutexOwnedGuard<T: ?Sized> {
95    mutex: Arc<Mutex<T>>,
96}
97
98#[derive(Debug)]
99pub struct MutexOwnedGuardFuture<T: ?Sized> {
100    mutex: Arc<Mutex<T>>,
101    is_realized: bool,
102}
103
104impl<'a, T: ?Sized> Future for MutexGuardFuture<'a, T> {
105    type Output = MutexGuard<'a, T>;
106
107    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
108        if self.mutex.inner.try_acquire() {
109            self.is_realized = true;
110            Poll::Ready(MutexGuard { mutex: self.mutex })
111        } else {
112            self.mutex.inner.store_waker(cx.waker());
113            Poll::Pending
114        }
115    }
116}
117
118impl<T: ?Sized> Future for MutexOwnedGuardFuture<T> {
119    type Output = MutexOwnedGuard<T>;
120
121    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122        if self.mutex.inner.try_acquire() {
123            self.is_realized = true;
124            Poll::Ready(MutexOwnedGuard {
125                mutex: self.mutex.clone(),
126            })
127        } else {
128            self.mutex.inner.store_waker(cx.waker());
129            Poll::Pending
130        }
131    }
132}
133
134crate::impl_send_sync_mutex!(Mutex, MutexGuard, MutexOwnedGuard);
135
136crate::impl_deref_mut!(MutexGuard, 'a);
137crate::impl_deref_mut!(MutexOwnedGuard);
138
139crate::impl_drop_guard!(MutexGuard, 'a, unlock);
140crate::impl_drop_guard!(MutexOwnedGuard, unlock);
141crate::impl_drop_guard_future!(MutexGuardFuture, 'a, unlock);
142crate::impl_drop_guard_future!(MutexOwnedGuardFuture, unlock);
143
#[cfg(test)]
mod tests {
    use crate::mutex::{Mutex, MutexGuard, MutexOwnedGuard};
    use futures::executor::block_on;
    use futures::{FutureExt, StreamExt, TryStreamExt};
    use std::ops::AddAssign;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex() {
        let c = Mutex::new(0);

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: MutexGuard<i32> = c.lock().await;
                *co += 1;
            })
            .await;

        let co = c.lock().await;
        assert_eq!(*co, 10000)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex_delay() {
        let expected_result = 100;
        let c = Mutex::new(0);

        futures::stream::iter(0..expected_result)
            .then(|i| c.lock().map(move |co| (i, co)))
            .for_each_concurrent(None, |(i, mut co)| async move {
                sleep(Duration::from_millis(expected_result - i)).await;
                *co += 1;
            })
            .await;

        let co = c.lock().await;
        assert_eq!(*co, expected_result)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_owned_mutex() {
        let c = Arc::new(Mutex::new(0));

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: MutexOwnedGuard<i32> = c.lock_owned().await;
                *co += 1;
            })
            .await;

        let co = c.lock_owned().await;
        assert_eq!(*co, 10000)
    }

    /// The owned guard is documented as `'static`: it must be movable into a spawned task
    /// and must release the lock when that task drops it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_owned_guard_moves_into_task() {
        let c = Arc::new(Mutex::new(0));

        let mut co: MutexOwnedGuard<i32> = c.lock_owned().await;
        tokio::spawn(async move {
            *co += 1;
        })
        .await
        .unwrap();

        let co = c.lock().await;
        assert_eq!(*co, 1)
    }

    #[tokio::test]
    async fn test_container() {
        let c = Mutex::new(String::from("lol"));

        let mut co: MutexGuard<String> = c.lock().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    #[tokio::test]
    async fn test_timeout() {
        let c = Mutex::new(String::from("lol"));

        let co: MutexGuard<String> = c.lock().await;

        futures::stream::iter(0..10000i32)
            .then(|_| tokio::time::timeout(Duration::from_nanos(1), c.lock()))
            .try_for_each_concurrent(None, |_c| futures::future::ok(()))
            .await
            .expect_err("timout must be");

        drop(co);

        let mut co: MutexGuard<String> = c.lock().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    #[test]
    fn multithreading_test() {
        let num = 100;
        let mutex = Arc::new(Mutex::new(0));
        let ths: Vec<_> = (0..num)
            .map(|_| {
                let mutex = mutex.clone();
                std::thread::spawn(move || {
                    block_on(async {
                        let mut lock = mutex.lock().await;
                        *lock += 1;
                    })
                })
            })
            .collect();

        for thread in ths {
            thread.join().unwrap();
        }

        block_on(async {
            let lock = mutex.lock().await;
            assert_eq!(num, *lock)
        })
    }
}