use crate::{
blocking::{DefaultMutex, Mutex, ScopedRawMutex},
loom::{
cell::UnsafeCell,
sync::atomic::{AtomicUsize, Ordering::*},
},
util::{fmt, CachePadded, WakeBatch},
};
use cordyceps::{
list::{self, List},
Linked,
};
use core::{
fmt::Debug,
future::Future,
marker::PhantomPinned,
mem,
pin::Pin,
ptr::{self, NonNull},
task::{Context, Poll, Waker},
};
use mycelium_bitfield::{enum_from_bits, FromBits};
use pin_project::{pin_project, pinned_drop};
#[cfg(test)]
mod tests;
/// Reasons a wait on a [`WaitMap`] can fail.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum WaitError {
    /// The [`WaitMap`] has been closed.
    Closed,
    /// The received data has already been consumed by an earlier poll.
    AlreadyConsumed,
    /// The waiter was never added to the [`WaitMap`].
    NeverAdded,
    /// A waiter with the same key is already registered.
    Duplicate,
}

/// Result of waiting on a [`WaitMap`]: the delivered value or a [`WaitError`].
pub type WaitResult<T> = Result<T, WaitError>;

/// Builds an immediately-ready poll result that carries `err`.
const fn ready_err<T>(err: WaitError) -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::Err(err))
}

/// Ready result reporting [`WaitError::Closed`].
const fn closed<T>() -> Poll<WaitResult<T>> {
    ready_err(WaitError::Closed)
}

/// Ready result reporting [`WaitError::AlreadyConsumed`].
const fn consumed<T>() -> Poll<WaitResult<T>> {
    ready_err(WaitError::AlreadyConsumed)
}

/// Ready result reporting [`WaitError::NeverAdded`].
const fn never_added<T>() -> Poll<WaitResult<T>> {
    ready_err(WaitError::NeverAdded)
}

/// Ready result reporting [`WaitError::Duplicate`].
const fn duplicate<T>() -> Poll<WaitResult<T>> {
    ready_err(WaitError::Duplicate)
}

/// Ready result delivering `data` to the waiter.
const fn notified<T>(data: T) -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::Ok(data))
}
/// A map of keyed waiters, each of which can be woken with a value of type `V`.
///
/// Waiters are kept in a linked list behind a mutex; `state` records whether
/// the map is empty, has waiters, or has been closed.
pub struct WaitMap<K: PartialEq, V, Lock: ScopedRawMutex = DefaultMutex> {
/// The map's current `State`, stored as its bit representation.
state: CachePadded<AtomicUsize>,
/// The list of waiters currently registered with this map.
queue: Mutex<List<Waiter<K, V>>, Lock>,
}
impl<K, V, Lock> Debug for WaitMap<K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    /// Formats the map as a `WaitMap` struct with its `state` and `queue` fields.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { state, queue } = self;
        let mut out = f.debug_struct("WaitMap");
        out.field("state", state);
        out.field("queue", queue);
        out.finish()
    }
}
/// Future returned by [`WaitMap::wait`].
///
/// Resolves to the value delivered for this waiter's key, or to a
/// [`WaitError`]. Dropping the future releases its waiter from the map.
#[derive(Debug)]
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
pub struct Wait<'a, K: PartialEq, V, Lock: ScopedRawMutex = DefaultMutex> {
/// The map this future waits on.
queue: &'a WaitMap<K, V, Lock>,
/// The entry that gets linked into the map's waiter list; pinned because
/// the list stores a pointer to it.
#[pin]
waiter: Waiter<K, V>,
}
impl<'map, 'wait, K, V, Lock> Wait<'map, K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    /// Returns a future that registers this waiter with its map.
    ///
    /// The returned [`Subscribe`] completes with `Ok(())` once the waiter has
    /// been enqueued (or was enqueued already), and with an error if the
    /// map is closed or the key is already taken.
    pub fn subscribe(self: Pin<&'wait mut Self>) -> Subscribe<'wait, 'map, K, V, Lock> {
        let wait = self;
        Subscribe { wait }
    }

    /// Deprecated name for [`Wait::subscribe`]; behaves identically.
    #[deprecated(
        since = "0.1.3",
        note = "renamed to `subscribe` for consistency, use that instead"
    )]
    // The return type is the deprecated alias, so the lint must be silenced here.
    #[allow(deprecated)]
    pub fn enqueue(self: Pin<&'wait mut Self>) -> EnqueueWait<'wait, 'map, K, V, Lock> {
        Self::subscribe(self)
    }
}
/// A single entry in a [`WaitMap`]'s waiter list.
#[pin_project]
struct Waiter<K: PartialEq, V> {
/// List links plus the wakeup slot. Kept in an `UnsafeCell` because it is
/// mutated through raw pointers held by the list.
#[pin]
node: UnsafeCell<Node<K, V>>,
/// Where this waiter is in its lifecycle.
state: WaitState,
/// The key this waiter is matched against when the map is woken.
key: K,
}
impl<K: PartialEq, V> Debug for Waiter<K, V> {
    /// Formats the waiter, printing the *type names* of `K` and `V` in place
    /// of their values (neither is required to implement `Debug`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let key_type = fmt::display(core::any::type_name::<K>());
        let val_type = fmt::display(core::any::type_name::<V>());
        f.debug_struct("Waiter")
            .field("node", &self.node)
            .field("state", &self.state)
            .field("key", &key_type)
            .field("val", &val_type)
            .finish()
    }
}
/// The part of a `Waiter` that the linked list and wakers operate on.
#[repr(C)]
struct Node<K: PartialEq, V> {
/// Links to the neighbouring waiters in the list.
links: list::Links<Waiter<K, V>>,
/// The stored waker, or the outcome that replaced it.
waker: Wakeup<V>,
/// Makes the node `!Unpin`.
_pin: PhantomPinned,
}
impl<K: PartialEq, V> Debug for Node<K, V> {
    /// Formats the node's `links` and `waker`; the pin marker is omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { links, waker, .. } = self;
        let mut out = f.debug_struct("Node");
        out.field("links", links);
        out.field("waker", waker);
        out.finish()
    }
}
enum_from_bits! {
// Lifecycle of a single `Waiter`.
#[derive(Debug, Eq, PartialEq)]
enum WaitState<u8> {
// Freshly created; not yet pushed onto the map's waiter list.
Start = 0b01,
// Set just before the waiter is pushed onto the map's waiter list.
Waiting = 0b10,
// The waiter has observed that the map was closed.
Completed = 0b11,
}
}
/// State of a whole `WaitMap`, stored in its atomic `state` field.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
enum State {
/// No waiters are queued.
Empty = 0b00,
/// One or more waiters are queued (or one is about to be pushed).
Waiting = 0b01,
/// The map has been closed. Every state bit is set, so `close` can reach
/// this value from any other state with a single `fetch_or`.
Closed = 0b11,
}
/// Contents of a waiter's wakeup slot.
#[derive(Clone)]
enum Wakeup<V> {
/// Nothing stored yet; the value a new waiter starts with.
Empty,
/// The waiter is enqueued and this waker should be invoked to wake it.
Waiting(Waker),
/// A value was delivered and has not been taken by the future yet.
DataReceived(V),
/// The delivered value has been taken by the future (spelling is historical).
Retreived,
/// The map was closed while this waiter was enqueued.
Closed,
}
impl<K: PartialEq, V> WaitMap<K, V> {
loom_const_fn! {
/// Returns a new, empty `WaitMap` guarded by a [`DefaultMutex`].
#[must_use]
pub fn new() -> Self {
Self::new_with_raw_mutex(DefaultMutex::new())
}
}
}
impl<K, V, Lock> Default for WaitMap<K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex + Default,
{
    /// Returns an empty map guarded by a default-constructed `Lock`.
    fn default() -> Self {
        let lock = Lock::default();
        Self::new_with_raw_mutex(lock)
    }
}
impl<K, V, Lock> WaitMap<K, V, Lock>
where
K: PartialEq,
Lock: ScopedRawMutex,
{
loom_const_fn! {
/// Returns a new, empty `WaitMap` whose waiter list is guarded by `lock`.
///
/// The map starts in the `Empty` state with no waiters.
#[must_use]
pub fn new_with_raw_mutex(lock: Lock) -> Self {
Self {
state: CachePadded::new(AtomicUsize::new(State::Empty.into_usize())),
queue: Mutex::new_with_raw_mutex(List::new(), lock),
}
}
}
}
impl<K: PartialEq, V, Lock: ScopedRawMutex> WaitMap<K, V, Lock> {
#[inline]
pub fn wake(&self, key: &K, val: V) -> WakeOutcome<V> {
let mut state = self.load();
match state {
State::Waiting => {}
State::Closed => return WakeOutcome::Closed(val),
State::Empty => return WakeOutcome::NoMatch(val),
}
let mut val = Some(val);
let maybe_waker = self.queue.with_lock(|queue| {
test_debug!("wake: -> locked");
state = self.load();
let node = self.node_match_locked(key, &mut *queue, state)?;
let val = val
.take()
.expect("value is only taken elsewhere if there is no waker, but there is one");
let waker = Waiter::<K, V>::wake(node, &mut *queue, Wakeup::DataReceived(val));
Some(waker)
});
if let Some(waker) = maybe_waker {
waker.wake();
WakeOutcome::Woke
} else {
let val =
val.expect("value is only taken elsewhere if there is a waker, and there isn't");
WakeOutcome::NoMatch(val)
}
}
#[must_use]
pub fn is_closed(&self) -> bool {
self.load() == State::Closed
}
pub fn close(&self) {
let state = self.state.fetch_or(State::Closed.into_usize(), SeqCst);
let state = test_dbg!(State::from_bits(state));
if state != State::Waiting {
return;
}
let mut batch = WakeBatch::new();
let mut waiters_remaining = true;
while waiters_remaining {
waiters_remaining = self.queue.with_lock(|waiters| {
while let Some(node) = waiters.pop_back() {
let waker = Waiter::wake(node, waiters, Wakeup::Closed);
if !batch.add_waker(waker) {
return true;
}
}
false
});
batch.wake_all();
}
}
pub fn wait(&self, key: K) -> Wait<'_, K, V, Lock> {
Wait {
queue: self,
waiter: self.waiter(key),
}
}
fn waiter(&self, key: K) -> Waiter<K, V> {
let state = WaitState::Start;
Waiter {
state,
node: UnsafeCell::new(Node {
links: list::Links::new(),
waker: Wakeup::Empty,
_pin: PhantomPinned,
}),
key,
}
}
#[cfg_attr(test, track_caller)]
fn load(&self) -> State {
#[allow(clippy::let_and_return)]
let state = State::from_bits(self.state.load(SeqCst));
test_debug!("state.load() = {state:?}");
state
}
#[cfg_attr(test, track_caller)]
fn store(&self, state: State) {
test_debug!("state.store({state:?}");
self.state.store(state as usize, SeqCst);
}
#[cfg_attr(test, track_caller)]
fn compare_exchange(&self, current: State, new: State) -> Result<State, State> {
#[allow(clippy::let_and_return)]
let res = self
.state
.compare_exchange(current as usize, new as usize, SeqCst, SeqCst)
.map(State::from_bits)
.map_err(State::from_bits);
test_debug!("state.compare_exchange({current:?}, {new:?}) = {res:?}");
res
}
#[cold]
#[inline(never)]
fn node_match_locked(
&self,
key: &K,
queue: &mut List<Waiter<K, V>>,
curr: State,
) -> Option<NonNull<Waiter<K, V>>> {
let state = curr;
if test_dbg!(state) != State::Waiting {
return None;
}
let mut cursor = queue.cursor_front_mut();
let opt_node = cursor.remove_first(|t| &t.key == key);
if test_dbg!(queue.is_empty()) {
self.store(State::Empty);
}
opt_node
}
}
/// Result of a call to [`WaitMap::wake`].
#[derive(Debug)]
pub enum WakeOutcome<V> {
/// A waiter with a matching key was found and woken with the value.
Woke,
/// No waiter matched the key; the value is handed back.
NoMatch(V),
/// The map was closed; the value is handed back.
Closed(V),
}
impl fmt::Display for WaitError {
    /// Writes a short human-readable description of the error, honouring any
    /// width/alignment requested by the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Closed => "WaitMap closed",
            Self::Duplicate => "duplicate key",
            Self::AlreadyConsumed => "received data has already been consumed",
            Self::NeverAdded => "Wait was never added to WaitMap",
        };
        f.pad(message)
    }
}
feature! {
#![feature = "core-error"]
// The impl body is empty, so every `Error` method uses the trait's default.
impl core::error::Error for WaitError {}
}
/// Future returned by [`Wait::subscribe`].
///
/// Completes once the borrowed [`Wait`] has been registered with its map; it
/// does not wait for a value to be delivered.
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
#[derive(Debug)]
pub struct Subscribe<'a, 'b, K, V, Lock = DefaultMutex>
where
K: PartialEq,
Lock: ScopedRawMutex,
{
/// The `Wait` future being registered.
wait: Pin<&'a mut Wait<'b, K, V, Lock>>,
}
impl<K, V, Lock> Future for Subscribe<'_, '_, K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    type Output = WaitResult<()>;

    /// Enqueues the underlying waiter if it has not been enqueued yet.
    ///
    /// A waiter that has already left the `Start` state yields `Ok(())`
    /// immediately.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.wait.as_mut().project();
        match test_dbg!(&this.waiter.state) {
            WaitState::Start => this.waiter.start_to_wait(this.queue, cx),
            WaitState::Waiting | WaitState::Completed => Poll::Ready(Ok(())),
        }
    }
}
/// Deprecated alias for [`Subscribe`], kept for backward compatibility.
#[deprecated(
since = "0.1.3",
note = "renamed to `Subscribe` for consistency, use that instead"
)]
pub type EnqueueWait<'a, 'b, K, V, Lock> = Subscribe<'a, 'b, K, V, Lock>;
impl<K: PartialEq, V> Waiter<K, V> {
/// Stores `wakeup` in the node behind `this` and returns the `Waker` that
/// was registered there.
///
/// Hits `unreachable!` if the node's slot did not hold `Wakeup::Waiting`.
#[inline(always)]
#[cfg_attr(loom, track_caller)]
fn wake(this: NonNull<Self>, list: &mut List<Self>, wakeup: Wakeup<V>) -> Waker {
Waiter::with_node(this, list, |node| {
// Swap the new outcome in and inspect what was there before.
let waker = test_dbg!(mem::replace(&mut node.waker, wakeup));
match waker {
Wakeup::Waiting(waker) => waker,
_ => unreachable!("tried to wake a waiter in the {:?} state!", waker),
}
})
}
/// Runs `f` with mutable access to the node behind `this`.
///
/// `_list` is not used by the body. NOTE(review): presumably it is required
/// as evidence that the caller holds the lock on the waiter list — confirm.
#[inline(always)]
#[cfg_attr(loom, track_caller)]
fn with_node<T>(
mut this: NonNull<Self>,
_list: &mut List<Self>,
f: impl FnOnce(&mut Node<K, V>) -> T,
) -> T {
// SAFETY (review): soundness rests on `this` pointing at a live waiter and
// on no other access to its node happening concurrently — presumably
// guaranteed by the list lock; confirm against callers.
unsafe {
this.as_mut().node.with_mut(|node| f(&mut *node))
}
}
/// Registers this waiter with `queue`, moving it from `Start` to `Waiting`.
///
/// While holding the list lock this:
/// 1. moves the map from `Empty` to `Waiting` if needed, or returns
///    `WaitError::Closed` if the map is closed;
/// 2. returns `WaitError::Duplicate` if a queued waiter has the same key;
/// 3. stores the task's waker in the node and pushes the waiter onto the
///    front of the list.
///
/// Returns `Poll::Ready(Ok(()))` once the waiter is enqueued.
fn start_to_wait<Lock>(
mut self: Pin<&mut Self>,
queue: &WaitMap<K, V, Lock>,
cx: &mut Context<'_>,
) -> Poll<WaitResult<()>>
where
Lock: ScopedRawMutex,
{
test_debug!("poll_wait: locking...");
queue.queue.with_lock(move |waiters| {
test_debug!("poll_wait: -> locked");
let mut this = self.as_mut().project();
debug_assert!(
matches!(this.state, WaitState::Start),
"start_to_wait should ONLY be called from the Start state!"
);
let mut queue_state = queue.load();
// Retry the `Empty -> Waiting` transition until it succeeds or the map
// is observed in another state.
'to_waiting: loop {
match test_dbg!(queue_state) {
State::Empty => match queue.compare_exchange(queue_state, State::Waiting) {
Ok(_) => break 'to_waiting,
Err(actual) => queue_state = actual,
},
State::Waiting => break 'to_waiting,
State::Closed => return closed(),
}
}
// Keys must be unique among queued waiters.
let mut cursor = waiters.cursor_front_mut();
if cursor.any(|n| &n.key == this.key) {
return duplicate();
}
*this.state = WaitState::Waiting;
this.node.as_mut().with_mut(|node| {
// SAFETY (review): the node is written while the list lock is held.
unsafe {
(*node).waker = Wakeup::Waiting(cx.waker().clone());
}
});
// The list stores a raw pointer to this pinned waiter.
let ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(self)) };
waiters.push_front(ptr);
Poll::Ready(Ok(()))
})
}
/// Polls this waiter for a delivered value.
///
/// * `Start`: enqueues the waiter (propagating any error) and returns
///   `Pending`.
/// * `Waiting`: inspects the wakeup slot under the list lock and returns
///   `Pending`, the delivered value, or the matching error.
/// * `Completed`: returns `WaitError::AlreadyConsumed`.
fn poll_wait<Lock>(
mut self: Pin<&mut Self>,
queue: &WaitMap<K, V, Lock>,
cx: &mut Context<'_>,
) -> Poll<WaitResult<V>>
where
Lock: ScopedRawMutex,
{
test_debug!(ptr = ?fmt::ptr(self.as_mut()), "Waiter::poll_wait");
let this = self.as_mut().project();
match test_dbg!(&this.state) {
WaitState::Start => {
// `?` returns early with the error if enqueueing failed.
let _ = self.start_to_wait(queue, cx)?;
Poll::Pending
}
WaitState::Waiting => {
// The list itself is not touched here; the lock is taken so the node
// can be mutated.
queue.queue.with_lock(|_waiters| {
this.node.with_mut(|node| unsafe {
let node = &mut *node;
let result;
// Take the current slot contents and decide what to put back.
node.waker = match mem::replace(&mut node.waker, Wakeup::Empty) {
// Still waiting: keep the stored waker unless the task's waker changed.
Wakeup::Waiting(waker) => {
result = Poll::Pending;
if !waker.will_wake(cx.waker()) {
Wakeup::Waiting(cx.waker().clone())
} else {
Wakeup::Waiting(waker)
}
}
// A value arrived: hand it out and mark the slot as emptied.
Wakeup::DataReceived(val) => {
result = notified(val);
Wakeup::Retreived
}
// The value was already handed out by an earlier poll.
Wakeup::Retreived => {
result = consumed();
Wakeup::Retreived
}
// The map was closed while we were queued.
Wakeup::Closed => {
*this.state = WaitState::Completed;
result = closed();
Wakeup::Closed
}
// A `Waiting` waiter with an empty slot was never properly enqueued.
Wakeup::Empty => {
result = never_added();
Wakeup::Closed
}
};
result
})
})
}
WaitState::Completed => consumed(),
}
}
/// Removes this waiter from `queue`'s list; called when the owning future
/// is dropped.
///
/// Does nothing unless the waiter is in the `Waiting` state. If the list is
/// empty afterwards and the map is still `Waiting`, the map is reset to
/// `Empty`.
fn release<Lock>(mut self: Pin<&mut Self>, queue: &WaitMap<K, V, Lock>)
where
Lock: ScopedRawMutex,
{
let state = *(self.as_mut().project().state);
let ptr = NonNull::from(unsafe { Pin::into_inner_unchecked(self) });
test_debug!(self = ?fmt::ptr(ptr), ?state, ?queue, "Waiter::release");
if state != WaitState::Waiting {
return;
}
queue.queue.with_lock(|waiters| {
let state = queue.load();
// SAFETY (review): the list lock is held while unlinking the node.
unsafe {
waiters.remove(ptr);
};
// Only `Waiting -> Empty` is allowed here, so a closed map stays closed.
if test_dbg!(waiters.is_empty()) && state == State::Waiting {
queue.store(State::Empty);
}
})
}
}
// Lets `Waiter` be stored in a cordyceps `List`.
unsafe impl<K: PartialEq, V> Linked<list::Links<Waiter<K, V>>> for Waiter<K, V> {
// The list handles waiters as plain non-null pointers.
type Handle = NonNull<Waiter<K, V>>;
/// Converts a handle into a raw pointer; the two are the same type, so this
/// is the identity.
fn into_ptr(r: Self::Handle) -> NonNull<Self> {
r
}
/// Converts a raw pointer back into a handle (identity).
unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
ptr
}
/// Returns a pointer to the `links` field inside `target`'s node.
unsafe fn links(target: NonNull<Self>) -> NonNull<list::Links<Waiter<K, V>>> {
// `addr_of!` takes the field's address without creating a reference to
// the whole waiter.
let node = ptr::addr_of!((*target.as_ptr()).node);
(*node).with_mut(|node| {
let links = ptr::addr_of_mut!((*node).links);
// The address of a field of a non-null `target` is itself non-null.
NonNull::new_unchecked(links)
})
}
}
impl<K, V, Lock> Future for Wait<'_, K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    type Output = WaitResult<V>;

    /// Polls the inner waiter against the map this future borrows.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let waiter = projected.waiter;
        waiter.poll_wait(projected.queue, cx)
    }
}
#[pinned_drop]
impl<K, V, Lock> PinnedDrop for Wait<'_, K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    /// Releases the inner waiter from the map when the future is dropped.
    fn drop(mut self: Pin<&mut Self>) {
        let projected = self.project();
        let waiter = projected.waiter;
        waiter.release(projected.queue);
    }
}
impl State {
#[inline]
fn from_bits(bits: usize) -> Self {
Self::try_from_bits(bits).expect("This shouldn't be possible")
}
}
impl FromBits<usize> for State {
// A `State` occupies two bits.
const BITS: u32 = 2;
type Error = core::convert::Infallible;
/// Decodes a `State` from its bit representation.
///
/// Never returns `Err`: the error type is `Infallible`.
fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
Ok(match bits as u8 {
bits if bits == Self::Empty as u8 => Self::Empty,
bits if bits == Self::Waiting as u8 => Self::Waiting,
bits if bits == Self::Closed as u8 => Self::Closed,
// NOTE(review): `0b10` (and anything above `0b11`) matches none of the
// arms above, so not every 2-bit pattern is actually covered. This arm is
// presumably undefined behaviour if reached — confirm that only valid
// `State` values are ever written to the atomic.
_ => unsafe {
unreachable_unchecked!("all potential 2-bit patterns should be covered!")
},
})
}
/// Encodes this state as its bit representation.
fn into_bits(self) -> usize {
self.into_usize()
}
}
impl State {
    /// Returns this state's discriminant widened to `usize`; usable in
    /// `const` contexts.
    const fn into_usize(self) -> usize {
        let discriminant = self as u8;
        discriminant as usize
    }
}
feature! {
#![feature = "alloc"]
use alloc::sync::Arc;
/// Owned counterpart of [`Wait`], returned by [`WaitMap::wait_owned`].
///
/// Holds an `Arc` clone of the map instead of borrowing it.
#[derive(Debug)]
#[pin_project(PinnedDrop)]
pub struct WaitOwned<K: PartialEq, V, Lock: ScopedRawMutex = DefaultMutex> {
/// The shared map this future waits on.
queue: Arc<WaitMap<K, V, Lock>>,
/// The entry that gets linked into the map's waiter list.
#[pin]
waiter: Waiter<K, V>,
}
impl<K: PartialEq, V, Lock: ScopedRawMutex> WaitMap<K, V, Lock> {
/// Returns a future that waits for a value under `key` and keeps the map
/// alive by holding a clone of the `Arc`.
pub fn wait_owned(self: &Arc<Self>, key: K) -> WaitOwned<K, V, Lock> {
let waiter = self.waiter(key);
let queue = self.clone();
WaitOwned { queue, waiter }
}
}
impl<K: PartialEq, V, Lock: ScopedRawMutex> Future for WaitOwned<K, V, Lock> {
type Output = WaitResult<V>;
/// Polls the inner waiter against the shared map.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.waiter.poll_wait(&*this.queue, cx)
}
}
#[pinned_drop]
impl<K, V, Lock> PinnedDrop for WaitOwned<K, V, Lock>
where
K: PartialEq,
Lock: ScopedRawMutex,
{
/// Releases the inner waiter from the map when the future is dropped.
fn drop(mut self: Pin<&mut Self>) {
let this = self.project();
this.waiter.release(&*this.queue);
}
}
}
impl<V> fmt::Debug for Wakeup<V> {
    /// Formats the variant name; a delivered value is elided as `..` because
    /// `V` is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            // The only variant whose payload is printed.
            Self::Waiting(waker) => {
                return f.debug_tuple("Wakeup::Waiting").field(waker).finish();
            }
            Self::Empty => "Wakeup::Empty",
            Self::DataReceived(_) => "Wakeup::DataReceived(..)",
            Self::Retreived => "Wakeup::Retrieved",
            Self::Closed => "Wakeup::Closed",
        };
        f.write_str(label)
    }
}