use crate::future::poll_fn;
use crate::sync::semaphore_ll::{AcquireError, Permit, Semaphore};
use std::cell::UnsafeCell;
use std::ops;
use std::task::{Context, Poll};
// Number of permits each lock's semaphore is created with. A reader takes one
// permit and a writer takes all of them, so this is also the maximum number of
// readers that can hold the lock at the same time.
// Under `loom` a smaller value is used (presumably to keep model checking tractable).
#[cfg(loom)]
const MAX_READS: usize = 10;
#[cfg(not(loom))]
const MAX_READS: usize = 32;
/// An asynchronous reader-writer lock protecting a value of type `T`.
///
/// Access is arbitrated by a semaphore created with `MAX_READS` permits:
/// `read` acquires one permit and `write` acquires all `MAX_READS`, so many
/// readers (up to `MAX_READS`) or a single writer may hold the lock at once.
#[derive(Debug)]
pub struct RwLock<T> {
/// Semaphore gating access to `c`; see the type-level docs for permit counts.
s: Semaphore,
/// The protected value; only dereferenced through the guard types.
c: UnsafeCell<T>,
}
/// Shared-access guard returned by `RwLock::read`.
///
/// Dereferences to `&T`. Holds one semaphore permit, which is returned to the
/// lock's semaphore when the guard (and thus its `ReleasingPermit`) is dropped.
#[derive(Debug)]
pub struct RwLockReadGuard<'a, T> {
/// The single permit backing this read access; released on drop.
permit: ReleasingPermit<'a, T>,
/// The lock whose value this guard exposes.
lock: &'a RwLock<T>,
}
/// Exclusive-access guard returned by `RwLock::write`.
///
/// Dereferences to `&T` and `&mut T`. Holds all `MAX_READS` semaphore permits,
/// which are returned to the lock's semaphore when the guard is dropped.
#[derive(Debug)]
pub struct RwLockWriteGuard<'a, T> {
/// The `MAX_READS` permits backing this write access; released on drop.
permit: ReleasingPermit<'a, T>,
/// The lock whose value this guard exposes.
lock: &'a RwLock<T>,
}
/// A semaphore `Permit` paired with the number of permits it stands for and
/// the lock it belongs to, so that `Drop` can hand exactly that many permits
/// back to `lock.s`.
///
/// NOTE(review): it is created before polling starts, so its `Drop` also runs if
/// an acquire future is dropped mid-wait; what `Permit::release` does with a
/// partially acquired permit is defined in `semaphore_ll` — confirm there.
#[derive(Debug)]
struct ReleasingPermit<'a, T> {
/// How many semaphore permits this value requests and later releases.
num_permits: u16,
/// The underlying low-level permit state.
permit: Permit,
/// Lock whose semaphore the permits are returned to on drop.
lock: &'a RwLock<T>,
}
impl<'a, T> ReleasingPermit<'a, T> {
    /// Polls the semaphore `s` for this value's `num_permits` permits.
    ///
    /// Returns `Poll::Pending` (after registering `cx`'s waker via the
    /// underlying `Permit`) until the permits are available.
    fn poll_acquire(
        &mut self,
        cx: &mut Context<'_>,
        s: &Semaphore,
    ) -> Poll<Result<(), AcquireError>> {
        let wanted = self.num_permits;
        let Self { permit, .. } = self;
        permit.poll_acquire(cx, wanted, s)
    }
}
impl<'a, T> Drop for ReleasingPermit<'a, T> {
    /// Hands this value's `num_permits` back to the owning lock's semaphore.
    fn drop(&mut self) {
        let count = self.num_permits;
        let lock = self.lock;
        self.permit.release(count, &lock.s);
    }
}
// SAFETY: `RwLock<T>` owns its `T` (inside an `UnsafeCell`), so sending the lock
// to another thread sends the `T`; that requires `T: Send`.
unsafe impl<T> Send for RwLock<T> where T: Send {}
// SAFETY: through `&RwLock<T>` any thread can obtain `&T` (read guards) or
// `&mut T` (the write guard), so `T` must be both `Sync` and `Send`, the same
// bounds as `std::sync::RwLock`.
unsafe impl<T> Sync for RwLock<T> where T: Send + Sync {}
// SAFETY: sharing `&Guard` across threads only exposes `&T` through `Deref`.
// NOTE(review): `T: Sync` alone appears sufficient for that; the extra `T: Send`
// bound is conservative. Whether these explicit impls are needed at all depends
// on whether `Permit` is auto-`Sync`, which is not visible here — confirm.
unsafe impl<'a, T> Sync for RwLockReadGuard<'a, T> where T: Send + Sync {}
unsafe impl<'a, T> Sync for RwLockWriteGuard<'a, T> where T: Send + Sync {}
impl<T> RwLock<T> {
pub fn new(value: T) -> RwLock<T> {
RwLock {
c: UnsafeCell::new(value),
s: Semaphore::new(MAX_READS),
}
}
pub async fn read(&self) -> RwLockReadGuard<'_, T> {
let mut permit = ReleasingPermit {
num_permits: 1,
permit: Permit::new(),
lock: self,
};
poll_fn(|cx| permit.poll_acquire(cx, &self.s))
.await
.unwrap_or_else(|_| {
unreachable!()
});
RwLockReadGuard { lock: self, permit }
}
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
let mut permit = ReleasingPermit {
num_permits: MAX_READS as u16,
permit: Permit::new(),
lock: self,
};
poll_fn(|cx| permit.poll_acquire(cx, &self.s))
.await
.unwrap_or_else(|_| {
unreachable!()
});
RwLockWriteGuard { lock: self, permit }
}
}
impl<T> ops::Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let raw = self.lock.c.get();
        // SAFETY: this guard holds one of the semaphore's permits, and a writer
        // needs all `MAX_READS` of them, so no `&mut T` can exist concurrently.
        unsafe { &*raw }
    }
}
impl<T> ops::Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value immutably through the write guard.
    fn deref(&self) -> &T {
        let raw = self.lock.c.get();
        // SAFETY: this guard holds every semaphore permit, so no other guard
        // exists; the `&self` borrow prevents a simultaneous `deref_mut`.
        unsafe { &*raw }
    }
}
impl<T> ops::DerefMut for RwLockWriteGuard<'_, T> {
    /// Borrows the protected value mutably through the write guard.
    fn deref_mut(&mut self) -> &mut T {
        let raw = self.lock.c.get();
        // SAFETY: this guard holds all `MAX_READS` permits, so no other guard
        // exists, and `&mut self` makes this the only live borrow via this guard.
        unsafe { &mut *raw }
    }
}
impl<T> From<T> for RwLock<T> {
fn from(s: T) -> Self {
Self::new(s)
}
}
impl<T> Default for RwLock<T>
where
T: Default,
{
fn default() -> Self {
Self::new(T::default())
}
}