#[cfg(any(test, maitake_ultraverbose))]
use crate::util::fmt;
use crate::{
blocking::{DefaultMutex, Mutex, ScopedRawMutex},
loom::{
cell::UnsafeCell,
sync::atomic::{AtomicUsize, Ordering::*},
},
util::{CachePadded, WakeBatch},
WaitResult,
};
use cordyceps::{
list::{self, List},
Linked,
};
use core::{
future::Future,
marker::PhantomPinned,
mem,
pin::Pin,
ptr::{self, NonNull},
task::{Context, Poll, Waker},
};
use mycelium_bitfield::{bitfield, enum_from_bits, FromBits};
use pin_project::{pin_project, pinned_drop};
#[cfg(test)]
mod tests;
/// A queue of waiting tasks which can be woken one at a time (`wake`), all at
/// once (`wake_all`), or closed (`close`).
///
/// `Lock` is the raw mutex type used to guard the intrusive list of waiters;
/// it defaults to `DefaultMutex`.
#[derive(Debug)]
pub struct WaitQueue<Lock: ScopedRawMutex = DefaultMutex> {
    /// The packed `QueueState` word: the queue's `State` plus a counter of
    /// `wake_all` calls. It is read and updated atomically, and some updates
    /// happen without holding the `queue` lock (the `wake` fast path and
    /// `close`).
    // NOTE(review): presumably cache-padded to keep this frequently-touched
    // atomic on its own cache line; confirm.
    state: CachePadded<AtomicUsize>,
    /// Intrusive list of enqueued waiters, guarded by `Lock`. New waiters are
    /// pushed to the front and woken from the back.
    queue: Mutex<List<Waiter>, Lock>,
}
/// Future returned by `WaitQueue::wait`, which completes once this task has
/// been woken by the queue.
///
/// It completes with `Ok(())` after a `wake` or `wake_all`, or with the closed
/// error if the queue is closed. Dropping it while it is still enqueued
/// removes its waiter from the queue.
#[derive(Debug)]
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
pub struct Wait<'a, Lock: ScopedRawMutex = DefaultMutex> {
    /// The queue this future waits on.
    queue: &'a WaitQueue<Lock>,
    /// This future's entry in the queue's intrusive list. It is pinned because
    /// the list holds a raw pointer to it while it is enqueued.
    #[pin]
    waiter: Waiter,
}
/// A single entry in a `WaitQueue`'s wait list, embedded in a `Wait` or
/// `WaitOwned` future.
#[derive(Debug)]
// NOTE(review): `repr(C)` fixes the field order (`node` first); confirm what
// relies on that layout.
#[repr(C)]
#[pin_project]
struct Waiter {
    /// The list links and wakeup slot. Wrapped in `UnsafeCell` because the
    /// list reaches it through raw pointers; it is accessed while the queue's
    /// lock is held.
    #[pin]
    node: UnsafeCell<Node>,
    /// This waiter's own progress (`WaitState`) plus the queue's `wake_all`
    /// count observed when the waiter was created. Only accessed through the
    /// pinned owning future, never through the list.
    state: WaitStateBits,
}
/// The part of a `Waiter` that the wait list reads and writes.
#[derive(Debug)]
struct Node {
    /// Intrusive links used by `cordyceps::List`.
    links: list::Links<Waiter>,
    /// Either the task's registered waker, or the kind of wakeup it received.
    waker: Wakeup,
    /// Makes `Node` (and therefore `Waiter`) `!Unpin`, since the list stores
    /// raw pointers into it.
    _pin: PhantomPinned,
}
// The packed contents of `WaitQueue::state`.
bitfield! {
    #[derive(Eq, PartialEq)]
    struct QueueState<usize> {
        // The queue's current `State` (`State::BITS` wide).
        const STATE: State;
        // Number of `wake_all` calls made while the queue was open; occupies
        // all of the remaining bits.
        const WAKE_ALLS = ..;
    }
}
// Per-waiter state stored in `Waiter::state`.
bitfield! {
    #[derive(Eq, PartialEq)]
    struct WaitStateBits<usize> {
        // Where this waiter is in its lifecycle.
        const STATE: WaitState;
        // Snapshot of the queue's `WAKE_ALLS` count taken when the waiter was
        // created. If the queue's count differs when the waiter tries to
        // enqueue, a `wake_all` happened in between and the waiter completes.
        const WAKE_ALLS = ..;
    }
}
// Lifecycle of a single `Waiter`.
enum_from_bits! {
    #[derive(Debug, Eq, PartialEq)]
    enum WaitState<u8> {
        // Not enqueued in the wait list (never polled, or previous polls
        // returned without enqueueing).
        Start = 0b00,
        // Enqueued in the wait list.
        Waiting = 0b01,
        // Received a wakeup (or observed closure); further polls return
        // `Ready(Ok(()))` immediately.
        Woken = 0b10,
    }
}
/// The coarse state of a `WaitQueue`, stored in the `STATE` field of
/// `QueueState`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
enum State {
    /// No tasks are waiting and no wakeup is stored.
    Empty = 0b00,
    /// One or more tasks are enqueued in the wait list.
    Waiting = 0b01,
    /// A `wake` arrived while nobody was waiting; the next waiter consumes it.
    Woken = 0b10,
    /// The queue has been closed. Encoded as all ones so that `close` can set
    /// it with a single `fetch_or`, whatever the previous state was.
    Closed = 0b11,
}
/// Contents of a waiter's wakeup slot (`Node::waker`).
#[derive(Clone, Debug)]
enum Wakeup {
    /// Initial value: no waker registered (e.g. enqueued via `subscribe`) and
    /// not yet woken.
    Empty,
    /// Enqueued with this waker and not yet woken.
    Waiting(Waker),
    /// Woken by `wake`.
    One,
    /// Woken by `wake_all`.
    All,
    /// Woken because the queue was closed.
    Closed,
}
impl WaitQueue {
    loom_const_fn! {
        /// Creates an empty `WaitQueue` whose wait list is synchronized by a
        /// `DefaultMutex`.
        #[must_use]
        pub fn new() -> Self {
            let lock = DefaultMutex::new();
            Self::new_with_raw_mutex(lock)
        }
    }
}
impl<Lock> Default for WaitQueue<Lock>
where
    Lock: ScopedRawMutex + Default,
{
    /// Creates an empty `WaitQueue` using `Lock::default()` as its raw mutex.
    fn default() -> Self {
        let lock: Lock = Default::default();
        Self::new_with_raw_mutex(lock)
    }
}
impl<Lock> WaitQueue<Lock>
where
    Lock: ScopedRawMutex,
{
    loom_const_fn! {
        /// Returns a new, empty `WaitQueue` whose wait list is guarded by a
        /// mutex built from the provided raw `lock`.
        #[must_use]
        pub fn new_with_raw_mutex(lock: Lock) -> Self {
            Self::make(State::Empty, lock)
        }
    }

    loom_const_fn! {
        /// Returns a new `WaitQueue` that starts out holding a stored wakeup,
        /// so the first waiter to be polled completes immediately.
        #[must_use]
        pub(crate) fn new_woken(lock: Lock) -> Self {
            Self::make(State::Woken, lock)
        }
    }

    loom_const_fn! {
        /// Builds a queue in the initial `state`, with an empty wait list and
        /// a `wake_all` count of zero.
        #[must_use]
        fn make(state: State, lock: Lock) -> Self {
            Self {
                state: CachePadded::new(AtomicUsize::new(state.into_usize())),
                queue: Mutex::new_with_raw_mutex(List::new(), lock),
            }
        }
    }

    /// Wakes the next task waiting on this queue.
    ///
    /// If no task is currently waiting, the wakeup is stored instead, and the
    /// next `wait` future to be polled consumes it and completes immediately.
    /// Stored wakeups do not accumulate: the queue records at most one. If the
    /// queue has been closed, this does nothing.
    #[inline]
    pub fn wake(&self) {
        let mut state = self.load();

        // Fast path: if nobody is waiting, record the wakeup without locking.
        loop {
            match state.get(QueueState::STATE) {
                State::Closed => return,
                State::Waiting => break,
                _ => {}
            }

            let next = state.with_state(State::Woken);
            match self.compare_exchange(state, next) {
                Ok(_) => return,
                Err(actual) => state = actual,
            }
        }

        // Slow path: there are waiters, so lock the list and dequeue one.
        let waker = self.queue.with_lock(|queue| {
            test_debug!("wake: -> locked");
            // The state may have changed while we were acquiring the lock, so
            // take a fresh snapshot before deciding what to do.
            state = self.load();
            self.wake_locked(queue, state)
        });

        // Invoke the waker only after the lock has been released.
        if let Some(waker) = waker {
            waker.wake();
        }
    }

    /// Wakes every task currently waiting on this queue.
    ///
    /// Unlike `wake`, this does not store a wakeup for tasks that start waiting
    /// later. It does increment the queue's `wake_all` counter, so a `Wait`
    /// future created before this call but not yet enqueued still completes
    /// when it is polled. If the queue is closed, this does nothing.
    ///
    /// Waiters are drained in batches: each batch is collected while holding
    /// the lock and its wakers are invoked after the lock is released.
    pub fn wake_all(&self) {
        let mut batch = WakeBatch::new();
        let mut waiters_remaining = true;

        // We must inspect the state under the lock, but cannot `return` from
        // `wake_all` inside the closure, so it reports whether we are done.
        let done = self.queue.with_lock(|queue| {
            let state = self.load();

            match test_dbg!(state.get(QueueState::STATE)) {
                // Nobody is waiting: just bump the `wake_all` counter. This
                // must happen under the lock so that it is ordered with respect
                // to waiters checking the counter while enqueueing.
                State::Woken | State::Empty => {
                    self.state.fetch_add(QueueState::ONE_WAKE_ALL, SeqCst);
                    true
                }
                State::Closed => true,
                State::Waiting => {
                    // Reset to `Empty` and increment the counter in one step.
                    let next_state = QueueState::new()
                        .with_state(State::Empty)
                        .with(QueueState::WAKE_ALLS, state.get(QueueState::WAKE_ALLS) + 1);

                    // While we hold the lock and there are waiters, the only
                    // state change that can race with us is `close()`, whose
                    // `fetch_or` does not take the lock. If it won, `close()`
                    // will wake every waiter itself, so there is nothing to do.
                    if let Err(actual) = self.compare_exchange(state, next_state) {
                        debug_assert!(
                            actual.get(QueueState::STATE) == State::Closed,
                            "only `close` may change a waiting queue's state without the lock"
                        );
                        return true;
                    }

                    waiters_remaining =
                        test_dbg!(Self::drain_to_wake_batch(&mut batch, queue, Wakeup::All));

                    false
                }
            }
        });

        if done {
            return;
        }

        batch.wake_all();
        // Keep draining until the list is empty, waking each batch with the
        // lock released.
        while waiters_remaining {
            self.queue.with_lock(|queue| {
                waiters_remaining = Self::drain_to_wake_batch(&mut batch, queue, Wakeup::All);
            });
            batch.wake_all();
        }
    }

    /// Closes the queue.
    ///
    /// Every task currently waiting is woken and its future completes with the
    /// closed error. Waiters that later try to enqueue also observe the closed
    /// state, and `wake` / `wake_all` become no-ops.
    pub fn close(&self) {
        // `Closed` is all ones, so OR-ing it in closes the queue from any
        // state. This deliberately happens without holding the lock.
        let state = self.state.fetch_or(State::Closed.into_usize(), SeqCst);
        let state = test_dbg!(QueueState::from_bits(state));
        if state.get(QueueState::STATE) != State::Waiting {
            return;
        }

        let mut batch = WakeBatch::new();
        let mut waking = true;
        while waking {
            waking = self
                .queue
                .with_lock(|queue| Self::drain_to_wake_batch(&mut batch, queue, Wakeup::Closed));
            batch.wake_all();
        }
    }

    /// Returns a future that completes when this task is woken by `wake` or
    /// `wake_all`, or with the closed error if the queue is closed.
    ///
    /// The future captures the current `wake_all` count, so a `wake_all` that
    /// happens after this call but before the future enqueues still
    /// completes it.
    pub fn wait(&self) -> Wait<'_, Lock> {
        Wait {
            queue: self,
            waiter: self.waiter(),
        }
    }

    /// Tries to consume a stored wakeup without registering a waiter.
    ///
    /// Returns `Ready(Ok(()))` if a stored wakeup was consumed (or a
    /// `wake_all` was observed while trying), the closed error if the queue is
    /// closed, and `Pending` otherwise.
    pub(crate) fn try_wait(&self) -> Poll<WaitResult<()>> {
        let mut state = self.load();
        let initial_wake_alls = state.get(QueueState::WAKE_ALLS);
        while state.get(QueueState::STATE) == State::Woken {
            match self.compare_exchange(state, state.with_state(State::Empty)) {
                Ok(_) => return Poll::Ready(Ok(())),
                Err(actual) => state = actual,
            }
        }

        match state.get(QueueState::STATE) {
            State::Closed => crate::closed(),
            _ if state.get(QueueState::WAKE_ALLS) > initial_wake_alls => Poll::Ready(Ok(())),
            State::Empty | State::Waiting => Poll::Pending,
            State::Woken => Poll::Ready(Ok(())),
        }
    }

    /// Waits on this queue until `f` returns `true`.
    ///
    /// `f` is evaluated once up front and again after every wakeup. The waiter
    /// is registered (via `subscribe`) *before* `f` runs, so a wakeup arriving
    /// between the check and the `.await` is not lost.
    ///
    /// # Errors
    ///
    /// Returns the closed error if the queue is closed while waiting.
    pub async fn wait_for<F: FnMut() -> bool>(&self, mut f: F) -> WaitResult<()> {
        loop {
            let wait = self.wait();
            let mut wait = core::pin::pin!(wait);
            let _ = wait.as_mut().subscribe()?;
            if f() {
                return Ok(());
            }
            wait.await?;
        }
    }

    /// Waits on this queue until `f` returns `Some`, then returns that value.
    ///
    /// Like `wait_for`, the waiter is registered before `f` is evaluated, so
    /// no wakeup is lost between the check and the `.await`.
    ///
    /// # Errors
    ///
    /// Returns the closed error if the queue is closed while waiting.
    pub async fn wait_for_value<T, F: FnMut() -> Option<T>>(&self, mut f: F) -> WaitResult<T> {
        loop {
            let wait = self.wait();
            let mut wait = core::pin::pin!(wait);
            let _ = wait.as_mut().subscribe()?;
            if let Some(t) = f() {
                return Ok(t);
            }
            wait.await?;
        }
    }

    /// Returns `true` if this queue has been closed.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.load().get(QueueState::STATE) == State::Closed
    }

    /// Builds a new, unqueued `Waiter` that records the current `wake_all`
    /// count, so it can detect a `wake_all` that happens before it enqueues.
    fn waiter(&self) -> Waiter {
        let current_wake_alls = test_dbg!(self.load().get(QueueState::WAKE_ALLS));
        let state = WaitStateBits::new()
            .with(WaitStateBits::WAKE_ALLS, current_wake_alls)
            .with(WaitStateBits::STATE, WaitState::Start);
        Waiter {
            state,
            node: UnsafeCell::new(Node {
                links: list::Links::new(),
                waker: Wakeup::Empty,
                _pin: PhantomPinned,
            }),
        }
    }

    /// Atomically loads the packed queue state.
    #[cfg_attr(test, track_caller)]
    fn load(&self) -> QueueState {
        // `let_and_return` fires when `test_debug!` compiles to nothing.
        #[allow(clippy::let_and_return)]
        let state = QueueState::from_bits(self.state.load(SeqCst));
        test_debug!("state.load() = {state:?}");
        state
    }

    /// Stores `state` as the queue's new state, unless the queue is closed.
    ///
    /// `close` sets `Closed` with a `fetch_or` *without* holding the wait-list
    /// lock, so it can race with code that updates the state while holding the
    /// lock. Closing is terminal, so once the queue is closed this leaves the
    /// state untouched rather than silently reopening it.
    #[cfg_attr(test, track_caller)]
    fn store(&self, state: QueueState) {
        test_debug!("state.store({state:?})");
        let mut curr = self.load();
        while curr.get(QueueState::STATE) != State::Closed {
            match self.compare_exchange(curr, state) {
                Ok(_) => return,
                Err(actual) => curr = actual,
            }
        }
    }

    /// Atomically replaces `current` with `new`, returning the previous state
    /// on success or the actual state on failure.
    #[cfg_attr(test, track_caller)]
    fn compare_exchange(
        &self,
        current: QueueState,
        new: QueueState,
    ) -> Result<QueueState, QueueState> {
        // `let_and_return` fires when `test_debug!` compiles to nothing.
        #[allow(clippy::let_and_return)]
        let res = self
            .state
            .compare_exchange(current.0, new.0, SeqCst, SeqCst)
            .map(QueueState::from_bits)
            .map_err(QueueState::from_bits);
        test_debug!("state.compare_exchange({current:?}, {new:?}) = {res:?}");
        res
    }

    /// Wakes one waiter while the wait-list lock is held, and returns its
    /// waker (if it registered one) so the caller can invoke it after
    /// unlocking.
    ///
    /// `curr` must be a snapshot of the state taken while holding the lock.
    /// If nobody is waiting, the wakeup is stored in the queue instead. A
    /// closed queue is left closed. `close` may set `Closed` (outside the
    /// lock) after the caller decided to wake someone, and `close` itself
    /// wakes every remaining waiter.
    #[cold]
    #[inline(never)]
    fn wake_locked(&self, queue: &mut List<Waiter>, curr: QueueState) -> Option<Waker> {
        match test_dbg!(curr.get(QueueState::STATE)) {
            State::Waiting => {}
            // Never reopen a closed queue by overwriting `Closed` with `Woken`.
            State::Closed => return None,
            // Nobody is waiting any more: store the wakeup for the next task.
            State::Empty | State::Woken => {
                self.store(curr.with_state(State::Woken));
                return None;
            }
        }

        // Waiters are pushed to the front, so the back is the oldest one.
        let node = queue
            .pop_back()
            .expect("if we are in the Waiting state, there must be waiters in the queue");
        let waker = Waiter::wake(node, queue, Wakeup::One);

        // If that was the last waiter, the queue is now empty.
        if test_dbg!(queue.is_empty()) {
            self.store(curr.with_state(State::Empty));
        }

        waker
    }

    /// Pops waiters from the back of `queue`, handing each one `wakeup` and
    /// collecting any registered wakers into `batch`.
    ///
    /// Stops when the queue is empty or when `batch.add_waker` returns `false`
    /// (presumably: the batch is full). Returns `true` if waiters remain.
    fn drain_to_wake_batch(
        batch: &mut WakeBatch,
        queue: &mut List<Waiter>,
        wakeup: Wakeup,
    ) -> bool {
        while let Some(node) = queue.pop_back() {
            // Waiters enqueued without a waker (via `subscribe`) have nothing
            // to wake; they observe the wakeup on their next poll.
            let Some(waker) = Waiter::wake(node, queue, wakeup.clone()) else {
                continue;
            };

            if batch.add_waker(waker) {
                continue;
            }

            break;
        }

        !queue.is_empty()
    }
}
impl Waiter {
    /// Stores `wakeup` in the waiter at `this` and returns the `Waker` it had
    /// registered, if any.
    ///
    /// Every caller in this file passes a node it has just popped from `list`.
    /// Panics if the waiter's slot already holds a wakeup (`One`, `All`, or
    /// `Closed`).
    #[inline(always)]
    #[cfg_attr(loom, track_caller)]
    fn wake(this: NonNull<Self>, list: &mut List<Self>, wakeup: Wakeup) -> Option<Waker> {
        Waiter::with_node(this, list, |node| {
            // Swap the wakeup in so the task can see why it was woken the next
            // time it is polled.
            let waker = test_dbg!(mem::replace(&mut node.waker, wakeup));
            match waker {
                Wakeup::Waiting(waker) => Some(waker),
                // Enqueued without a waker (e.g. via `subscribe`): nothing to
                // invoke now.
                Wakeup::Empty => None,
                _ => unreachable!("tried to wake a waiter in the {:?} state!", waker),
            }
        })
    }

    /// Runs `f` with mutable access to the `Node` of the waiter at `this`.
    ///
    /// `_list` is unused. Requiring `&mut List<Self>` presumably serves as
    /// evidence that the caller holds the wait-list lock, which is what guards
    /// access to nodes.
    #[inline(always)]
    #[cfg_attr(loom, track_caller)]
    fn with_node<T>(
        mut this: NonNull<Self>,
        _list: &mut List<Self>,
        f: impl FnOnce(&mut Node) -> T,
    ) -> T {
        // SAFETY (assumed): `this` points to a live, pinned `Waiter`, and node
        // contents are only touched while the wait-list lock is held, which
        // the `&mut List` argument is meant to witness.
        unsafe {
            this.as_mut().node.with_mut(|node| f(&mut *node))
        }
    }

    /// Polls this waiter against `queue`.
    ///
    /// On the first poll (`Start`) it consumes a stored wakeup if there is
    /// one. Otherwise it takes the lock and completes if a `wake_all` happened
    /// since the waiter was created; if not, it enqueues itself (registering
    /// `waker`, if provided). On later polls (`Waiting`) it checks under the
    /// lock whether a wakeup has been delivered, refreshing the stored waker
    /// if it changed. `waker` is `None` when called from `subscribe`.
    fn poll_wait<Lock>(
        mut self: Pin<&mut Self>,
        queue: &WaitQueue<Lock>,
        waker: Option<&Waker>,
    ) -> Poll<WaitResult<()>>
    where
        Lock: ScopedRawMutex,
    {
        test_debug!(ptr = ?fmt::ptr(self.as_mut()), "Waiter::poll_wait");
        // The list stores this raw pointer; the waiter is pinned, so it stays
        // valid until `release` removes it.
        let ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(self.as_mut())) };
        let mut this = self.as_mut().project();
        match test_dbg!(this.state.get(WaitStateBits::STATE)) {
            WaitState::Start => {
                // Lock-free attempt to consume a stored wakeup.
                let queue_state = queue.load();
                if queue
                    .compare_exchange(
                        queue_state.with_state(State::Woken),
                        queue_state.with_state(State::Empty),
                    )
                    .is_ok()
                {
                    this.state.set(WaitStateBits::STATE, WaitState::Woken);
                    return Poll::Ready(Ok(()));
                }
                test_debug!("poll_wait: locking...");
                queue.queue.with_lock(move |waiters| {
                    test_debug!("poll_wait: -> locked");
                    let mut queue_state = queue.load();
                    // A `wake_all` happened after this waiter was created.
                    if queue_state.get(QueueState::WAKE_ALLS)
                        != this.state.get(WaitStateBits::WAKE_ALLS)
                    {
                        this.state.set(WaitStateBits::STATE, WaitState::Woken);
                        return Poll::Ready(Ok(()));
                    }
                    // Move the queue into `Waiting`, unless a wakeup can be
                    // consumed or the queue is closed. CAS retries handle
                    // changes made without the lock (`wake`, `close`).
                    'to_waiting: loop {
                        match test_dbg!(queue_state.get(QueueState::STATE)) {
                            State::Empty => {
                                match queue.compare_exchange(
                                    queue_state,
                                    queue_state.with_state(State::Waiting),
                                ) {
                                    Ok(_) => break 'to_waiting,
                                    Err(actual) => queue_state = actual,
                                }
                            }
                            State::Waiting => break 'to_waiting,
                            State::Woken => {
                                match queue.compare_exchange(
                                    queue_state,
                                    queue_state.with_state(State::Empty),
                                ) {
                                    Ok(_) => {
                                        this.state.set(WaitStateBits::STATE, WaitState::Woken);
                                        return Poll::Ready(Ok(()));
                                    }
                                    Err(actual) => queue_state = actual,
                                }
                            }
                            State::Closed => return crate::closed(),
                        }
                    }
                    this.state.set(WaitStateBits::STATE, WaitState::Waiting);
                    if let Some(waker) = waker {
                        this.node.as_mut().with_mut(|node| {
                            // SAFETY (assumed): the node is not yet in the
                            // list and the lock is held, so nothing else can
                            // access it.
                            unsafe {
                                debug_assert!(matches!((*node).waker, Wakeup::Empty));
                                (*node).waker = Wakeup::Waiting(waker.clone());
                            }
                        });
                    }
                    waiters.push_front(ptr);
                    Poll::Pending
                })
            }
            WaitState::Waiting => {
                // The lock guards the node's wakeup slot against concurrent
                // wakers.
                queue.queue.with_lock(|_waiters| {
                    this.node.with_mut(|node| unsafe {
                        let node = &mut *node;
                        match node.waker {
                            // Not woken yet: refresh the waker only if it
                            // would wake a different task.
                            Wakeup::Waiting(ref mut curr_waker) => {
                                match waker {
                                    Some(waker) if !curr_waker.will_wake(waker) => {
                                        *curr_waker = waker.clone()
                                    }
                                    _ => {}
                                }
                                Poll::Pending
                            }
                            Wakeup::All | Wakeup::One => {
                                this.state.set(WaitStateBits::STATE, WaitState::Woken);
                                Poll::Ready(Ok(()))
                            }
                            Wakeup::Closed => {
                                this.state.set(WaitStateBits::STATE, WaitState::Woken);
                                crate::closed()
                            }
                            // Enqueued via `subscribe` without a waker: register
                            // one now if we have it.
                            Wakeup::Empty => {
                                if let Some(waker) = waker {
                                    node.waker = Wakeup::Waiting(waker.clone());
                                }
                                Poll::Pending
                            }
                        }
                    })
                })
            }
            WaitState::Woken => Poll::Ready(Ok(())),
        }
    }

    /// Called when the owning future is dropped: removes this waiter from
    /// `queue` if it is still `Waiting`.
    ///
    /// If the waiter had already been handed a `wake` (`Wakeup::One`) but was
    /// never polled to observe it, the wakeup is passed on via `wake_locked`
    /// so it is not lost.
    fn release<Lock>(mut self: Pin<&mut Self>, queue: &WaitQueue<Lock>)
    where
        Lock: ScopedRawMutex,
    {
        let state = *(self.as_mut().project().state);
        let ptr = NonNull::from(unsafe { Pin::into_inner_unchecked(self) });
        test_debug!(self = ?fmt::ptr(ptr), ?state, ?queue.state, "Waiter::release");
        // Never enqueued, or already completed: nothing to clean up.
        if state.get(WaitStateBits::STATE) != WaitState::Waiting {
            return;
        }
        let next_waiter = queue.queue.with_lock(|waiters| {
            let state = queue.load();
            // SAFETY (assumed): the lock is held, and `ptr` was pushed onto
            // this queue's list by `poll_wait`. It may already have been
            // popped by a waker.
            unsafe {
                waiters.remove(ptr);
            };
            // Removing the last waiter moves the queue back to `Empty`.
            if test_dbg!(waiters.is_empty()) && state.get(QueueState::STATE) == State::Waiting {
                queue.store(state.with_state(State::Empty));
            }
            // An unconsumed single wakeup is forwarded to another waiter.
            if Waiter::with_node(ptr, waiters, |node| matches!(&node.waker, Wakeup::One)) {
                queue.wake_locked(waiters, state)
            } else {
                None
            }
        });
        // Wake the forwarded waiter after the lock is released.
        if let Some(next) = next_waiter {
            next.wake();
        }
    }
}
// SAFETY (assumed): a `Waiter` is `!Unpin` (via `Node::_pin`) and lives in a
// pinned `Wait`/`WaitOwned` future whose drop removes it from the list while
// it is enqueued, so the raw pointers the list holds stay valid.
unsafe impl Linked<list::Links<Waiter>> for Waiter {
    /// The list does not own waiters; it only holds raw pointers to them.
    type Handle = NonNull<Waiter>;

    fn into_ptr(r: Self::Handle) -> NonNull<Self> {
        r
    }

    unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
        ptr
    }

    /// Returns a pointer to the `links` field inside the waiter's `Node`.
    unsafe fn links(target: NonNull<Self>) -> NonNull<list::Links<Waiter>> {
        // `addr_of!` projects to the field without creating an intermediate
        // reference to the whole `Waiter`.
        let node = ptr::addr_of!((*target.as_ptr()).node);
        (*node).with_mut(|node| {
            let links = ptr::addr_of_mut!((*node).links);
            // `links` is derived from a non-null `target`, so it is non-null.
            NonNull::new_unchecked(links)
        })
    }
}
impl<Lock: ScopedRawMutex> Wait<'_, Lock> {
    /// Returns `true` if this future is waiting on `queue` (by identity).
    #[inline]
    #[must_use]
    pub fn waits_on(&self, queue: &WaitQueue<Lock>) -> bool {
        let ours: *const WaitQueue<Lock> = self.queue;
        ours == queue as *const WaitQueue<Lock>
    }

    /// Returns `true` if `self` and `other` wait on the same queue.
    #[inline]
    #[must_use]
    pub fn same_queue(&self, other: &Wait<'_, Lock>) -> bool {
        other.waits_on(self.queue)
    }

    /// Registers this future's waiter with the queue without a waker, so that
    /// wakeups occurring before the first real poll are not missed.
    ///
    /// Returns `Ready` if the wait has already completed (successfully or
    /// because the queue is closed), and `Pending` otherwise.
    pub fn subscribe(self: Pin<&mut Self>) -> Poll<WaitResult<()>> {
        let projected = self.project();
        let queue = *projected.queue;
        projected.waiter.poll_wait(queue, None)
    }
}
impl<Lock: ScopedRawMutex> Future for Wait<'_, Lock> {
    type Output = WaitResult<()>;

    /// Polls the underlying waiter, registering the task's current waker.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let task_waker = cx.waker();
        let projected = self.project();
        let queue = *projected.queue;
        projected.waiter.poll_wait(queue, Some(task_waker))
    }
}
#[pinned_drop]
impl<Lock: ScopedRawMutex> PinnedDrop for Wait<'_, Lock> {
    fn drop(mut self: Pin<&mut Self>) {
        // Unlink the waiter (if still enqueued) before its memory goes away.
        let projected = self.project();
        let queue = *projected.queue;
        projected.waiter.release(queue);
    }
}
impl QueueState {
    /// The lowest bit of the `WAKE_ALLS` field. Adding it to the packed word
    /// increments the `wake_all` count by one.
    const ONE_WAKE_ALL: usize = Self::WAKE_ALLS.first_bit();

    /// Returns a copy of this state with its `STATE` field set to `state`.
    fn with_state(self, state: State) -> Self {
        let field = Self::STATE;
        self.with(field, state)
    }
}
impl FromBits<usize> for State {
const BITS: u32 = 2;
type Error = core::convert::Infallible;
fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
Ok(match bits as u8 {
bits if bits == Self::Empty as u8 => Self::Empty,
bits if bits == Self::Waiting as u8 => Self::Waiting,
bits if bits == Self::Woken as u8 => Self::Woken,
bits if bits == Self::Closed as u8 => Self::Closed,
_ => unsafe {
unreachable_unchecked!("all potential 2-bit patterns should be covered!")
},
})
}
fn into_bits(self) -> usize {
self.into_usize()
}
}
impl State {
    /// Returns this state's discriminant widened to `usize`, ready to be packed
    /// into the queue's state word.
    const fn into_usize(self) -> usize {
        let discriminant = self as u8;
        discriminant as usize
    }
}
feature! {
    #![feature = "alloc"]

    use alloc::sync::Arc;

    /// Future returned by `WaitQueue::wait_owned`.
    ///
    /// It behaves like `Wait`, but holds an `Arc` reference to the queue instead
    /// of borrowing it, so it is not tied to the lifetime of a borrow.
    #[derive(Debug)]
    #[pin_project(PinnedDrop)]
    #[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
    pub struct WaitOwned<Lock: ScopedRawMutex = DefaultMutex> {
        /// The queue this future waits on, kept alive by this reference.
        queue: Arc<WaitQueue<Lock>>,
        /// This future's entry in the queue's intrusive list. It is pinned
        /// because the list holds a raw pointer to it while it is enqueued.
        #[pin]
        waiter: Waiter,
    }

    impl<Lock: ScopedRawMutex> WaitQueue<Lock> {
        /// Like `wait`, but the returned future owns a clone of this `Arc`.
        ///
        /// The future captures the current `wake_all` count, so a `wake_all`
        /// that happens before it enqueues still completes it.
        pub fn wait_owned(self: &Arc<Self>) -> WaitOwned<Lock> {
            let waiter = self.waiter();
            let queue = Arc::clone(self);
            WaitOwned { queue, waiter }
        }
    }

    impl<Lock: ScopedRawMutex> WaitOwned<Lock> {
        /// Returns `true` if this future is waiting on `queue` (by identity).
        #[inline]
        #[must_use]
        pub fn waits_on(&self, queue: &WaitQueue<Lock>) -> bool {
            ptr::eq(&*self.queue, queue)
        }

        /// Returns `true` if `self` and `other` wait on the same queue.
        #[inline]
        #[must_use]
        pub fn same_queue(&self, other: &WaitOwned<Lock>) -> bool {
            Arc::ptr_eq(&self.queue, &other.queue)
        }

        /// Registers this future's waiter with the queue without a waker, so
        /// that wakeups occurring before the first real poll are not missed.
        ///
        /// Returns `Ready` if the wait has already completed (successfully or
        /// because the queue is closed), and `Pending` otherwise.
        pub fn subscribe(self: Pin<&mut Self>) -> Poll<WaitResult<()>> {
            let this = self.project();
            this.waiter.poll_wait(this.queue, None)
        }
    }

    impl<Lock: ScopedRawMutex> Future for WaitOwned<Lock> {
        type Output = WaitResult<()>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.project();
            this.waiter.poll_wait(&*this.queue, Some(cx.waker()))
        }
    }

    #[pinned_drop]
    impl<Lock: ScopedRawMutex> PinnedDrop for WaitOwned<Lock> {
        fn drop(mut self: Pin<&mut Self>) {
            // Unlink the waiter (if still enqueued) before its memory goes away.
            let this = self.project();
            this.waiter.release(&*this.queue);
        }
    }
}