use std::cell::UnsafeCell;
use std::future::Future;
use std::mem::align_of;
use std::pin::Pin;
use std::ptr::{from_ref, null_mut, with_exposed_provenance};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
#[cfg(not(feature = "loom"))]
use std::sync::atomic::{AtomicPtr, AtomicU16};
use std::task::{Context, Poll, Waker};
#[cfg(not(feature = "loom"))]
use std::thread::{Thread, current, park, yield_now};
#[cfg(feature = "loom")]
use loom::sync::atomic::{AtomicPtr, AtomicU16};
#[cfg(feature = "loom")]
use loom::thread::{Thread, current, park, yield_now};
use crate::opcode::Opcode;
use crate::sync_primitive::SyncPrimitive;
/// An entry in an intrusive, doubly-linked list of waiters on a synchronization primitive.
///
/// Each entry records the operation it waits to perform (`opcode`), the address of the
/// primitive it waits on (`addr`), and an atomic state word through which `set_result` hands a
/// `u8` result to the waiter. The 128-byte alignment guarantees that the low bits of an entry's
/// address are zero; `LOCKED_FLAG`/`DATA_MASK`/`ADDR_MASK` name those bits
/// (NOTE(review): presumably used to tag entry addresses elsewhere — confirm at the use sites).
#[repr(align(128))]
#[derive(Debug)]
pub(crate) struct WaitQueue {
    /// Link followed by `iter_forward` (which starts at the tail); stored with `Release`,
    /// loaded with `Acquire`.
    next_entry_ptr: AtomicPtr<Self>,
    /// Reverse link followed by `iter_backward` (which starts at the head); filled in by
    /// `set_prev_ptr` or by `iter_forward` when `set_prev` is requested.
    prev_entry_ptr: AtomicPtr<Self>,
    /// Address of the synchronization primitive this entry waits on (see `sync_primitive_ref`).
    addr: usize,
    /// The operation the waiter wants to perform.
    opcode: Opcode,
    /// Result byte in the low 8 bits plus the `ENQUEUED`, `RESULT_SET`, `WAKER_SET`,
    /// `RESULT_FINALIZED`, and `RESULT_ACKED` flags above it.
    state: AtomicU16,
    /// How the waiter is woken: through an async `Waker` or by unparking a `Thread`.
    monitor: Monitor,
}
/// A pinned borrow of a [`WaitQueue`] that can be `.await`ed to obtain the result byte
/// (see its `Future` implementation, which delegates to `WaitQueue::poll_result_async`).
pub(crate) struct PinnedWaitQueue<'w>(pub(crate) Pin<&'w WaitQueue>);
/// Wake-up state for an asynchronous waiter.
#[derive(Debug)]
struct AsyncContext {
    /// Waker registered by `poll_result_async` and taken by `set_result`; which side may
    /// touch the cell is decided by the `WAKER_SET`/`RESULT_SET` bits of `WaitQueue::state`.
    waker: UnsafeCell<Option<Waker>>,
    /// Called from `Drop` when an enqueued entry is dropped before its result is acknowledged.
    cleaner: fn(&WaitQueue),
}
/// Wake-up state for a waiter that blocks its thread.
#[derive(Debug)]
struct SyncContext {
    /// Handle registered by `poll_result_sync` and taken (then unparked) by `set_result`;
    /// which side may touch the cell is decided by the `WAKER_SET`/`RESULT_SET` bits of
    /// `WaitQueue::state`.
    thread: UnsafeCell<Option<Thread>>,
}
/// How the owner of a waiting entry is woken once its result has been set.
#[derive(Debug)]
enum Monitor {
    /// Wake a task via its `Waker` (entries created by `WaitQueue::new_async`).
    Async(AsyncContext),
    /// Unpark a blocked thread (entries created by `WaitQueue::new_sync`).
    Sync(SyncContext),
}
impl WaitQueue {
/// Result byte returned by `poll_result_async`/`poll_result_sync` when called on an entry of
/// the other mode.
pub(crate) const ERROR_WRONG_MODE: u8 = u8::MAX;
/// Highest bit that is always zero in a `WaitQueue` address (half the alignment).
pub(crate) const LOCKED_FLAG: usize = align_of::<Self>() >> 1;
/// Every always-zero address bit below `LOCKED_FLAG`.
pub(crate) const DATA_MASK: usize = Self::LOCKED_FLAG - 1;
/// Address bits left after clearing everything below the `WaitQueue` alignment.
pub(crate) const ADDR_MASK: usize = !(align_of::<Self>() - 1);
// `state` layout: bits 0..8 hold the result byte; the flags occupy consecutive bits above it.
const ENQUEUED: u16 = 1_u16 << u8::BITS;
const RESULT_SET: u16 = Self::ENQUEUED << 1;
const WAKER_SET: u16 = Self::RESULT_SET << 1;
const RESULT_FINALIZED: u16 = Self::WAKER_SET << 1;
const RESULT_ACKED: u16 = Self::RESULT_FINALIZED << 1;
/// Creates an unlinked entry for an asynchronous waiter.
///
/// `cleaner` is invoked on drop if the entry was marked enqueued but its result was never
/// acknowledged; `addr` is the address of the primitive being waited on.
pub(crate) fn new_async(opcode: Opcode, cleaner: fn(&WaitQueue), addr: usize) -> Self {
    Self {
        next_entry_ptr: AtomicPtr::new(null_mut()),
        prev_entry_ptr: AtomicPtr::new(null_mut()),
        addr,
        opcode,
        state: AtomicU16::new(0),
        monitor: Monitor::Async(AsyncContext {
            waker: UnsafeCell::new(None),
            cleaner,
        }),
    }
}
/// Creates an unlinked entry for a waiter that blocks its thread until the result is set.
///
/// `addr` is the address of the primitive being waited on.
pub(crate) fn new_sync(opcode: Opcode, addr: usize) -> Self {
    let thread = UnsafeCell::new(None);
    Self {
        next_entry_ptr: AtomicPtr::new(null_mut()),
        prev_entry_ptr: AtomicPtr::new(null_mut()),
        addr,
        opcode,
        state: AtomicU16::new(0),
        monitor: Monitor::Sync(SyncContext { thread }),
    }
}
/// Returns the `next` link (null if unset), loaded with `Acquire`.
pub(crate) fn next_entry_ptr(&self) -> *const Self {
    let next: *mut Self = self.next_entry_ptr.load(Acquire);
    next.cast_const()
}
/// Returns the `prev` link (null if unset), loaded with `Acquire`.
pub(crate) fn prev_entry_ptr(&self) -> *const Self {
    let prev: *mut Self = self.prev_entry_ptr.load(Acquire);
    prev.cast_const()
}
/// Stores `next_entry_ptr` as this entry's `next` link with `Release`.
///
/// Debug builds assert that the pointer is null or aligned like a `WaitQueue`.
pub(crate) fn update_next_entry_ptr(&self, next_entry_ptr: *const Self) {
    debug_assert_eq!(next_entry_ptr as usize % align_of::<Self>(), 0);
    let new_next = next_entry_ptr.cast_mut();
    self.next_entry_ptr.store(new_next, Release);
}
/// Stores `prev_entry_ptr` as this entry's `prev` link with `Release`.
///
/// Debug builds assert that the pointer is null or aligned like a `WaitQueue`.
pub(crate) fn update_prev_entry_ptr(&self, prev_entry_ptr: *const Self) {
    debug_assert_eq!(prev_entry_ptr as usize % align_of::<Self>(), 0);
    let new_prev = prev_entry_ptr.cast_mut();
    self.prev_entry_ptr.store(new_prev, Release);
}
pub(crate) const fn opcode(&self) -> Opcode {
self.opcode
}
/// Converts a shared reference to an entry into a raw pointer to the same entry.
pub(crate) const fn ref_to_ptr(this: &Self) -> *const Self {
    from_ref::<Self>(this)
}
/// Rebuilds an entry pointer from a plain address using its exposed provenance.
///
/// Debug builds assert that the address is aligned like a `WaitQueue`.
pub(crate) fn addr_to_ptr(wait_queue_addr: usize) -> *const Self {
    let misalignment = wait_queue_addr % align_of::<Self>();
    debug_assert_eq!(misalignment, 0);
    with_exposed_provenance::<Self>(wait_queue_addr)
}
/// Reinterprets `addr` as a reference to the synchronization primitive this entry waits on.
///
/// NOTE(review): this safe function dereferences a raw address. Soundness relies on callers
/// naming the `S` that actually lives at `addr`, and on that primitive outliving the returned
/// borrow — confirm at the call sites.
pub(crate) fn sync_primitive_ref<S: SyncPrimitive>(&self) -> &S {
    let primitive_ptr = with_exposed_provenance::<S>(self.addr);
    // SAFETY: assumed, per the note above, that `addr` came from a live `S` whose provenance
    // was exposed.
    unsafe { &*primitive_ptr }
}
/// Fills in missing `prev` links, starting at `tail_entry_ptr` and following `next` links.
///
/// Each successor whose `prev` link is null is pointed back at its predecessor. The walk stops
/// at the first successor whose `prev` link is already set, which debug builds check points
/// back at the current entry. It also stops at the end of the chain.
pub(crate) fn set_prev_ptr(tail_entry_ptr: *const Self) {
    let mut current_ptr = tail_entry_ptr;
    // SAFETY (assumed): every entry reachable from `tail_entry_ptr` stays alive during the walk
    // — NOTE(review): presumably guaranteed by the caller; confirm.
    while let Some(current) = unsafe { current_ptr.as_ref() } {
        let successor_ptr = current.next_entry_ptr();
        if let Some(successor) = unsafe { successor_ptr.as_ref() } {
            if !successor.prev_entry_ptr().is_null() {
                // The back-link is already present: stop, but verify it points at `current`.
                debug_assert_eq!(successor.prev_entry_ptr(), current_ptr);
                return;
            }
            successor.update_prev_entry_ptr(current_ptr);
        }
        current_ptr = successor_ptr;
    }
}
/// Visits entries starting at `tail_entry_ptr` and following `next` links.
///
/// `f` receives each entry and its successor (if any); returning `true` stops the walk.
/// When `set_prev` is true, each successor's `prev` link is pointed back at the current entry
/// before `f` is called.
pub(crate) fn iter_forward<F: FnMut(&Self, Option<&Self>) -> bool>(
    tail_entry_ptr: *const Self,
    set_prev: bool,
    mut f: F,
) {
    let mut cursor = tail_entry_ptr;
    // SAFETY (assumed): every visited entry stays alive during the walk — NOTE(review):
    // presumably guaranteed by the caller; confirm.
    while let Some(entry) = unsafe { cursor.as_ref() } {
        let next_ptr = entry.next_entry_ptr();
        let next = unsafe { next_ptr.as_ref() };
        if set_prev {
            if let Some(next_entry) = next {
                next_entry.update_prev_entry_ptr(cursor);
            }
        }
        if f(entry, next) {
            break;
        }
        cursor = next_ptr;
    }
}
/// Visits entries starting at `head_entry_ptr` and following `prev` links.
///
/// `f` receives each entry and its predecessor (if any); returning `true` stops the walk.
pub(crate) fn iter_backward<F: FnMut(&Self, Option<&Self>) -> bool>(
    head_entry_ptr: *const Self,
    mut f: F,
) {
    let mut cursor = head_entry_ptr;
    // SAFETY (assumed): every visited entry stays alive during the walk — NOTE(review):
    // presumably guaranteed by the caller; confirm.
    while let Some(entry) = unsafe { cursor.as_ref() } {
        let prev_ptr = entry.prev_entry_ptr();
        if f(entry, unsafe { prev_ptr.as_ref() }) {
            break;
        }
        cursor = prev_ptr;
    }
}
/// Returns `true` for entries created by `new_sync` (woken by unparking a thread) and
/// `false` for entries created by `new_async`.
pub(crate) fn is_sync(&self) -> bool {
    match self.monitor {
        Monitor::Sync(_) => true,
        Monitor::Async(_) => false,
    }
}
pub(crate) fn set_result(&self, result: u8) {
let mut state = self.state.load(Acquire);
loop {
debug_assert_eq!(state & Self::RESULT_SET, 0);
debug_assert_eq!(state & Self::RESULT_FINALIZED, 0);
let next_state = (state | Self::RESULT_SET) | u16::from(result);
match self
.state
.compare_exchange_weak(state, next_state, AcqRel, Acquire)
{
Ok(_) => {
state = next_state;
break;
}
Err(new_state) => state = new_state,
}
}
if state & Self::WAKER_SET == Self::WAKER_SET {
unsafe {
match &self.monitor {
Monitor::Async(async_context) => {
if let Some(waker) = (*async_context.waker.get()).take() {
self.state.fetch_or(Self::RESULT_FINALIZED, Release);
waker.wake();
return;
}
}
Monitor::Sync(sync_context) => {
if let Some(thread) = (*sync_context.thread.get()).take() {
self.state.fetch_or(Self::RESULT_FINALIZED, Release);
thread.unpark();
return;
}
}
}
}
}
self.state.fetch_or(Self::RESULT_FINALIZED, Release);
}
/// Polls for the result on behalf of an asynchronous waiter.
///
/// Returns `Poll::Ready(Self::ERROR_WRONG_MODE)` if this entry was not created by `new_async`,
/// and `Poll::Ready(result)` once `set_result` has finalized the result (which this call
/// acknowledges). Otherwise it returns `Poll::Pending` after one of two actions:
/// - registering `cx`'s waker for `set_result` to take, or
/// - if registration is not possible right now, waking `cx` immediately so that it is
///   polled again.
pub(crate) fn poll_result_async(&self, cx: &mut Context<'_>) -> Poll<u8> {
    let Monitor::Async(async_context) = &self.monitor else {
        return Poll::Ready(Self::ERROR_WRONG_MODE);
    };
    if let Some(result) = self.try_acknowledge_result() {
        return Poll::Ready(result);
    }
    let mut this_waker = None;
    let state = self.state.load(Acquire);
    if state & Self::RESULT_SET == Self::RESULT_SET {
        // The result is written but not yet finalized: retry once. If it is still not
        // finalized, fall through to request an immediate re-poll.
        if let Some(result) = self.try_acknowledge_result() {
            return Poll::Ready(result);
        }
    } else if state & Self::WAKER_SET == Self::WAKER_SET {
        // A waker from an earlier poll is registered. Clearing `WAKER_SET` hands the cell
        // back to this side. The exchange fails if `set_result` changed the state first, or
        // spuriously (weak CAS); in that case only an immediate re-poll is requested below.
        if self
            .state
            .compare_exchange_weak(state, state & !Self::WAKER_SET, AcqRel, Acquire)
            .is_ok()
        {
            this_waker.replace(cx.waker().clone());
        }
    } else {
        // Neither a result nor a waker yet: `set_result` will not read the cell, so fill it.
        this_waker.replace(cx.waker().clone());
    }
    if let Some(waker) = this_waker {
        // Write the cell before publishing `WAKER_SET` with `Release`; `set_result` reads the
        // cell only after observing `WAKER_SET`.
        unsafe {
            (*async_context.waker.get()).replace(waker);
        }
        // If the result was set before `WAKER_SET` was published, `set_result` did not see
        // the waker, so wake ourselves to be polled again.
        if self.state.fetch_or(Self::WAKER_SET, Release) & Self::RESULT_SET == Self::RESULT_SET
        {
            cx.waker().wake_by_ref();
        }
    } else {
        cx.waker().wake_by_ref();
    }
    Poll::Pending
}
/// Blocks the calling thread until the result is finalized, then acknowledges and returns it.
///
/// Returns `Self::ERROR_WRONG_MODE` immediately if this entry was not created by `new_sync`.
/// The current thread's handle is registered for `set_result` to unpark. When registration
/// is not possible at the moment, the thread yields and retries instead of parking.
pub(crate) fn poll_result_sync(&self) -> u8 {
    let Monitor::Sync(sync_context) = &self.monitor else {
        return Self::ERROR_WRONG_MODE;
    };
    loop {
        if let Some(result) = self.try_acknowledge_result() {
            return result;
        }
        let mut this_thread = None;
        let state = self.state.load(Acquire);
        if state & Self::RESULT_SET == Self::RESULT_SET {
            // The result is written but not yet finalized: retry once, then yield below.
            if let Some(result) = self.try_acknowledge_result() {
                return result;
            }
        } else if state & Self::WAKER_SET == Self::WAKER_SET {
            // A handle from an earlier iteration is registered (e.g. `park` returned without
            // an unpark). Clearing `WAKER_SET` hands the cell back to this side. If the
            // exchange fails (state changed or spurious weak-CAS failure), just yield and
            // retry.
            if self
                .state
                .compare_exchange_weak(state, state & !Self::WAKER_SET, AcqRel, Acquire)
                .is_ok()
            {
                this_thread.replace(current());
            }
        } else {
            // Neither a result nor a handle yet: `set_result` will not read the cell.
            this_thread.replace(current());
        }
        if let Some(thread) = this_thread {
            // Write the cell before publishing `WAKER_SET` with `Release`; `set_result` reads
            // the cell only after observing `WAKER_SET`.
            unsafe {
                (*sync_context.thread.get()).replace(thread);
            }
            // If the result was set before `WAKER_SET` was published, `set_result` will not
            // unpark this thread, so do not park.
            if self.state.fetch_or(Self::WAKER_SET, Release) & Self::RESULT_SET
                == Self::RESULT_SET
            {
                yield_now();
            } else {
                park();
            }
        } else {
            yield_now();
        }
    }
}
/// Marks an asynchronous entry as enqueued.
///
/// Dropping the entry afterwards, before its result is acknowledged, runs its `cleaner`.
/// Has no effect on entries created by `new_sync`. Debug builds assert the entry was not
/// already marked.
pub(crate) fn enqueued(&self) {
    if matches!(self.monitor, Monitor::Async(_)) {
        debug_assert_eq!(self.state.load(Relaxed) & Self::ENQUEUED, 0);
        self.state.fetch_or(Self::ENQUEUED, Release);
    }
}
/// Returns whether `set_result` has finalized the result (loaded with `Acquire`).
pub(crate) fn result_finalized(&self) -> bool {
    (self.state.load(Acquire) & Self::RESULT_FINALIZED) != 0
}
/// Waits until the result is finalized, then acknowledges and returns it.
///
/// Between attempts the thread yields; it never parks.
pub(crate) fn acknowledge_result_sync(&self) -> u8 {
    loop {
        match self.try_acknowledge_result() {
            Some(result) => break result,
            None => yield_now(),
        }
    }
}
/// Returns the result byte if `set_result` has finalized it, and marks it acknowledged
/// (`RESULT_ACKED`); returns `None` if the result is not finalized yet.
pub(crate) fn try_acknowledge_result(&self) -> Option<u8> {
    let state = self.state.load(Acquire);
    if state & Self::RESULT_FINALIZED != Self::RESULT_FINALIZED {
        return None;
    }
    debug_assert_ne!(state & Self::RESULT_SET, 0);
    self.state.fetch_or(Self::RESULT_ACKED, Release);
    // The result occupies the low byte of `state`.
    u8::try_from(state & u16::from(u8::MAX)).ok()
}
}
impl Drop for WaitQueue {
    /// Runs the async entry's `cleaner` if the entry was enqueued but its result was never
    /// acknowledged; does nothing for entries created by `new_sync`.
    #[inline]
    fn drop(&mut self) {
        if let Monitor::Async(async_context) = &self.monitor {
            let state = self.state.load(Acquire);
            let enqueued = (state & Self::ENQUEUED) != 0;
            let acknowledged = (state & Self::RESULT_ACKED) != 0;
            if enqueued && !acknowledged {
                (async_context.cleaner)(self);
            }
        }
    }
}
impl Future for PinnedWaitQueue<'_> {
type Output = u8;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
this.0.poll_result_async(cx)
}
}
// SAFETY: the `UnsafeCell`s holding the waiter's `Waker`/`Thread` are what keep `Monitor` from
// being `Sync` automatically. Which side may access a cell is decided by the `WAKER_SET` and
// `RESULT_SET` bits of `WaitQueue::state`, updated with Acquire/Release atomics in
// `poll_result_async`, `poll_result_sync`, and `set_result`. NOTE(review): this relies on the
// handle types themselves being `Send` (true for `std`'s `Waker` and `Thread`; confirm for the
// `loom` `Thread`).
unsafe impl Send for Monitor {}
unsafe impl Sync for Monitor {}