fast_async_mutex/
rwlock.rs

1use crate::inner::Inner;
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9/// The RW Lock mechanism accepts you get concurrent shared access to your data without waiting.
10/// And get unique access with locks like a Mutex.
11#[derive(Debug)]
12pub struct RwLock<T: ?Sized> {
13    readers: AtomicUsize,
14    inner: Inner<T>,
15}
16
17impl<T> RwLock<T> {
18    /// Create a new `RWLock`
19    #[inline]
20    pub const fn new(data: T) -> RwLock<T> {
21        RwLock {
22            readers: AtomicUsize::new(0),
23            inner: Inner::new(data),
24        }
25    }
26}
27
28impl<T: ?Sized> RwLock<T> {
29    /// Acquires the mutex for are write.
30    ///
31    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
32    ///
33    /// # Examples
34    ///
35    /// ```
36    /// use fast_async_mutex::rwlock::RwLock;
37    ///
38    /// #[tokio::main]
39    /// async fn main() {
40    ///     let mutex = RwLock::new(10);
41    ///     let mut guard = mutex.write().await;
42    ///     *guard += 1;
43    ///     assert_eq!(*guard, 11);
44    /// }
45    /// ```
46    #[inline]
47    pub fn write(&self) -> RwLockWriteGuardFuture<T> {
48        RwLockWriteGuardFuture {
49            mutex: &self,
50            is_realized: false,
51        }
52    }
53
54    /// Acquires the mutex for are write.
55    ///
56    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
57    /// `WriteLockOwnedGuard` have a `'static` lifetime, but requires the `Arc<RWLock<T>>` type
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use fast_async_mutex::rwlock::RwLock;
63    /// use std::sync::Arc;
64    /// #[tokio::main]
65    /// async fn main() {
66    ///     let mutex = Arc::new(RwLock::new(10));
67    ///     let mut guard = mutex.write_owned().await;
68    ///     *guard += 1;
69    ///     assert_eq!(*guard, 11);
70    /// }
71    /// ```
72    #[inline]
73    pub fn write_owned(self: &Arc<Self>) -> RwLockWriteOwnedGuardFuture<T> {
74        RwLockWriteOwnedGuardFuture {
75            mutex: self.clone(),
76            is_realized: false,
77        }
78    }
79
80    /// Acquires the mutex for are read.
81    ///
82    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// use fast_async_mutex::rwlock::RwLock;
88    ///
89    /// #[tokio::main]
90    /// async fn main() {
91    ///     let mutex = RwLock::new(10);
92    ///     let guard = mutex.read().await;
93    ///     let guard2 = mutex.read().await;
94    ///     assert_eq!(*guard, *guard2);
95    /// }
96    /// ```
97    #[inline]
98    pub fn read(&self) -> RwLockReadGuardFuture<T> {
99        RwLockReadGuardFuture {
100            mutex: &self,
101            is_realized: false,
102        }
103    }
104
105    /// Acquires the mutex for are write.
106    ///
107    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
108    /// `WriteLockOwnedGuard` have a `'static` lifetime, but requires the `Arc<RWLock<T>>` type
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// use fast_async_mutex::rwlock::RwLock;
114    /// use std::sync::Arc;
115    /// #[tokio::main]
116    /// async fn main() {
117    ///     let mutex = Arc::new(RwLock::new(10));
118    ///     let guard = mutex.read().await;
119    ///     let guard2 = mutex.read().await;
120    ///     assert_eq!(*guard, *guard2);
121    /// }
122    /// ```
123    #[inline]
124    pub fn read_owned(self: &Arc<Self>) -> RwLockReadOwnedGuardFuture<T> {
125        RwLockReadOwnedGuardFuture {
126            mutex: self.clone(),
127            is_realized: false,
128        }
129    }
130
131    #[inline]
132    fn unlock_reader(&self) {
133        if self.readers.fetch_sub(1, Ordering::Release) == 1 {
134            self.inner.unlock()
135        }
136    }
137
138    #[inline]
139    fn add_reader(&self) {
140        self.readers.fetch_add(1, Ordering::Release);
141    }
142
143    #[inline]
144    fn try_acquire_reader(&self) -> bool {
145        self.readers.load(Ordering::Acquire) > 0 || self.inner.try_acquire()
146    }
147}
148
149/// The Simple Write Lock Guard
150/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard internally borrows the RWLock, so the mutex will not be dropped while a guard exists.
151/// The lock is automatically released and waked the next locker whenever the guard is dropped, at which point lock will succeed yet again.
152#[derive(Debug)]
153pub struct RwLockWriteGuard<'a, T: ?Sized> {
154    mutex: &'a RwLock<T>,
155}
156
157#[derive(Debug)]
158pub struct RwLockWriteGuardFuture<'a, T: ?Sized> {
159    mutex: &'a RwLock<T>,
160    is_realized: bool,
161}
162
163/// An owned handle to a held RWLock.
164/// This guard is only available from a RWLock that is wrapped in an `Arc`. It is identical to `WriteLockGuard`, except that rather than borrowing the `RWLock`, it clones the `Arc`, incrementing the reference count. This means that unlike `WriteLockGuard`, it will have the `'static` lifetime.
165/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard internally keeps a reference-couned pointer to the original `RWLock`, so even if the lock goes away, the guard remains valid.
166/// The lock is automatically released and waked the next locker whenever the guard is dropped, at which point lock will succeed yet again.
167#[derive(Debug)]
168pub struct RwLockWriteOwnedGuard<T: ?Sized> {
169    mutex: Arc<RwLock<T>>,
170}
171
172#[derive(Debug)]
173pub struct RwLockWriteOwnedGuardFuture<T: ?Sized> {
174    mutex: Arc<RwLock<T>>,
175    is_realized: bool,
176}
177
178/// The Simple Write Lock Guard
179/// As long as you have this guard, you have shared access to the underlying `T`. The guard internally borrows the `RWLock`, so the mutex will not be dropped while a guard exists.
180/// The lock is automatically released and waked the next locker whenever the guard is dropped, at which point lock will succeed yet again.
181#[derive(Debug)]
182pub struct RwLockReadGuard<'a, T: ?Sized> {
183    mutex: &'a RwLock<T>,
184}
185
186#[derive(Debug)]
187pub struct RwLockReadGuardFuture<'a, T: ?Sized> {
188    mutex: &'a RwLock<T>,
189    is_realized: bool,
190}
191
192/// An owned handle to a held RWLock.
193/// This guard is only available from a RWLock that is wrapped in an `Arc`. It is identical to `WriteLockGuard`, except that rather than borrowing the `RWLock`, it clones the `Arc`, incrementing the reference count. This means that unlike `WriteLockGuard`, it will have the `'static` lifetime.
194/// As long as you have this guard, you have shared access to the underlying `T`. The guard internally keeps a reference-couned pointer to the original `RWLock`, so even if the lock goes away, the guard remains valid.
195/// The lock is automatically released and waked the next locker whenever the guard is dropped, at which point lock will succeed yet again.
196#[derive(Debug)]
197pub struct RwLockReadOwnedGuard<T: ?Sized> {
198    mutex: Arc<RwLock<T>>,
199}
200
201#[derive(Debug)]
202pub struct RwLockReadOwnedGuardFuture<T: ?Sized> {
203    mutex: Arc<RwLock<T>>,
204    is_realized: bool,
205}
206
207impl<'a, T: ?Sized> Future for RwLockWriteGuardFuture<'a, T> {
208    type Output = RwLockWriteGuard<'a, T>;
209
210    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
211        if self.mutex.inner.try_acquire() {
212            self.is_realized = true;
213            Poll::Ready(RwLockWriteGuard { mutex: self.mutex })
214        } else {
215            self.mutex.inner.store_waker(cx.waker());
216            Poll::Pending
217        }
218    }
219}
220
221impl<T: ?Sized> Future for RwLockWriteOwnedGuardFuture<T> {
222    type Output = RwLockWriteOwnedGuard<T>;
223
224    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
225        if self.mutex.inner.try_acquire() {
226            self.is_realized = true;
227            Poll::Ready(RwLockWriteOwnedGuard {
228                mutex: self.mutex.clone(),
229            })
230        } else {
231            self.mutex.inner.store_waker(cx.waker());
232            Poll::Pending
233        }
234    }
235}
236
237impl<'a, T: ?Sized> Future for RwLockReadGuardFuture<'a, T> {
238    type Output = RwLockReadGuard<'a, T>;
239
240    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241        if self.mutex.try_acquire_reader() {
242            self.is_realized = true;
243            self.mutex.add_reader();
244            Poll::Ready(RwLockReadGuard { mutex: self.mutex })
245        } else {
246            self.mutex.inner.store_waker(cx.waker());
247            Poll::Pending
248        }
249    }
250}
251
252impl<T: ?Sized> Future for RwLockReadOwnedGuardFuture<T> {
253    type Output = RwLockReadOwnedGuard<T>;
254
255    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256        if self.mutex.try_acquire_reader() {
257            self.is_realized = true;
258            self.mutex.add_reader();
259            Poll::Ready(RwLockReadOwnedGuard {
260                mutex: self.mutex.clone(),
261            })
262        } else {
263            self.mutex.inner.store_waker(cx.waker());
264            Poll::Pending
265        }
266    }
267}
268
269crate::impl_send_sync_rwlock!(
270    RwLock,
271    RwLockReadGuard,
272    RwLockReadOwnedGuard,
273    RwLockWriteGuard,
274    RwLockWriteOwnedGuard
275);
276
277crate::impl_deref_mut!(RwLockWriteGuard, 'a);
278crate::impl_deref_mut!(RwLockWriteOwnedGuard);
279crate::impl_deref!(RwLockReadGuard, 'a);
280crate::impl_deref!(RwLockReadOwnedGuard);
281
282crate::impl_drop_guard!(RwLockWriteGuard, 'a, unlock);
283crate::impl_drop_guard!(RwLockWriteOwnedGuard, unlock);
284crate::impl_drop_guard_self!(RwLockReadGuard, 'a, unlock_reader);
285crate::impl_drop_guard_self!(RwLockReadOwnedGuard, unlock_reader);
286
287crate::impl_drop_guard_future!(RwLockWriteGuardFuture, 'a, unlock);
288crate::impl_drop_guard_future!(RwLockWriteOwnedGuardFuture, unlock);
289crate::impl_drop_guard_future!(RwLockReadGuardFuture, 'a, unlock);
290crate::impl_drop_guard_future!(RwLockReadOwnedGuardFuture, unlock);
291
#[cfg(test)]
mod tests {
    use crate::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard, RwLockWriteOwnedGuard};
    use futures::executor::block_on;
    use futures::{FutureExt, StreamExt, TryStreamExt};
    use std::ops::AddAssign;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    // Many concurrent writers must each get exclusive access (no lost increments).
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex() {
        let c = RwLock::new(0);

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: RwLockWriteGuard<i32> = c.write().await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, 10000)
    }

    // Write guards held across an await point still serialize correctly.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex_delay() {
        let expected_result = 100;
        let c = RwLock::new(0);

        futures::stream::iter(0..expected_result)
            .then(|i| c.write().map(move |co| (i, co)))
            .for_each_concurrent(None, |(i, mut co)| async move {
                sleep(Duration::from_millis(expected_result - i)).await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, expected_result)
    }

    // Same as `test_mutex`, but through the `Arc`-based owned guards.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_owned_mutex() {
        let c = Arc::new(RwLock::new(0));

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: RwLockWriteOwnedGuard<i32> = c.write_owned().await;
                *co += 1;
            })
            .await;

        let co = c.write_owned().await;
        assert_eq!(*co, 10000)
    }

    // A write guard gives mutable access to a non-`Copy` payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_container() {
        let c = RwLock::new(String::from("lol"));

        let mut co: RwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    // Waiting writers must time out while a write guard is held, and the lock must still
    // be usable once that guard is dropped.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_timeout() {
        let c = RwLock::new(String::from("lol"));

        let co: RwLockWriteGuard<String> = c.write().await;

        futures::stream::iter(0..10000i32)
            .then(|_| tokio::time::timeout(Duration::from_nanos(1), c.write()))
            .try_for_each_concurrent(None, |_c| futures::future::ok(()))
            .await
            .expect_err("timeout must be");

        drop(co);

        let mut co: RwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    // Readers can share the lock while a writer is kept out.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading() {
        let c = RwLock::new(String::from("lol"));

        let co: RwLockReadGuard<String> = c.read().await;

        futures::stream::iter(0..10000i32)
            .then(|_| c.read())
            .inspect(|c| assert_eq!(*co, **c))
            .for_each_concurrent(None, |_c| futures::future::ready(()))
            .await;

        // A writer must not get in while a read guard is alive.
        assert!(tokio::time::timeout(Duration::from_millis(1), c.write())
            .await
            .is_err());

        let co2: RwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);
    }

    // Readers and writers alternate: a held write guard blocks new readers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading_writing() {
        let c = RwLock::new(String::from("lol"));

        let co: RwLockReadGuard<String> = c.read().await;
        let co2: RwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);

        drop(co);
        drop(co2);

        let mut co: RwLockWriteGuard<String> = c.write().await;

        // A reader must not get in while the write guard is alive.
        assert!(tokio::time::timeout(Duration::from_millis(1), c.read())
            .await
            .is_err());

        *co += "lol";

        drop(co);

        let co: RwLockReadGuard<String> = c.read().await;
        let co2: RwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, "lollol");
        assert_eq!(*co, *co2);
    }

    // Mixed readers and writers on plain OS threads with a simple executor.
    #[test]
    fn multithreading_test() {
        let num = 100;
        let mutex = Arc::new(RwLock::new(0));
        let ths: Vec<_> = (0..num)
            .map(|i| {
                let mutex = mutex.clone();
                std::thread::spawn(move || {
                    block_on(async {
                        if i % 2 == 0 {
                            let mut lock = mutex.write().await;
                            *lock += 1;
                            drop(lock)
                        } else {
                            let lock1 = mutex.read().await;
                            let lock2 = mutex.read().await;
                            assert_eq!(*lock1, *lock2);
                            drop(lock1);
                            drop(lock2);
                        }
                    })
                })
            })
            .collect();

        for thread in ths {
            thread.join().unwrap();
        }

        block_on(async {
            let lock = mutex.read().await;
            assert_eq!(num / 2, *lock)
        })
    }
}