use futures::{
future::{
Future,
Shared,
},
FutureExt,
lock,
};
use std::{
fmt,
ops::{
Deref,
DerefMut,
},
pin::Pin,
sync::{
Arc,
atomic::{
AtomicUsize,
Ordering,
},
RwLock,
RwLockReadGuard,
RwLockWriteGuard,
Weak,
},
task::{
Context,
Poll,
Waker,
},
};
/// An asynchronous reader/writer lock around a value of type `T`.
///
/// Writers serialize on an async mutex and wait, without blocking the thread,
/// until the count of outstanding read guards reaches zero. Readers pass
/// through the same async mutex before registering, so they queue behind any
/// writer that currently holds it.
pub struct FutRwLock<T: ?Sized> {
/// The protected value; the std guards handed out inside
/// `FutRwLockReadGuard` / `FutRwLockWriteGuard` borrow from this lock.
inner: Arc<RwLock<T>>,
/// Number of outstanding read guards: incremented by `read` / `try_read_now`
/// and decremented when a read guard is dropped.
reader_locks: Arc<AtomicUsize>,
/// Waker of a writer waiting for `reader_locks` to reach zero; stored by
/// `WriterAwaitingReaderLocksFuture::poll` and woken when a read guard drops.
waker: Arc<RwLock<Option<Waker>>>,
/// The shared wait-future of a writer that is currently waiting for readers
/// to finish; set by `write` while it waits and reset to `None` afterwards.
writer_awaiting_reader_locks_future: Arc<
lock::Mutex<Option<Shared<WriterAwaitingReaderLocksFuture>>>
>,
/// Async mutex held by a writer for the whole life of its write guard and
/// held briefly by each reader while it registers itself.
writer_lock: Arc<lock::Mutex<()>>,
}
impl<T> FutRwLock<T> {
    /// Creates an unlocked `FutRwLock` that owns `t`.
    ///
    /// The new lock starts with zero registered readers, no stored waker and
    /// no writer waiting.
    pub fn new(t: T) -> FutRwLock<T> {
        let inner = Arc::new(RwLock::new(t));
        let reader_locks = Arc::new(AtomicUsize::new(0usize));
        let waker = Arc::new(RwLock::new(None));
        let writer_awaiting_reader_locks_future = Arc::new(lock::Mutex::new(None));
        let writer_lock = Arc::new(lock::Mutex::new(()));
        Self {
            inner,
            reader_locks,
            waker,
            writer_awaiting_reader_locks_future,
            writer_lock,
        }
    }
}
impl <T> From<T> for FutRwLock<T> {
fn from(t: T) -> FutRwLock<T> {
FutRwLock::new(t)
}
}
impl <T: Default> Default for FutRwLock<T> {
fn default() -> FutRwLock<T> {
FutRwLock::new(Default::default())
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for FutRwLock<T> {
    /// Formats the lock as a struct named `FutRwLock` exposing only `inner`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("FutRwLock");
        out.field("inner", &self.inner);
        out.finish()
    }
}
impl <T: ?Sized> FutRwLock<T> {
pub fn is_poisoned(&self) -> bool {
self.inner.is_poisoned()
}
pub async fn read(&self) -> FutRwLockReadGuard<'_, T> {
let _writer_lock = self.writer_lock.lock().await;
self.read_lock_increment();
FutRwLockReadGuard::new(
self,
self.inner
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner()),
)
}
fn read_lock_decrement(&self) {
let _post_op_reader_lock_count = self
.reader_locks
.fetch_sub(1usize, Ordering::SeqCst);
let _ = self
.waker
.read()
.map(|waker_unlock_result|
waker_unlock_result
.as_ref()
.map(|waker|{
waker.wake_by_ref()
})
);
}
fn read_lock_increment(&self) {
let _ = self
.reader_locks
.fetch_add(1usize, Ordering::SeqCst);
}
pub fn try_read_now(&self) -> Option<FutRwLockReadGuard<'_, T>> {
self
.writer_lock
.try_lock()
.map(|_writer_lock_guard| {
self.read_lock_increment();
FutRwLockReadGuard::new(
self,
self.inner
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner()),
)
})
}
pub fn try_write_now(&self) -> Option<FutRwLockWriteGuard<'_, T>> {
if let Some (writer_lock) = self
.writer_lock
.try_lock()
{
if self.reader_locks.load(Ordering::SeqCst) == 0 {
Some(FutRwLockWriteGuard::new(
self,
self.inner
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner()),
writer_lock ))
} else {
None
}
} else {
None
}
}
pub async fn write(&self) -> FutRwLockWriteGuard<'_, T> {
let writer_lock = self.writer_lock.lock().await;
if self.reader_locks.load(Ordering::SeqCst) > 0 {
let new_writer_awaiting_reader_locks_future = WriterAwaitingReaderLocksFuture{
reader_locks: Arc::downgrade(&self.reader_locks),
waker: Arc::downgrade(&self.waker),
};
let shared_future = new_writer_awaiting_reader_locks_future.shared();
self.writer_awaiting_reader_locks_future
.lock()
.await
.replace(
shared_future.clone()
);
let _ = shared_future.await;
*self.writer_awaiting_reader_locks_future.lock().await = None;
}
FutRwLockWriteGuard::new(
self,
self.inner
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner()),
writer_lock )
}
}
/// Guard granting shared read access to the value inside a `FutRwLock`.
///
/// Dropping the guard unregisters the reader from the owning lock.
pub struct FutRwLockReadGuard<'a, T: ?Sized + 'a> {
/// The lock this guard was handed out by; used on drop to decrement its
/// reader count.
async_rwlock: &'a FutRwLock<T>,
/// The std read guard through which the value is dereferenced.
inner_read_guard: RwLockReadGuard<'a, T>,
}
impl<T: ?Sized> Deref for FutRwLockReadGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value through the inner std read guard.
    fn deref(&self) -> &Self::Target {
        &*self.inner_read_guard
    }
}
impl <'a, T: ?Sized + 'a> Drop for FutRwLockReadGuard <'a, T>{
fn drop(&mut self) {
self
.async_rwlock
.read_lock_decrement();
}
}
impl <'a, T: 'a + ?Sized > FutRwLockReadGuard <'a, T> {
fn new(
async_rwlock: &'a FutRwLock<T>,
inner_read_guard: RwLockReadGuard<'a, T>,
) -> FutRwLockReadGuard<'a, T> {
FutRwLockReadGuard {
async_rwlock,
inner_read_guard,
}
}
}
/// Guard granting exclusive write access to the value inside a `FutRwLock`.
///
/// Fields are dropped in declaration order, so the std write guard is released
/// before the writer mutex.
// `writer_lock` is never read in the code visible here; it is stored only so
// that the writer mutex stays held until the guard is dropped.
#[allow(dead_code)]
pub struct FutRwLockWriteGuard<'a, T: ?Sized + 'a> {
/// The lock this guard was handed out by; shown by the `Debug` impl.
async_rwlock: &'a FutRwLock<T>,
/// The std write guard through which the value is dereferenced.
inner_write_guard: RwLockWriteGuard<'a, T>,
/// Guard of the owning lock's writer mutex, held for the life of this guard.
writer_lock: lock::MutexGuard<'a, ()>,
}
impl<'a, T: 'a + ?Sized> FutRwLockWriteGuard<'a, T> {
    /// Bundles the owning lock, an already-acquired std write guard and the
    /// held writer-mutex guard into one write guard.
    fn new(
        async_rwlock: &'a FutRwLock<T>,
        inner_write_guard: RwLockWriteGuard<'a, T>,
        writer_lock: lock::MutexGuard<'a, ()>,
    ) -> FutRwLockWriteGuard<'a, T> {
        Self {
            async_rwlock,
            inner_write_guard,
            writer_lock,
        }
    }
}
impl<'a, T: 'a + ?Sized> Deref for FutRwLockWriteGuard<'a, T> {
    type Target = T;

    /// Borrows the protected value through the inner std write guard.
    fn deref(&self) -> &Self::Target {
        &*self.inner_write_guard
    }
}
impl<'a, T: 'a + ?Sized> DerefMut for FutRwLockWriteGuard<'a, T> {
    /// Mutably borrows the protected value through the inner std write guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.inner_write_guard
    }
}
// `?Sized` keeps this impl available for unsized `T`, matching the `Display`
// impl of this guard and the `Debug` impl of `FutRwLock`.
impl<T: ?Sized + fmt::Debug> fmt::Debug for FutRwLockWriteGuard<'_, T> {
    /// Formats the guard as a struct named `FutRwLockWriteGuard` showing the
    /// owning lock and the inner std write guard.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FutRwLockWriteGuard")
            .field("async_rwlock", &self.async_rwlock)
            .field("inner_write_guard", &self.inner_write_guard)
            .finish()
    }
}
impl<T: ?Sized + fmt::Display> fmt::Display for FutRwLockWriteGuard<'_, T> {
    /// Displays the protected value exactly as `T`'s own `Display` would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &*self.inner_write_guard;
        fmt::Display::fmt(value, f)
    }
}
/// Future a writer awaits until the reader count of its lock reaches zero.
///
/// It holds only weak handles, so it resolves with `Err(())` instead of
/// keeping the lock's shared state alive once that state has been dropped.
struct WriterAwaitingReaderLocksFuture {
    /// Count of outstanding read guards of the lock being waited on.
    reader_locks: Weak<AtomicUsize>,
    /// Slot in which this future leaves its waker for departing readers.
    waker: Weak<RwLock<Option<Waker>>>,
}

impl Future for WriterAwaitingReaderLocksFuture {
    type Output = Result<(), ()>;

    /// Resolves with `Ok(())` once no readers remain, with `Err(())` if the
    /// lock's shared state is gone, and otherwise stores the current waker
    /// and returns `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (reader_locks, waker_slot) =
            match (self.reader_locks.upgrade(), self.waker.upgrade()) {
                (Some(reader_locks), Some(waker_slot)) => (reader_locks, waker_slot),
                _ => return Poll::Ready(Err(())),
            };
        // Take the slot's write lock BEFORE looking at the count. A reader
        // decrements the count first and reads the slot afterwards, so while
        // this lock is held it either has already decremented (seen by the
        // load below) or will find the waker stored below. Checking the count
        // first and locking afterwards leaves a window in which the last
        // reader sees an empty slot and the writer is never woken.
        let mut registered_waker = waker_slot
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if reader_locks.load(Ordering::SeqCst) > 0 {
            *registered_waker = Some(cx.waker().clone());
            Poll::Pending
        } else {
            // Nobody is left to wake us; do not leave a stale waker behind.
            *registered_waker = None;
            Poll::Ready(Ok(()))
        }
    }
}
#[cfg(test)]
mod tests {
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
use super::*;
use futures::{
join,
future::{
join_all,
},
};
use rand::{RngCore,SeedableRng};
use rand_chacha::ChaChaRng;
use wasm_bindgen_test::*;
use instant::{Instant, Duration};
/// Takes two read guards, records the instant at which both are held, then
/// sleeps while still holding them.
///
/// `wait_ns` is forwarded unchanged to `sleep`.
async fn get_some_reads_then_wait(rwlock: &FutRwLock<()>, wait_ns: u64) -> Instant {
    let _first_read = rwlock.read().await;
    let _second_read = rwlock.read().await;
    let acquired_at = Instant::now();
    sleep(wait_ns).await;
    acquired_at
}
/// Takes the write guard, records the instant at which it is held, then
/// sleeps while still holding it.
///
/// `wait_ns` is forwarded unchanged to `sleep`.
async fn get_write_then_wait(rwlock: &FutRwLock<()>, wait_ns: u64) -> Instant {
    let _write_guard = rwlock.write().await;
    let acquired_at = Instant::now();
    sleep(wait_ns).await;
    acquired_at
}
/// Takes the write guard and returns the instant at which it was held; the
/// guard is released when the function returns.
async fn get_write(rwlock: &FutRwLock<()>) -> Instant {
    let _write_guard = rwlock.write().await;
    Instant::now()
}
/// Takes two read guards and returns the instant at which both were held;
/// the guards are released when the function returns.
async fn get_some_reads(rwlock: &FutRwLock<()>) -> Instant {
    let _first_read = rwlock.read().await;
    let _second_read = rwlock.read().await;
    Instant::now()
}
/// Suspends the current task for `t` and then emits `t` (as `i32`) as a
/// message: a console warning on wasm32, a stdout line elsewhere.
///
/// `t` is handed to the browser's `setTimeout` on wasm32 and to
/// `Duration::from_millis` on other targets.
async fn sleep(t: u64) {
    let msg: String = (t as i32).to_string();
    if cfg!(target_arch = "wasm32") {
        // Promise-backed callback that resolves once the timeout fires.
        let timeout_callback = js_function_promisify::Callback::new(move || {
            web_sys::console::warn_1(js_sys::JsString::from(msg.clone()).as_ref());
            Ok("".into())
        });
        let window = web_sys::window().expect("Must get window to sleep");
        let _ = window
            .set_timeout_with_callback_and_timeout_and_arguments_0(
                timeout_callback.as_function().as_ref(),
                t as i32,
            )
            .unwrap();
        let _ = timeout_callback.await;
    } else {
        // `cfg!` keeps both branches compiled, so the tokio call is removed
        // from wasm32 builds by the attribute below.
        #[cfg(not(target_arch = "wasm32"))]
        let _ = tokio::time::sleep(Duration::from_millis(t)).await;
        std::println!("{}", msg);
    }
}
/// Native test driver: calls the given zero-argument closure and blocks on
/// the future it returns.
#[cfg(not(target_arch = "wasm32"))]
macro_rules! run_test {
    ($make_future:ident) => {{
        tokio_test::block_on($make_future())
    }};
}
/// wasm32 test driver: calls the given zero-argument closure and spawns the
/// future it returns with `spawn_local`.
#[cfg(target_arch = "wasm32")]
macro_rules! run_test {
    ($make_future:ident) => {{
        wasm_bindgen_futures::spawn_local($make_future())
    }};
}
/// A value stored through a write guard is visible through that same guard.
#[test]
#[wasm_bindgen_test]
fn write_mutates_inner_value() {
    let async_test = move || async {
        let lock: FutRwLock<Option<u8>> = FutRwLock::new(None);
        let mut guard = lock.write().await;
        *guard = Some(16);
        assert_eq!(*guard, Some(16));
    };
    run_test!(async_test)
}
/// A writer started alongside readers that hold their guards across a sleep
/// only acquires the lock after those readers are done.
#[test]
#[wasm_bindgen_test]
fn write_awaits_reads() {
    let async_test = move || async {
        let lock: FutRwLock<()> = FutRwLock::new(());
        let target_wait_ns = 10u64;
        let readers = get_some_reads_then_wait(&lock, target_wait_ns);
        let writer = get_write(&lock);
        // `join!` polls the readers first, so they register before the writer.
        let (reads_at, write_at) = join!(readers, writer);
        assert!(write_at >= reads_at, "Writes acquired after reads done");
        let gap = write_at.duration_since(reads_at);
        assert!(
            gap >= Duration::from_nanos(target_wait_ns),
            "Writes acquired after reads done by duration magnitude"
        );
    };
    run_test!(async_test)
}
/// Writers that start waiting at staggered times must push into the shared
/// buffer in the order in which they attempted to acquire the lock.
#[test]
#[wasm_bindgen_test]
fn write_order() {
    let async_test = move || async {
        type TimeTy = u64;
        type ValTy = i8;
        type TestTy = Vec<ValTy>;
        type FutrwTy = FutRwLock<TestTy>;
        const TEST_DEPTH: usize = 32;
        const TEST_NS_STEP: TimeTy = 100;
        let p: FutrwTy = FutRwLock::new(Vec::with_capacity(TEST_DEPTH));
        {
            let mut rng = ChaChaRng::from_entropy();
            // (sleep before attempting the write, creation index, value to push)
            type Spec = (TimeTy, usize, ValTy);

            /// Sleeps, notes when the write attempt starts, then pushes `val`.
            async fn sleep_then_write_push(rwlock: &FutrwTy, ns: TimeTy, val: ValTy) -> Instant {
                sleep(ns).await;
                let attempted_at = Instant::now();
                let mut buffer = rwlock.write().await;
                buffer.push(val);
                attempted_at
            }

            // Earlier specs sleep longer, so attempts happen in reverse creation order.
            let specs: Vec<Spec> = (0..TEST_DEPTH)
                .map(|i| {
                    (
                        (TEST_DEPTH - i) as TimeTy * TEST_NS_STEP,
                        i,
                        rng.next_u32() as ValTy,
                    )
                })
                .collect();
            let futs: Vec<_> = specs
                .iter()
                .map(|(ns, _, val)| sleep_then_write_push(&p, *ns, *val))
                .collect();
            let pushed_vals: Vec<ValTy> = specs.iter().map(|(_, _, val)| *val).collect();

            let attempt_instants: Vec<Instant> = join_all(futs).await;
            let mut vals_by_attempt: Vec<(Instant, ValTy)> = attempt_instants
                .iter()
                .copied()
                .zip(pushed_vals.iter().copied())
                .collect();
            vals_by_attempt.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap());
            let target_vals: Vec<ValTy> = vals_by_attempt.iter().map(|(_, v)| *v).collect();

            let result = p.read().await;
            assert_eq!(
                *result,
                target_vals,
                "Write results in a buffer must match acquire attempt order: \n{:#?}",
                specs
            )
        }
    };
    run_test!(async_test)
}
/// After a write guard is released, a following read sees the written value;
/// checked for two consecutive write/read rounds.
#[test]
#[wasm_bindgen_test]
fn write_then_drop_then_read() {
    let async_test = move || async {
        let lock: FutRwLock<Option<u8>> = FutRwLock::new(None);

        {
            let mut writer = lock.write().await;
            *writer = Some(16);
        }
        let seen = *lock.read().await;
        assert_eq!(seen, Some(16), "read following write 1 must see new value");

        {
            let mut writer = lock.write().await;
            *writer = Some(144);
        }
        let seen = *lock.read().await;
        assert_eq!(seen, Some(144), "read following write 2 must see new value");
    };
    run_test!(async_test)
}
/// While a write guard is alive both non-waiting acquisitions fail, however
/// often they are attempted; once it is dropped both succeed again.
#[test]
#[wasm_bindgen_test]
fn write_prevents_try_write_now_and_try_read_now() {
    let async_test = move || async {
        let p: FutRwLock<Option<u8>> = FutRwLock::new(None);
        let write_guard = p.write().await;
        // Checked twice on purpose: a failed attempt must not change the lock's state.
        assert!(p.try_write_now().is_none(), "try_write_now returns None when write lock active");
        assert!(p.try_read_now().is_none(), "try_read_now returns None when write lock active");
        assert!(p.try_write_now().is_none(), "try_write_now returns None when write lock active");
        assert!(p.try_read_now().is_none(), "try_read_now returns None when write lock active");
        drop(write_guard);
        // The messages below describe the released state these assertions actually check.
        assert!(p.try_write_now().is_some(), "try_write_now returns Some once write lock is released");
        assert!(p.try_read_now().is_some(), "try_read_now returns Some once write lock is released");
    };
    run_test!(async_test)
}
/// Readers started alongside a writer that holds its guard across a sleep
/// only acquire the lock after that writer is done.
#[test]
#[wasm_bindgen_test]
fn reads_await_write() {
    let async_test = move || async {
        let lock: FutRwLock<()> = FutRwLock::new(());
        let target_wait_ns = 10u64;
        let writer = get_write_then_wait(&lock, target_wait_ns);
        let readers = get_some_reads(&lock);
        // `join!` polls the writer first, so it takes the lock before the readers.
        let (write_at, reads_at) = join!(writer, readers);
        assert!(reads_at >= write_at, "Read acquired after write done");
        let gap = reads_at.duration_since(write_at);
        assert!(
            gap >= Duration::from_nanos(target_wait_ns),
            "Reads acquired after write done by duration magnitude"
        );
    };
    run_test!(async_test)
}
/// A single read guard exposes the value the lock was created with.
#[test]
#[wasm_bindgen_test]
fn read_one() {
    let async_test = move || async {
        let lock: FutRwLock<Option<u8>> = FutRwLock::new(Some(22));
        let reader = lock.read().await;
        assert_eq!(*reader, Some(22));
    };
    run_test!(async_test)
}
/// Two read guards can be held at once and both expose the same value.
#[test]
#[wasm_bindgen_test]
fn read_multiple() {
    let async_test = move || async {
        let lock: FutRwLock<Option<u8>> = FutRwLock::new(Some(22));
        let first = lock.read().await;
        let second = lock.read().await;
        assert_eq!(*first, Some(22));
        assert_eq!(*second, Some(22));
    };
    run_test!(async_test)
}
/// While read guards are alive, `try_write_now` fails but `try_read_now`
/// still succeeds and sees the same value.
#[test]
#[wasm_bindgen_test]
fn read_prevents_try_write_now_and_allows_try_read_now() {
    let async_test = move || async {
        let lock: FutRwLock<u16> = FutRwLock::new(1622);
        let held_read = lock.read().await;
        let extra_read = lock.try_read_now();
        assert!(lock.try_write_now().is_none(), "try_write_now returns None when read lock active");
        assert!(extra_read.is_some(), "try_read_now returns Some when read lock active");
        let extra_read = extra_read.unwrap();
        assert_eq!(*held_read, *extra_read, "Read and try read now point to same");
    };
    run_test!(async_test)
}
}