#![deny(unsafe_code)]
use std::marker::PhantomData;
use std::pin::Pin;
#[cfg(not(feature = "loom"))]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed};
use std::task::{Context, Poll};
#[cfg(feature = "loom")]
use loom::sync::atomic::AtomicUsize;
use crate::opcode::Opcode;
use crate::sync_primitive::SyncPrimitive;
use crate::wait_queue::WaitQueue;
/// A gate that tasks and threads pass through according to its current [`State`].
///
/// All bookkeeping lives in one atomic word: the `WaitQueue::DATA_MASK` bits hold the
/// encoded [`State`], and the `WaitQueue::ADDR_MASK` bits hold the address of queued wait
/// entries (zero when nothing is queued). A default-constructed gate is `Controlled` with an
/// empty queue.
#[derive(Default, Debug)]
pub struct Gate {
    /// Packed state bits plus wait-queue address.
    state: AtomicUsize,
}
/// A handle for waiting on a [`Gate`].
///
/// A pager holds at most one wait-queue entry. It is empty by default and gets an entry when
/// registered with a gate. The `'g` lifetime ties the pager to the gate it may be registered
/// with.
#[derive(Default, Debug)]
pub struct Pager<'g> {
    /// The wait-queue entry, present only while the pager is registered.
    entry: Option<WaitQueue>,
    /// Carries the borrow of the gate without storing a reference.
    _phantom: PhantomData<&'g Gate>,
}
/// The state of a [`Gate`].
///
/// The discriminants fit in the gate's two low state bits. Declaration order defines the
/// derived ordering, so variants must not be reordered.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum State {
    /// New entrants are queued until the gate permits, rejects, opens, or seals.
    Controlled = 0,
    /// New entrants are turned away immediately with `Error::Sealed`.
    Sealed = 1,
    /// New entrants pass immediately.
    Open = 2,
}
/// Failure reported when entering or polling a [`Gate`].
///
/// Every code is a multiple of four. It therefore sits above the two [`State`] bits and can be
/// OR'ed with a state into a single result byte.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum Error {
    /// Waiters were turned away by `Gate::reject`.
    Rejected = 4,
    /// The gate is (or became) sealed.
    Sealed = 8,
    /// Delivered to waiters when a registered pager was dropped before receiving its result.
    SpuriousFailure = 12,
    /// The pager holds no wait-queue entry.
    NotRegistered = 16,
    /// The wait queue reported a mode mismatch (presumably sync vs. async polling).
    WrongMode = 20,
    /// No result is available yet.
    NotReady = 24,
}
impl Gate {
const STATE_MASK: u8 = 0b11;
#[inline]
pub fn state(&self, mo: Ordering) -> State {
State::from(self.state.load(mo) & WaitQueue::DATA_MASK)
}
#[inline]
pub fn reset(&self) -> Option<State> {
match self.state.fetch_update(Relaxed, Relaxed, |value| {
let state = State::from(value & WaitQueue::DATA_MASK);
if state == State::Controlled {
None
} else {
debug_assert_eq!(value & WaitQueue::ADDR_MASK, 0);
Some((value & WaitQueue::ADDR_MASK) | u8::from(state) as usize)
}
}) {
Ok(state) => Some(State::from(state & WaitQueue::DATA_MASK)),
Err(_) => None,
}
}
#[inline]
pub fn permit(&self) -> Result<usize, State> {
let (state, count) = self.wake_all(None, None);
if state == State::Controlled {
Ok(count)
} else {
debug_assert_eq!(count, 0);
Err(state)
}
}
#[inline]
pub fn reject(&self) -> Result<usize, State> {
let (state, count) = self.wake_all(None, Some(Error::Rejected));
if state == State::Controlled {
Ok(count)
} else {
debug_assert_eq!(count, 0);
Err(state)
}
}
#[inline]
pub fn open(&self) -> (State, usize) {
self.wake_all(Some(State::Open), None)
}
#[inline]
pub fn seal(&self) -> (State, usize) {
self.wake_all(Some(State::Sealed), Some(Error::Sealed))
}
#[inline]
pub async fn enter_async(&self) -> Result<State, Error> {
let mut pager = Pager::default();
pager
.entry
.replace(WaitQueue::new_async(Opcode::Wait, Self::noop, self.addr()));
let mut pinned_pager = Pin::new(&mut pager);
self.push_wait_queue_entry(&mut pinned_pager, || {});
pinned_pager.await
}
#[inline]
pub async fn enter_async_with<F: FnOnce()>(&self, wait_callback: F) -> Result<State, Error> {
let mut pager = Pager::default();
pager
.entry
.replace(WaitQueue::new_async(Opcode::Wait, Self::noop, self.addr()));
let mut pinned_pager = Pin::new(&mut pager);
self.push_wait_queue_entry(&mut pinned_pager, wait_callback);
pinned_pager.await
}
#[inline]
pub fn enter_sync(&self) -> Result<State, Error> {
self.enter_sync_with(|| ())
}
#[inline]
pub fn enter_sync_with<F: FnOnce()>(&self, wait_callback: F) -> Result<State, Error> {
let mut pager = Pager::default();
pager
.entry
.replace(WaitQueue::new_sync(Opcode::Wait, self.addr()));
let mut pinned_pager = Pin::new(&mut pager);
self.push_wait_queue_entry(&mut pinned_pager, wait_callback);
pinned_pager.poll_sync()
}
#[inline]
pub fn register_async<'g>(&'g self, pager: &mut Pin<&mut Pager<'g>>) -> bool {
if pager.entry.is_some() {
return false;
}
pager
.entry
.replace(WaitQueue::new_async(Opcode::Wait, Self::noop, self.addr()));
self.push_wait_queue_entry(pager, || ());
true
}
#[inline]
pub fn register_sync<'g>(&'g self, pager: &mut Pin<&mut Pager<'g>>) -> bool {
if pager.entry.is_some() {
return false;
}
pager
.entry
.replace(WaitQueue::new_sync(Opcode::Wait, self.addr()));
self.push_wait_queue_entry(pager, || ());
true
}
fn wake_all(&self, next_state: Option<State>, error: Option<Error>) -> (State, usize) {
match self.state.fetch_update(AcqRel, Acquire, |value| {
if let Some(new_value) = next_state {
Some(u8::from(new_value) as usize)
} else {
Some(value & WaitQueue::DATA_MASK)
}
}) {
Ok(value) | Err(value) => {
let mut count = 0;
let entry_addr = value & WaitQueue::ADDR_MASK;
let prev_state = State::from(value & WaitQueue::DATA_MASK);
let next_state = next_state.unwrap_or(prev_state);
let result = Self::into_u8(next_state, error);
if entry_addr != 0 {
WaitQueue::iter_forward(
WaitQueue::addr_to_ptr(entry_addr),
false,
|entry, _| {
entry.set_result(result);
count += 1;
false
},
);
}
(prev_state, count)
}
}
}
#[inline]
fn push_wait_queue_entry<F: FnOnce()>(
&self,
entry: &mut Pin<&mut Pager>,
mut wait_callback: F,
) {
if let Some(entry) = entry.entry.as_ref() {
let pinned_entry = Pin::new(entry);
loop {
let state = self.state.load(Acquire);
match State::from(state & WaitQueue::DATA_MASK) {
State::Controlled => {
if let Some(returned) =
self.try_push_wait_queue_entry(pinned_entry, state, wait_callback)
{
wait_callback = returned;
continue;
}
}
State::Sealed => {
entry.set_result(Self::into_u8(State::Sealed, Some(Error::Sealed)));
}
State::Open => {
entry.set_result(Self::into_u8(State::Open, None));
}
}
break;
}
}
}
#[inline]
fn noop(_entry: &WaitQueue) {
unreachable!("Noop function called");
}
#[inline]
fn into_u8(state: State, error: Option<Error>) -> u8 {
u8::from(state) | error.map_or(0_u8, u8::from)
}
#[inline]
fn from_u8(value: u8) -> (State, Option<Error>) {
let state = State::from(value & Self::STATE_MASK);
let error = value & !(Self::STATE_MASK);
if error != 0 {
(state, Some(Error::from(error)))
} else {
(state, None)
}
}
}
impl Drop for Gate {
    /// Seals the gate on drop if any wait-queue address is still recorded.
    ///
    /// Sealing delivers `Error::Sealed` to the queued waiters. A gate with no queued entries
    /// needs no work.
    #[inline]
    fn drop(&mut self) {
        let has_queued_entries = (self.state.load(Relaxed) & WaitQueue::ADDR_MASK) != 0;
        if has_queued_entries {
            self.seal();
        }
    }
}
impl SyncPrimitive for Gate {
    /// A gate places no limit on the number of simultaneous shared owners.
    #[inline]
    fn max_shared_owners() -> usize {
        usize::MAX
    }

    /// Exposes the packed state/wait-queue word to the shared wait-queue machinery.
    #[inline]
    fn state(&self) -> &AtomicUsize {
        &self.state
    }
}
impl<'g> Pager<'g> {
#[inline]
pub fn is_registered(&self) -> bool {
self.entry.is_some()
}
#[inline]
pub fn is_sync(&self) -> bool {
self.entry.as_ref().is_some_and(WaitQueue::is_sync)
}
#[inline]
pub fn poll_sync(self: &mut Pin<&mut Pager<'g>>) -> Result<State, Error> {
let Some(entry) = self.entry.as_ref() else {
return Err(Error::NotRegistered);
};
let result = entry.poll_result_sync();
if result == WaitQueue::ERROR_WRONG_MODE {
return Err(Error::WrongMode);
}
self.entry.take();
let (state, error) = Gate::from_u8(result);
error.map_or(Ok(state), Err)
}
#[inline]
pub fn try_poll(&self) -> Result<State, Error> {
let Some(entry) = self.entry.as_ref() else {
return Err(Error::NotRegistered);
};
if let Some(result) = entry.try_acknowledge_result() {
let (state, error) = Gate::from_u8(result);
error.map_or(Ok(state), Err)
} else {
Err(Error::NotReady)
}
}
}
impl Drop for Pager<'_> {
    /// Cleans up a registration whose result has not been acknowledged yet.
    ///
    /// Every waiter on the gate is woken with `Error::SpuriousFailure`. The drop then waits
    /// synchronously until this entry's result is acknowledged.
    #[inline]
    fn drop(&mut self) {
        if let Some(entry) = self.entry.as_mut() {
            if entry.try_acknowledge_result().is_some() {
                // A result already arrived, so nothing in the gate refers to us any more.
                return;
            }
            let gate: &Gate = entry.sync_primitive_ref();
            gate.wake_all(None, Some(Error::SpuriousFailure));
            entry.acknowledge_result_sync();
        }
    }
}
impl Future for Pager<'_> {
    type Output = Result<State, Error>;

    /// Polls the registered entry for the gate's result.
    ///
    /// Resolves to `Err(NotRegistered)` if the pager holds no entry. Resolves to
    /// `Err(WrongMode)` if the wait queue reports a mode mismatch. Otherwise resolves to the
    /// state or error delivered by the gate.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Some(entry) = self.entry.as_ref() else {
            return Poll::Ready(Err(Error::NotRegistered));
        };
        let Poll::Ready(result) = entry.poll_result_async(cx) else {
            return Poll::Pending;
        };
        if result == WaitQueue::ERROR_WRONG_MODE {
            // Keep the entry, as `poll_sync` does. It is still registered with the gate, so
            // the `Drop` implementation must still see it in order to clean up.
            return Poll::Ready(Err(Error::WrongMode));
        }
        // A real result was delivered, so the entry can be released.
        self.entry.take();
        let (state, error) = Gate::from_u8(result);
        Poll::Ready(error.map_or(Ok(state), Err))
    }
}
impl From<State> for u8 {
#[inline]
fn from(value: State) -> Self {
match value {
State::Controlled => 0_u8,
State::Sealed => 1_u8,
State::Open => 2_u8,
}
}
}
impl From<u8> for State {
#[inline]
fn from(value: u8) -> Self {
State::from(value as usize)
}
}
impl From<usize> for State {
#[inline]
fn from(value: usize) -> Self {
match value {
0 => State::Controlled,
1 => State::Sealed,
_ => State::Open,
}
}
}
impl From<Error> for u8 {
#[inline]
fn from(value: Error) -> Self {
match value {
Error::Rejected => 4_u8,
Error::Sealed => 8_u8,
Error::SpuriousFailure => 12_u8,
Error::NotRegistered => 16_u8,
Error::WrongMode => 20_u8,
Error::NotReady => 24_u8,
}
}
}
impl From<u8> for Error {
#[inline]
fn from(value: u8) -> Self {
Error::from(value as usize)
}
}
impl From<usize> for Error {
#[inline]
fn from(value: usize) -> Self {
match value {
4 => Error::Rejected,
8 => Error::Sealed,
12 => Error::SpuriousFailure,
16 => Error::NotRegistered,
20 => Error::WrongMode,
_ => Error::NotReady,
}
}
}