use core::cell::UnsafeCell;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use crate::Mutex;
/// Bookkeeping shared between an `RwLock` and the guards it hands out; it is
/// only ever accessed through the lock's inner `Mutex`.
struct State {
    /// Set while a write guard is alive; cleared when that guard is dropped.
    writing: bool,
    /// How many read guards are currently alive.
    read_count: u32,
}
/// A reader-writer lock whose `read`/`write` methods report a conflict as an
/// error instead of waiting for the conflicting guard to go away.
pub struct RwLock<T: ?Sized> {
    /// Reader count and writer flag, guarded by an inner mutex.
    state: Mutex<State>,
    /// The protected value. It must remain the last field because `T` may be unsized.
    elem: UnsafeCell<T>,
}

// SAFETY: moving the lock to another thread moves the owned `T` with it,
// which requires `T: Send`.
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}

// SAFETY: a shared `&RwLock<T>` lets several threads obtain `&T` (readers),
// which requires `T: Sync`. It also lets one thread obtain `&mut T` (writer),
// which effectively moves access to `T` across threads and requires `T: Send`.
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}
/// Guard granting shared access to the value inside an `RwLock`.
///
/// Obtained from `RwLock::read`; dropping it decrements the lock's reader count.
#[must_use = "if unused, the lock will release automatically"]
pub struct RwLockReadGuard<'a, T: ?Sized> {
    elem: NonNull<T>,
    guard: &'a Mutex<State>,
}

// SAFETY: a read guard behaves like `&T`. Moving it to another thread lets that
// thread read `T` while readers on other threads may do the same, so it is only
// sound when `T: Sync` (the same bound `&T: Send` requires). Previously this impl
// had no bound, which let e.g. a `Cell` be read and written from two threads.
// NOTE(review): this also assumes `crate::Mutex<State>` is `Sync`, since the
// guard carries a `&Mutex<State>`.
unsafe impl<T: ?Sized + Sync> Send for RwLockReadGuard<'_, T> {}

// SAFETY: `&RwLockReadGuard` only exposes `&T`, which requires `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}
impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    /// Gives the read slot back by decrementing the shared reader count.
    fn drop(&mut self) {
        let mut state = self.guard.lock();
        state.read_count -= 1;
    }
}
impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value immutably for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        // SAFETY: while this guard is alive, the reader count it incremented
        // makes `RwLock::write` refuse to hand out a writer, so no `&mut T`
        // can coexist with this `&T`.
        unsafe { &*self.elem.as_ptr() }
    }
}
/// Guard granting exclusive access to the value inside an `RwLock`.
///
/// Obtained from `RwLock::write`; dropping it clears the lock's writer flag.
#[must_use = "if unused, the lock will release automatically"]
pub struct RwLockWriteGuard<'a, T: ?Sized> {
    elem: NonNull<T>,
    guard: &'a Mutex<State>,
}

// SAFETY: a write guard behaves like `&mut T`. Sending it to another thread
// transfers exclusive access to `T` to that thread, which is only sound when
// `T: Send` (the same bound `&mut T: Send` requires). Previously this impl had
// no bound, allowing e.g. an `Rc` to be cloned/dropped from a foreign thread.
// NOTE(review): this also assumes `crate::Mutex<State>` is `Sync`, since the
// guard carries a `&Mutex<State>`.
unsafe impl<T: ?Sized + Send> Send for RwLockWriteGuard<'_, T> {}

// SAFETY: `&RwLockWriteGuard` only exposes `&T` (via `Deref`), which requires `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value immutably through the exclusive guard.
    fn deref(&self) -> &T {
        // SAFETY: the `writing` flag set by `RwLock::write` keeps every other
        // reader and writer out while this guard is alive.
        unsafe { &*self.elem.as_ptr() }
    }
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    /// Borrows the protected value mutably through the exclusive guard.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: `RwLock::write` only succeeds with zero readers and no other
        // writer, and sets `writing` so none can appear until this guard drops.
        // Taking `&mut self` ensures this is the only borrow derived from the guard.
        unsafe { &mut *self.elem.as_ptr() }
    }
}
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    /// Releases exclusive access by clearing the shared writer flag.
    fn drop(&mut self) {
        let mut state = self.guard.lock();
        state.writing = false;
    }
}
/// Reason a non-waiting `RwLock` acquisition failed.
///
/// Derives `Debug` so that `RwLockResult::unwrap`/`expect` are usable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RwLockError {
    /// `RwLock::write` was called while a write guard was alive.
    WriteWhileWrite,
    /// `RwLock::write` was called while at least one read guard was alive.
    WriteWhileRead,
    /// `RwLock::read` was called while a write guard was alive.
    ReadWhileWrite,
}

impl core::fmt::Display for RwLockError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Self::WriteWhileWrite => "cannot write-lock: already write-locked",
            Self::WriteWhileRead => "cannot write-lock: currently read-locked",
            Self::ReadWhileWrite => "cannot read-lock: currently write-locked",
        })
    }
}

/// Result type returned by `RwLock::read` and `RwLock::write`.
pub type RwLockResult<T> = Result<T, RwLockError>;
impl<T> RwLock<T> {
    /// Creates an unlocked `RwLock` (no readers, no writer) that owns `elem`.
    pub const fn new(elem: T) -> Self {
        let cell = UnsafeCell::new(elem);
        let state = Mutex::new(State { read_count: 0, writing: false });
        Self { state, elem: cell }
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Taking `self` by value guarantees no guards are outstanding.
    pub fn into_inner(self) -> T {
        let Self { elem, .. } = self;
        elem.into_inner()
    }
}
impl<T: ?Sized> RwLock<T> {
    /// Attempts to acquire shared (read) access.
    ///
    /// This does not wait for a conflicting writer; it fails immediately instead.
    /// Any number of read guards may be alive at once as long as no write guard is.
    ///
    /// # Errors
    ///
    /// Returns [`RwLockError::ReadWhileWrite`] if a write guard is currently alive.
    ///
    /// # Panics
    ///
    /// Panics if the reader count would overflow `u32`. This is only reachable
    /// by leaking guards, e.g. with `mem::forget`. Letting the count wrap to 0
    /// would allow `write` to hand out `&mut T` while a reader still holds `&T`.
    pub fn read(&self) -> RwLockResult<RwLockReadGuard<'_, T>> {
        let mut state = self.state.lock();
        if state.writing {
            return Err(RwLockError::ReadWhileWrite);
        }
        state.read_count = state
            .read_count
            .checked_add(1)
            .expect("RwLock read count overflowed");
        Ok(RwLockReadGuard {
            guard: &self.state,
            // SAFETY: `UnsafeCell::get` never returns a null pointer.
            elem: unsafe { NonNull::new_unchecked(self.elem.get()) },
        })
    }

    /// Attempts to acquire exclusive (write) access.
    ///
    /// This does not wait for conflicting guards; it fails immediately instead.
    ///
    /// # Errors
    ///
    /// Returns [`RwLockError::WriteWhileWrite`] if a write guard is alive, or
    /// [`RwLockError::WriteWhileRead`] if at least one read guard is alive.
    pub fn write(&self) -> RwLockResult<RwLockWriteGuard<'_, T>> {
        let mut state = self.state.lock();
        if state.writing {
            Err(RwLockError::WriteWhileWrite)
        } else if state.read_count > 0 {
            Err(RwLockError::WriteWhileRead)
        } else {
            state.writing = true;
            Ok(RwLockWriteGuard {
                guard: &self.state,
                // SAFETY: `UnsafeCell::get` never returns a null pointer.
                elem: unsafe { NonNull::new_unchecked(self.elem.get()) },
            })
        }
    }
}