use std::{
    cell::{Cell, RefCell},
    marker::PhantomData,
    mem::MaybeUninit,
    ptr,
    sync::Once,
    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
};
/// RAII token proving the current thread holds the process-wide lock.
///
/// Returned by [`read`] and [`write`]. Guards are re-entrant per thread: only
/// the first guard taken on a thread touches the underlying `RwLock`; nested
/// guards just bump a thread-local depth counter. The real lock is released
/// when the *last* guard on the thread is dropped, whatever the drop order.
///
/// The `PhantomData` marker keeps `Guard` `!Send`, like a real lock guard:
/// the depth bookkeeping is thread-local, so a guard must be dropped on the
/// thread that created it.
#[derive(Debug)]
pub(crate) struct Guard(PhantomData<RwLockReadGuard<'static, ()>>);

/// The real `RwLock` guard backing every live [`Guard`] on a thread.
#[derive(Debug)]
enum Repr {
    Read(RwLockReadGuard<'static, ()>),
    Write(RwLockWriteGuard<'static, ()>),
}

/// Acquires the global lock in exclusive (write) mode.
///
/// If this thread already holds a write guard, the call only deepens the
/// existing hold and never blocks.
///
/// # Panics
///
/// Panics if this thread currently holds a read guard, because upgrading a
/// held read lock to a write lock would deadlock.
pub(crate) fn write() -> Guard {
    let next = match CACHE.with(Cell::get) {
        Cache::Write(depth) => Cache::Write(depth + 1),
        Cache::Read(readers) => {
            assert_eq!(
                readers, 0,
                "calling write() with an active read guard on the same thread would deadlock"
            );
            // The lock only protects `()`, so a poisoned lock is still usable.
            let w_guard = static_rw_lock().write().unwrap_or_else(|err| err.into_inner());
            HELD.with(|held| *held.borrow_mut() = Some(Repr::Write(w_guard)));
            Cache::Write(1)
        }
    };
    CACHE.with(|it| it.set(next));
    Guard(PhantomData)
}

/// Acquires the global lock in shared (read) mode.
///
/// Never blocks if this thread already holds any guard: a read nested inside a
/// read or a write is covered by the lock already held. The nested guard is
/// still counted, so that lock stays held until this guard is dropped too.
pub(crate) fn read() -> Guard {
    let next = match CACHE.with(Cell::get) {
        // The exclusive lock already covers reads; count this as write depth.
        Cache::Write(depth) => Cache::Write(depth + 1),
        Cache::Read(0) => {
            let r_guard = static_rw_lock().read().unwrap_or_else(|err| err.into_inner());
            HELD.with(|held| *held.borrow_mut() = Some(Repr::Read(r_guard)));
            Cache::Read(1)
        }
        // Calling `RwLock::read` again here may deadlock if another thread is
        // queued as a writer (platform-dependent), so nested reads only count.
        Cache::Read(readers) => Cache::Read(readers + 1),
    };
    CACHE.with(|it| it.set(next));
    Guard(PhantomData)
}

/// Returns the lazily-initialised process-wide lock.
fn static_rw_lock() -> &'static RwLock<()> {
    static mut LOCK: MaybeUninit<RwLock<()>> = MaybeUninit::uninit();
    static LOCK_INIT: Once = Once::new();
    // SAFETY: `LOCK` is written exactly once, inside `call_once`, and `Once`
    // guarantees that write happens-before every return from `call_once`.
    // After that it is only read through shared references. `addr_of{_mut}!`
    // yields raw pointers, so no `&`/`&mut` to the `static mut` itself is ever
    // created. `MaybeUninit<T>` has the same layout as `T`, so the casts are valid.
    unsafe {
        LOCK_INIT.call_once(|| {
            ptr::write(ptr::addr_of_mut!(LOCK).cast::<RwLock<()>>(), RwLock::new(()));
        });
        &*ptr::addr_of!(LOCK).cast::<RwLock<()>>()
    }
}

/// Per-thread lock mode and number of live [`Guard`]s.
///
/// `Read(0)` means the thread holds nothing; `Write(n)` counts every guard
/// (read or write) taken while the thread holds the exclusive lock.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Cache {
    Read(usize),
    Write(usize),
}

thread_local! {
    static CACHE: Cell<Cache> = Cell::new(Cache::Read(0));
    /// The real lock guard; `Some` exactly while `CACHE` records a non-zero depth.
    static HELD: RefCell<Option<Repr>> = RefCell::new(None);
}

impl Drop for Guard {
    fn drop(&mut self) {
        match CACHE.with(Cell::get) {
            // Every live `Guard` accounts for one unit of depth, so zero is impossible here.
            Cache::Read(0) | Cache::Write(0) => {
                unreachable!("Guard dropped on a thread that holds no guards")
            }
            // Last guard on this thread: go idle and release the real lock.
            Cache::Read(1) | Cache::Write(1) => {
                CACHE.with(|it| it.set(Cache::Read(0)));
                drop(HELD.with(|held| held.borrow_mut().take()));
            }
            Cache::Read(n) => CACHE.with(|it| it.set(Cache::Read(n - 1))),
            Cache::Write(n) => CACHE.with(|it| it.set(Cache::Write(n - 1))),
        }
    }
}
/// Checks that a nested `read()` does not block while another thread is
/// waiting in `write()`.
///
/// `r1` takes the real read lock, so the spawned thread's `write()` blocks on
/// it. The 300 ms sleep is meant to let that thread start waiting before `r2`
/// is requested. `r2` is a nested read on this thread, so `read()` returns it
/// without calling into the `RwLock` again.
///
/// NOTE(review): this is timing-based. If the spawned thread has not reached
/// `write()` within the sleep, the test still passes but does not exercise the
/// contended case.
#[test]
fn read_write_read() {
    eprintln!("get r1");
    let r1 = read();
    eprintln!("got r1");
    let h = std::thread::spawn(|| {
        eprintln!("get w1");
        let w1 = write();
        eprintln!("got w1");
        drop(w1);
        eprintln!("gave w1");
    });
    // Give the writer time to start blocking on the lock held via `r1`.
    std::thread::sleep(std::time::Duration::from_millis(300));
    eprintln!("get r2");
    // Nested read on this thread while a writer may be queued on the lock.
    let r2 = read();
    eprintln!("got r2");
    drop(r1);
    eprintln!("gave r1");
    drop(r2);
    eprintln!("gave r2");
    // Propagates any panic from the writer thread.
    h.join().unwrap();
}
/// A `read()` requested while this thread already holds `write()` must be
/// granted without blocking; both guards live until the end of the test.
#[test]
fn write_read() {
    let _outer_writer = write();
    let _inner_reader = read();
}
/// Requesting `write()` while this thread holds a read guard is refused with
/// a panic instead of being allowed to deadlock.
#[test]
#[should_panic(
    expected = "calling write() with an active read guard on the same thread would deadlock"
)]
fn read_write() {
    let _held_reader = read();
    let _attempted_writer = write();
}