use std::future::Future;
use std::mem::forget;
use std::pin::Pin;
use std::process;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use event_listener::{Event, EventListener};
use crate::futures::Lock;
use crate::Mutex;
/// Lowest bit of the lock state: set while a writer holds the lock or has
/// claimed it and is waiting for the remaining readers to leave.
const WRITER_BIT: usize = 0b01;

/// Amount added to the lock state for every active reader; the reader count
/// lives in the bits above `WRITER_BIT`.
const ONE_READER: usize = 0b10;
/// Type-erased core of the reader/writer lock: tracks who holds the lock but
/// stores no user data.
pub(super) struct RawRwLock {
/// Acquired by writers and upgradable readers, so at most one of them is
/// active at a time. Plain readers never take it.
mutex: Mutex<()>,
/// Notified when the last reader releases its read lock; writers and
/// upgrades waiting for readers to drain listen on it.
no_readers: Event,
/// Notified when the writer bit is cleared; readers blocked by a writer
/// listen on it.
no_writer: Event,
/// Bit 0 is `WRITER_BIT`; the remaining bits count readers in units of
/// `ONE_READER`.
state: AtomicUsize,
}
impl RawRwLock {
#[inline]
pub(super) const fn new() -> Self {
RawRwLock {
mutex: Mutex::new(()),
no_readers: Event::new(),
no_writer: Event::new(),
state: AtomicUsize::new(0),
}
}
pub(super) fn try_read(&self) -> bool {
let mut state = self.state.load(Ordering::Acquire);
loop {
if state & WRITER_BIT != 0 {
return false;
}
if state > std::isize::MAX as usize {
process::abort();
}
match self.state.compare_exchange(
state,
state + ONE_READER,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return true,
Err(s) => state = s,
}
}
}
#[inline]
pub(super) fn read(&self) -> RawRead<'_> {
RawRead {
lock: self,
state: self.state.load(Ordering::Acquire),
listener: None,
}
}
pub(super) fn try_upgradable_read(&self) -> bool {
let lock = if let Some(lock) = self.mutex.try_lock() {
lock
} else {
return false;
};
forget(lock);
let mut state = self.state.load(Ordering::Acquire);
if state > std::isize::MAX as usize {
process::abort();
}
loop {
match self.state.compare_exchange(
state,
state + ONE_READER,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return true,
Err(s) => state = s,
}
}
}
#[inline]
pub(super) fn upgradable_read(&self) -> RawUpgradableRead<'_> {
RawUpgradableRead {
lock: self,
acquire: self.mutex.lock(),
}
}
pub(super) fn try_write(&self) -> bool {
let lock = if let Some(lock) = self.mutex.try_lock() {
lock
} else {
return false;
};
if self
.state
.compare_exchange(0, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
forget(lock);
true
} else {
drop(lock);
false
}
}
#[inline]
pub(super) fn write(&self) -> RawWrite<'_> {
RawWrite {
lock: self,
state: WriteState::Acquiring(self.mutex.lock()),
}
}
pub(super) unsafe fn try_upgrade(&self) -> bool {
self.state
.compare_exchange(ONE_READER, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
}
pub(super) unsafe fn upgrade(&self) -> RawUpgrade<'_> {
self.state
.fetch_sub(ONE_READER - WRITER_BIT, Ordering::SeqCst);
RawUpgrade {
lock: Some(self),
listener: None,
}
}
#[inline]
pub(super) unsafe fn downgrade_upgradable_read(&self) {
self.mutex.unlock_unchecked();
}
pub(super) unsafe fn downgrade_write(&self) {
self.state
.fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst);
self.mutex.unlock_unchecked();
self.no_writer.notify(1);
}
pub(super) unsafe fn downgrade_to_upgradable(&self) {
self.state
.fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst);
}
pub(super) unsafe fn read_unlock(&self) {
if self.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER {
self.no_readers.notify(1);
}
}
pub(super) unsafe fn upgradable_read_unlock(&self) {
if self.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER {
self.no_readers.notify(1);
}
self.mutex.unlock_unchecked();
}
pub(super) unsafe fn write_unlock(&self) {
self.state.fetch_and(!WRITER_BIT, Ordering::SeqCst);
self.no_writer.notify(1);
self.mutex.unlock_unchecked();
}
}
/// Future created by `RawRwLock::read`; resolves once a read lock is taken.
pub(super) struct RawRead<'a> {
/// The lock being acquired.
pub(super) lock: &'a RawRwLock,
/// Most recently observed value of `lock.state`, used as the expected value
/// for the next compare-exchange.
state: usize,
/// Listener on `lock.no_writer`, present only while waiting for a writer.
listener: Option<EventListener>,
}
impl<'a> Future for RawRead<'a> {
    type Output = ();

    /// Drives the read-lock acquisition: bumps the reader count when no
    /// writer is present, otherwise waits on the "no writer" event.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.get_mut();

        loop {
            // A writer holds or is claiming the lock: wait, then look again.
            if this.state & WRITER_BIT != 0 {
                let load_ordering = match this.listener.as_mut() {
                    Some(listener) => {
                        // Wait for the writer to finish.
                        ready!(Pin::new(listener).poll(cx));
                        this.listener = None;

                        // Pass the wake-up on to the next waiting reader.
                        this.lock.no_writer.notify(1);

                        Ordering::Acquire
                    }
                    None => {
                        // Start listening, then re-check the state with the
                        // strongest ordering before actually waiting.
                        this.listener = Some(this.lock.no_writer.listen());

                        Ordering::SeqCst
                    }
                };

                // Reload the state.
                this.state = this.lock.state.load(load_ordering);
                continue;
            }

            // Make sure the number of readers doesn't overflow.
            if this.state > std::isize::MAX as usize {
                process::abort();
            }

            // No writer in sight: try to register one more reader.
            let attempt = this.lock.state.compare_exchange(
                this.state,
                this.state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            );
            match attempt {
                Ok(_) => return Poll::Ready(()),
                Err(current) => this.state = current,
            }
        }
    }
}
/// Future created by `RawRwLock::upgradable_read`; resolves once an
/// upgradable read lock is taken.
pub(super) struct RawUpgradableRead<'a> {
/// The lock being acquired.
pub(super) lock: &'a RawRwLock,
/// In-progress acquisition of `lock.mutex`.
acquire: Lock<'a, ()>,
}
impl<'a> Future for RawUpgradableRead<'a> {
    type Output = ();

    /// Waits for the internal mutex, keeps it locked, then registers one
    /// reader in the lock state.
    ///
    /// Aborts the process if the reader count would overflow.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.get_mut();

        // Acquire the mutex and keep it locked after this future completes.
        let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx));
        forget(mutex_guard);

        // Load the current state.
        let mut state = this.lock.state.load(Ordering::Acquire);

        loop {
            // Check for overflow on every attempt, including values observed
            // by a failed compare-exchange.
            if state > std::isize::MAX as usize {
                process::abort();
            }

            // Increment the number of readers.
            match this.lock.state.compare_exchange(
                state,
                state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Poll::Ready(()),
                Err(s) => state = s,
            }
        }
    }
}
/// Future created by `RawRwLock::write`; resolves once the write lock is taken.
pub(super) struct RawWrite<'a> {
/// The lock being acquired.
pub(super) lock: &'a RawRwLock,
/// Which stage of the acquisition this future has reached.
state: WriteState<'a>,
}
/// Stages a `RawWrite` future moves through.
enum WriteState<'a> {
/// Waiting to lock the internal mutex.
Acquiring(Lock<'a, ()>),
/// The mutex is held and the writer bit is set; waiting for readers to leave.
WaitingReaders {
/// Listener on the "no readers" event; `None` right after it fired.
listener: Option<EventListener>,
},
/// The write lock has been handed out; polling again panics.
Acquired,
}
impl<'a> Future for RawWrite<'a> {
    type Output = ();

    /// Drives the write-lock acquisition: lock the internal mutex, set the
    /// writer bit, then wait until no readers remain.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.get_mut();

        loop {
            match &mut this.state {
                WriteState::Acquiring(lock) => {
                    // First grab the mutex and keep it locked past this future.
                    let mutex_guard = ready!(Pin::new(lock).poll(cx));
                    forget(mutex_guard);

                    // Set `WRITER_BIT`. `fetch_or` returns the value from
                    // *before* the update.
                    let prev_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst);

                    // A previous value of zero means there were no readers,
                    // so the lock is ours right away and no listener is needed.
                    if prev_state == 0 {
                        this.state = WriteState::Acquired;
                        return Poll::Ready(());
                    }

                    // Start waiting for the readers to finish.
                    this.state = WriteState::WaitingReaders {
                        listener: Some(this.lock.no_readers.listen()),
                    };
                }

                WriteState::WaitingReaders { ref mut listener } => {
                    // Use the strongest ordering when no listener is
                    // registered yet.
                    let load_ordering = if listener.is_some() {
                        Ordering::Acquire
                    } else {
                        Ordering::SeqCst
                    };

                    // Only the writer bit left: every reader is gone.
                    if this.lock.state.load(load_ordering) == WRITER_BIT {
                        this.state = WriteState::Acquired;
                        return Poll::Ready(());
                    }

                    match listener {
                        None => {
                            // Register a listener, then loop to re-check the state.
                            *listener = Some(this.lock.no_readers.listen());
                        }

                        Some(ref mut evl) => {
                            // Wait for the readers to finish.
                            ready!(Pin::new(evl).poll(cx));
                            *listener = None;
                        }
                    };
                }

                WriteState::Acquired => panic!("Write lock already acquired"),
            }
        }
    }
}
impl<'a> Drop for RawWrite<'a> {
    /// Undoes a half-finished acquisition: a future dropped while waiting for
    /// readers has already locked the mutex and set the writer bit, so both
    /// are released here.
    fn drop(&mut self) {
        match self.state {
            WriteState::WaitingReaders { .. } => {
                // SAFETY: in this state the future owns the mutex and the
                // writer bit, which is exactly what `write_unlock` releases.
                unsafe {
                    self.lock.write_unlock();
                }
            }
            // Nothing to undo: either the mutex was never obtained by this
            // future, or the lock has already been handed out.
            WriteState::Acquiring(_) | WriteState::Acquired => {}
        }
    }
}
/// Future created by `RawRwLock::upgrade`; resolves to the lock once the
/// remaining readers have left.
pub(super) struct RawUpgrade<'a> {
/// The lock being upgraded; taken (set to `None`) when the future completes.
lock: Option<&'a RawRwLock>,
/// Listener on the "no readers" event, present only while waiting.
listener: Option<EventListener>,
}
impl<'a> Future for RawUpgrade<'a> {
    type Output = &'a RawRwLock;

    /// Waits until the lock state is exactly `WRITER_BIT`, i.e. no readers
    /// remain, then yields the lock.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'a RawRwLock> {
        let this = self.get_mut();
        let lock = this.lock.expect("cannot poll future after completion");

        loop {
            // Use the strongest ordering when no listener is registered yet.
            let load_ordering = match this.listener {
                Some(_) => Ordering::Acquire,
                None => Ordering::SeqCst,
            };

            // Done as soon as the reader count has dropped to zero.
            if lock.state.load(load_ordering) == WRITER_BIT {
                return Poll::Ready(this.lock.take().unwrap());
            }

            match this.listener.as_mut() {
                Some(listener) => {
                    // Wait for the readers to finish.
                    ready!(Pin::new(listener).poll(cx));
                    this.listener = None;
                }
                None => {
                    // Start listening for "no readers" events, then re-check.
                    this.listener = Some(lock.no_readers.listen());
                }
            }
        }
    }
}
impl<'a> Drop for RawUpgrade<'a> {
    /// Releases the write claim if the upgrade is dropped before it completed.
    #[inline]
    fn drop(&mut self) {
        // `lock` is `None` once the future has resolved; nothing to release then.
        let lock = match self.lock {
            Some(lock) => lock,
            None => return,
        };

        // SAFETY: the upgrade set the writer bit and still owns the mutex, and
        // the write lock it would have produced is no longer wanted.
        unsafe {
            lock.write_unlock();
        }
    }
}
impl<'a> RawUpgrade<'a> {
/// Returns `true` once the lock reference has been taken out of this future,
/// which `poll` does when it completes.
#[inline]
pub(super) fn is_ready(&self) -> bool {
self.lock.is_none()
}
}