use crate::sync::AtomicUsize;
use core::{
cell::UnsafeCell,
fmt::{self, Debug},
sync::atomic::Ordering,
task::Waker,
};
/// State value meaning no `register` or `take` currently owns the waker slot.
const WAITING: usize = 0;
/// Bit set by `register` while it is writing the waker slot.
const REGISTERING: usize = 1 << 0;
/// Bit set by `take` while it removes the waker; also left set on top of `REGISTERING`
/// when a `take` races with an in-progress `register`.
const WAKING: usize = 1 << 1;
/// A single slot holding the most recently registered [`Waker`], shared between the side
/// that registers interest and the side that wakes it.
///
/// Access to `waker` is serialized by the bits in `state` rather than by a mutex.
pub struct AtomicWaker {
    /// Combination of `WAITING` / `REGISTERING` / `WAKING`; acts as the lock for `waker`.
    state: AtomicUsize,
    /// The registered waker. Only touched by the thread that currently owns `state`.
    waker: UnsafeCell<Option<Waker>>,
}
impl AtomicWaker {
    /// Creates an `AtomicWaker` in the idle (`WAITING`) state with no waker stored.
    #[inline]
    pub const fn new() -> Self {
        AtomicWaker { state: AtomicUsize::new(WAITING), waker: UnsafeCell::new(None) }
    }

    /// Stores a clone of `waker` so a later [`wake`](Self::wake) or [`take`](Self::take)
    /// can reach it, replacing any previously registered waker.
    ///
    /// What happens depends on what other threads are doing with this `AtomicWaker`:
    /// - idle: `waker` is stored. If a `take` starts while it is being stored, the stored
    ///   waker is removed again and woken immediately, so the caller gets re-polled.
    /// - a `take` is in progress: nothing is stored and `waker` is woken right away.
    /// - another `register` is in progress: this call does nothing.
    #[inline]
    pub fn register(&self, waker: &Waker) {
        let prev_state = self
            .state
            .compare_exchange(WAITING, REGISTERING, Ordering::Acquire, Ordering::Acquire)
            .unwrap_or_else(|actual| actual);
        match prev_state {
            WAITING => {
                {
                    // SAFETY: the successful CAS to `REGISTERING` is the lock on the slot.
                    // `take` only touches the cell after observing `WAITING`, and another
                    // `register` only after winning the same CAS, so this access is exclusive.
                    let slot = unsafe { &mut *self.waker.get() };
                    match slot {
                        // `clone_from` may skip the clone when the stored waker already
                        // wakes the same task.
                        Some(current) => current.clone_from(waker),
                        None => *slot = Some(waker.clone()),
                    }
                }

                // Release the lock. `AcqRel` publishes the slot write; on failure we must
                // also acquire the concurrent `take`'s `fetch_or`.
                let release = self.state.compare_exchange(
                    REGISTERING,
                    WAITING,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                );
                if let Err(actual) = release {
                    // A `take` set `WAKING` while we held the lock. It returned `None`
                    // without touching the slot, so we still own the slot and must
                    // deliver the wake ourselves.
                    debug_assert_eq!(actual, REGISTERING | WAKING);

                    // SAFETY: the state is still `REGISTERING | WAKING`, so no other thread
                    // may access the cell until we reset it below.
                    let pending = unsafe { (*self.waker.get()).take() };

                    // Always return to `WAITING` before waking. Bailing out early here would
                    // leave `REGISTERING | WAKING` set forever, turning every later `register`
                    // and `take` into a silent no-op.
                    let _ = self.state.swap(WAITING, Ordering::AcqRel);

                    if let Some(pending) = pending {
                        pending.wake();
                    }
                }
            }
            // A `take` currently owns the slot; ask to be polled again instead of waiting.
            WAKING => waker.wake_by_ref(),
            state => {
                // Another `register` holds the lock (possibly with a pending `WAKING` bit).
                // It will complete the registration; dropping this call is memory-safe.
                debug_assert!(state == REGISTERING || state == REGISTERING | WAKING);
            }
        }
    }

    /// Removes and returns the registered waker, if any.
    ///
    /// Returns `None` when nothing is registered or when another `register`/`take` holds
    /// the slot. When a `register` holds it, the `WAKING` bit set here makes that
    /// `register` wake the waker it is storing, so the notification is not lost.
    #[inline]
    pub fn take(&self) -> Option<Waker> {
        // `AcqRel`: acquire the slot contents written by `register`, and release whatever
        // the caller wrote before requesting the wake.
        match self.state.fetch_or(WAKING, Ordering::AcqRel) {
            WAITING => {
                // SAFETY: we moved the state from `WAITING` to `WAKING`. While `WAKING` is set,
                // `register`'s CAS from `WAITING` fails and other `take`s see a non-`WAITING`
                // value, so this thread has exclusive access to the cell.
                let waker = unsafe { (*self.waker.get()).take() };
                // Only the `WAKING` bit can be set at this point, so resetting to `WAITING`
                // clears exactly our bit.
                let _ = self.state.swap(WAITING, Ordering::Release);
                waker
            }
            _ => None,
        }
    }

    /// Takes the registered waker (see [`take`](Self::take)) and wakes it, if there was one.
    #[inline]
    pub fn wake(&self) {
        if let Some(waker) = self.take() {
            waker.wake();
        }
    }
}
impl Debug for AtomicWaker {
    /// Writes only the type name; neither the state word nor the waker slot is printed.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
impl Default for AtomicWaker {
#[inline]
fn default() -> Self {
AtomicWaker::new()
}
}
// SAFETY: besides the `state` atomic, the only field is `UnsafeCell<Option<Waker>>`, and
// `Waker` is `Send`, so moving an `AtomicWaker` to another thread just moves an owned
// `Option<Waker>` along with it.
unsafe impl Send for AtomicWaker {}
// SAFETY: shared access to the `waker` cell is serialized through `state`. `register`
// touches the cell only after its compare-exchange `WAITING -> REGISTERING` succeeds.
// `take` touches it only after `fetch_or(WAKING)` observes `WAITING`. Both reset the state
// only after they are done with the cell, so at most one thread accesses it at a time.
unsafe impl Sync for AtomicWaker {}
#[cfg(test)]
mod tests {
    use crate::{
        executor::Runtime,
        sync::{Arc, AtomicBool, AtomicWaker},
    };
    use core::{future::poll_fn, sync::atomic::Ordering, task::Poll};
    use std::thread;

    /// A task that registered its waker must be woken by `wake` called from another thread.
    ///
    /// `waiting` is published with `Release` and observed with `Acquire`, so the task's
    /// `register` (done before setting `waiting`) happens-before the main thread's `wake`.
    /// With `Relaxed` there is no such edge: `take` could be ordered before the registration,
    /// find an empty slot, and the task could then hang forever.
    #[test]
    fn non_blocking_operation() {
        let atomic_waker = Arc::new(AtomicWaker::new());
        let atomic_waker_clone = atomic_waker.clone();
        let waiting = Arc::new(AtomicBool::new(false));
        let waiting_clone = waiting.clone();
        let woken = Arc::new(AtomicBool::new(false));
        let woken_clone = woken.clone();
        let jh = thread::spawn(move || {
            let mut pending = 0;
            Runtime::new().block_on(poll_fn(move |cx| {
                if woken_clone.load(Ordering::Acquire) {
                    Poll::Ready(())
                } else {
                    // The task must only be polled while pending once: the single `wake`
                    // below should lead straight to the `Ready` branch.
                    assert_eq!(0, pending);
                    pending += 1;
                    atomic_waker_clone.register(cx.waker());
                    waiting_clone.store(true, Ordering::Release);
                    Poll::Pending
                }
            }));
        });
        while !waiting.load(Ordering::Acquire) {
            core::hint::spin_loop();
        }
        thread::yield_now();
        woken.store(true, Ordering::Release);
        atomic_waker.wake();
        jh.join().unwrap();
    }
}