use core::ptr;
use std::{
cell::UnsafeCell,
sync::atomic::{AtomicU8, Ordering},
task::{RawWaker, RawWakerVTable, Waker},
thread,
};
/// Single-slot waker cell that may be registered and woken from different threads.
///
/// `state` holds the `WAKING` / `REGISTERING` bits that decide which caller may
/// touch `waker` at any moment; all accesses to the cell go through `register`
/// and `wake`.
pub(crate) struct AtomicWaker {
    /// Bitset of `WAKING` / `REGISTERING` flags guarding `waker`.
    state: AtomicU8,
    /// The most recently registered waker (`Default` fills it with a no-op placeholder).
    waker: UnsafeCell<Waker>,
}

// `Send` is not implemented by hand: `AtomicU8` and `UnsafeCell<Waker>` are both
// `Send` (because `Waker: Send`), so the compiler derives it and will flag any
// future field that is not `Send`.

// SAFETY: `UnsafeCell` makes the type `!Sync` by default. Shared access is sound only
// because `waker` is touched exclusively inside `register` and `wake`, which are
// intended to use the bits in `state` as a lock so that the cell is written by at
// most one thread at a time and never while it is being read. Any change to those
// methods must preserve that protocol.
unsafe impl Sync for AtomicWaker {}
/// Idle: neither `register` nor `wake` is touching the waker slot.
const WAITING: u8 = 0;
/// Set while a `wake` is delivering a notification, or left set as a pending
/// notification for an in-progress `register`.
const WAKING: u8 = 1 << 0;
/// Set while a `register` call owns the waker slot.
const REGISTERING: u8 = 1 << 1;
/// A `wake` arrived while a `register` was in progress.
const FULL: u8 = REGISTERING | WAKING;
impl AtomicWaker {
    /// Registers `waker` to be notified by the next call to [`AtomicWaker::wake`].
    ///
    /// The slot is claimed only from the idle `WAITING` state, via a compare-exchange
    /// to `REGISTERING`. When the claim fails the state is left untouched, so this
    /// call can never clear a bit that another thread owns.
    ///
    /// Behavior by the state observed:
    /// - `WAITING`: stores a clone of `waker` (skipped when the stored waker already
    ///   `will_wake` it) and releases the slot. If a `wake` arrived meanwhile (state
    ///   became `FULL`), that wake was deferred to this call, so the stored waker is
    ///   woken here before the state returns to `WAITING`.
    /// - `WAKING`: a wake is being delivered from the slot, so the slot is left alone;
    ///   `waker` is woken directly and the thread yields.
    /// - `REGISTERING` / `FULL`: another `register` is writing the slot. Writing it too
    ///   would be a data race, so the slot is left alone and `waker` is woken directly,
    ///   which makes its task poll again.
    ///
    /// NOTE(review): as with other atomic-waker designs, callers presumably register
    /// before re-checking their readiness condition — confirm at call sites.
    pub fn register(&self, waker: &Waker) {
        let observed = self
            .state
            .compare_exchange(WAITING, REGISTERING, Ordering::AcqRel, Ordering::Acquire)
            .unwrap_or_else(|actual| actual);
        match observed {
            WAITING => {
                // SAFETY: winning WAITING -> REGISTERING gives this call exclusive access
                // to the cell until REGISTERING is cleared below: `wake` only reads the
                // cell after moving WAITING -> WAKING, and any other `register` only
                // writes it after winning its own WAITING -> REGISTERING exchange.
                let inner_waker = unsafe { &mut *self.waker.get() };
                if !inner_waker.will_wake(waker) {
                    *inner_waker = waker.clone();
                }
                match self.state.fetch_and(!REGISTERING, Ordering::AcqRel) {
                    FULL => {
                        // A `wake` saw REGISTERING, set WAKING and returned without waking.
                        // The state is now WAKING: `register` fails its exchange and `wake`
                        // takes its no-op arm, so nobody else touches the cell while the
                        // deferred wake is delivered here.
                        // SAFETY: shared read while only this thread may access the cell.
                        let inner_waker = unsafe { &*self.waker.get() };
                        inner_waker.wake_by_ref();
                        self.state.store(WAITING, Ordering::Release);
                    }
                    state => debug_assert_eq!(state, REGISTERING),
                }
            }
            WAKING => {
                // The slot is being read by a wake in progress; notify this waker directly.
                // The yield presumably gives the in-flight wake a chance to finish before
                // the task re-polls.
                waker.wake_by_ref();
                thread::yield_now();
            }
            state => {
                debug_assert!(state == REGISTERING || state == FULL);
                // Concurrent registration: do not touch the cell; avoid a lost wakeup by
                // waking the caller's waker immediately.
                waker.wake_by_ref();
            }
        }
    }

    /// Wakes the currently registered waker.
    ///
    /// From `WAITING` the stored waker is woken immediately. If a `register` holds the
    /// slot (`REGISTERING`), the `WAKING` bit is left set and that `register` delivers
    /// the wake once its waker is stored. Otherwise (`WAKING` / `FULL`) a wake is
    /// already being delivered or is already pending, and this call does nothing.
    pub fn wake(&self) {
        match self.state.fetch_or(WAKING, Ordering::AcqRel) {
            WAITING => {
                // SAFETY: WAITING -> WAKING means no `register` owns the slot, and while
                // WAKING is set `register` does not write the cell, so this shared read
                // cannot race with a write.
                let inner_waker = unsafe { &*self.waker.get() };
                inner_waker.wake_by_ref();
                self.state.fetch_and(!WAKING, Ordering::Release);
            }
            state => {
                debug_assert!(state == REGISTERING || state == FULL || state == WAKING);
            }
        }
    }
}
impl Default for AtomicWaker {
fn default() -> Self {
Self {
state: AtomicU8::new(WAITING),
waker: UnsafeCell::new(dummy_waker()),
}
}
}
/// Vtable for the placeholder waker; none of its four entries does any work.
const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    noop_clone, // clone
    noop,       // wake
    noop,       // wake_by_ref
    noop,       // drop
);

/// `clone` entry: yields another null-data raw waker backed by the same vtable.
unsafe fn noop_clone(_data: *const ()) -> RawWaker {
    let data: *const () = ptr::null();
    RawWaker::new(data, &NOOP_WAKER_VTABLE)
}

/// `wake`, `wake_by_ref` and `drop` entry: ignores its data pointer.
unsafe fn noop(_data: *const ()) {}

/// Builds a [`Waker`] whose operations are all no-ops.
fn dummy_waker() -> Waker {
    let raw = RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE);
    // SAFETY: no vtable entry dereferences the (null) data pointer, and `noop_clone`
    // returns a raw waker with the same vtable, as the `RawWaker` contract requires.
    unsafe { Waker::from_raw(raw) }
}