use super::semaphore::Semaphore;
use std::sync::Arc;
/// A reader-writer lock usable from both async code and synchronous threads,
/// built on a counting `Semaphore`.
///
/// The lock protects no data of its own; it only coordinates access.
/// Every reader holds one permit, and every writer holds all
/// `Semaphore::MAX_PERMITS` permits. A writer therefore excludes all readers
/// and all other writers.
///
/// NOTE(review): the `Rf` prefix presumably stands for "reader-favoring". The
/// module's tests expect new readers to be admitted while a writer is waiting.
/// That ordering policy lives in `Semaphore`, not here; confirm it there.
pub struct RfRwLock {
/// Semaphore that hands out the reader (1) and writer (`MAX_PERMITS`) permits.
ll_sem: Semaphore,
}
impl Default for RfRwLock {
fn default() -> Self {
Self::new()
}
}
impl RfRwLock {
pub fn new() -> Self {
Self { ll_sem: Semaphore::new(Semaphore::MAX_PERMITS) }
}
pub async fn read(&self) -> RfRwLockReadGuard<'_> {
self.ll_sem.acquire(1).await;
RfRwLockReadGuard(self)
}
pub fn blocking_read(&self) -> RfRwLockReadGuard<'_> {
self.ll_sem.blocking_acquire(1);
RfRwLockReadGuard(self)
}
pub async fn read_owned(self: Arc<Self>) -> RfRwLockOwnedReadGuard {
self.ll_sem.acquire(1).await;
RfRwLockOwnedReadGuard(self)
}
pub async fn write(&self) -> RfRwLockWriteGuard<'_> {
self.ll_sem.acquire(Semaphore::MAX_PERMITS).await;
RfRwLockWriteGuard(self)
}
pub fn blocking_write(&self) -> RfRwLockWriteGuard<'_> {
self.ll_sem.blocking_acquire(Semaphore::MAX_PERMITS);
RfRwLockWriteGuard(self)
}
pub async fn write_owned(self: Arc<Self>) -> RfRwLockOwnedWriteGuard {
self.ll_sem.acquire(Semaphore::MAX_PERMITS).await;
RfRwLockOwnedWriteGuard(self)
}
fn release_read(&self) {
self.ll_sem.release(1);
}
fn release_write(&self) {
self.ll_sem.release(Semaphore::MAX_PERMITS);
}
fn blocking_yield_writer(&self) {
self.ll_sem.blocking_yield(Semaphore::MAX_PERMITS);
}
}
/// Guard for shared access, returned by `RfRwLock::read` and
/// `RfRwLock::blocking_read`.
///
/// Shared access lasts exactly as long as this guard. Dropping it returns the
/// reader's single permit to the lock.
#[must_use = "if unused the RfRwLock read lock is released immediately"]
pub struct RfRwLockReadGuard<'a>(&'a RfRwLock);

impl Drop for RfRwLockReadGuard<'_> {
    fn drop(&mut self) {
        self.0.release_read();
    }
}
/// Owned guard for shared access, returned by `RfRwLock::read_owned`.
///
/// It keeps the lock alive through its own `Arc`, so it has no borrow
/// lifetime. Dropping it returns the reader's single permit.
#[must_use = "if unused the RfRwLock read lock is released immediately"]
pub struct RfRwLockOwnedReadGuard(Arc<RfRwLock>);

impl Drop for RfRwLockOwnedReadGuard {
    fn drop(&mut self) {
        self.0.release_read();
    }
}
/// Guard for exclusive access, returned by `RfRwLock::write` and
/// `RfRwLock::blocking_write`.
///
/// Exclusive access lasts exactly as long as this guard. Dropping it returns
/// all `Semaphore::MAX_PERMITS` permits to the lock.
#[must_use = "if unused the RfRwLock write lock is released immediately"]
pub struct RfRwLockWriteGuard<'a>(&'a RfRwLock);

impl Drop for RfRwLockWriteGuard<'_> {
    fn drop(&mut self) {
        self.0.release_write();
    }
}
impl RfRwLockWriteGuard<'_> {
pub fn blocking_yield(&mut self) {
self.0.blocking_yield_writer();
}
}
/// Owned guard for exclusive access, returned by `RfRwLock::write_owned`.
///
/// It keeps the lock alive through its own `Arc`, so it has no borrow
/// lifetime. Dropping it returns all `Semaphore::MAX_PERMITS` permits.
#[must_use = "if unused the RfRwLock write lock is released immediately"]
pub struct RfRwLockOwnedWriteGuard(Arc<RfRwLock>);

impl Drop for RfRwLockOwnedWriteGuard {
    fn drop(&mut self) {
        self.0.release_write();
    }
}

impl RfRwLockOwnedWriteGuard {
    /// Gives readers waiting on the lock a chance to acquire it, without
    /// giving up this guard.
    ///
    /// This mirrors `RfRwLockWriteGuard::blocking_yield` for the owned guard,
    /// which holds the same write permits. The hand-off is implemented by
    /// `Semaphore::blocking_yield`. Presumably it blocks the calling thread
    /// and re-acquires exclusive access before returning; confirm this in
    /// `Semaphore`.
    pub fn blocking_yield(&mut self) {
        self.0.blocking_yield_writer();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, Ordering::SeqCst},
        time::Duration,
    };
    use tokio::{sync::oneshot, time::sleep, time::timeout};

    /// Upper bound on how long any acquisition in these tests may take.
    const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5);

    /// A writer that keeps calling `blocking_yield` must let a waiting reader in.
    ///
    /// The writer thread sleeps about 2 ms ten times, roughly 20 ms in total.
    /// The reader's 18 ms timeout means it must get in *during* those yields,
    /// not after the writer finishes.
    #[tokio::test]
    async fn test_writer_reentrance() {
        for i in 0..16 {
            let l = Arc::new(RfRwLock::new());
            let (tx, rx) = oneshot::channel();
            let l_clone = l.clone();
            let h = std::thread::spawn(move || {
                let mut write = l_clone.blocking_write();
                tx.send(()).unwrap();
                for _ in 0..10 {
                    std::thread::sleep(Duration::from_millis(2));
                    write.blocking_yield();
                }
            });
            // Only start the reader once the writer definitely holds the lock.
            rx.await.unwrap();
            let read = timeout(Duration::from_millis(18), l.read())
                .await
                .unwrap_or_else(|_| panic!("failed at iteration {i}"));
            drop(read);
            // Join the std thread off the runtime so its panics fail the test.
            timeout(
                Duration::from_millis(100),
                tokio::task::spawn_blocking(move || h.join()),
            )
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        }
    }

    /// Readers may keep joining while a writer is queued. Once they all leave,
    /// the writer runs, and a reader arriving after that must wait for the
    /// writer's release.
    #[tokio::test]
    async fn test_readers_preferred() {
        let l = Arc::new(RfRwLock::new());
        let read1 = l.read().await;
        let read2 = l.read().await;
        let read3 = l.read().await;
        let (tx, rx) = oneshot::channel();
        let (tx_back, rx_back) = oneshot::channel();
        let l_clone = l.clone();
        let h = tokio::spawn(async move {
            // `fut` is lazy and only registers the writer when first polled.
            // On the default current-thread runtime of `#[tokio::test]`, this
            // task reaches `fut.await` (and parks) before the main task
            // resumes from `rx`, so the writer is queued by then.
            let fut = l_clone.write();
            tx.send(()).unwrap();
            let _write = fut.await;
            println!("writer acquired");
            rx_back.await.unwrap();
            println!("releasing writer");
        });
        rx.await.unwrap();
        let read4 = timeout(ACQUIRE_TIMEOUT, l.read()).await.unwrap();
        let read5 = timeout(ACQUIRE_TIMEOUT, l.read()).await.unwrap();
        drop(read1);
        drop(read2);
        drop(read3);
        drop(read4);
        drop(read5);
        println!("dropped all readers");
        let f = Arc::new(AtomicBool::new(false));
        let f_clone = f.clone();
        let l_clone = l.clone();
        let late_reader = tokio::spawn(async move {
            let _read = l_clone.read().await;
            assert!(f_clone.load(SeqCst), "reader acquired before writer release");
            println!("late reader acquired");
        });
        sleep(Duration::from_secs(1)).await;
        f.store(true, SeqCst);
        tx_back.send(()).unwrap();
        timeout(ACQUIRE_TIMEOUT, h).await.unwrap().unwrap();
        // Tokio swallows a panic in a detached task. Without awaiting this
        // handle, a failed ordering assertion above could never fail the test.
        timeout(ACQUIRE_TIMEOUT, late_reader).await.unwrap().unwrap();
    }
}