use std::future::Future;
use std::pin::Pin;
use std::ptr::addr_of_mut;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{AcqRel, Relaxed};
use std::sync::{Condvar, Mutex};
use std::task::{Context, Poll, Waker};
/// Tag OR-ed into the least significant bit of an address stored in the wait queue to mark
/// the entry as an `AsyncWait`; an address with this bit clear refers to a `SyncWait`.
const ASYNC: usize = 1_usize;
/// Queue of waiters kept as a singly linked list whose head lives in one atomic word.
///
/// The queue does not own its entries: every entry is a `SyncWait` or an `AsyncWait` that
/// lives elsewhere and is referenced only by its address. A head value of `0` means that the
/// queue is empty.
#[derive(Debug, Default)]
pub(crate) struct WaitQueue {
    /// Address of the most recently pushed entry, with the `ASYNC` bit set when that entry
    /// is an `AsyncWait`.
    wait_queue: AtomicUsize,
}

impl WaitQueue {
    /// Registers the calling thread as a waiter, runs `f`, and blocks until the waiter is
    /// signaled.
    ///
    /// The entry lives on this function's stack and is pushed *before* `f` is invoked. If `f`
    /// succeeds the whole queue is signaled right away. In every case the call returns only
    /// after the pushed entry itself has been signaled, and the value produced by `f` is
    /// returned unchanged.
    #[inline]
    pub(crate) fn wait_sync<T, F: FnOnce() -> Result<T, ()>>(&self, f: F) -> Result<T, ()> {
        let mut head = self.wait_queue.load(Relaxed);
        let mut entry = SyncWait::new(head);

        // Publish the entry as the new head; every time the exchange loses a race the entry
        // is re-linked to the head that was actually observed.
        loop {
            match self.wait_queue.compare_exchange(
                head,
                addr_of_mut!(entry) as usize,
                AcqRel,
                Relaxed,
            ) {
                Ok(_) => break,
                Err(observed) => {
                    head = observed;
                    entry.next = observed;
                }
            }
        }

        // The closure runs only once the entry is reachable through the queue.
        let outcome = f();
        if outcome.is_ok() {
            self.signal();
        }

        // `entry` is referenced by address from the queue, so this frame must not be left
        // before the entry has been signaled.
        entry.wait();
        outcome
    }

    /// Pushes `async_wait` into the queue and then runs `f`.
    ///
    /// Returns `Ok` with the value produced by `f` only when `f` succeeded *and* the pushed
    /// entry is already marked as signaled afterwards; the entry's internal state is cleared
    /// in that case. In every other case `Err(())` is returned, and a value successfully
    /// produced by `f` is dropped.
    ///
    /// NOTE(review): `async_wait` is dereferenced without any check and its address stays in
    /// the queue until some `signal` call reaches it, so the caller presumably has to keep
    /// the pointee valid and at the same address until then - confirm against the callers.
    #[inline]
    pub(crate) fn push_async_entry<T, F: FnOnce() -> Result<T, ()>>(
        &self,
        async_wait: *mut AsyncWait,
        f: F,
    ) -> Result<T, ()> {
        // SAFETY: relies on the caller passing a pointer to a live `AsyncWait` that is not
        // otherwise accessed while this function runs (not verifiable from here).
        let waiter = unsafe { &mut *async_wait };
        debug_assert!(waiter.mutex.is_none());

        let mut head = self.wait_queue.load(Relaxed);
        waiter.next = head;
        waiter.mutex.replace(Mutex::new((false, None)));

        // Publish the tagged address as the new head, re-linking on every lost race.
        loop {
            match self.wait_queue.compare_exchange(
                head,
                (async_wait as usize) | ASYNC,
                AcqRel,
                Relaxed,
            ) {
                Ok(_) => break,
                Err(observed) => {
                    head = observed;
                    waiter.next = observed;
                }
            }
        }

        // The closure runs only once the entry is reachable through the queue.
        let value = f()?;
        self.signal();
        if waiter.try_wait() {
            // The entry has been signaled: reset it and hand the value out.
            waiter.mutex.take();
            Ok(value)
        } else {
            // The entry was not signaled by the `signal` call above, so it is still owed a
            // signal from elsewhere; `value` is discarded.
            Err(())
        }
    }

    /// Detaches every entry currently in the queue and signals each of them, starting with
    /// the most recently pushed one.
    #[inline]
    pub(crate) fn signal(&self) {
        let mut cursor = self.wait_queue.swap(0, AcqRel);
        loop {
            let addr = cursor & !ASYNC;
            if addr == 0 {
                break;
            }
            // `next` is always read before the entry is signaled: the entry is not accessed
            // through the queue again once its waiter has been notified.
            cursor = if (cursor & ASYNC) == 0 {
                // SAFETY: an untagged non-zero value was stored by `wait_sync`, which keeps
                // the `SyncWait` alive until it has been signaled.
                let waiter = unsafe { &*SyncWait::reinterpret(addr) };
                let successor = waiter.next;
                waiter.signal();
                successor
            } else {
                // SAFETY: a tagged value was stored by `push_async_entry`; validity of the
                // pointee is the responsibility of that function's caller.
                let waiter = unsafe { &*AsyncWait::reinterpret(addr) };
                let successor = waiter.next;
                waiter.signal();
                successor
            };
        }
    }
}
/// Wait queue entry for an asynchronous waiter.
///
/// The entry is also a [`Future`] that resolves once the entry has been signaled.
#[derive(Debug, Default)]
pub(crate) struct AsyncWait {
    /// Queue value (address plus tag bit) of the entry that follows this one, or `0`.
    next: usize,
    /// `None` while the entry is not registered in a queue; otherwise the pair
    /// `(signaled, waker of the task polling this entry)`.
    mutex: Option<Mutex<(bool, Option<Waker>)>>,
}

impl AsyncWait {
    /// Returns a raw mutable pointer to `self`.
    #[inline]
    pub(crate) fn mut_ptr(&mut self) -> *mut AsyncWait {
        addr_of_mut!(*self)
    }

    /// Marks the entry as signaled and wakes the task registered by `poll`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the entry has no state, i.e. it was never registered in a queue.
    fn signal(&self) {
        if let Some(mutex) = self.mutex.as_ref() {
            // The guarded pair is valid after any panic, so a poisoned lock is recovered
            // instead of silently dropping the signal and leaving the waiter pending forever.
            let mut state = mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            state.0 = true;
            let waker = state.1.take();
            // Release the lock before running foreign code: a waker that panics can then no
            // longer poison the mutex, and one that polls re-entrantly cannot deadlock on it.
            drop(state);
            if let Some(waker) = waker {
                waker.wake();
            }
        } else {
            unreachable!();
        }
    }

    /// Returns `true` if the entry has state and has already been signaled.
    fn try_wait(&self) -> bool {
        match self.mutex.as_ref() {
            Some(mutex) => {
                // Poison is recovered for the same reason as in `signal`.
                let state = mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
                state.0
            }
            None => false,
        }
    }

    /// Converts an address (with the tag bit already cleared) back into an entry pointer.
    ///
    /// # Safety
    ///
    /// The cast itself cannot fail; dereferencing the result is only sound if `val` is the
    /// address of a live `AsyncWait`.
    unsafe fn reinterpret(val: usize) -> *mut AsyncWait {
        val as *mut AsyncWait
    }
}

impl Future for AsyncWait {
    type Output = ();

    /// Completes immediately when the entry has no state or has been signaled; otherwise
    /// stores the current task's waker (replacing any earlier one) and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.mutex.as_ref() {
            None => Poll::Ready(()),
            Some(mutex) => {
                // A poisoned lock must not be skipped: returning `Pending` without having
                // stored a waker would leave the task asleep forever.
                let mut state = mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
                if state.0 {
                    Poll::Ready(())
                } else {
                    state.1.replace(cx.waker().clone());
                    Poll::Pending
                }
            }
        }
    }
}
/// Wait queue entry for a thread that blocks until it is signaled.
#[derive(Debug)]
struct SyncWait {
    /// Queue value (address plus tag bit) of the entry that follows this one, or `0`.
    next: usize,
    /// Notified by `signal` once `mutex` holds `true`.
    condvar: Condvar,
    /// `true` once the entry has been signaled.
    mutex: Mutex<bool>,
}

impl SyncWait {
    /// Creates an entry that has not been signaled and links to `next`.
    fn new(next: usize) -> SyncWait {
        #[allow(clippy::mutex_atomic)]
        let mutex = Mutex::new(false);
        Self {
            next,
            condvar: Condvar::new(),
            mutex,
        }
    }

    /// Blocks the calling thread until the entry has been signaled.
    ///
    /// Returns immediately if `signal` has already been called.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn wait(&self) {
        #[allow(clippy::mutex_atomic)]
        let guard = self.mutex.lock().unwrap();
        // Keep sleeping for as long as the flag is still `false`; this also absorbs spurious
        // wake-ups of the condition variable.
        let _signaled = self
            .condvar
            .wait_while(guard, |completed| !*completed)
            .unwrap();
    }

    /// Marks the entry as signaled and wakes the thread blocked in `wait`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn signal(&self) {
        #[allow(clippy::mutex_atomic)]
        let mut guard = self.mutex.lock().unwrap();
        *guard = true;
        // Notify while the lock is still held: a waiter cannot return from `wait` - and
        // its entry cannot go away - before the guard below is released.
        self.condvar.notify_one();
        drop(guard);
    }

    /// Converts an address back into an entry pointer.
    ///
    /// # Safety
    ///
    /// The cast itself cannot fail; dereferencing the result is only sound if `val` is the
    /// address of a live `SyncWait`.
    unsafe fn reinterpret(val: usize) -> *mut SyncWait {
        val as *mut Self
    }
}