use crate::{
loom::{
atomic::{AtomicUsize, Ordering::*},
cell::UnsafeCell,
},
util::{mutex::Mutex, CachePadded},
wait::{Notify, WaitResult},
};
use core::{fmt, marker::PhantomPinned, pin::Pin, ptr::NonNull};
/// A queue of waiters: an intrusive linked list of `Waiter` nodes guarded by a
/// mutex, plus an atomic `state` word (one of `EMPTY`, `WAITING`, `WAKING`,
/// `CLOSED`) that some operations check or update without taking the lock.
#[derive(Debug)]
pub(crate) struct WaitQueue<T> {
/// Queue-level state: one of `EMPTY`, `WAITING`, `WAKING`, or `CLOSED`.
state: CachePadded<AtomicUsize>,
/// The linked list of waiting nodes. Every list mutation in this file is
/// performed while this lock is held.
list: Mutex<List<T>>,
}
/// A node that one waiter links into a `WaitQueue`.
///
/// The queue's list stores raw (`NonNull`) pointers to linked nodes, which is
/// why the queue operations take `Pin<&mut Waiter<T>>`: a node must not move
/// while it is linked.
#[derive(Debug)]
pub(crate) struct Waiter<T> {
/// Per-node state: `EMPTY` when created, `WAITING` once linked by
/// `start_wait`, and `WAKING` or `CLOSED` after being dequeued by `notify` or
/// `close` respectively.
state: CachePadded<AtomicUsize>,
/// List links and the stored notification handle; accessed through
/// `with_node` and the `unsafe` link helpers on `Waiter`.
node: UnsafeCell<Node<T>>,
}
/// The intrusive-list part of a `Waiter`: links to its neighbours plus the
/// handle used to notify this waiter.
#[derive(Debug)]
#[pin_project::pin_project]
struct Node<T> {
/// Neighbour towards the list's tail (an older waiter); `None` at the tail.
next: Link<Waiter<T>>,
/// Neighbour towards the list's head (a newer waiter); `None` at the head.
prev: Link<Waiter<T>>,
/// Notification handle stored by `start_wait`; taken by `List::dequeue`.
waiter: Option<T>,
/// Makes `Node` (and therefore `Waiter`) `!Unpin`, so a pinned waiter
/// cannot be moved out of its pin.
#[pin]
_pin: PhantomPinned,
}
/// A nullable, non-owning pointer to a list node.
type Link<T> = Option<NonNull<T>>;
/// Intrusive doubly-linked list of `Waiter`s.
///
/// `enqueue` pushes at `head` and `dequeue` pops from `tail`, so waiters are
/// woken in the order they started waiting.
struct List<T> {
/// Most recently enqueued node, or `None` if the list is empty.
head: Link<Waiter<T>>,
/// Oldest enqueued node (the next one `dequeue` returns), or `None`.
tail: Link<Waiter<T>>,
}
// State values, shared by `WaitQueue::state` (queue) and `Waiter::state` (node).

/// Queue: no waiter is linked and no notification is stored.
/// Node: not waiting yet (the state `Waiter::new` starts in).
const EMPTY: usize = 0;
/// Queue: one or more waiters are linked into the list.
/// Node: set when the node is linked by `start_wait`.
const WAITING: usize = 1;
/// Queue: a notification arrived while no waiter was linked; the next
/// `start_wait` consumes it and returns `WaitResult::Notified`.
/// Node: dequeued by `notify`, i.e. this waiter was notified.
const WAKING: usize = 2;
/// Queue: closed by `WaitQueue::close`.
/// Node: dequeued by `close`.
const CLOSED: usize = 3;
impl<T> WaitQueue<T> {
    /// Returns a new wait queue in the `EMPTY` state: no waiters and no
    /// stored notification.
    ///
    /// The loom build uses this non-`const` constructor.
    #[cfg(loom)]
    pub(crate) fn new() -> Self {
        let state = CachePadded(AtomicUsize::new(EMPTY));
        let list = Mutex::new(List::new());
        Self { state, list }
    }

    /// Returns a new wait queue in the `EMPTY` state: no waiters and no
    /// stored notification.
    ///
    /// Being a `const fn`, this can initialize a queue in a constant context.
    #[cfg(not(loom))]
    pub(crate) const fn new() -> Self {
        Self {
            list: crate::util::mutex::const_mutex(List::new()),
            state: CachePadded(AtomicUsize::new(EMPTY)),
        }
    }
}
/// Wait and notify operations.
///
/// The queue-level `state` moves between:
/// - `EMPTY`: no waiter is linked and no notification is stored;
/// - `WAITING`: waiters are linked into `list` (entered in `start_wait_slow`
///   while the list lock is held);
/// - `WAKING`: `notify` found no linked waiter and stored the notification,
///   which the next `start_wait` consumes instead of waiting;
/// - `CLOSED`: set by `close`; `start_wait` then returns `WaitResult::Closed`.
impl<T: Notify + Unpin> WaitQueue<T> {
    /// Starts waiting on the queue, using `node` as this waiter's list entry.
    ///
    /// Returns `WaitResult::Notified` if a stored notification was consumed,
    /// or `WaitResult::Closed` if the queue is closed. Otherwise it stores a
    /// clone of `waiter` in `node`, links the node into the list, and returns
    /// `WaitResult::Wait`. A node linked this way is then checked with
    /// `continue_wait`.
    #[inline(always)]
    pub(crate) fn start_wait(&self, node: Pin<&mut Waiter<T>>, waiter: &T) -> WaitResult {
        test_println!("WaitQueue::start_wait({:p})", node);
        // Fast path: consume a stored notification (or notice the queue is
        // closed) without taking the list lock.
        match test_dbg!(self.state.compare_exchange(WAKING, EMPTY, SeqCst, SeqCst)) {
            Ok(_) => return WaitResult::Notified,
            Err(CLOSED) => return WaitResult::Closed,
            Err(_) => {}
        }
        self.start_wait_slow(node, waiter)
    }

    /// Slow path of `start_wait`: takes the list lock, moves the queue into
    /// `WAITING` (unless a notification can be consumed or the queue closed),
    /// then links `node`.
    #[cold]
    #[inline(never)]
    fn start_wait_slow(&self, node: Pin<&mut Waiter<T>>, waiter: &T) -> WaitResult {
        test_println!("WaitQueue::start_wait_slow({:p})", node);
        let mut list = self.list.lock();
        let mut state = self.state.load(Acquire);
        // The weak CASes may fail spuriously or race with lock-free updates
        // (`notify`'s fast path, `close`); every failure retries with the
        // value actually observed.
        loop {
            match test_dbg!(state) {
                EMPTY => {
                    match test_dbg!(self
                        .state
                        .compare_exchange_weak(EMPTY, WAITING, SeqCst, SeqCst))
                    {
                        Ok(_) => break,
                        Err(actual) => {
                            debug_assert!(actual == EMPTY || actual == WAKING || actual == CLOSED);
                            state = actual;
                        }
                    }
                }
                // A notification was stored after the fast path ran: consume it.
                WAKING => {
                    match test_dbg!(self
                        .state
                        .compare_exchange_weak(WAKING, EMPTY, SeqCst, SeqCst))
                    {
                        Ok(_) => return WaitResult::Notified,
                        Err(actual) => {
                            debug_assert!(actual == WAKING || actual == EMPTY || actual == CLOSED);
                            state = actual;
                        }
                    }
                }
                CLOSED => return WaitResult::Closed,
                // Other waiters are already linked; the queue stays WAITING.
                _state => {
                    debug_assert_eq!(_state, WAITING,
                        "start_wait_slow: unexpected state value {:?} (expected WAITING). this is a bug!",
                        _state,
                    );
                    break;
                }
            }
        }
        // Store this waiter's handle in the node, mark the node WAITING, and
        // link it in at the head of the list (still under the lock).
        node.with_node(&mut *list, |node| {
            let _prev = node.waiter.replace(waiter.clone());
            debug_assert!(
                _prev.is_none(),
                "start_wait_slow: called with a node that already had a waiter!"
            );
        });
        let _prev_state = test_dbg!(node.state.swap(WAITING, Release));
        debug_assert!(
            _prev_state == EMPTY || _prev_state == WAKING,
            "start_wait_slow: called with a node that was not empty ({}) or woken ({})! actual={}",
            EMPTY,
            WAKING,
            _prev_state,
        );
        list.enqueue(node);
        WaitResult::Wait
    }

    /// Checks a node previously linked by `start_wait`.
    ///
    /// Returns `WaitResult::Notified` or `WaitResult::Closed` if the node was
    /// dequeued by `notify` or `close`. Otherwise it makes sure the node's
    /// stored handle matches `my_waiter` (replacing it when `Notify::same`
    /// reports a different one) and returns `WaitResult::Wait`.
    #[inline(always)]
    pub(crate) fn continue_wait(&self, node: Pin<&mut Waiter<T>>, my_waiter: &T) -> WaitResult {
        test_println!("WaitQueue::continue_wait({:p})", node);
        // Fast path: the node's state says whether it has already been woken,
        // without taking the lock.
        let state = test_dbg!(node.state.load(Acquire));
        match state {
            WAKING => return WaitResult::Notified,
            CLOSED => return WaitResult::Closed,
            _state => {
                debug_assert_eq!(
                    _state, WAITING,
                    "continue_wait should not be called unless the node has been enqueued"
                );
            }
        }
        self.continue_wait_slow(node, my_waiter)
    }

    /// Slow path of `continue_wait`: re-checks the node state under the lock,
    /// then updates the stored handle if needed.
    #[cold]
    #[inline(never)]
    fn continue_wait_slow(&self, node: Pin<&mut Waiter<T>>, my_waiter: &T) -> WaitResult {
        test_println!("WaitQueue::continue_wait_slow({:p})", node);
        let mut list = self.list.lock();
        // Re-check: the node may have been dequeued between the fast-path
        // load and acquiring the lock.
        match test_dbg!(node.state.load(Acquire)) {
            WAKING => return WaitResult::Notified,
            CLOSED => return WaitResult::Closed,
            _state => {
                debug_assert_eq!(
                    _state, WAITING,
                    "continue_wait_slow should not be called unless the node has been enqueued"
                );
            }
        }
        node.with_node(&mut *list, |node| {
            if let Some(ref mut waiter) = node.waiter {
                // Only clone when the handle actually changed.
                if !waiter.same(my_waiter) {
                    *waiter = my_waiter.clone();
                }
            } else {
                // `start_wait_slow` always stores a handle before linking, so
                // this branch is defensive.
                node.waiter = Some(my_waiter.clone());
            }
        });
        WaitResult::Wait
    }

    /// Wakes one waiter.
    ///
    /// If no waiter is linked, the notification is stored (queue state
    /// `WAKING`) for the next `start_wait` to consume, and `false` is
    /// returned. Returns `true` only if a linked waiter was dequeued and its
    /// handle notified. Notifying a closed queue does nothing and returns
    /// `false`.
    #[inline(always)]
    pub(crate) fn notify(&self) -> bool {
        test_println!("WaitQueue::notify()");
        // Fast path: with no linked waiters (EMPTY) or a notification already
        // stored (WAKING), store the notification without locking.
        let mut state = self.state.load(Acquire);
        while test_dbg!(state) == WAKING || state == EMPTY {
            match test_dbg!(self
                .state
                .compare_exchange_weak(state, WAKING, SeqCst, SeqCst))
            {
                Ok(_) => return false,
                Err(actual) => state = actual,
            }
        }
        self.notify_slow(state)
    }

    /// Slow path of `notify`: takes the list lock and wakes the oldest waiter
    /// (or stores the notification if none is linked any more).
    ///
    /// `state` is the value `notify` observed *before* locking and is only
    /// used for logging; the real state is re-read under the lock.
    #[cold]
    #[inline(never)]
    fn notify_slow(&self, state: usize) -> bool {
        test_println!("WaitQueue::notify_slow(state: {})", state);
        let mut list = self.list.lock();

        // The caller's `state` may be stale. For example, `Waiter::remove` may
        // have unlinked the last waiter and moved the queue WAITING -> EMPTY
        // while we waited for the lock. Acting on the stale WAITING would
        // dequeue nothing and lose the notification. WAITING is only entered
        // or left by lock holders (or replaced by CLOSED via `close`), so a
        // WAITING value read here guarantees the list is non-empty.
        let mut state = test_dbg!(self.state.load(SeqCst));
        loop {
            match state {
                // Nobody to wake: store the notification for the next
                // `start_wait`. A CAS (not a store) so that a concurrent
                // `close` is never overwritten; lock-free transitions just
                // make us retry with the observed value.
                EMPTY | WAKING => {
                    match test_dbg!(self.state.compare_exchange(state, WAKING, SeqCst, SeqCst)) {
                        Ok(_) => return false,
                        Err(actual) => state = actual,
                    }
                }
                WAITING => {
                    let waiter = list.dequeue(WAKING);
                    debug_assert!(
                        waiter.is_some(),
                        "if we were in the `WAITING` state, there must be a waiter in the queue!\nself={:#?}",
                        self,
                    );
                    if test_dbg!(list.is_empty()) {
                        // CAS rather than a plain store: `close` sets CLOSED
                        // without taking the lock, and that must not be
                        // overwritten.
                        let _ = test_dbg!(self.state.compare_exchange(WAITING, EMPTY, SeqCst, SeqCst));
                    }
                    // Release the lock before running the waiter's notify.
                    drop(list);
                    if let Some(waiter) = waiter {
                        waiter.notify();
                        return true;
                    }
                    return false;
                }
                // `close` has dequeued (or is dequeuing) every waiter itself.
                CLOSED => return false,
                _weird => {
                    if cfg!(debug_assertions) {
                        unreachable!("notify_slow: unexpected state value {:?}", _weird);
                    }
                    return false;
                }
            }
        }
    }

    /// Closes the queue.
    ///
    /// Sets the queue state to `CLOSED` (without taking the lock), then, under
    /// the lock, dequeues every linked waiter with node state `CLOSED` and
    /// notifies it. These notifications run while the lock is still held.
    pub(crate) fn close(&self) {
        test_println!("WaitQueue::close()");
        test_dbg!(self.state.swap(CLOSED, SeqCst));
        let mut list = self.list.lock();
        while !list.is_empty() {
            if let Some(waiter) = list.dequeue(CLOSED) {
                waiter.notify();
            }
        }
    }
}
impl<T: Notify> Waiter<T> {
    /// Returns a fresh waiter node: unlinked, in the `EMPTY` state, and with
    /// no notification handle stored.
    pub(crate) fn new() -> Self {
        let state = CachePadded(AtomicUsize::new(EMPTY));
        let node = UnsafeCell::new(Node {
            next: None,
            prev: None,
            waiter: None,
            _pin: PhantomPinned,
        });
        Self { state, node }
    }

    /// Unlinks this waiter from `q`'s list while holding `q`'s lock. Calling
    /// it on a node that is not linked leaves the list unchanged.
    ///
    /// If the list is empty afterwards, the queue state is moved from
    /// `WAITING` to `EMPTY` via compare-exchange, so any other state is left
    /// alone. This node's own `state` is not modified.
    #[inline(never)]
    pub(crate) fn remove(self: Pin<&mut Self>, q: &WaitQueue<T>) {
        test_println!("Waiter::remove({:p})", self);
        let mut list = q.list.lock();
        // SAFETY: the list lock is held for the whole unlink. `List::remove`
        // cannot check that this node belongs to *this* list.
        // NOTE(review): callers must only remove a waiter from the queue it was
        // enqueued on.
        unsafe { list.remove(self) };
        let now_empty = test_dbg!(list.is_empty());
        if now_empty {
            let _ = test_dbg!(q.state.compare_exchange(WAITING, EMPTY, SeqCst, SeqCst));
        }
    }

    /// Returns `true` if this node's state is `WAITING`.
    ///
    /// That is the case from `start_wait` linking it until `notify`/`close`
    /// dequeues it. `remove` does not reset the state, so this still reports
    /// `true` after `remove`.
    #[inline]
    pub(crate) fn is_linked(&self) -> bool {
        matches!(test_dbg!(self.state.load(Acquire)), WAITING)
    }
}
impl<T> Waiter<T> {
    /// Calls `f` with a mutable reference to this waiter's list node and
    /// returns its result.
    ///
    /// `_list` is never read. Requiring `&mut List<T>` means the caller has
    /// exclusive access to the list, either through the queue's lock guard
    /// or from inside a `List` method. That access is what justifies handing
    /// out `&mut Node<T>` from `&self`.
    #[inline(always)]
    #[cfg_attr(loom, track_caller)]
    fn with_node<U>(&self, _list: &mut List<T>, f: impl FnOnce(&mut Node<T>) -> U) -> U {
        self.node.with_mut(move |raw| {
            // SAFETY: the `&mut List<T>` witness stands for exclusive access
            // to the nodes of the list, so no other reference to this node is
            // live while `f` runs.
            let node = unsafe { &mut *raw };
            f(node)
        })
    }

    /// Overwrites this node's `prev` link.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller is expected to have exclusive access to the
    /// list this node is linked into (all callers in this file are `List`
    /// methods reaching the node through a raw pointer).
    #[cfg_attr(loom, track_caller)]
    unsafe fn set_prev(&mut self, prev: Option<NonNull<Waiter<T>>>) {
        self.node.with_mut(|raw| {
            (*raw).prev = prev;
        });
    }

    /// Takes this node's `prev` link, leaving `None` in its place.
    ///
    /// # Safety
    ///
    /// Same requirements as `set_prev`.
    #[cfg_attr(loom, track_caller)]
    unsafe fn take_prev(&mut self) -> Option<NonNull<Waiter<T>>> {
        self.node.with_mut(|raw| Option::take(&mut (*raw).prev))
    }

    /// Takes this node's `next` link, leaving `None` in its place.
    ///
    /// # Safety
    ///
    /// Same requirements as `set_prev`.
    #[cfg_attr(loom, track_caller)]
    unsafe fn take_next(&mut self) -> Option<NonNull<Waiter<T>>> {
        self.node.with_mut(|raw| Option::take(&mut (*raw).next))
    }
}
// SAFETY (review note): `Waiter`'s `UnsafeCell` node is only accessed through
// `with_node`, which demands `&mut List<T>` (exclusive list access), or through
// the `unsafe` link helpers, which take `&mut self`. The `state` field is
// atomic. The stored `T` may be taken out and notified on another thread
// (`List::dequeue` is reached from `notify`/`close`), so `T: Send` is required.
// NOTE(review): soundness also relies on code outside this file only touching
// nodes through these paths — confirm.
unsafe impl<T: Send> Send for Waiter<T> {}
unsafe impl<T: Send> Sync for Waiter<T> {}
impl<T> List<T> {
    /// Returns an empty list.
    const fn new() -> Self {
        Self {
            head: None,
            tail: None,
        }
    }

    /// Pushes `waiter` onto the head of the list (newest position).
    ///
    /// The list keeps a raw pointer to the node. NOTE(review): callers must
    /// keep the node pinned and alive for as long as it is linked.
    fn enqueue(&mut self, waiter: Pin<&mut Waiter<T>>) {
        test_println!("List::enqueue({:p})", waiter);
        // SAFETY: the node is never moved out of; we only take its address and
        // update its link fields in place.
        let node = unsafe { waiter.get_unchecked_mut() };
        let head = self.head.take();
        node.with_node(self, |node| {
            node.next = head;
            node.prev = None;
        });
        let ptr = NonNull::from(node);
        // Compare against the *previous* head: `self.head` was just `take()`n
        // and is always `None` here, so checking it could never fail.
        debug_assert_ne!(
            head,
            Some(ptr),
            "tried to enqueue the same waiter twice!"
        );
        if let Some(mut head) = head {
            // SAFETY: `head` was linked, so it points to a live, pinned node,
            // and `&mut self` gives exclusive access to the list's nodes.
            unsafe {
                head.as_mut().set_prev(Some(ptr));
            }
        }
        self.head = Some(ptr);
        if self.tail.is_none() {
            self.tail = Some(ptr);
        }
    }

    /// Pops the oldest waiter (the tail) and returns the notification handle
    /// it stored, after setting the node's `state` to `new_state`.
    ///
    /// Returns `None` if the list is empty (or if the node held no handle).
    fn dequeue(&mut self, new_state: usize) -> Option<T> {
        let mut last = self.tail?;
        test_println!("List::dequeue({:?}) -> {:p}", new_state, last);
        // SAFETY: `last` is linked, so it is alive and pinned, and `&mut self`
        // gives exclusive access to the list's nodes.
        let last = unsafe { last.as_mut() };
        // Finish every write to the node *before* publishing `new_state`. The
        // node's owner reads its state without the list lock
        // (`WaitQueue::continue_wait`'s fast path, `Waiter::is_linked`). Once
        // it sees WAKING/CLOSED it may stop treating the node as linked and
        // release it, so `last` must not be touched after the swap below.
        let (prev, waiter) = last.with_node(self, |node| {
            node.next = None;
            (node.prev.take(), node.waiter.take())
        });
        match prev {
            // SAFETY: `prev` is still linked, hence alive and pinned.
            Some(mut prev) => unsafe {
                let _ = prev.as_mut().take_next();
            },
            None => self.head = None,
        }
        self.tail = prev;
        // Publish last; Release pairs with the owner's Acquire load of the
        // node state, so the writes above happen-before its observation.
        let _prev_state = test_dbg!(last.state.swap(new_state, Release));
        debug_assert_eq!(_prev_state, WAITING);
        waiter
    }

    /// Unlinks `node`, reconnecting its neighbours to each other or moving
    /// `head`/`tail` past it.
    ///
    /// A node that is not linked has no `prev`/`next` and is neither `head`
    /// nor `tail`, so removing it leaves the list unchanged.
    ///
    /// # Safety
    ///
    /// If `node` is linked, it must be linked into *this* list: its
    /// neighbours are dereferenced as live nodes, and nothing here can verify
    /// which list they belong to.
    unsafe fn remove(&mut self, node: Pin<&mut Waiter<T>>) {
        test_println!("List::remove({:p})", node);
        let node_ref = node.get_unchecked_mut();
        let prev = node_ref.take_prev();
        let next = node_ref.take_next();
        let ptr = NonNull::from(node_ref);
        if let Some(mut prev) = prev {
            prev.as_mut().with_node(self, |prev| {
                debug_assert_eq!(prev.next, Some(ptr));
                prev.next = next;
            });
        } else if self.head == Some(ptr) {
            self.head = next;
        }
        if let Some(mut next) = next {
            next.as_mut().with_node(self, |next| {
                debug_assert_eq!(next.prev, Some(ptr));
                next.prev = prev;
            });
        } else if self.tail == Some(ptr) {
            self.tail = prev;
        }
    }

    /// Returns `true` if no node is linked.
    fn is_empty(&self) -> bool {
        self.head.is_none() && self.tail.is_none()
    }
}
impl<T> fmt::Debug for List<T> {
    /// Formats as `List { head, tail, is_empty }`, printing the raw link
    /// pointers rather than walking the nodes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { head, tail } = self;
        let empty = self.is_empty();
        f.debug_struct("List")
            .field("head", head)
            .field("tail", tail)
            .field("is_empty", &empty)
            .finish()
    }
}
// SAFETY (review note): `List` holds only raw `NonNull` pointers to `Waiter`s,
// which makes it `!Send` by default. It lives inside `WaitQueue`'s mutex, and
// every list mutation in this file happens under that lock. NOTE(review):
// soundness also relies on every linked `Waiter` staying pinned and alive
// while linked — confirm against callers.
unsafe impl<T: Send> Send for List<T> {}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    /// Notification handle for tests: records whether `notify` was called.
    /// Clones share the same flag, and `same` compares by `Arc` identity.
    #[derive(Debug, Clone)]
    struct MockNotify(Arc<AtomicBool>);

    impl Notify for MockNotify {
        fn notify(self) {
            self.0.store(true, Ordering::SeqCst);
        }

        fn same(&self, Self(other): &Self) -> bool {
            Arc::ptr_eq(&self.0, other)
        }
    }

    impl MockNotify {
        fn new() -> Self {
            Self(Arc::new(AtomicBool::new(false)))
        }

        fn was_notified(&self) -> bool {
            self.0.load(Ordering::SeqCst)
        }
    }

    /// `notify` wakes exactly one waiter: the one that started waiting first.
    #[test]
    fn notify_one() {
        let q = WaitQueue::new();
        let notify1 = MockNotify::new();
        let notify2 = MockNotify::new();
        let mut waiter1 = Box::pin(Waiter::new());
        let mut waiter2 = Box::pin(Waiter::new());
        assert_eq_dbg!(q.start_wait(waiter1.as_mut(), &notify1), WaitResult::Wait);
        assert_dbg!(waiter1.is_linked());
        assert_eq_dbg!(q.start_wait(waiter2.as_mut(), &notify2), WaitResult::Wait);
        assert_dbg!(waiter2.is_linked());
        assert_dbg!(!notify1.was_notified());
        assert_dbg!(!notify2.was_notified());
        assert_dbg!(q.notify());
        assert_dbg!(notify1.was_notified());
        assert_dbg!(!waiter1.is_linked());
        assert_dbg!(!notify2.was_notified());
        assert_dbg!(waiter2.is_linked());
        assert_eq_dbg!(
            q.continue_wait(waiter2.as_mut(), &notify2),
            WaitResult::Wait
        );
        assert_eq_dbg!(
            q.continue_wait(waiter1.as_mut(), &notify1),
            WaitResult::Notified
        );
    }

    /// `close` wakes every waiter, and each then observes `Closed`.
    #[test]
    fn close() {
        let q = WaitQueue::new();
        let notify1 = MockNotify::new();
        let notify2 = MockNotify::new();
        let mut waiter1 = Box::pin(Waiter::new());
        let mut waiter2 = Box::pin(Waiter::new());
        assert_eq_dbg!(q.start_wait(waiter1.as_mut(), &notify1), WaitResult::Wait);
        assert_dbg!(waiter1.is_linked());
        assert_eq_dbg!(q.start_wait(waiter2.as_mut(), &notify2), WaitResult::Wait);
        assert_dbg!(waiter2.is_linked());
        assert_dbg!(!notify1.was_notified());
        assert_dbg!(!notify2.was_notified());
        q.close();
        assert_dbg!(notify1.was_notified());
        assert_dbg!(!waiter1.is_linked());
        assert_dbg!(notify2.was_notified());
        assert_dbg!(!waiter2.is_linked());
        assert_eq_dbg!(
            q.continue_wait(waiter2.as_mut(), &notify2),
            WaitResult::Closed
        );
        assert_eq_dbg!(
            q.continue_wait(waiter1.as_mut(), &notify1),
            WaitResult::Closed
        );
    }

    /// Removing the middle waiter keeps the list consistent for the others.
    #[test]
    fn remove_from_middle() {
        let q = WaitQueue::new();
        let notify1 = MockNotify::new();
        let notify2 = MockNotify::new();
        let notify3 = MockNotify::new();
        let mut waiter1 = Box::pin(Waiter::new());
        let mut waiter2 = Box::pin(Waiter::new());
        let mut waiter3 = Box::pin(Waiter::new());
        assert_eq_dbg!(q.start_wait(waiter1.as_mut(), &notify1), WaitResult::Wait);
        assert_dbg!(waiter1.is_linked());
        assert_eq_dbg!(q.start_wait(waiter2.as_mut(), &notify2), WaitResult::Wait);
        assert_dbg!(waiter2.is_linked());
        assert_eq_dbg!(q.start_wait(waiter3.as_mut(), &notify3), WaitResult::Wait);
        assert_dbg!(waiter3.is_linked());
        assert_dbg!(!notify1.was_notified());
        assert_dbg!(!notify2.was_notified());
        assert_dbg!(!notify3.was_notified());
        waiter2.as_mut().remove(&q);
        assert_dbg!(!notify2.was_notified());
        drop(waiter2);
        assert_dbg!(q.notify());
        assert_dbg!(notify1.was_notified());
        assert_dbg!(!waiter1.is_linked());
        assert_dbg!(!notify3.was_notified());
        assert_dbg!(waiter3.is_linked());
        assert_eq_dbg!(
            q.continue_wait(waiter3.as_mut(), &notify3),
            WaitResult::Wait
        );
        assert_eq_dbg!(
            q.continue_wait(waiter1.as_mut(), &notify1),
            WaitResult::Notified
        );
    }

    /// Removing a waiter that became the tail after a `notify` keeps the
    /// remaining waiter linked.
    #[test]
    fn remove_after_notify() {
        let q = WaitQueue::new();
        let notify1 = MockNotify::new();
        let notify2 = MockNotify::new();
        let notify3 = MockNotify::new();
        let mut waiter1 = Box::pin(Waiter::new());
        let mut waiter2 = Box::pin(Waiter::new());
        let mut waiter3 = Box::pin(Waiter::new());
        assert_eq_dbg!(q.start_wait(waiter1.as_mut(), &notify1), WaitResult::Wait);
        assert_dbg!(waiter1.is_linked());
        assert_eq_dbg!(q.start_wait(waiter2.as_mut(), &notify2), WaitResult::Wait);
        assert_dbg!(waiter2.is_linked());
        assert_eq_dbg!(q.start_wait(waiter3.as_mut(), &notify3), WaitResult::Wait);
        assert_dbg!(waiter3.is_linked());
        assert_dbg!(!notify1.was_notified());
        assert_dbg!(!notify2.was_notified());
        assert_dbg!(!notify3.was_notified());
        assert_dbg!(q.notify());
        assert_dbg!(notify1.was_notified());
        assert_dbg!(!waiter1.is_linked());
        assert_dbg!(!notify2.was_notified());
        assert_dbg!(waiter2.is_linked());
        assert_dbg!(!notify3.was_notified());
        assert_dbg!(waiter3.is_linked());
        assert_eq_dbg!(
            q.continue_wait(waiter3.as_mut(), &notify3),
            WaitResult::Wait
        );
        assert_eq_dbg!(
            q.continue_wait(waiter2.as_mut(), &notify2),
            WaitResult::Wait
        );
        assert_eq_dbg!(
            q.continue_wait(waiter1.as_mut(), &notify1),
            WaitResult::Notified
        );
        waiter2.as_mut().remove(&q);
        assert_dbg!(!notify2.was_notified());
        drop(waiter2);
        assert_dbg!(!notify3.was_notified());
        assert_dbg!(waiter3.is_linked());
        assert_eq_dbg!(
            q.continue_wait(waiter3.as_mut(), &notify3),
            WaitResult::Wait
        );
    }
}