use std::cell::{Cell, RefCell};
use std::task::Poll;
use futures::future::poll_fn;
use crate::utils::yield_now;
mod error;
mod read_guard;
mod wakers;
mod write_guard;
pub use error::*;
pub use read_guard::*;
use wakers::Wakers;
pub use write_guard::*;
/// An async reader-writer lock whose lock state is the borrow state of a `RefCell`.
///
/// Any number of read guards (shared `RefCell` borrows) may coexist, while a
/// write guard (a mutable `RefCell` borrow) is exclusive. Because the value
/// lives in a `RefCell`, the lock is not `Sync` and is meant for tasks that
/// share one thread.
#[derive(Debug)]
pub struct RwLock<T>
where
    T: ?Sized,
{
    /// Wakers of tasks that found the lock busy and parked themselves.
    wakers: Wakers,
    /// The protected value; its dynamic borrow flag doubles as the lock.
    val: RefCell<T>,
}
impl<T> RwLock<T> {
    /// Creates an unlocked `RwLock` that wraps `val` and has no parked waiters.
    pub fn new(val: T) -> Self {
        let wakers = Wakers::new();
        let val = RefCell::new(val);
        Self { wakers, val }
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Taking `self` by value guarantees that no guards are outstanding.
    pub fn into_inner(self) -> T {
        let Self { val, .. } = self;
        val.into_inner()
    }
}
impl<T> RwLock<T>
where
T: ?Sized,
{
pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<'_, T>> {
let read_inner = self.val.try_borrow().map_err(|_| TryLockError::new())?;
let wake_guard = self.wakers.wake_guard();
Ok(RwLockReadGuard {
val: read_inner,
wake_guard,
})
}
pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<'_, T>> {
let write_inner = self.val.try_borrow_mut().map_err(|_| TryLockError::new())?;
let wake_guard = self.wakers.wake_guard();
Ok(RwLockWriteGuard {
val: write_inner,
wake_guard,
})
}
async fn wait(&self) {
let awaited = Cell::new(false);
poll_fn(move |cx| {
if awaited.get() {
return Poll::Ready(());
}
awaited.set(true);
self.wakers.push(cx.waker().clone());
Poll::Pending
})
.await;
}
pub async fn read(&self) -> RwLockReadGuard<'_, T> {
yield_now().await;
loop {
if let Ok(m) = self.try_read() {
return m;
}
self.wait().await;
}
}
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
yield_now().await;
loop {
if let Ok(m) = self.try_write() {
return m;
}
self.wait().await;
}
}
pub fn get_mut(&mut self) -> &mut T {
self.val.get_mut()
}
}
#[cfg(test)]
mod tests {
    use std::rc::Rc;
    use std::time::Duration;

    use futures::future::FutureExt;
    use futures::{pin_mut, poll};
    // Imported so that the `#[test]` attributes below resolve to `tokio::test`,
    // which is what allows them to be `async fn`.
    use tokio::test;
    use tokio::time::timeout;

    use super::*;

    /// Outer bound on each test, so a lost wakeup fails instead of hanging.
    const SEC_5: Duration = Duration::from_secs(5);

    #[test]
    async fn into_inner() {
        let rwlock = RwLock::new(42);
        assert_eq!(rwlock.into_inner(), 42);
    }

    #[test]
    async fn read_shared() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let _r1 = rwlock.read().await;
            let _r2 = rwlock.read().await;
        })
        .await
        .expect("timed out")
    }

    #[test]
    async fn write_shared_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let _r1 = rwlock.read().await;
            timeout(Duration::from_millis(500), rwlock.write())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn read_exclusive_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let _w1 = rwlock.write().await;
            timeout(Duration::from_millis(500), rwlock.read())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_exclusive_pending() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let _w1 = rwlock.write().await;
            timeout(Duration::from_millis(500), rwlock.write())
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_shared_drop() {
        timeout(SEC_5, async {
            let rwlock = Rc::new(RwLock::new(100));
            let w1 = rwlock.write().await;
            let try_write_2 = rwlock.write();
            pin_mut!(try_write_2);
            // `matches!` only yields a bool; without `assert!` these checks were no-ops.
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_2), Poll::Pending));
            // Releasing the first writer must let the parked writer finish.
            drop(w1);
            try_write_2.await;
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn write_pending_read_shared_ready() {
        timeout(SEC_5, async {
            let rwlock = RwLock::new(100);
            let _r1 = rwlock.read().await;
            let _r2 = rwlock.read().await;
            let try_write_1 = rwlock.write();
            pin_mut!(try_write_1);
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            assert!(matches!(poll!(&mut try_write_1), Poll::Pending));
            // A waiting writer does not block new readers (no writer priority).
            let _r3 = rwlock.read().await;
            timeout(Duration::from_millis(500), try_write_1)
                .await
                .expect_err("not timed out?");
        })
        .await
        .expect("timed out");
    }

    #[test]
    async fn read_uncontested() {
        let rwlock = RwLock::new(100);
        let result = *rwlock.read().await;
        assert_eq!(result, 100);
    }

    #[test]
    async fn write_uncontested() {
        let rwlock = RwLock::new(100);
        let mut result = rwlock.write().await;
        *result += 50;
        assert_eq!(*result, 150);
    }

    #[test]
    async fn write_order() {
        let rwlock = RwLock::<Vec<u32>>::new(vec![]);
        // Futures are lazy: creation order does not matter, only await order does.
        let fut2 = rwlock.write().map(|mut guard| guard.push(2));
        let fut1 = rwlock.write().map(|mut guard| guard.push(1));
        fut1.await;
        fut2.await;
        let g = rwlock.read().await;
        assert_eq!(*g, vec![1, 2]);
    }

    #[test]
    async fn try_write() {
        let lock = RwLock::new(0);
        let read_guard = lock.read().await;
        assert!(lock.try_write().is_err());
        drop(read_guard);
        assert!(lock.try_write().is_ok());
    }

    #[test]
    async fn try_read_try_write() {
        let lock: RwLock<usize> = RwLock::new(15);
        {
            let rg1 = lock.try_read().unwrap();
            assert_eq!(*rg1, 15);
            assert!(lock.try_write().is_err());
            let rg2 = lock.try_read().unwrap();
            assert_eq!(*rg2, 15)
        }
        {
            let mut wg = lock.try_write().unwrap();
            *wg = 1515;
            assert!(lock.try_read().is_err())
        }
        assert_eq!(*lock.try_read().unwrap(), 1515);
    }
}