1use std::cell::{Cell, RefCell};
2use std::task::Poll;
3
4use futures::future::poll_fn;
5
6use crate::utils::yield_now;
7
8mod error;
9mod read_guard;
10mod wakers;
11mod write_guard;
12pub use error::*;
13pub use read_guard::*;
14use wakers::Wakers;
15pub use write_guard::*;
16
17#[derive(Debug)]
51pub struct RwLock<T: ?Sized> {
52 wakers: Wakers,
53 val: RefCell<T>,
54}
55
56impl<T> RwLock<T> {
57 pub fn new(val: T) -> Self {
59 Self {
60 wakers: Wakers::new(),
61 val: RefCell::new(val),
62 }
63 }
64
65 pub fn into_inner(self) -> T {
67 self.val.into_inner()
68 }
69}
70
71impl<T> RwLock<T>
72where
73 T: ?Sized,
74{
75 pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<'_, T>> {
85 let read_inner = self.val.try_borrow().map_err(|_| TryLockError::new())?;
86 let wake_guard = self.wakers.wake_guard();
87
88 Ok(RwLockReadGuard {
89 val: read_inner,
90 wake_guard,
91 })
92 }
93
94 pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<'_, T>> {
104 let write_inner = self.val.try_borrow_mut().map_err(|_| TryLockError::new())?;
105 let wake_guard = self.wakers.wake_guard();
106
107 Ok(RwLockWriteGuard {
108 val: write_inner,
109 wake_guard,
110 })
111 }
112
113 async fn wait(&self) {
115 let awaited = Cell::new(false);
116
117 poll_fn(move |cx| {
118 if awaited.get() {
119 return Poll::Ready(());
120 }
121
122 awaited.set(true);
123
124 self.wakers.push(cx.waker().clone());
125 Poll::Pending
126 })
127 .await;
128 }
129
130 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
138 yield_now().await;
141
142 loop {
143 if let Ok(m) = self.try_read() {
144 return m;
145 }
146 self.wait().await;
147 }
148 }
149
150 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
158 yield_now().await;
161
162 loop {
163 if let Ok(m) = self.try_write() {
164 return m;
165 }
166 self.wait().await;
167 }
168 }
169
170 pub fn get_mut(&mut self) -> &mut T {
174 self.val.get_mut()
175 }
176}
177
#[cfg(test)]
mod tests {
    // Behavioural tests for `RwLock`. `tokio::test` shadows the built-in
    // `#[test]` attribute so that every test body can be `async`.

    use std::time::Duration;

    use futures::future::FutureExt;
    use futures::{pin_mut, poll};
    use tokio::test;
    use tokio::time::timeout;

    use super::*;

    /// Upper bound for each test, so a deadlock fails fast instead of hanging.
    const SEC_5: Duration = Duration::from_secs(5);

    #[test]
    async fn into_inner() {
        let rwlock = RwLock::new(42);
        assert_eq!(rwlock.into_inner(), 42);
    }

    #[test]
    async fn read_shared() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _r1 = rwlock.read().await;
            let _r2 = rwlock.read().await;
        })
        .await
        .expect("timed out")
    }

    #[test]
    async fn write_shared_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _r1 = rwlock.read().await;
            timeout(Duration::from_millis(500), rwlock.write())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn read_exclusive_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _w1 = rwlock.write().await;
            timeout(Duration::from_millis(500), rwlock.read())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_exclusive_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _w1 = rwlock.write().await;
            timeout(Duration::from_millis(500), rwlock.write())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_shared_drop() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let w1 = rwlock.write().await;

            let try_write_2 = rwlock.write();
            pin_mut!(try_write_2);

            // `matches!` only yields a bool; it must be wrapped in `assert!`
            // for the check to actually run.
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));

            // Releasing the writer must let the pending writer complete.
            drop(w1);

            try_write_2.await;
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_pending_read_shared_ready() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);

            let _r1 = rwlock.read().await;
            let _r2 = rwlock.read().await;

            let try_write_1 = rwlock.write();
            pin_mut!(try_write_1);

            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));

            // A waiting writer does not block new readers.
            let _r3 = rwlock.read().await;

            timeout(Duration::from_millis(500), try_write_1)
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn read_uncontested() {
        let rwlock = RwLock::new(100);
        let result = *rwlock.read().await;

        assert_eq!(result, 100);
    }

    #[test]
    async fn write_uncontested() {
        let rwlock = RwLock::new(100);
        let mut result = rwlock.write().await;
        *result += 50;
        assert_eq!(*result, 150);
    }

    #[test]
    async fn write_order() {
        let rwlock = RwLock::<Vec<u32>>::new(vec![]);
        // Futures are lazy: awaiting order, not creation order, decides who
        // writes first.
        let fut2 = rwlock.write().map(|mut guard| guard.push(2));
        let fut1 = rwlock.write().map(|mut guard| guard.push(1));
        fut1.await;
        fut2.await;

        let g = rwlock.read().await;
        assert_eq!(*g, vec![1, 2]);
    }

    #[test]
    async fn try_write() {
        let lock = RwLock::new(0);
        let read_guard = lock.read().await;
        assert!(lock.try_write().is_err());
        drop(read_guard);
        assert!(lock.try_write().is_ok());
    }

    #[test]
    async fn try_read_try_write() {
        let lock: RwLock<usize> = RwLock::new(15);

        {
            let rg1 = lock.try_read().unwrap();
            assert_eq!(*rg1, 15);

            assert!(lock.try_write().is_err());

            let rg2 = lock.try_read().unwrap();
            assert_eq!(*rg2, 15)
        }

        {
            let mut wg = lock.try_write().unwrap();
            *wg = 1515;

            assert!(lock.try_read().is_err())
        }

        assert_eq!(*lock.try_read().unwrap(), 1515);
    }
}