use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::process;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use event_listener::{Event, EventListener};
use crate::futures::Lock;
use crate::{Mutex, MutexGuard};
// Lowest bit of `RwLock::state`: set while a writer holds the lock or has announced itself and
// is waiting for the remaining readers to leave.
const WRITER_BIT: usize = 1;
// Amount added to `RwLock::state` per active reader; the reader count therefore lives in the
// bits above `WRITER_BIT`.
const ONE_READER: usize = 2;
/// An async reader-writer lock.
///
/// Any number of plain readers may hold the lock at once. Writers and upgradable readers are
/// serialized through an inner mutex, and a writer additionally waits until the reader count in
/// `state` has dropped to zero.
pub struct RwLock<T: ?Sized> {
// Acquired by writers and upgradable readers, so at most one of them exists at a time.
mutex: Mutex<()>,
// Notified when the last reader guard is dropped; writers and upgrades listen on it.
no_readers: Event,
// Notified when a writer releases or downgrades the lock; pending readers listen on it.
no_writer: Event,
// Bit 0 is `WRITER_BIT`; the remaining bits count readers in units of `ONE_READER`.
state: AtomicUsize,
// The protected value, handed out through the guard types.
value: UnsafeCell<T>,
}
// SAFETY: moving the lock to another thread moves the owned `T` with it, hence `T: Send`.
unsafe impl<T: Send + ?Sized> Send for RwLock<T> {}
// SAFETY: a shared `&RwLock<T>` hands out `&T` to concurrent readers (needs `T: Sync`) and
// `&mut T` to a writer that may live on another thread (needs `T: Send`).
// NOTE(review): soundness relies on the state protocol never admitting a writer alongside
// readers.
unsafe impl<T: Send + Sync + ?Sized> Sync for RwLock<T> {}
impl<T> RwLock<T> {
    /// Creates a new, unlocked reader-writer lock that owns `t`.
    ///
    /// The state word starts at zero: no writer bit and no readers.
    pub const fn new(t: T) -> RwLock<T> {
        Self {
            value: UnsafeCell::new(t),
            state: AtomicUsize::new(0),
            no_writer: Event::new(),
            no_readers: Event::new(),
            mutex: Mutex::new(()),
        }
    }

    /// Consumes the lock and returns the value it protected.
    ///
    /// Taking `self` by value proves that no guard can still be borrowing the lock.
    pub fn into_inner(self) -> T {
        let RwLock { value, .. } = self;
        value.into_inner()
    }
}
impl<T: ?Sized> RwLock<T> {
/// Attempts to acquire a read lock without waiting.
///
/// Returns `None` when `WRITER_BIT` is set, i.e. a writer holds the lock or has announced
/// itself. Aborts the process if the reader count would exceed `isize::MAX`.
pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
    let mut observed = self.state.load(Ordering::Acquire);
    loop {
        // A writer is present or pending: readers must not slip in.
        if observed & WRITER_BIT != 0 {
            return None;
        }

        // Guard the reader counter against overflow.
        if observed > std::isize::MAX as usize {
            process::abort();
        }

        // Try to register one more reader; on contention retry with the value we lost to.
        let with_reader = observed + ONE_READER;
        if let Err(actual) = self.state.compare_exchange(
            observed,
            with_reader,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            observed = actual;
            continue;
        }

        return Some(RwLockReadGuard(self));
    }
}
/// Returns a future that resolves to a read guard once no writer is present.
///
/// The current state is sampled eagerly so the first poll can attempt the acquisition
/// straight away.
pub fn read(&self) -> Read<'_, T> {
    let state = self.state.load(Ordering::Acquire);
    Read {
        lock: self,
        state,
        listener: None,
    }
}
/// Attempts to acquire an upgradable read lock without waiting.
///
/// Returns `None` if the inner mutex is currently held (by a writer or by another upgradable
/// reader). On success the returned guard keeps the mutex locked and counts as one reader.
/// Aborts the process if the reader count would exceed `isize::MAX`.
pub fn try_upgradable_read(&self) -> Option<RwLockUpgradableReadGuard<'_, T>> {
    // Writers and upgradable readers are serialized through the mutex.
    let lock = self.mutex.try_lock()?;

    let mut state = self.state.load(Ordering::Acquire);
    loop {
        // Re-check on every attempt: `state` is refreshed after each failed CAS, and the
        // other acquisition paths (`try_read`, `Read::poll`) validate each value they use.
        if state > std::isize::MAX as usize {
            process::abort();
        }

        // Register ourselves as one more reader.
        match self.state.compare_exchange(
            state,
            state + ONE_READER,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                return Some(RwLockUpgradableReadGuard {
                    reader: RwLockReadGuard(self),
                    reserved: lock,
                })
            }
            Err(s) => state = s,
        }
    }
}
/// Returns a future that resolves to an upgradable read guard.
///
/// The future first has to win the inner mutex, which it keeps for the guard's lifetime.
pub fn upgradable_read(&self) -> UpgradableRead<'_, T> {
    let acquire = self.mutex.lock();
    UpgradableRead {
        lock: self,
        acquire,
    }
}
/// Attempts to acquire the write lock without waiting.
///
/// Succeeds only if the inner mutex is free and the state is exactly zero (no readers and no
/// writer); otherwise returns `None`.
pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
    let reserved = self.mutex.try_lock()?;

    // With the mutex held, the lock is ours iff nobody is reading right now.
    match self
        .state
        .compare_exchange(0, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
    {
        Ok(_) => Some(RwLockWriteGuard {
            writer: RwLockWriteGuardInner(self),
            reserved,
        }),
        Err(_) => None,
    }
}
/// Returns a future that resolves to a write guard.
///
/// The future starts out acquiring the inner mutex; see the `Future` impl of `Write` for the
/// remaining steps.
pub fn write(&self) -> Write<'_, T> {
    let acquiring = WriteState::Acquiring(self.mutex.lock());
    Write {
        lock: self,
        state: acquiring,
    }
}
/// Returns a mutable reference to the protected value without any locking.
pub fn get_mut(&mut self) -> &mut T {
    let value_ptr = self.value.get();
    // SAFETY: `&mut self` guarantees exclusive access to the lock, so no guard can be alive
    // and the pointer obtained from the `UnsafeCell` is valid and unaliased.
    unsafe { &mut *value_ptr }
}
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLock<T> {
    /// Formats the lock as `RwLock { value: .. }`.
    ///
    /// The value is shown only if a read lock can be taken immediately; otherwise the
    /// placeholder `<locked>` is printed instead.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Placeholder that renders as `<locked>`.
        struct Locked;
        impl fmt::Debug for Locked {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        // Take the read lock (if available) before producing any output; it is released
        // when `maybe_guard` goes out of scope at the end of this function.
        let maybe_guard = self.try_read();
        let mut out = f.debug_struct("RwLock");
        if let Some(held) = &maybe_guard {
            out.field("value", &&**held);
        } else {
            out.field("value", &Locked);
        }
        out.finish()
    }
}
impl<T> From<T> for RwLock<T> {
fn from(val: T) -> RwLock<T> {
RwLock::new(val)
}
}
impl<T: Default + ?Sized> Default for RwLock<T> {
fn default() -> RwLock<T> {
RwLock::new(Default::default())
}
}
/// The future returned by `RwLock::read`.
pub struct Read<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,

    /// The most recently observed value of the lock's state word.
    state: usize,

    /// Listener for the lock's "no writer" event, present while waiting for a writer.
    listener: Option<EventListener>,
}

impl<T: ?Sized> fmt::Debug for Read<'_, T> {
    /// Prints an opaque placeholder; none of the fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let repr = "Read { .. }";
        f.write_str(repr)
    }
}

impl<T: ?Sized> Unpin for Read<'_, T> {}
impl<'a, T: ?Sized> Future for Read<'a, T> {
    type Output = RwLockReadGuard<'a, T>;

    /// Drives the read acquisition.
    ///
    /// While `WRITER_BIT` is set in the cached state, the future listens on the lock's
    /// "no writer" event and reloads the state. Otherwise it tries to add one reader with a
    /// compare-exchange, retrying with the refreshed state on contention. Aborts the process
    /// if the reader count would exceed `isize::MAX`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            if this.state & WRITER_BIT != 0 {
                // A writer is present or pending: wait for it to go away.
                let reload_ordering = match this.listener.as_mut() {
                    None => {
                        // Register a listener first, then re-check the state so a release
                        // that raced with the registration is not missed.
                        this.listener = Some(this.lock.no_writer.listen());
                        Ordering::SeqCst
                    }
                    Some(pending) => {
                        // Suspend until the event fires.
                        ready!(Pin::new(pending).poll(cx));
                        this.listener = None;

                        // Pass the notification on to the next listener on the event.
                        this.lock.no_writer.notify(1);
                        Ordering::Acquire
                    }
                };

                this.state = this.lock.state.load(reload_ordering);
                continue;
            }

            // Guard the reader counter against overflow.
            if this.state > std::isize::MAX as usize {
                process::abort();
            }

            // No writer in sight: try to register as one more reader.
            let observed = this.state;
            match this.lock.state.compare_exchange(
                observed,
                observed + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Poll::Ready(RwLockReadGuard(this.lock)),
                Err(actual) => this.state = actual,
            }
        }
    }
}
/// The future returned by `RwLock::upgradable_read`.
pub struct UpgradableRead<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,

    /// Future acquiring the lock's inner mutex.
    acquire: Lock<'a, ()>,
}

impl<T: ?Sized> fmt::Debug for UpgradableRead<'_, T> {
    /// Prints an opaque placeholder; none of the fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let repr = "UpgradableRead { .. }";
        f.write_str(repr)
    }
}

impl<T: ?Sized> Unpin for UpgradableRead<'_, T> {}
impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> {
    type Output = RwLockUpgradableReadGuard<'a, T>;

    /// Drives the upgradable-read acquisition.
    ///
    /// First waits for the inner mutex, then registers as one more reader with a
    /// compare-exchange loop. The resulting guard keeps the mutex locked. Aborts the process
    /// if the reader count would exceed `isize::MAX`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Writers and upgradable readers are serialized through the mutex.
        let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx));

        let mut state = this.lock.state.load(Ordering::Acquire);
        loop {
            // Re-check on every attempt: `state` is refreshed after each failed CAS, and the
            // other acquisition paths validate each value before using it.
            if state > std::isize::MAX as usize {
                process::abort();
            }

            // Register ourselves as one more reader.
            match this.lock.state.compare_exchange(
                state,
                state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    return Poll::Ready(RwLockUpgradableReadGuard {
                        reader: RwLockReadGuard(this.lock),
                        reserved: mutex_guard,
                    });
                }
                Err(s) => state = s,
            }
        }
    }
}
/// The future returned by `RwLock::write`.
pub struct Write<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,

    /// Which phase of the acquisition the future is in.
    state: WriteState<'a, T>,
}

/// Phases of a pending write acquisition.
enum WriteState<'a, T: ?Sized> {
    /// Waiting for the lock's inner mutex.
    Acquiring(Lock<'a, ()>),

    /// `WRITER_BIT` has been set; waiting for the remaining readers to leave.
    WaitingReaders {
        /// The guard to hand out on completion. Dropping it (e.g. when the future is
        /// dropped early) clears `WRITER_BIT` again.
        guard: Option<RwLockWriteGuard<'a, T>>,

        /// Listener for the lock's "no readers" event.
        listener: Option<EventListener>,
    },
}

impl<T: ?Sized> fmt::Debug for Write<'_, T> {
    /// Prints an opaque placeholder; none of the fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let repr = "Write { .. }";
        f.write_str(repr)
    }
}

impl<T: ?Sized> Unpin for Write<'_, T> {}
impl<'a, T: ?Sized> Future for Write<'a, T> {
    type Output = RwLockWriteGuard<'a, T>;

    /// Drives the write acquisition.
    ///
    /// 1. Acquire the inner mutex (serializes writers and upgradable readers).
    /// 2. Set `WRITER_BIT` so no new readers are admitted.
    /// 3. If readers were still present, wait on the "no readers" event until the state is
    ///    exactly `WRITER_BIT`.
    ///
    /// # Panics
    ///
    /// May panic if polled again after it has returned `Poll::Ready` from the waiting phase.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            match &mut this.state {
                WriteState::Acquiring(lock) => {
                    // First grab the mutex.
                    let mutex_guard = ready!(Pin::new(lock).poll(cx));

                    // Announce the writer. `fetch_or` returns the state as it was *before*
                    // the bit was set, so the reader count in it tells us whether anybody
                    // still has to leave.
                    let prev_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst);

                    // Wrap the bit in a guard right away so it is cleared again if this
                    // future is dropped before completing.
                    let guard = RwLockWriteGuard {
                        writer: RwLockWriteGuardInner(this.lock),
                        reserved: mutex_guard,
                    };

                    // Fast path: there were no readers, and none can arrive now that
                    // `WRITER_BIT` is set, so the lock is exclusively ours. (Comparing the
                    // previous value against `WRITER_BIT` alone could never succeed while
                    // we hold the mutex, which forced a needless listener registration.)
                    if prev_state & !WRITER_BIT == 0 {
                        return Poll::Ready(guard);
                    }

                    // Start waiting for the readers to finish.
                    this.state = WriteState::WaitingReaders {
                        guard: Some(guard),
                        listener: Some(this.lock.no_readers.listen()),
                    };
                }
                WriteState::WaitingReaders {
                    guard,
                    ref mut listener,
                } => {
                    // A freshly registered listener is followed by a re-check of the state;
                    // without a listener the stronger ordering is used.
                    let load_ordering = if listener.is_some() {
                        Ordering::Acquire
                    } else {
                        Ordering::SeqCst
                    };

                    // Only the writer bit left: every reader is gone.
                    if this.lock.state.load(load_ordering) == WRITER_BIT {
                        return Poll::Ready(guard.take().unwrap());
                    }

                    match listener {
                        None => {
                            // Register a listener, then loop to re-check the state.
                            *listener = Some(this.lock.no_readers.listen());
                        }
                        Some(ref mut evl) => {
                            // Suspend until the "no readers" event fires.
                            ready!(Pin::new(evl).poll(cx));
                            *listener = None;
                        }
                    };
                }
            }
        }
    }
}
/// A guard representing one registered reader of an `RwLock`.
///
/// Dropping the guard subtracts `ONE_READER` from the lock's state.
#[clippy::has_significant_drop]
pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>);
// SAFETY: the guard only ever exposes `&T` (see its `Deref` impl), so sending or sharing it
// across threads is equivalent to sharing `&T`, which requires `T: Sync`.
unsafe impl<T: Sync + ?Sized> Send for RwLockReadGuard<'_, T> {}
unsafe impl<T: Sync + ?Sized> Sync for RwLockReadGuard<'_, T> {}
impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    /// Unregisters this reader and, if it was the last one, fires the "no readers" event.
    fn drop(&mut self) {
        let rwlock = self.0;

        // `fetch_sub` yields the state from before the decrement.
        let previous = rwlock.state.fetch_sub(ONE_READER, Ordering::SeqCst);

        // Ignoring the writer bit, exactly one reader before means none are left now.
        let was_last_reader = previous & !WRITER_BIT == ONE_READER;
        if was_last_reader {
            rwlock.no_readers.notify(1);
        }
    }
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLockReadGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}
impl<T: fmt::Display + ?Sized> fmt::Display for RwLockReadGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}
impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Gives shared access to the protected value.
    fn deref(&self) -> &T {
        let value_ptr = self.0.value.get();
        // SAFETY: this guard is counted as a reader in the lock's state, and writers wait
        // for the reader count to reach zero before handing out `&mut T`.
        unsafe { &*value_ptr }
    }
}
/// A read guard that can later be upgraded to a write guard.
///
/// It is an ordinary reader (`reader`) that additionally keeps the lock's inner mutex held
/// (`reserved`), so no writer or other upgradable reader can exist alongside it.
#[clippy::has_significant_drop]
pub struct RwLockUpgradableReadGuard<'a, T: ?Sized> {
// Counts this guard as one reader in the lock's state.
reader: RwLockReadGuard<'a, T>,
// Keeps the inner mutex locked for as long as this guard lives.
reserved: MutexGuard<'a, ()>,
}
// SAFETY: NOTE(review): `Send` additionally requires `T: Send`, presumably because the guard
// can be upgraded to a write guard (which exposes `&mut T`) on the receiving thread — confirm.
unsafe impl<T: Send + Sync + ?Sized> Send for RwLockUpgradableReadGuard<'_, T> {}
// SAFETY: a shared reference to the guard only exposes `&T`, which requires `T: Sync`.
unsafe impl<T: Sync + ?Sized> Sync for RwLockUpgradableReadGuard<'_, T> {}
impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> {
fn into_writer(self) -> RwLockWriteGuard<'a, T> {
let writer = RwLockWriteGuard {
writer: RwLockWriteGuardInner(self.reader.0),
reserved: self.reserved,
};
mem::forget(self.reader);
writer
}
pub fn downgrade(guard: Self) -> RwLockReadGuard<'a, T> {
guard.reader
}
pub fn try_upgrade(guard: Self) -> Result<RwLockWriteGuard<'a, T>, Self> {
if guard
.reader
.0
.state
.compare_exchange(ONE_READER, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
Ok(guard.into_writer())
} else {
Err(guard)
}
}
pub fn upgrade(guard: Self) -> Upgrade<'a, T> {
guard
.reader
.0
.state
.fetch_sub(ONE_READER - WRITER_BIT, Ordering::SeqCst);
let guard = guard.into_writer();
Upgrade {
guard: Some(guard),
listener: None,
}
}
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLockUpgradableReadGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}
impl<T: fmt::Display + ?Sized> fmt::Display for RwLockUpgradableReadGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}
impl<T: ?Sized> Deref for RwLockUpgradableReadGuard<'_, T> {
    type Target = T;

    /// Gives shared access to the protected value.
    fn deref(&self) -> &T {
        let value_ptr = self.reader.0.value.get();
        // SAFETY: the embedded read guard counts as a reader in the lock's state, and
        // writers wait for the reader count to reach zero before handing out `&mut T`.
        unsafe { &*value_ptr }
    }
}
/// The future returned by `RwLockUpgradableReadGuard::upgrade`.
pub struct Upgrade<'a, T: ?Sized> {
    /// The write guard to hand out; `None` once the future has completed.
    guard: Option<RwLockWriteGuard<'a, T>>,

    /// Listener for the lock's "no readers" event, present while waiting.
    listener: Option<EventListener>,
}

impl<T: ?Sized> fmt::Debug for Upgrade<'_, T> {
    /// Prints just the type name; none of the fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Upgrade");
        out.finish()
    }
}

impl<T: ?Sized> Unpin for Upgrade<'_, T> {}
impl<'a, T: ?Sized> Future for Upgrade<'a, T> {
    type Output = RwLockWriteGuard<'a, T>;

    /// Waits until the lock's state is exactly `WRITER_BIT` (all other readers gone), then
    /// yields the write guard.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Copy out the lock reference; the guard itself stays in place until completion.
        let rwlock = this
            .guard
            .as_ref()
            .expect("cannot poll future after completion")
            .writer
            .0;

        loop {
            // A freshly registered listener is followed by a re-check of the state; without
            // a listener the stronger ordering is used.
            let load_ordering = match this.listener {
                Some(_) => Ordering::Acquire,
                None => Ordering::SeqCst,
            };

            // Only the writer bit left: every reader is gone.
            if rwlock.state.load(load_ordering) == WRITER_BIT {
                return Poll::Ready(this.guard.take().unwrap());
            }

            match this.listener.as_mut() {
                Some(pending) => {
                    // Suspend until the "no readers" event fires.
                    ready!(Pin::new(pending).poll(cx));
                    this.listener = None;
                }
                None => {
                    // Register a listener, then loop to re-check the state.
                    this.listener = Some(rwlock.no_readers.listen());
                }
            }
        }
    }
}
/// Owner of the lock's `WRITER_BIT`: clears the bit when dropped.
///
/// Kept separate from `RwLockWriteGuard` so that the bit can be released (or deliberately
/// leaked with `mem::forget`) independently of the mutex guard.
struct RwLockWriteGuardInner<'a, T: ?Sized>(&'a RwLock<T>);

impl<T: ?Sized> Drop for RwLockWriteGuardInner<'_, T> {
    /// Clears `WRITER_BIT` and fires the "no writer" event.
    fn drop(&mut self) {
        let rwlock = self.0;

        // Let readers in again...
        rwlock.state.fetch_and(!WRITER_BIT, Ordering::SeqCst);

        // ...and notify one listener on the "no writer" event.
        rwlock.no_writer.notify(1);
    }
}
/// A guard giving exclusive access to the value protected by an `RwLock`.
///
/// Fields are dropped in declaration order, so `WRITER_BIT` is cleared (by `writer`) before
/// the inner mutex is released (by `reserved`).
#[clippy::has_significant_drop]
pub struct RwLockWriteGuard<'a, T: ?Sized> {
// Owns `WRITER_BIT`; clears it and notifies "no writer" on drop.
writer: RwLockWriteGuardInner<'a, T>,
// Keeps the inner mutex locked for as long as this guard lives.
reserved: MutexGuard<'a, ()>,
}
// SAFETY: the guard exposes `&mut T` (see its `DerefMut` impl), so sending it to another
// thread requires `T: Send`.
unsafe impl<T: Send + ?Sized> Send for RwLockWriteGuard<'_, T> {}
// SAFETY: a shared reference to the guard only exposes `&T`, which requires `T: Sync`.
unsafe impl<T: Sync + ?Sized> Sync for RwLockWriteGuard<'_, T> {}
impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> {
    /// Downgrades the write guard into a plain read guard.
    ///
    /// The state switches from "writer" to "one reader" in a single atomic step, so no other
    /// writer can slip in between. The inner mutex is released when this function returns.
    pub fn downgrade(guard: Self) -> RwLockReadGuard<'a, T> {
        let rwlock = guard.writer.0;

        // Clear `WRITER_BIT` and register one reader at the same time.
        rwlock
            .state
            .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst);

        // The writer is gone: notify a listener on the "no writer" event.
        rwlock.no_writer.notify(1);

        let new_guard = RwLockReadGuard(rwlock);
        // `RwLockWriteGuardInner::drop()` must not run: the bit has already been cleared and
        // the event has already been fired above.
        mem::forget(guard.writer);
        new_guard
    }

    /// Downgrades the write guard into an upgradable read guard.
    ///
    /// The state switches from "writer" to "one reader" in a single atomic step, and the
    /// inner mutex stays locked by the returned guard.
    pub fn downgrade_to_upgradable(guard: Self) -> RwLockUpgradableReadGuard<'a, T> {
        let rwlock = guard.writer.0;

        // Clear `WRITER_BIT` and register one reader at the same time.
        rwlock
            .state
            .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst);

        // The writer is gone, so plain readers may enter again. Fire the "no writer" event
        // exactly like `downgrade` and the write-guard drop do; otherwise `Read` futures
        // already parked on that event would stay asleep until some later writer releases
        // the lock.
        rwlock.no_writer.notify(1);

        let new_guard = RwLockUpgradableReadGuard {
            reader: RwLockReadGuard(rwlock),
            reserved: guard.reserved,
        };
        // `RwLockWriteGuardInner::drop()` must not run: the bit has already been cleared and
        // the event has already been fired above.
        mem::forget(guard.writer);
        new_guard
    }
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLockWriteGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}
impl<T: fmt::Display + ?Sized> fmt::Display for RwLockWriteGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Gives shared access to the protected value.
    fn deref(&self) -> &T {
        let value_ptr = self.writer.0.value.get();
        // SAFETY: a completed write guard is only handed out once the state is exactly
        // `WRITER_BIT`, i.e. there are no readers, and the inner mutex excludes other
        // writers.
        unsafe { &*value_ptr }
    }
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    /// Gives exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr = self.writer.0.value.get();
        // SAFETY: a completed write guard is only handed out once the state is exactly
        // `WRITER_BIT`, i.e. there are no readers, and the inner mutex excludes other
        // writers; `&mut self` makes this the only access through the guard.
        unsafe { &mut *value_ptr }
    }
}