use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::mem::align_of;
use std::pin::Pin;
use std::ptr::{from_ref, null_mut, with_exposed_provenance};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
#[cfg(not(feature = "loom"))]
use std::sync::atomic::{AtomicBool, AtomicPtr};
#[cfg(not(feature = "loom"))]
use std::sync::{Condvar, Mutex};
use std::task::{Context, Poll, Waker};
#[cfg(feature = "loom")]
use loom::sync::atomic::{AtomicBool, AtomicPtr};
#[cfg(feature = "loom")]
use loom::sync::{Condvar, Mutex};
use crate::opcode::Opcode;
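/// An entry in an intrusive list of threads or tasks waiting for an operation to
/// complete.
///
/// The 128-byte alignment leaves the low bits of an entry address free so that a
/// lock flag and extra data can be packed into a single `usize` alongside the
/// address (see [`WaitQueue::LOCKED_FLAG`] and [`WaitQueue::DATA_MASK`]).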
#[repr(align(128))]
pub(crate) struct WaitQueue {
next_entry_ptr: AtomicPtr<Self>,
prev_entry_ptr: AtomicPtr<Self>,
opcode: Opcode,
monitor: Monitor,
}
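/// A pinned [`WaitQueue`] reference that can be awaited to obtain the result of the
/// associated operation.
///
/// A minimal usage sketch, not compiled as a doc-test; `opcode`, `arg`, and `cleanup`
/// are placeholders for values the caller already has:
///
/// ```ignore
/// let entry = WaitQueue::new(opcode, Some((arg, cleanup)));
/// // Publish the entry address, then await the signal sent by `set_result`.
/// let granted = PinnedWaitQueue(Pin::new(&entry)).await;
/// ```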
pub(crate) struct PinnedWaitQueue<'w>(pub(crate) Pin<&'w WaitQueue>);
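/// An argument and callback pair invoked when an asynchronous entry is dropped
/// without its result having been acknowledged.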
type AsyncContextCleaner = (usize, fn(&WaitQueue, usize));
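/// State shared between an asynchronous waiter and the thread signalling it; the
/// `waker` slot is protected by the `waker_lock` spinlock flag.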
struct AsyncContext {
result: AtomicBool,
ready: AtomicBool,
finalized: AtomicBool,
acknowledged: AtomicBool,
waker_lock: AtomicBool,
waker: UnsafeCell<Option<Waker>>,
cleaner: AsyncContextCleaner,
}
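/// State shared between a synchronous waiter and the thread signalling it.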
#[derive(Debug, Default)]
struct SyncContext {
state: Mutex<Option<bool>>,
cond_var: Condvar,
}
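/// The wakeup mechanism attached to a [`WaitQueue`] entry.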
#[derive(Debug)]
enum Monitor {
Async(AsyncContext),
Sync(SyncContext),
}
impl WaitQueue {
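    /// Bit marking a packed [`WaitQueue`] address as locked.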
pub(crate) const LOCKED_FLAG: usize = align_of::<Self>() >> 1;
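    /// Mask of the low bits carrying user data in a packed [`WaitQueue`] address.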
pub(crate) const DATA_MASK: usize = (align_of::<Self>() >> 1) - 1;
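    /// Mask extracting the aligned [`WaitQueue`] address from a packed value.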
pub(crate) const ADDR_MASK: usize = !(Self::LOCKED_FLAG | Self::DATA_MASK);
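    /// Creates a new [`WaitQueue`] entry for the given [`Opcode`].
    ///
    /// Passing a cleaner makes the entry asynchronous; passing `None` makes it
    /// synchronous, blocking the caller in [`Self::poll_result_sync`].
    ///
    /// A minimal usage sketch of the synchronous path, not compiled as a doc-test;
    /// `opcode` is a placeholder for any valid [`Opcode`] value:
    ///
    /// ```ignore
    /// let entry = WaitQueue::new(opcode, None);
    /// // Publish `WaitQueue::ref_to_ptr(&entry)` so that another thread can later
    /// // call `entry.set_result(true)`.
    /// assert!(entry.poll_result_sync());
    /// ```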
pub(crate) fn new(opcode: Opcode, async_cleanup_fn: Option<AsyncContextCleaner>) -> Self {
let monitor = if let Some(async_cleanup_fn) = async_cleanup_fn {
Monitor::Async(AsyncContext {
result: AtomicBool::new(false),
ready: AtomicBool::new(false),
finalized: AtomicBool::new(false),
acknowledged: AtomicBool::new(false),
waker_lock: AtomicBool::new(false),
waker: UnsafeCell::new(None),
cleaner: async_cleanup_fn,
})
} else {
Monitor::Sync(SyncContext::default())
};
Self {
next_entry_ptr: AtomicPtr::new(null_mut()),
prev_entry_ptr: AtomicPtr::new(null_mut()),
opcode,
monitor,
}
}
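    /// Returns a pointer to the next entry in the list.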
pub(crate) fn next_entry_ptr(&self) -> *const Self {
self.next_entry_ptr.load(Acquire)
}
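    /// Returns a pointer to the previous entry in the list.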
pub(crate) fn prev_entry_ptr(&self) -> *const Self {
self.prev_entry_ptr.load(Acquire)
}
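    /// Updates the pointer to the next entry in the list.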
pub(crate) fn update_next_entry_ptr(&self, next_entry_ptr: *const Self) {
debug_assert_eq!(next_entry_ptr as usize % align_of::<Self>(), 0);
self.next_entry_ptr
.store(next_entry_ptr.cast_mut(), Release);
}
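    /// Updates the pointer to the previous entry in the list.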
pub(crate) fn update_prev_entry_ptr(&self, prev_entry_ptr: *const Self) {
debug_assert_eq!(prev_entry_ptr as usize % align_of::<Self>(), 0);
self.prev_entry_ptr
.store(prev_entry_ptr.cast_mut(), Release);
}
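    /// Returns the [`Opcode`] associated with this entry.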
pub(crate) const fn opcode(&self) -> Opcode {
self.opcode
}
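    /// Converts a [`WaitQueue`] reference into a raw pointer.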
pub(crate) fn ref_to_ptr(this: &Self) -> *const Self {
let wait_queue_ptr: *const Self = from_ref(this);
debug_assert_eq!(wait_queue_ptr as usize % align_of::<Self>(), 0);
wait_queue_ptr
}
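    /// Recovers a [`WaitQueue`] pointer from a previously exposed address.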
pub(crate) fn addr_to_ptr(wait_queue_addr: usize) -> *const Self {
debug_assert_eq!(wait_queue_addr % align_of::<Self>(), 0);
with_exposed_provenance(wait_queue_addr)
}
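    /// Walks the list from `tail_entry_ptr` along `next` links, installing `prev`
    /// links until an entry with an already installed `prev` link is reached.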
pub(crate) fn install_backward_link(tail_entry_ptr: *const Self) {
let mut entry_ptr = tail_entry_ptr;
while !entry_ptr.is_null() {
entry_ptr = unsafe {
let next_entry_ptr = (*entry_ptr).next_entry_ptr();
if let Some(next_entry) = next_entry_ptr.as_ref() {
if next_entry.prev_entry_ptr().is_null() {
next_entry.update_prev_entry_ptr(entry_ptr);
} else {
debug_assert_eq!(next_entry.prev_entry_ptr(), entry_ptr);
return;
}
}
next_entry_ptr
};
}
}
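    /// Iterates over entries starting from `tail_entry_ptr`, following `next` links
    /// and repairing missing `prev` links on the way; stops early when `f` returns
    /// `true`.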
pub(crate) fn iter_forward<F: FnMut(&Self, Option<&Self>) -> bool>(
tail_entry_ptr: *const Self,
mut f: F,
) {
let mut entry_ptr = tail_entry_ptr;
while !entry_ptr.is_null() {
entry_ptr = unsafe {
let next_entry_ptr = (*entry_ptr).next_entry_ptr();
if let Some(next_entry) = next_entry_ptr.as_ref() {
next_entry.update_prev_entry_ptr(entry_ptr);
}
if f(&*entry_ptr, next_entry_ptr.as_ref()) {
return;
}
next_entry_ptr
};
}
}
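    /// Iterates over entries starting from `head_entry_ptr`, following `prev` links;
    /// stops early when `f` returns `true`.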
pub(crate) fn iter_backward<F: FnMut(&Self, Option<&Self>) -> bool>(
head_entry_ptr: *const Self,
mut f: F,
) {
let mut entry_ptr = head_entry_ptr;
while !entry_ptr.is_null() {
entry_ptr = unsafe {
let prev_entry_ptr = (*entry_ptr).prev_entry_ptr();
if f(&*entry_ptr, prev_entry_ptr.as_ref()) {
return;
}
prev_entry_ptr
};
}
}
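    /// Publishes the result of the operation and wakes up the waiting task or thread.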
pub(crate) fn set_result(&self, result: bool) {
match &self.monitor {
Monitor::Async(async_context) => {
debug_assert!(!async_context.finalized.load(Relaxed));
async_context.result.store(result, Release);
async_context.ready.store(true, Release);
                // Acquire the waker lock so that the waker slot is never accessed
                // while a concurrent `poll_result_async` call is writing to it.
                while async_context
                    .waker_lock
                    .compare_exchange(false, true, AcqRel, Relaxed)
                    .is_err()
                {
                    yield_now();
                }
                // Safety: the waker lock is held, giving this thread exclusive access
                // to the waker slot.
                unsafe {
                    if let Some(waker) = (*async_context.waker.get()).take() {
                        waker.wake();
                    }
                }
                async_context.waker_lock.store(false, Release);
async_context.finalized.store(true, Release);
}
Monitor::Sync(sync_context) => {
if let Ok(mut state) = sync_context.state.lock() {
*state = Some(result);
sync_context.cond_var.notify_one();
}
}
}
}
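    /// Polls the result of an asynchronous entry, registering the waker from `cx`
    /// when the result is not ready yet.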
pub(crate) fn poll_result_async(&self, cx: &mut Context<'_>) -> Poll<bool> {
let Monitor::Async(async_context) = &self.monitor else {
debug_assert!(false, "Logic error");
return Poll::Ready(false);
};
if async_context.finalized.load(Acquire) {
debug_assert!(async_context.ready.load(Relaxed));
async_context.acknowledged.store(true, Release);
return Poll::Ready(async_context.result.load(Relaxed));
}
let waker = cx.waker().clone();
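        // The result has been published but not finalized yet: wake the task so that
        // it polls again, since no waker is registered on this path.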
if async_context.ready.load(Acquire) {
waker.wake();
if async_context.finalized.load(Acquire) {
debug_assert!(async_context.ready.load(Relaxed));
async_context.acknowledged.store(true, Release);
return Poll::Ready(async_context.result.load(Relaxed));
}
} else if async_context
.waker_lock
.compare_exchange(false, true, AcqRel, Acquire)
.is_ok()
{
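            // The waker lock was acquired: store the waker for `set_result` to
            // consume, unless the result arrived in the meantime.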
if async_context.ready.load(Acquire) {
waker.wake();
} else {
unsafe {
(*async_context.waker.get()) = Some(waker);
}
}
async_context.waker_lock.store(false, Release);
} else {
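            // Another thread holds the waker lock, typically the signaller publishing
            // the result: wake the task and let it poll again.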
waker.wake();
if async_context.finalized.load(Acquire) {
debug_assert!(async_context.ready.load(Relaxed));
async_context.acknowledged.store(true, Release);
return Poll::Ready(async_context.result.load(Relaxed));
}
}
Poll::Pending
}
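    /// Blocks the calling thread until the result has been published by
    /// [`Self::set_result`].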
pub(crate) fn poll_result_sync(&self) -> bool {
let Monitor::Sync(sync_context) = &self.monitor else {
debug_assert!(false, "Logic error");
return false;
};
let Ok(mut state) = sync_context.state.lock() else {
debug_assert!(false, "The mutex can never be poisoned");
return false;
};
loop {
if let Some(result) = (*state).take() {
drop(state);
return result;
}
let Ok(returned) = sync_context.cond_var.wait(state) else {
debug_assert!(false, "The mutex can never be poisoned");
return false;
};
state = returned;
}
}
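    /// Marks the result of an asynchronous entry as acknowledged so that the cleanup
    /// callback is not invoked when the entry is dropped.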
pub(crate) fn result_acknowledged(&self) {
let Monitor::Async(async_context) = &self.monitor else {
debug_assert!(false, "Logic error");
return;
};
debug_assert!(!async_context.acknowledged.load(Relaxed));
async_context.acknowledged.store(true, Release);
}
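    /// Returns `true` if the result of an asynchronous entry has been finalized.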
pub(crate) fn result_finalized(&self) -> bool {
let Monitor::Async(async_context) = &self.monitor else {
debug_assert!(false, "Logic error");
return false;
};
debug_assert!(!async_context.acknowledged.load(Relaxed));
async_context.finalized.load(Acquire)
}
}
impl Drop for WaitQueue {
#[inline]
fn drop(&mut self) {
let Monitor::Async(async_context) = &self.monitor else {
return;
};
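        // The waiter never acknowledged its result: delegate cleanup to the
        // registered callback.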
if !async_context.acknowledged.load(Acquire) {
async_context.cleaner.1(self, async_context.cleaner.0);
}
}
}
impl Future for PinnedWaitQueue<'_> {
type Output = bool;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
this.0.poll_result_async(cx)
}
}
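// Safety: the `waker` slot behind the `UnsafeCell` is only accessed while the
// `waker_lock` flag is held, so sharing `AsyncContext` across threads is sound.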
unsafe impl Sync for AsyncContext {}
impl fmt::Debug for AsyncContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AsyncContext")
.field("result", &self.result)
.field("ready", &self.ready)
.field("finalized", &self.finalized)
.field("acknowledged", &self.acknowledged)
.field("waker_lock", &self.waker_lock)
.field("waker", &self.waker)
.field("cleaner_arg", &self.cleaner.0)
.finish()
}
}