fast_async_mutex/
rwlock_ordered.rs

1use crate::inner::OrderedInner;
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
/// An ordered asynchronous read-write lock.
///
/// Reads that are requested after a write are held back and released only once that write has completed.
/// This may slow reads down, but it shortens the time a write has to wait on systems where that is critical.
///
/// **This lock has a limitation: avoid acquiring a second read before releasing the first one inside the same future.
/// A write from another thread may take its place in line between your two reads, and you will get a deadlock.**
#[derive(Debug)]
pub struct OrderedRwLock<T: ?Sized> {
    // Count of readers currently admitted: bumped by `add_reader`, lowered by `unlock_reader`.
    readers: AtomicUsize,
    // Shared lock core from `crate::inner`; this file uses its `generate_id`, `try_acquire`,
    // `store_waker`, `unlock` and the `current` counter.
    inner: OrderedInner<T>,
}
19
20impl<T> OrderedRwLock<T> {
21    /// Create a new `OrderedRWLock`
22    #[inline]
23    pub const fn new(data: T) -> OrderedRwLock<T> {
24        OrderedRwLock {
25            readers: AtomicUsize::new(0),
26            inner: OrderedInner::new(data),
27        }
28    }
29}
30
31impl<T: ?Sized> OrderedRwLock<T> {
32    /// Acquires the mutex for are write.
33    ///
34    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use fast_async_mutex::rwlock_ordered::OrderedRwLock;
40    ///
41    /// #[tokio::main]
42    /// async fn main() {
43    ///     let mutex = OrderedRwLock::new(10);
44    ///     let mut guard = mutex.write().await;
45    ///     *guard += 1;
46    ///     assert_eq!(*guard, 11);
47    /// }
48    /// ```
49    #[inline]
50    pub fn write(&self) -> OrderedRwLockWriteGuardFuture<T> {
51        OrderedRwLockWriteGuardFuture {
52            mutex: &self,
53            id: self.inner.generate_id(),
54            is_realized: false,
55        }
56    }
57
58    /// Acquires the mutex for are write.
59    ///
60    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
61    /// `WriteLockOwnedGuard` have a `'static` lifetime, but requires the `Arc<RWLock<T>>` type
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use fast_async_mutex::rwlock_ordered::OrderedRwLock;
67    /// use std::sync::Arc;
68    /// #[tokio::main]
69    /// async fn main() {
70    ///     let mutex = Arc::new(OrderedRwLock::new(10));
71    ///     let mut guard = mutex.write_owned().await;
72    ///     *guard += 1;
73    ///     assert_eq!(*guard, 11);
74    /// }
75    /// ```
76    #[inline]
77    pub fn write_owned(self: &Arc<Self>) -> OrderedRwLockWriteOwnedGuardFuture<T> {
78        OrderedRwLockWriteOwnedGuardFuture {
79            mutex: self.clone(),
80            id: self.inner.generate_id(),
81            is_realized: false,
82        }
83    }
84
85    /// Acquires the mutex for are read.
86    ///
87    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
88    ///
89    /// # Examples
90    ///
91    /// ```
92    /// use fast_async_mutex::rwlock_ordered::OrderedRwLock;
93    ///
94    /// #[tokio::main]
95    /// async fn main() {
96    ///     let mutex = OrderedRwLock::new(10);
97    ///     let guard = mutex.read().await;
98    ///     let guard2 = mutex.read().await;
99    ///     assert_eq!(*guard, *guard2);
100    /// }
101    /// ```
102    #[inline]
103    pub fn read(&self) -> OrderedRwLockReadGuardFuture<T> {
104        OrderedRwLockReadGuardFuture {
105            mutex: &self,
106            id: self.inner.generate_id(),
107            is_realized: false,
108        }
109    }
110
111    /// Acquires the mutex for are write.
112    ///
113    /// Returns a guard that releases the mutex and wake the next locker when it will be dropped.
114    /// `WriteLockOwnedGuard` have a `'static` lifetime, but requires the `Arc<RWLock<T>>` type
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use fast_async_mutex::rwlock_ordered::OrderedRwLock;
120    /// use std::sync::Arc;
121    /// #[tokio::main]
122    /// async fn main() {
123    ///     let mutex = Arc::new(OrderedRwLock::new(10));
124    ///     let guard = mutex.read().await;
125    ///     let guard2 = mutex.read().await;
126    ///     assert_eq!(*guard, *guard2);
127    /// }
128    /// ```
129    #[inline]
130    pub fn read_owned(self: &Arc<Self>) -> OrderedRwLockReadOwnedGuardFuture<T> {
131        OrderedRwLockReadOwnedGuardFuture {
132            mutex: self.clone(),
133            id: self.inner.generate_id(),
134            is_realized: false,
135        }
136    }
137
138    #[inline]
139    fn unlock_reader(&self) {
140        self.readers.fetch_sub(1, Ordering::Release);
141        self.inner.unlock()
142    }
143
144    #[inline]
145    fn add_reader(&self) {
146        self.readers.fetch_add(1, Ordering::Release);
147    }
148
149    #[inline]
150    pub fn try_acquire_reader(&self, id: usize) -> bool {
151        id == self.inner.current.load(Ordering::Acquire) + self.readers.load(Ordering::Acquire)
152    }
153}
154
/// The borrowed write guard of `OrderedRwLock`.
///
/// As long as you hold this guard, you have exclusive access to the underlying `T`.
/// The guard borrows the `OrderedRwLock`, so the lock cannot be dropped while a guard exists.
///
/// The lock is released and the next locker is woken when the guard is dropped,
/// at which point locking can succeed again.
#[derive(Debug)]
pub struct OrderedRwLockWriteGuard<'a, T: ?Sized> {
    // The lock this guard was issued by.
    mutex: &'a OrderedRwLock<T>,
}
162
/// Future returned by `OrderedRwLock::write`; it resolves to an `OrderedRwLockWriteGuard`.
#[derive(Debug)]
pub struct OrderedRwLockWriteGuardFuture<'a, T: ?Sized> {
    // The lock being acquired.
    mutex: &'a OrderedRwLock<T>,
    // Ticket obtained from `inner.generate_id()` when the future was created.
    id: usize,
    // Set to `true` by `poll` once the guard has been handed out.
    // NOTE(review): presumably consulted by the drop impl generated by `impl_drop_guard_future!` — confirm.
    is_realized: bool,
}
169
/// An owned write handle to a held `OrderedRwLock`.
///
/// This guard is only available from an `OrderedRwLock` that is wrapped in an `Arc`.
/// It plays the same role as `OrderedRwLockWriteGuard`, except that rather than borrowing the lock
/// it holds a clone of the `Arc`, incrementing the reference count, so it carries no borrowed lifetime.
///
/// As long as you hold this guard, you have exclusive access to the underlying `T`.
/// The guard keeps a reference-counted pointer to the original lock, so the guard remains valid
/// even if every other handle to the lock goes away.
///
/// The lock is released and the next locker is woken when the guard is dropped,
/// at which point locking can succeed again.
#[derive(Debug)]
pub struct OrderedRwLockWriteOwnedGuard<T: ?Sized> {
    // Shared ownership of the lock this guard was issued by.
    mutex: Arc<OrderedRwLock<T>>,
}
178
/// Future returned by `OrderedRwLock::write_owned`; it resolves to an `OrderedRwLockWriteOwnedGuard`.
#[derive(Debug)]
pub struct OrderedRwLockWriteOwnedGuardFuture<T: ?Sized> {
    // Shared ownership of the lock being acquired.
    mutex: Arc<OrderedRwLock<T>>,
    // Ticket obtained from `inner.generate_id()` when the future was created.
    id: usize,
    // Set to `true` by `poll` once the guard has been handed out.
    // NOTE(review): presumably consulted by the drop impl generated by `impl_drop_guard_future!` — confirm.
    is_realized: bool,
}
185
/// The borrowed read guard of `OrderedRwLock`.
///
/// As long as you hold this guard, you have shared access to the underlying `T`.
/// The guard borrows the `OrderedRwLock`, so the lock cannot be dropped while a guard exists.
///
/// The read slot is released and the next locker is woken when the guard is dropped,
/// at which point locking can succeed again.
#[derive(Debug)]
pub struct OrderedRwLockReadGuard<'a, T: ?Sized> {
    // The lock this guard was issued by.
    mutex: &'a OrderedRwLock<T>,
}
193
/// Future returned by `OrderedRwLock::read`; it resolves to an `OrderedRwLockReadGuard`.
#[derive(Debug)]
pub struct OrderedRwLockReadGuardFuture<'a, T: ?Sized> {
    // The lock being acquired.
    mutex: &'a OrderedRwLock<T>,
    // Ticket obtained from `inner.generate_id()` when the future was created.
    id: usize,
    // Set to `true` by `poll` once the guard has been handed out.
    // NOTE(review): presumably consulted by the drop impl generated by `impl_drop_guard_future!` — confirm.
    is_realized: bool,
}
200
/// An owned read handle to a held `OrderedRwLock`.
///
/// This guard is only available from an `OrderedRwLock` that is wrapped in an `Arc`.
/// It plays the same role as `OrderedRwLockReadGuard`, except that rather than borrowing the lock
/// it holds a clone of the `Arc`, incrementing the reference count, so it carries no borrowed lifetime.
///
/// As long as you hold this guard, you have shared access to the underlying `T`.
/// The guard keeps a reference-counted pointer to the original lock, so the guard remains valid
/// even if every other handle to the lock goes away.
///
/// The read slot is released and the next locker is woken when the guard is dropped,
/// at which point locking can succeed again.
#[derive(Debug)]
pub struct OrderedRwLockReadOwnedGuard<T: ?Sized> {
    // Shared ownership of the lock this guard was issued by.
    mutex: Arc<OrderedRwLock<T>>,
}
209
/// Future returned by `OrderedRwLock::read_owned`; it resolves to an `OrderedRwLockReadOwnedGuard`.
#[derive(Debug)]
pub struct OrderedRwLockReadOwnedGuardFuture<T: ?Sized> {
    // Shared ownership of the lock being acquired.
    mutex: Arc<OrderedRwLock<T>>,
    // Ticket obtained from `inner.generate_id()` when the future was created.
    id: usize,
    // Set to `true` by `poll` once the guard has been handed out.
    // NOTE(review): presumably consulted by the drop impl generated by `impl_drop_guard_future!` — confirm.
    is_realized: bool,
}
216
217impl<'a, T: ?Sized> Future for OrderedRwLockWriteGuardFuture<'a, T> {
218    type Output = OrderedRwLockWriteGuard<'a, T>;
219
220    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
221        if self.mutex.inner.try_acquire(self.id) {
222            self.is_realized = true;
223            Poll::Ready(OrderedRwLockWriteGuard { mutex: self.mutex })
224        } else {
225            self.mutex.inner.store_waker(cx.waker());
226            Poll::Pending
227        }
228    }
229}
230
231impl<T: ?Sized> Future for OrderedRwLockWriteOwnedGuardFuture<T> {
232    type Output = OrderedRwLockWriteOwnedGuard<T>;
233
234    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235        if self.mutex.inner.try_acquire(self.id) {
236            self.is_realized = true;
237            Poll::Ready(OrderedRwLockWriteOwnedGuard {
238                mutex: self.mutex.clone(),
239            })
240        } else {
241            self.mutex.inner.store_waker(cx.waker());
242            Poll::Pending
243        }
244    }
245}
246
247impl<'a, T: ?Sized> Future for OrderedRwLockReadGuardFuture<'a, T> {
248    type Output = OrderedRwLockReadGuard<'a, T>;
249
250    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
251        if self.mutex.try_acquire_reader(self.id) {
252            self.is_realized = true;
253            self.mutex.add_reader();
254            Poll::Ready(OrderedRwLockReadGuard { mutex: &self.mutex })
255        } else {
256            self.mutex.inner.store_waker(cx.waker());
257            Poll::Pending
258        }
259    }
260}
261
262impl<T: ?Sized> Future for OrderedRwLockReadOwnedGuardFuture<T> {
263    type Output = OrderedRwLockReadOwnedGuard<T>;
264
265    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
266        if self.mutex.try_acquire_reader(self.id) {
267            self.is_realized = true;
268            self.mutex.add_reader();
269            Poll::Ready(OrderedRwLockReadOwnedGuard {
270                mutex: self.mutex.clone(),
271            })
272        } else {
273            self.mutex.inner.store_waker(cx.waker());
274            Poll::Pending
275        }
276    }
277}
278
// Crate-level macros generate the trait impls for the lock, its guards and its futures.
// NOTE(review): the macros are defined elsewhere in the crate; the descriptions below are
// inferred from their names and arguments — confirm against the macro definitions.

// Presumably provides the `Send`/`Sync` impls for the lock and its four guard types.
crate::impl_send_sync_rwlock!(
    OrderedRwLock,
    OrderedRwLockReadGuard,
    OrderedRwLockReadOwnedGuard,
    OrderedRwLockWriteGuard,
    OrderedRwLockWriteOwnedGuard
);

// Write guards get mutable dereferencing; read guards get shared dereferencing only.
crate::impl_deref_mut!(OrderedRwLockWriteGuard, 'a);
crate::impl_deref_mut!(OrderedRwLockWriteOwnedGuard);
crate::impl_deref!(OrderedRwLockReadGuard, 'a);
crate::impl_deref!(OrderedRwLockReadOwnedGuard);

// Drop impls for guards: write guards are paired with `unlock`, read guards with
// `unlock_reader` (which also lowers the reader count).
crate::impl_drop_guard!(OrderedRwLockWriteGuard, 'a, unlock);
crate::impl_drop_guard!(OrderedRwLockWriteOwnedGuard, unlock);
crate::impl_drop_guard_self!(OrderedRwLockReadGuard, 'a, unlock_reader);
crate::impl_drop_guard_self!(OrderedRwLockReadOwnedGuard, unlock_reader);

// Drop impls for the pending futures. All four, including the read futures, are paired with `unlock`.
// NOTE(review): presumably this lets a future dropped before it resolved give up its ticket — confirm.
crate::impl_drop_guard_future!(OrderedRwLockWriteGuardFuture, 'a, unlock);
crate::impl_drop_guard_future!(OrderedRwLockWriteOwnedGuardFuture, unlock);
crate::impl_drop_guard_future!(OrderedRwLockReadGuardFuture, 'a, unlock);
crate::impl_drop_guard_future!(OrderedRwLockReadOwnedGuardFuture, unlock);
301
// Tests for `OrderedRwLock`: exclusive writes, shared reads, timeouts and counter wrap-around.
#[cfg(test)]
mod tests {
    use crate::rwlock_ordered::{
        OrderedRwLock, OrderedRwLockReadGuard, OrderedRwLockWriteGuard,
        OrderedRwLockWriteOwnedGuard,
    };
    use futures::executor::block_on;
    use futures::{FutureExt, StreamExt, TryStreamExt};
    use std::ops::AddAssign;
    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    // 10000 concurrent write-increments must all be applied.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex() {
        let c = OrderedRwLock::new(0);

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: OrderedRwLockWriteGuard<i32> = c.write().await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, 10000)
    }

    // Each write guard is held across a sleep before incrementing; every increment must still land.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex_delay() {
        let expected_result = 100;
        let c = OrderedRwLock::new(0);

        futures::stream::iter(0..expected_result)
            .then(|i| c.write().map(move |co| (i, co)))
            .for_each_concurrent(None, |(i, mut co)| async move {
                // Earlier items sleep longer than later ones.
                sleep(Duration::from_millis(expected_result - i)).await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, expected_result)
    }

    // Same as `test_mutex`, but through the `Arc`-based `write_owned` API.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_owned_mutex() {
        let c = Arc::new(OrderedRwLock::new(0));

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: OrderedRwLockWriteOwnedGuard<i32> = c.write_owned().await;
                *co += 1;
            })
            .await;

        let co = c.write_owned().await;
        assert_eq!(*co, 10000)
    }

    // A write guard dereferences mutably to a non-`Copy` payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_container() {
        let c = OrderedRwLock::new(String::from("lol"));

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    // The lock must still be acquirable when both inner counters start at `usize::MAX`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_overflow() {
        let mut c = OrderedRwLock::new(String::from("lol"));

        c.inner.state = AtomicUsize::new(usize::max_value());
        c.inner.current = AtomicUsize::new(usize::max_value());

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    // Write attempts that time out while a write guard is held must not break the lock:
    // after the guard is dropped, a fresh write must succeed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_timeout() {
        let c = OrderedRwLock::new(String::from("lol"));

        let co: OrderedRwLockWriteGuard<String> = c.write().await;

        futures::stream::iter(0..10000i32)
            .then(|_| tokio::time::timeout(Duration::from_nanos(1), c.write()))
            .try_for_each_concurrent(None, |_c| futures::future::ok(()))
            .await
            .expect_err("timout must be");

        drop(co);

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    // Many reads succeed while a read guard is held; a write must time out as long as that guard lives.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading() {
        let c = OrderedRwLock::new(String::from("lol"));

        let co: OrderedRwLockReadGuard<String> = c.read().await;

        futures::stream::iter(0..10000i32)
            .then(|_| c.read())
            .inspect(|c| assert_eq!(*co, **c))
            .for_each_concurrent(None, |_c| futures::future::ready(()))
            .await;

        // `co` is still alive here, so the write cannot be granted.
        assert!(matches!(
            tokio::time::timeout(Duration::from_millis(1), c.write()).await,
            Err(_)
        ));

        let co2: OrderedRwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);
    }

    // Reads and writes alternate: a read must time out while a write guard is held,
    // and reads after the write observe the written value.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading_writing() {
        let c = OrderedRwLock::new(String::from("lol"));

        let co: OrderedRwLockReadGuard<String> = c.read().await;
        let co2: OrderedRwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);

        drop(co);
        drop(co2);

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;

        // The write guard is held, so the read cannot be granted.
        assert!(matches!(
            tokio::time::timeout(Duration::from_millis(1), c.read()).await,
            Err(_)
        ));

        *co += "lol";

        drop(co);

        let co: OrderedRwLockReadGuard<String> = c.read().await;
        let co2: OrderedRwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, "lollol");
        assert_eq!(*co, *co2);
    }

    // 100 OS threads share the lock: even-numbered threads increment under a write guard,
    // odd-numbered threads only read. The final value must equal the number of writers.
    #[test]
    fn multithreading_test() {
        let num = 100;
        let mutex = Arc::new(OrderedRwLock::new(0));
        let ths: Vec<_> = (0..num)
            .map(|i| {
                let mutex = mutex.clone();
                std::thread::spawn(move || {
                    block_on(async {
                        if i % 2 == 0 {
                            let mut lock = mutex.write().await;
                            *lock += 1;
                            drop(lock);
                        } else {
                            let _lock = mutex.read().await;
                        }
                    })
                })
            })
            .collect();

        for thread in ths {
            thread.join().unwrap();
        }

        block_on(async {
            let lock = mutex.read().await;
            assert_eq!(num / 2, *lock)
        })
    }
}