pinned/rwlock/mod.rs

1use std::cell::{Cell, RefCell};
2use std::task::Poll;
3
4use futures::future::poll_fn;
5
6use crate::utils::yield_now;
7
8mod error;
9mod read_guard;
10mod wakers;
11mod write_guard;
12pub use error::*;
13pub use read_guard::*;
14use wakers::Wakers;
15pub use write_guard::*;
16
17/// An asynchronous reader-writer lock.
18///
19/// This type of lock allows a number of readers or at most one writer at any point in time. The
20/// write portion of this lock typically allows modification of the underlying data (exclusive
21/// access) and the read portion of this lock typically allows for read-only access (shared access).
22///
23/// The acquisition order of this lock is not guaranteed and depending on the runtime's
24/// implementation and preference of any used polling combinators.
25///
26/// # Examples
27///
28/// ```
29/// # #[tokio::main]
30/// # async fn main() {
31/// # use pinned::RwLock;
32/// let lock = RwLock::new(5);
33///
34/// // many reader locks can be held at once
35/// {
36///     let r1 = lock.read().await;
37///     let r2 = lock.read().await;
38///     assert_eq!(*r1, 5);
39///     assert_eq!(*r2, 5);
40/// } // read locks are dropped at this point
41///
42/// // only one write lock may be held, however
43/// {
44///     let mut w = lock.write().await;
45///     *w += 1;
46///     assert_eq!(*w, 6);
47/// } // write lock is dropped here
48/// # }
49/// ```
50#[derive(Debug)]
51pub struct RwLock<T: ?Sized> {
52    wakers: Wakers,
53    val: RefCell<T>,
54}
55
56impl<T> RwLock<T> {
57    /// Creates a new `RwLock` containing value `T`
58    pub fn new(val: T) -> Self {
59        Self {
60            wakers: Wakers::new(),
61            val: RefCell::new(val),
62        }
63    }
64
65    /// Consumes the lock, returning the underlying data.
66    pub fn into_inner(self) -> T {
67        self.val.into_inner()
68    }
69}
70
71impl<T> RwLock<T>
72where
73    T: ?Sized,
74{
75    /// Attempts to acquire this `RwLock` with shared read access.
76    ///
77    /// If the access couldn’t be acquired immediately, returns [`TryLockError`]. Otherwise, an RAII
78    /// guard is returned which will release read access when dropped.
79    ///
80    /// This function does not block.
81    ///
82    /// This function does not provide any guarantees with respect to the ordering of whether
83    /// contentious readers or writers will acquire the lock first.
84    pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<'_, T>> {
85        let read_inner = self.val.try_borrow().map_err(|_| TryLockError::new())?;
86        let wake_guard = self.wakers.wake_guard();
87
88        Ok(RwLockReadGuard {
89            val: read_inner,
90            wake_guard,
91        })
92    }
93
94    /// Attempts to lock this `RwLock` with exclusive write access.
95    ///
96    /// If the lock could not be acquired immediately, returns [`TryLockError`]. Otherwise, an RAII
97    /// guard is returned which will release the lock when it is dropped.
98    ///
99    /// This function does not block.
100    ///
101    /// This function does not provide any guarantees with respect to the ordering of whether
102    /// contentious readers or writers will acquire the lock first.
103    pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<'_, T>> {
104        let write_inner = self.val.try_borrow_mut().map_err(|_| TryLockError::new())?;
105        let wake_guard = self.wakers.wake_guard();
106
107        Ok(RwLockWriteGuard {
108            val: write_inner,
109            wake_guard,
110        })
111    }
112
113    /// Wait for the next lock release.
114    async fn wait(&self) {
115        let awaited = Cell::new(false);
116
117        poll_fn(move |cx| {
118            if awaited.get() {
119                return Poll::Ready(());
120            }
121
122            awaited.set(true);
123
124            self.wakers.push(cx.waker().clone());
125            Poll::Pending
126        })
127        .await;
128    }
129
130    /// Locks the current `RwLock` with shared read access, causing the current task to yield
131    /// until the lock has been acquired.
132    ///
133    /// This method does not provide any guarantees with respect to the ordering of whether
134    /// contentious readers or writers will acquire the lock first.
135    ///
136    /// Returns an RAII guard which will release this task's shared access once it is dropped.
137    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
138        // We yield to provide some fairness over the current runtime / polling combinator so that
139        // one task is not starving the lock.
140        yield_now().await;
141
142        loop {
143            if let Ok(m) = self.try_read() {
144                return m;
145            }
146            self.wait().await;
147        }
148    }
149
150    /// Locks the current `RwLock` with exclusive write access, causing the current task to yield
151    /// until the lock has been acquired.
152    ///
153    /// This method does not provide any guarantees with respect to the ordering of whether
154    /// contentious readers or writers will acquire the lock first.
155    ///
156    /// Returns an RAII guard which will drop the write access once it is dropped.
157    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
158        // We yield to provide some fairness over the current runtime / polling combinator so that
159        // one task is not starving the lock.
160        yield_now().await;
161
162        loop {
163            if let Ok(m) = self.try_write() {
164                return m;
165            }
166            self.wait().await;
167        }
168    }
169
170    /// Returns a mutable reference to the underlying data.
171    ///
172    /// This call borrows `RwLock` mutably (at compile-time) so there is no need for dynamic checks.
173    pub fn get_mut(&mut self) -> &mut T {
174        self.val.get_mut()
175    }
176}
177
#[cfg(test)]
mod tests {
    //! Tests, mostly borrowed from tokio, adapted to this implementation.

    use std::time::Duration;

    use futures::future::FutureExt;
    use futures::{pin_mut, poll};
    use tokio::test;
    use tokio::time::timeout;

    use super::*;

    /// Upper bound for each test body, so a lost wakeup fails the test instead of hanging it.
    const SEC_5: Duration = Duration::from_secs(5);

    /// How long an acquisition must stay pending for the tests to consider it blocked.
    const PENDING_WINDOW: Duration = Duration::from_millis(500);

    #[test]
    async fn into_inner() {
        let rwlock = RwLock::new(42);
        assert_eq!(rwlock.into_inner(), 42);
    }

    #[test]
    async fn get_mut() {
        let mut rwlock = RwLock::new(1);
        *rwlock.get_mut() += 1;
        assert_eq!(rwlock.into_inner(), 2);
    }

    #[test]
    async fn read_shared() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _r1 = rwlock.read().await;
            let _r2 = rwlock.read().await;
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_shared_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _r1 = rwlock.read().await;
            timeout(PENDING_WINDOW, rwlock.write())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn read_exclusive_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _w1 = rwlock.write().await;
            timeout(PENDING_WINDOW, rwlock.read())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_exclusive_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _w1 = rwlock.write().await;
            timeout(PENDING_WINDOW, rwlock.write())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_shared_drop() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let w1 = rwlock.write().await;

            let try_write_2 = rwlock.write();
            pin_mut!(try_write_2);

            // `matches!` only yields a bool; it has to be asserted to actually check anything.
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));

            drop(w1);

            // Releasing the first writer must let the pending writer complete.
            try_write_2.await;
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_pending_read_shared_ready() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _r1 = rwlock.read().await;
            let _r2 = rwlock.read().await;

            let try_write_1 = rwlock.write();
            pin_mut!(try_write_1);

            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));

            // A pending writer must not block further readers.
            let _r3 = rwlock.read().await;

            timeout(PENDING_WINDOW, try_write_1)
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn read_uncontested() {
        let rwlock = RwLock::new(100);
        let result = *rwlock.read().await;

        assert_eq!(result, 100);
    }

    #[test]
    async fn write_uncontested() {
        let rwlock = RwLock::new(100);
        let mut result = rwlock.write().await;
        *result += 50;
        assert_eq!(*result, 150);
    }

    #[test]
    async fn write_order() {
        let rwlock = RwLock::<Vec<u32>>::new(vec![]);
        let fut2 = rwlock.write().map(|mut guard| guard.push(2));
        let fut1 = rwlock.write().map(|mut guard| guard.push(1));
        fut1.await;
        fut2.await;

        let g = rwlock.read().await;
        assert_eq!(*g, vec![1, 2]);
    }

    #[test]
    async fn try_write() {
        let lock = RwLock::new(0);
        let read_guard = lock.read().await;
        assert!(lock.try_write().is_err());
        drop(read_guard);
        assert!(lock.try_write().is_ok());
    }

    #[test]
    async fn try_read_try_write() {
        let lock: RwLock<usize> = RwLock::new(15);

        {
            let rg1 = lock.try_read().unwrap();
            assert_eq!(*rg1, 15);

            assert!(lock.try_write().is_err());

            let rg2 = lock.try_read().unwrap();
            assert_eq!(*rg2, 15)
        }

        {
            let mut wg = lock.try_write().unwrap();
            *wg = 1515;

            assert!(lock.try_read().is_err())
        }

        assert_eq!(*lock.try_read().unwrap(), 1515);
    }
}