glommio/sync/rwlock.rs

// Unless explicitly stated otherwise all files in this repository are licensed
// under the MIT/Apache-2.0 License, at your convenience

//! Read-write locks.
//!
//! Provides functionality similar to [`std::sync::RwLock`], except that this
//! lock cannot be poisoned.
//!
//! # Examples
//!
//! ```
//! use glommio::{sync::RwLock, LocalExecutor};
//! let lock = RwLock::new(5);
//! let ex = LocalExecutor::default();
//!
//! ex.run(async move {
//!     // many reader locks can be held at once
//!     {
//!         let r1 = lock.read().await.unwrap();
//!         let r2 = lock.read().await.unwrap();
//!         assert_eq!(*r1, 5);
//!         assert_eq!(*r2, 5);
//!     } // read locks are dropped at this point
//!
//!     // only one write lock may be held, however
//!     {
//!         let mut w = lock.write().await.unwrap();
//!         *w += 1;
//!         assert_eq!(*w, 6);
//!     } // write lock is dropped here
//! });
//! ```
use core::fmt::Debug;
use std::{
    cell::{Ref, RefCell, RefMut},
    future::Future,
    ops::{Deref, DerefMut},
    pin::Pin,
    task::{Context, Poll, Waker},
};

use intrusive_collections::{
    container_of, linked_list::LinkOps, offset_of, Adapter, LinkedList, LinkedListLink, PointerOps,
};

use crate::{GlommioError, ResourceType};
use std::{marker::PhantomPinned, ptr::NonNull};

/// A type alias for the result of a locking method that may suspend.
pub type LockResult<T> = Result<T, GlommioError<()>>;

/// A type alias for the result of a non-suspending locking method.
pub type TryLockResult<T> = Result<T, GlommioError<()>>;

#[derive(Debug)]
struct Waiter<'a, T> {
    node: WaiterNode,
    rw: &'a RwLock<T>,
}

#[derive(Copy, Clone, Eq, PartialEq, Debug)]
enum WaiterKind {
    Reader,
    Writer,
}

#[derive(Debug)]
struct WaiterNode {
    kind: WaiterKind,
    link: LinkedListLink,
    waker: RefCell<Option<Waker>>,

    // The waiter node must not be `Unpin`, so its pointer can be used inside
    // the intrusive collection. It also must not outlive the container, which
    // is guaranteed by the `Waiter` lifetime being bound to the `RwLock` that
    // contains all the `Waiter`s.
    _p: PhantomPinned,
}

struct WaiterPointerOps;

unsafe impl PointerOps for WaiterPointerOps {
    type Value = WaiterNode;
    type Pointer = NonNull<WaiterNode>;

    unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer {
        NonNull::new(value as *mut Self::Value).expect("Pointer to the value cannot be null")
    }

    fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value {
        ptr.as_ptr() as *const Self::Value
    }
}

struct WaiterAdapter {
    pointers_ops: WaiterPointerOps,
    link_ops: LinkOps,
}

impl WaiterAdapter {
    fn new() -> Self {
        WaiterAdapter {
            pointers_ops: WaiterPointerOps,
            link_ops: LinkOps,
        }
    }
}

/// Adapter which converts a pointer to the link into a pointer to the object
/// held in the collection, and vice versa.
unsafe impl Adapter for WaiterAdapter {
    type LinkOps = LinkOps;
    type PointerOps = WaiterPointerOps;

    unsafe fn get_value(
        &self,
        link: <Self::LinkOps as intrusive_collections::LinkOps>::LinkPtr,
    ) -> *const <Self::PointerOps as PointerOps>::Value {
        container_of!(link.as_ptr(), WaiterNode, link)
    }

    unsafe fn get_link(
        &self,
        value: *const <Self::PointerOps as PointerOps>::Value,
    ) -> <Self::LinkOps as intrusive_collections::LinkOps>::LinkPtr {
        if value.is_null() {
            panic!("Passed in pointer to the value cannot be null");
        }

        let ptr = (value as *const u8).add(offset_of!(WaiterNode, link));
        // null check is performed above
        core::ptr::NonNull::new_unchecked(ptr as *mut _)
    }

    fn link_ops(&self) -> &Self::LinkOps {
        &self.link_ops
    }

    fn link_ops_mut(&mut self) -> &mut Self::LinkOps {
        &mut self.link_ops
    }

    fn pointer_ops(&self) -> &Self::PointerOps {
        &self.pointers_ops
    }
}

/// A reader-writer lock
///
/// This type of lock allows a number of readers or at most one writer at any
/// point in time. The write portion of this lock typically allows modification
/// of the underlying data (exclusive access) and the read portion of this lock
/// typically allows for read-only access (shared access).
///
/// An `RwLock` will allow any number of readers to acquire the
/// lock as long as a writer is not holding the lock.
///
/// The lock uses a FIFO priority policy: fibers are granted access in the
/// order in which they requested it.
///
/// The lock is not reentrant yet; two subsequent calls requesting write access
/// from the same fiber will deadlock.
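///
/// A minimal sketch of that hazard, assuming only the public API shown in this
/// module: while a write guard is alive, a second `write().await` from the
/// same fiber would suspend forever, which the non-suspending
/// [`RwLock::try_write`] makes visible as an `Err` instead:
///
/// ```
/// use glommio::{sync::RwLock, LocalExecutor};
///
/// let ex = LocalExecutor::default();
/// ex.run(async {
///     let lock = RwLock::new(());
///     let _guard = lock.write().await.unwrap();
///     // A second `lock.write().await` here would never complete.
///     assert!(lock.try_write().is_err());
/// });
/// ```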
///
/// The type parameter `T` represents the data that this lock protects. The RAII
/// guards returned from the locking methods implement [`Deref`] (and
/// [`DerefMut`] for the `write` methods) to allow access to the content of the
/// lock.
///
/// # Examples
///
/// ```
/// use glommio::{sync::RwLock, LocalExecutor};
///
/// let lock = RwLock::new(5);
/// let ex = LocalExecutor::default();
///
/// ex.run(async move {
///     // many reader locks can be held at once
///     {
///         let r1 = lock.read().await.unwrap();
///         let r2 = lock.read().await.unwrap();
///         assert_eq!(*r1, 5);
///         assert_eq!(*r2, 5);
///     } // read locks are dropped at this point
///
///     // only one write lock may be held, however
///     {
///         let mut w = lock.write().await.unwrap();
///         *w += 1;
///         assert_eq!(*w, 6);
///     } // write lock is dropped here
/// });
/// ```
#[derive(Debug)]
pub struct RwLock<T> {
    state: RefCell<State>,
    // The Option is needed only to implement the `into_inner` method, so it is
    // always safe to unwrap it by reference during execution.
    value: RefCell<Option<T>>,
}

#[derive(Debug)]
struct State {
    // Number of granted write accesses.
    // There can only be a single writer, but we use the u32 type to support
    // reentrancy for the lock in the future.
    writers: u32,
    // Number of granted read accesses
    readers: u32,

    // Number of queued requests to get write access
    queued_writers: u32,

    waiters_queue: LinkedList<WaiterAdapter>,
    closed: bool,
}

impl<'a, T> Waiter<'a, T> {
    fn new(kind: WaiterKind, rw: &'a RwLock<T>) -> Self {
        Waiter {
            rw,
            node: WaiterNode {
                kind,
                link: LinkedListLink::new(),
                waker: RefCell::new(None),
                _p: PhantomPinned,
            },
        }
    }

    fn remove_from_waiting_queue(node: Pin<&mut WaiterNode>, rw: &mut State) {
        if node.link.is_linked() {
            let mut cursor = unsafe {
                rw.waiters_queue
                    .cursor_mut_from_ptr(node.get_unchecked_mut())
            };

            if cursor.remove().is_none() {
                panic!("Waiter has to be linked into the list of waiting futures");
            }
        }
    }

    fn register_in_waiting_queue(
        node: Pin<&mut WaiterNode>,
        rw: &mut State,
        waker: Waker,
        kind: WaiterKind,
    ) {
        *node.waker.borrow_mut() = Some(waker);

        if node.link.is_linked() {
            return;
        }

        if kind == WaiterKind::Writer {
            rw.queued_writers += 1;
        }

        // It is safe to skip the null check here because we start from an
        // object reference, which is never null.
        rw.waiters_queue
            .push_back(unsafe { NonNull::new_unchecked(node.get_unchecked_mut()) });
    }
}

impl<'a, T> Future for Waiter<'a, T> {
    type Output = LockResult<()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut rw = self.rw.state.borrow_mut();
        let future_mut = unsafe { self.get_unchecked_mut() };
        let pinned_node = unsafe { Pin::new_unchecked(&mut future_mut.node) };

        match pinned_node.kind {
            WaiterKind::Writer => {
                if rw.try_write()? {
                    Self::remove_from_waiting_queue(pinned_node, &mut rw);
                    Poll::Ready(Ok(()))
                } else {
                    Self::register_in_waiting_queue(
                        pinned_node,
                        &mut rw,
                        cx.waker().clone(),
                        WaiterKind::Writer,
                    );
                    Poll::Pending
                }
            }

            WaiterKind::Reader => {
                if rw.try_read()? {
                    Self::remove_from_waiting_queue(pinned_node, &mut rw);
                    Poll::Ready(Ok(()))
                } else {
                    Self::register_in_waiting_queue(
                        pinned_node,
                        &mut rw,
                        cx.waker().clone(),
                        WaiterKind::Reader,
                    );
                    Poll::Pending
                }
            }
        }
    }
}

impl<'a, T> Drop for Waiter<'a, T> {
    fn drop(&mut self) {
        if self.node.link.is_linked() {
            // If the node is linked then the future has already been pinned
            let pinned_node = unsafe { Pin::new_unchecked(&mut self.node) };
            Self::remove_from_waiting_queue(pinned_node, &mut self.rw.state.borrow_mut())
        }
    }
}

impl State {
    fn new() -> Self {
        State {
            writers: 0,
            readers: 0,
            queued_writers: 0,
            closed: false,

            waiters_queue: LinkedList::new(WaiterAdapter::new()),
        }
    }

    fn try_read(&mut self) -> LockResult<bool> {
        if self.closed {
            return Err(GlommioError::Closed(ResourceType::RwLock));
        }

        debug_assert!(!(self.readers > 0 && self.writers > 0));

        if self.writers == 0 {
            self.readers += 1;
            return Ok(true);
        }

        Ok(false)
    }

    fn try_write(&mut self) -> LockResult<bool> {
        if self.closed {
            return Err(GlommioError::Closed(ResourceType::RwLock));
        }

        debug_assert!(!(self.readers > 0 && self.writers > 0));

        if self.readers == 0 && self.writers == 0 {
            self.writers += 1;
            return Ok(true);
        }

        Ok(false)
    }
}

/// RAII structure used to release the shared read access of a lock when
/// dropped.
///
/// This structure is created by the [`read`] and [`try_read`] methods on
/// [`RwLock`].
///
/// [`read`]: RwLock::read
/// [`try_read`]: RwLock::try_read
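///
/// A minimal sketch of guard scoping, assuming only the public API of this
/// module: the shared access lives exactly as long as the guard.
///
/// ```
/// use glommio::{sync::RwLock, LocalExecutor};
///
/// let lock = RwLock::new(0);
/// let ex = LocalExecutor::default();
///
/// ex.run(async move {
///     let guard = lock.read().await.unwrap();
///     assert_eq!(*guard, 0);
///     drop(guard); // the shared access is released here
///
///     // with no readers left, a writer can get in immediately
///     assert!(lock.try_write().is_ok());
/// });
/// ```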
#[derive(Debug)]
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, T> {
    rw: &'a RwLock<T>,
    value_ref: Ref<'a, Option<T>>,
}

impl<'a, T> Deref for RwLockReadGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.value_ref.as_ref().unwrap()
    }
}

impl<'a, T> Drop for RwLockReadGuard<'a, T> {
    fn drop(&mut self) {
        let mut state = self.rw.state.borrow_mut();

        if !state.closed {
            debug_assert!(state.readers > 0);
            state.readers -= 1;

            RwLock::<T>::wake_up_fibers(&mut state);
        }
    }
}

/// RAII structure used to release the exclusive write access of a lock when
/// dropped.
///
/// This structure is created by the [`write`] and [`try_write`] methods
/// on [`RwLock`].
///
/// [`write`]: RwLock::write
/// [`try_write`]: RwLock::try_write
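///
/// A minimal sketch of guard scoping, assuming only the public API of this
/// module: the exclusive access lives exactly as long as the guard.
///
/// ```
/// use glommio::{sync::RwLock, LocalExecutor};
///
/// let lock = RwLock::new(0);
/// let ex = LocalExecutor::default();
///
/// ex.run(async move {
///     {
///         let mut guard = lock.write().await.unwrap();
///         *guard += 1;
///     } // the exclusive access is released here
///
///     assert_eq!(*lock.read().await.unwrap(), 1);
/// });
/// ```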
#[must_use = "if unused the RwLock will immediately unlock"]
#[derive(Debug)]
pub struct RwLockWriteGuard<'a, T> {
    rw: &'a RwLock<T>,
    value_ref: RefMut<'a, Option<T>>,
}

impl<'a, T> Deref for RwLockWriteGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.value_ref.as_ref().unwrap()
    }
}

impl<'a, T> DerefMut for RwLockWriteGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let state = self.rw.state.borrow();

        // Mutating through a guard whose lock was closed underneath it is a
        // logic error, so panic instead of handing out stale access.
        if state.closed {
            panic!("Related RwLock is already closed");
        }

        self.value_ref.as_mut().unwrap()
    }
}

impl<'a, T> Drop for RwLockWriteGuard<'a, T> {
    fn drop(&mut self) {
        let mut state = self.rw.state.borrow_mut();

        if !state.closed {
            debug_assert!(state.writers > 0);
            state.writers -= 1;

            RwLock::<T>::wake_up_fibers(&mut state);
        }
    }
}

impl<T> RwLock<T> {
    /// Creates a new instance of an `RwLock<T>` which is unlocked.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::sync::RwLock;
    ///
    /// let lock = RwLock::new(5);
    /// ```
    pub fn new(value: T) -> Self {
        RwLock {
            state: RefCell::new(State::new()),
            value: RefCell::new(Some(value)),
        }
    }

    /// Returns a mutable reference to the underlying data.
    ///
    /// Since this call borrows the `RwLock` mutably, no actual locking needs to
    /// take place -- the mutable borrow statically guarantees no locks exist.
    ///
    /// # Errors
    ///
    /// This function will return an error if the RwLock is closed.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::{sync::RwLock, LocalExecutor};
    ///
    /// let mut lock = RwLock::new(0);
    /// let ex = LocalExecutor::default();
    ///
    /// ex.run(async move {
    ///     *lock.get_mut().unwrap() = 10;
    ///     assert_eq!(*lock.read().await.unwrap(), 10);
    /// });
    /// ```
    pub fn get_mut(&mut self) -> LockResult<&mut T> {
        let state = self.state.borrow();
        if state.closed {
            return Err(GlommioError::Closed(ResourceType::RwLock));
        }

        // Safety: `&mut self` guarantees exclusive access, so extending the
        // borrow of the inner value to the lifetime of `self` is sound.
        Ok(unsafe { &mut *(self.value.borrow_mut().as_mut().unwrap() as *mut _) })
    }

    /// Locks this RwLock with shared read access, suspending the current fiber
    /// until the lock can be acquired.
    ///
    /// The calling fiber will be suspended until there are no more writers
    /// which hold the lock. There may be other readers currently inside the
    /// lock when this method returns.
    ///
    /// Returns an RAII guard which will release this fiber's shared access
    /// once it is dropped.
    ///
    /// # Errors
    ///
    /// This function will return an error if the RwLock is closed.
    ///
    /// # Examples
    ///
    /// ```
    /// use futures::future::join;
    /// use glommio::{sync::RwLock, LocalExecutor};
    /// use std::rc::Rc;
    ///
    /// let lock = Rc::new(RwLock::new(1));
    /// let c_lock = lock.clone();
    ///
    /// let ex = LocalExecutor::default();
    ///
    /// ex.run(async move {
    ///     let first_reader = glommio::spawn_local(async move {
    ///         let n = lock.read().await.unwrap();
    ///         assert_eq!(*n, 1);
    ///     })
    ///     .detach();
    ///
    ///     let second_reader = glommio::spawn_local(async move {
    ///         let r = c_lock.read().await;
    ///         assert!(r.is_ok());
    ///     })
    ///     .detach();
    ///
    ///     join(first_reader, second_reader).await;
    /// });
    /// ```
    pub async fn read(&self) -> LockResult<RwLockReadGuard<'_, T>> {
        let waiter = {
            let mut state = self.state.borrow_mut();
            let try_result = state.try_read()?;

            if try_result {
                return Ok(RwLockReadGuard {
                    rw: self,
                    value_ref: self.value.borrow(),
                });
            }

            Waiter::new(WaiterKind::Reader, self)
        };

        waiter.await.map(|_| RwLockReadGuard {
            rw: self,
            value_ref: self.value.borrow(),
        })
    }

    /// Locks this RwLock with exclusive write access, suspending the current
    /// fiber until the lock can be acquired.
    ///
    /// This function will not return while other writers or other readers
    /// currently have access to the lock.
    ///
    /// Returns an RAII guard which will release the write access of this
    /// RwLock when it is dropped.
    ///
    /// # Errors
    ///
    /// This function will return an error if the RwLock is closed.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::{sync::RwLock, LocalExecutor};
    ///
    /// let lock = RwLock::new(1);
    /// let ex = LocalExecutor::default();
    ///
    /// ex.run(async move {
    ///     let mut n = lock.write().await.unwrap();
    ///     *n = 2;
    ///
    ///     assert!(lock.try_read().is_err());
    /// });
    /// ```
    pub async fn write(&self) -> LockResult<RwLockWriteGuard<'_, T>> {
        let waiter = {
            let mut state = self.state.borrow_mut();
            let try_result = state.try_write()?;

            if try_result {
                return Ok(RwLockWriteGuard {
                    rw: self,
                    value_ref: self.value.borrow_mut(),
                });
            }

            Waiter::new(WaiterKind::Writer, self)
        };
        waiter.await?;

        Ok(RwLockWriteGuard {
            rw: self,
            value_ref: self.value.borrow_mut(),
        })
    }

    /// Attempts to acquire this RwLock with shared read access.
    ///
    /// If the access could not be granted at this time, then `Err` is returned.
    /// Otherwise, an RAII guard is returned which will release the shared
    /// access when it is dropped.
    ///
    /// This function does not suspend.
    ///
    /// # Errors
    ///
    /// This function will return an error if the RwLock is closed.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::sync::RwLock;
    ///
    /// let lock = RwLock::new(1);
    ///
    /// match lock.try_read() {
    ///     Ok(n) => assert_eq!(*n, 1),
    ///     Err(_) => unreachable!(),
    /// };
    /// ```
    pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<'_, T>> {
        let mut state = self.state.borrow_mut();
        let try_result = state.try_read()?;

        if try_result {
            return Ok(RwLockReadGuard {
                rw: self,
                value_ref: self.value.borrow(),
            });
        }

        Err(GlommioError::WouldBlock(ResourceType::RwLock))
    }

    /// Attempts to lock this RwLock with exclusive write access.
    ///
    /// If the lock could not be acquired at this time, then `Err` is returned.
    /// Otherwise, an RAII guard is returned which will release the lock when
    /// it is dropped.
    ///
    /// This function does not suspend.
    ///
    /// # Errors
    ///
    /// This function will return an error if the RwLock is closed.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::{sync::RwLock, LocalExecutor};
    ///
    /// let lock = RwLock::new(1);
    /// let ex = LocalExecutor::default();
    ///
    /// ex.run(async move {
    ///     let n = lock.read().await.unwrap();
    ///     assert_eq!(*n, 1);
    ///
    ///     assert!(lock.try_write().is_err());
    /// });
    /// ```
    pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<'_, T>> {
        let mut state = self.state.borrow_mut();
        let try_result = state.try_write()?;

        if try_result {
            return Ok(RwLockWriteGuard {
                rw: self,
                value_ref: self.value.borrow_mut(),
            });
        }

        Err(GlommioError::WouldBlock(ResourceType::RwLock))
    }

    /// Indicates whether the current RwLock is closed. Once the lock is
    /// closed, all subsequent calls to methods which request lock access will
    /// return `Err`.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::sync::RwLock;
    ///
    /// let lock = RwLock::new(());
    ///
    /// lock.close();
    ///
    /// assert!(lock.is_closed());
    /// ```
    pub fn is_closed(&self) -> bool {
        self.state.borrow().closed
    }

    /// Closes the current RwLock. Once the lock is closed, all fibers waiting
    /// to acquire it are woken up, and all subsequent calls to methods which
    /// request lock access will return `Err`.
    ///
    /// # Errors
    ///
    /// This function will return an error if the RwLock is still held by any
    /// reader(s) or writer.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::{sync::RwLock, LocalExecutor};
    ///
    /// let ex = LocalExecutor::default();
    /// ex.run(async move {
    ///     let lock = RwLock::new(());
    ///     let guard = lock.read().await.unwrap();
    ///     assert!(lock.close().is_err());
    ///
    ///     drop(guard);
    ///     lock.close().unwrap();
    ///
    ///     assert!(lock.read().await.is_err());
    /// });
    /// ```
    pub fn close(&self) -> LockResult<()> {
        let mut state = self.state.borrow_mut();
        if state.closed {
            return Ok(());
        }

        if state.writers > 0 || state.readers > 0 {
            return Err(GlommioError::CanNotBeClosed(
                ResourceType::RwLock,
                "Lock is still held by fiber(s)",
            ));
        }

        state.closed = true;

        Self::wake_up_fibers(&mut state);
        Ok(())
    }

    /// Consumes this [`RwLock`], returning the underlying data.
    ///
    /// # Errors
    ///
    /// This function will return an error if the [`RwLock`] is closed.
    ///
    /// # Examples
    ///
    /// ```
    /// use glommio::{sync::RwLock, LocalExecutor};
    ///
    /// let lock = RwLock::new(String::new());
    /// let ex = LocalExecutor::default();
    ///
    /// ex.run(async move {
    ///     {
    ///         let mut s = lock.write().await.unwrap();
    ///         *s = "modified".to_owned();
    ///     }
    ///
    ///     assert_eq!(lock.into_inner().unwrap(), "modified");
    /// });
    /// ```
    pub fn into_inner(self) -> LockResult<T> {
        let state = self.state.borrow();
        if state.closed {
            return Err(GlommioError::Closed(ResourceType::RwLock));
        }

        drop(state);

        self.close().unwrap();

        let value = self.value.borrow_mut().take().unwrap();
        Ok(value)
    }

    fn wake_up_fibers(rw: &mut State) {
        // Written with the assumption that the waker triggers delayed
        // execution of fibers; such behaviour matches users' intuition about
        // tasks.

        // Tasks are woken up in fair order (the order in which they requested
        // the lock). That avoids lock starvation as much as possible.
        if rw.readers == 0 && rw.writers == 0 {
            if rw.queued_writers == 0 {
                // Only readers are waiting in the queue and no one holds the
                // lock, so wake up all of them.
                Self::wake_up_all_fibers(rw);
            } else {
                // There are writers waiting in the queue and no one holds the
                // lock, so wake up all readers up to and including the first
                // writer. They will likely be executed in order, so all of
                // them have a chance to proceed.
                Self::wake_up_readers_and_first_writer(rw);
            }
        } else if rw.writers == 0 {
            if rw.queued_writers == 0 {
                // Only readers are in the waiting queue while some readers
                // hold the lock, so wake up all of them.
                Self::wake_up_all_fibers(rw);
            } else {
                // There are both readers and writers in the queue, so only the
                // readers ahead of the first writer are woken up.
                Self::wake_up_all_readers_till_first_writer(rw);
            }
        }
        // The only remaining possibility is that a writer still holds the
        // lock, so there is no reason to wake anyone up.
    }

    fn wake_up_all_readers_till_first_writer(rw: &mut State) {
        let mut cursor = rw.waiters_queue.front_mut();
        while !cursor.is_null() {
            {
                let node = unsafe { Pin::new_unchecked(cursor.get().unwrap()) };
                if node.kind == WaiterKind::Writer {
                    break;
                }

                let waker = node.waker.borrow_mut().take();
                if let Some(waker) = waker {
                    waker.wake();
                } else {
                    panic!("Future was linked into the waiting list without a waker");
                }
            }

            cursor.remove();
        }
    }

    fn wake_up_readers_and_first_writer(rw: &mut State) {
        let mut cursor = rw.waiters_queue.front_mut();
        // We need to remove the writer from the list too,
        // so we use a flag instead of breaking out of the loop.
        let mut only_readers = true;

        while !cursor.is_null() && only_readers {
            {
                let node = unsafe { Pin::new_unchecked(cursor.get().unwrap()) };

                let waker = node.waker.borrow_mut().take();
                if let Some(waker) = waker {
                    waker.wake();
                } else {
                    panic!("Future was linked into the waiting list without a waker");
                }

                if node.kind == WaiterKind::Writer {
                    rw.queued_writers -= 1;
                    only_readers = false;
                }
            }

            cursor.remove();
        }
    }

    fn wake_up_all_fibers(rw: &mut State) {
        let mut cursor = rw.waiters_queue.front_mut();
        while !cursor.is_null() {
            let node = cursor.remove().unwrap();
            let waker = (unsafe { node.as_ref() }).waker.borrow_mut().take();
            if let Some(waker) = waker {
                waker.wake();
            } else {
                panic!("Future was linked into the waiting list without a waker");
            }
        }
    }
}

impl<T: Default> Default for RwLock<T> {
    fn default() -> Self {
        Self::new(T::default())
    }
}

impl<T> Drop for RwLock<T> {
    fn drop(&mut self) {
        // Lifetime annotations prevent guards from outliving the RwLock, so
        // this unwrap is safe.
        self.close().unwrap();
        assert!(self.state.borrow().waiters_queue.is_empty());
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::{sync::rwlock::RwLock, timer::Timer, LocalExecutor};
    use futures::future::join_all;
    use std::time::Duration;

    use crate::sync::Semaphore;
    use std::{cell::RefCell, rc::Rc};

    #[derive(Eq, PartialEq, Debug)]
    struct NonCopy(i32);

    #[test]
    fn test_smoke() {
        test_executor!(async move {
            let lock = RwLock::new(());
            drop(lock.read().await.unwrap());
            drop(lock.write().await.unwrap());
            #[allow(clippy::mixed_read_write_in_expression)]
            drop((lock.read().await.unwrap(), lock.read().await.unwrap()));
            drop(lock.read().await.unwrap());
        });
    }

    #[test]
    fn test_frob() {
        test_executor!(async move {
            const N: u32 = 10;
            const M: usize = 1000;

            let r = Rc::new(RwLock::new(()));
            let mut futures = Vec::new();

            for _ in 0..N {
                let r = r.clone();

                let f = crate::spawn_local(async move {
                    for _ in 0..M {
                        if fastrand::u32(0..N) == 0 {
                            drop(r.write().await.unwrap());
                        } else {
                            drop(r.read().await.unwrap());
                        }
                    }
                });

                futures.push(f);
            }

            join_all(futures).await;
        });
    }

    #[test]
    fn test_close_w() {
        test_executor!(async move {
            let rc = Rc::new(RwLock::new(1));
            let rc2 = rc.clone();

            crate::spawn_local(async move {
                let _lock = rc2.write().await.unwrap();
                assert!(rc2.close().is_err());
            })
            .await;
        });
    }

    #[test]
    fn test_close_r() {
        test_executor!(async move {
            let rc = Rc::new(RwLock::new(1));
            let rc2 = rc.clone();

            crate::spawn_local(async move {
                let _lock = rc2.read().await.unwrap();
                assert!(rc2.close().is_err());
            })
            .await;
        });
    }

    #[test]
    fn test_global_lock() {
        test_executor!(async move {
            let rc = Rc::new(RwLock::new(0));
            let rc2 = rc.clone();

            let s = Rc::new(Semaphore::new(0));
            let s2 = s.clone();

            let mut fibers = Vec::new();
            fibers.push(crate::spawn_local(async move {
                let mut lock = rc2.write().await.unwrap();

                for _ in 0..10 {
                    crate::executor().yield_task_queue_now().await;
                    let tmp = *lock;
                    *lock -= 1;
                    crate::executor().yield_task_queue_now().await;
                    *lock = tmp + 1;
                }

                s2.signal(1);
            }));

            for _ in 0..5 {
                let rc3 = rc.clone();

                fibers.push(crate::spawn_local(async move {
                    let lock = rc3.read().await.unwrap();
                    assert!(*lock == 0 || *lock == 10);

                    crate::executor().yield_task_queue_now().await;
                }));
            }

            join_all(fibers).await;
            s.acquire(1).await.unwrap();
            let lock = rc.read().await.unwrap();
            assert_eq!(*lock, 10);
        });
    }

    #[test]
    fn test_local_lock() {
        test_executor!(async move {
            let rc = Rc::new(RwLock::new(0));
            let rc2 = rc.clone();

            let s = Rc::new(Semaphore::new(0));
            let s2 = s.clone();

            let mut fibers = Vec::new();
            fibers.push(crate::spawn_local(async move {
                for _ in 0..10 {
                    let mut lock = rc2.write().await.unwrap();
                    let tmp = *lock;
                    *lock -= 1;

                    crate::executor().yield_task_queue_now().await;

                    *lock = tmp + 1;
                }

                s2.signal(1);
            }));

            for _ in 0..5 {
                let rc3 = rc.clone();

                fibers.push(crate::spawn_local(async move {
                    let lock = rc3.read().await.unwrap();
                    assert!(*lock >= 0);

                    crate::executor().yield_task_queue_now().await;
                }));
            }

            join_all(fibers).await;
            s.acquire(1).await.unwrap();
            let lock = rc.read().await.unwrap();
            assert_eq!(*lock, 10);
        });
    }

    #[test]
    fn test_ping_pong() {
        test_executor!(async move {
            const ITERATIONS: i32 = 10;

            let ball = Rc::new(RwLock::new(0));
            let ball2 = ball.clone();
            let ball3 = ball.clone();

            let mut fibers = Vec::new();

            let pinger = crate::spawn_local(async move {
                let mut prev = -1;
                loop {
                    // give room for other fibers to participate
                    crate::executor().yield_task_queue_now().await;

                    let mut lock = ball2.write().await.unwrap();
                    if *lock == ITERATIONS {
                        break;
                    }

                    if *lock % 2 == 0 {
                        *lock += 1;

                        if prev >= 0 {
                            assert_eq!(prev + 2, *lock);
                        }

                        prev = *lock;
                    }
                }
            });

            let ponger = crate::spawn_local(async move {
                let mut prev = -1;
                loop {
                    // give room for other fibers to participate
                    crate::executor().yield_task_queue_now().await;

                    let mut lock = ball3.write().await.unwrap();
                    if *lock == ITERATIONS {
                        break;
                    }

                    if *lock % 2 == 1 {
                        *lock += 1;

                        if prev >= 0 {
                            assert_eq!(prev + 2, *lock);
                        }

                        prev = *lock;
                    }
                }
            });

            fibers.push(pinger);
            fibers.push(ponger);

            for _ in 0..12 {
                let ball = ball.clone();
                let reader = crate::spawn_local(async move {
                    let mut prev = -1;
                    loop {
                        // give room for other fibers to participate
                        crate::executor().yield_task_queue_now().await;
                        let lock = ball.read().await.unwrap();

                        if *lock == ITERATIONS {
                            break;
                        }

                        assert!(prev <= *lock);
                        prev = *lock;
                    }
                });
                fibers.push(reader);
            }

            join_all(fibers).await;
        });
    }

    #[test]
    fn test_try_write() {
        test_executor!(async move {
            let lock = RwLock::new(());
            let read_guard = lock.read().await.unwrap();

            let write_result = lock.try_write();
            match write_result {
                Err(GlommioError::WouldBlock(ResourceType::RwLock)) => (),
                Ok(_) => unreachable!("try_write should not succeed while read_guard is in scope"),
                Err(_) => unreachable!("unexpected error"),
            }

            drop(read_guard);
        });
    }

    #[test]
    fn test_try_read() {
        test_executor!(async move {
            let lock = RwLock::new(());
            let write_guard = lock.write().await.unwrap();

            let read_result = lock.try_read();
            match read_result {
                Err(GlommioError::WouldBlock(ResourceType::RwLock)) => (),
                Ok(_) => unreachable!("try_read should not succeed while write_guard is in scope"),
                Err(_) => unreachable!("unexpected error"),
            }

            drop(write_guard);
        });
    }

    #[test]
    fn test_into_inner() {
        let lock = RwLock::new(NonCopy(10));
        assert_eq!(lock.into_inner().unwrap(), NonCopy(10));
    }

    #[test]
    fn test_into_inner_drop() {
        struct Foo(Rc<RefCell<usize>>);

        impl Drop for Foo {
            fn drop(&mut self) {
                *self.0.borrow_mut() += 1;
            }
        }

        let num_drop = Rc::new(RefCell::new(0));
        let lock = RwLock::new(Foo(num_drop.clone()));
        assert_eq!(*num_drop.borrow(), 0);

        {
            let _inner = lock.into_inner().unwrap();
            assert_eq!(*_inner.0.borrow(), 0);
        }

        assert_eq!(*num_drop.borrow(), 1);
    }

    #[test]
    fn test_into_inner_close() {
        let lock = RwLock::new(());
        lock.close().unwrap();

        assert!(lock.is_closed());
        let into_inner_result = lock.into_inner();
        match into_inner_result {
            Err(_) => (),
            Ok(_) => panic!("into_inner of closed lock is Ok"),
        }
    }

    #[test]
    fn test_get_mut() {
        let mut lock = RwLock::new(NonCopy(10));
        *lock.get_mut().unwrap() = NonCopy(20);
        assert_eq!(lock.into_inner().unwrap(), NonCopy(20));
    }

    #[test]
    fn test_get_mut_close() {
        let mut lock = RwLock::new(());
        lock.close().unwrap();

        assert!(lock.is_closed());
        let get_mut_result = lock.get_mut();
        match get_mut_result {
            Err(_) => (),
            Ok(_) => panic!("get_mut of closed lock is Ok"),
        }
    }

    #[test]
    fn rwlock_overflow() {
        let ex = LocalExecutor::default();

        let lock = Rc::new(RwLock::new(()));
        let c_lock = lock.clone();

        let cond = Rc::new(RefCell::new(0));
        let c_cond = cond.clone();

        let semaphore = Rc::new(Semaphore::new(0));
        let c_semaphore = semaphore.clone();

        ex.run(async move {
            crate::spawn_local(async move {
                c_semaphore.acquire(1).await.unwrap();

                let _g = c_lock.read().await.unwrap();
                *c_cond.borrow_mut() = 1;

                wait_on_cond!(c_cond, 2);

                for _ in 0..100 {
                    Timer::new(Duration::from_micros(100)).await;
                }

                let mut waiters_count = 0;
                for _ in &c_lock.state.borrow().waiters_queue {
                    waiters_count += 1;
                }

                assert_eq!(waiters_count, 1);
            })
            .detach();

            semaphore.signal(1);
            wait_on_cond!(cond, 1);
            *cond.borrow_mut() = 2;
            let _ = lock.write().await.unwrap();
        })
    }

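    // A sketch, not part of the original suite: closing the lock must wake a
    // queued waiter, which then observes the closed lock as an `Err`. This
    // relies on the behaviour documented in `wake_up_fibers`: a woken fiber
    // is polled only after the current fiber yields.
    #[test]
    fn test_close_wakes_waiters() {
        test_executor!(async move {
            let lock = Rc::new(RwLock::new(()));
            let c_lock = lock.clone();

            let guard = lock.read().await.unwrap();

            let writer = crate::spawn_local(async move {
                // The lock is held by a reader when this is first polled, so
                // the fiber suspends; by the time it is polled again the lock
                // has been closed.
                assert!(c_lock.write().await.is_err());
            });

            // Let the writer enqueue itself behind the active reader.
            crate::executor().yield_task_queue_now().await;

            drop(guard);
            lock.close().unwrap();
            writer.await;
        });
    }
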
    #[test]
    fn rwlock_reentrant() {
        let ex = LocalExecutor::default();
        ex.run(async {
            let lock = RwLock::new(());
            let _guard = lock.write().await.unwrap();
            assert!(lock.try_write().is_err());
        });
    }
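
    // A sketch, not part of the original suite, exercising the FIFO policy
    // documented on `RwLock`: a queued writer stays suspended while a read
    // guard is alive and proceeds once the guard is dropped.
    #[test]
    fn test_write_waits_for_readers() {
        test_executor!(async move {
            let lock = Rc::new(RwLock::new(0));
            let c_lock = lock.clone();

            let read_guard = lock.read().await.unwrap();

            let writer = crate::spawn_local(async move {
                *c_lock.write().await.unwrap() += 1;
            });

            // Give the writer a chance to run; it must stay queued behind the
            // active reader.
            crate::executor().yield_task_queue_now().await;
            assert_eq!(*read_guard, 0);

            drop(read_guard);
            writer.await;

            assert_eq!(*lock.read().await.unwrap(), 1);
        });
    }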
}