use crate::spinlock::Spinlock;
use std::cell::UnsafeCell;
use std::fmt::Display;
use std::sync::atomic::AtomicU8;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(target_arch = "wasm32")]
use wasm_safe_thread as thread;
use super::UNLOCKED;
#[derive(Debug, Default)]
pub struct RwLock<T> {
pub(crate) inner: UnsafeCell<T>,
pub(crate) data_lock: AtomicU8,
pub(crate) waiting_sync_read_threads: Spinlock<Vec<thread::Thread>>,
pub(crate) waiting_sync_write_threads: Spinlock<Vec<thread::Thread>>,
pub(crate) waiting_async_read_threads: Spinlock<Vec<r#continue::Sender<()>>>,
pub(crate) waiting_async_write_threads: Spinlock<Vec<r#continue::Sender<()>>>,
}
/// Formats the protected value using its own `Display` impl.
///
/// A read guard is requested with `try_lock_read`. If that call fails, the placeholder
/// `RwLock { <locked> }` is written instead of the value.
impl<T: Display> Display for RwLock<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.try_lock_read() {
            // Delegate to `T::fmt` so width/fill/precision flags are honoured.
            Ok(guard) => std::fmt::Display::fmt(&*guard, f),
            // The placeholder names this type (it previously said `Mutex`).
            Err(_) => f.write_str("RwLock { <locked> }"),
        }
    }
}
impl<T> From<T> for RwLock<T> {
fn from(value: T) -> Self {
RwLock::new(value)
}
}
// SAFETY: the lock owns its `T`, so sending the lock to another thread moves the `T`
// there; `T: Send` is what that requires.
unsafe impl<T: Send> Send for RwLock<T> {}
// SAFETY: sharing `&RwLock<T>` across threads lets read guards on different threads
// each hand out `&T` at the same time (a read guard derefs to `&T`), which requires
// `T: Sync`. Access through a guard on another thread can also reach a `T` stored by
// a different thread, which requires `T: Send`. These bounds match `std::sync::RwLock`.
// The previous `T: Send`-only bound made e.g. `RwLock<Cell<_>>` `Sync`, which is unsound.
unsafe impl<T: Send + Sync> Sync for RwLock<T> {}
impl<T> RwLock<T> {
pub const fn new(value: T) -> RwLock<T> {
RwLock {
inner: UnsafeCell::new(value),
data_lock: AtomicU8::new(UNLOCKED),
waiting_sync_read_threads: Spinlock::new(vec![]),
waiting_async_read_threads: Spinlock::new(vec![]),
waiting_sync_write_threads: Spinlock::new(vec![]),
waiting_async_write_threads: Spinlock::new(vec![]),
}
}
pub(crate) fn did_unlock_write(&self) {
let threads = self.waiting_sync_read_threads.with_mut(std::mem::take);
for thread in threads {
thread.unpark();
}
let threads = self.waiting_async_read_threads.with_mut(std::mem::take);
for thread in threads {
thread.send(())
}
let threads = self.waiting_sync_write_threads.with_mut(std::mem::take);
for thread in threads {
thread.unpark();
}
let threads = self.waiting_async_write_threads.with_mut(std::mem::take);
for thread in threads {
thread.send(())
}
}
pub(crate) fn did_unlock_read(&self) {
let threads = self.waiting_sync_write_threads.with_mut(std::mem::take);
for thread in threads {
thread.unpark();
}
let threads = self.waiting_async_write_threads.with_mut(std::mem::take);
for thread in threads {
thread.send(())
}
}
}