use crate::shim::atomic::{AtomicUsize, Ordering};
use crate::shim::cell::UnsafeCell;
use core::task::Waker;
/// No `register` or `take` currently owns the waker slot.
const WAITING: usize = 0;
/// Bit set while a `register` call has exclusive access to the waker slot.
const REGISTERING: usize = 1 << 0;
/// Bit set by `take` while it owns the slot, or to signal a `register` in
/// progress that a wakeup arrived and must be delivered on release.
const WAKING: usize = 1 << 1;
/// A slot holding at most one [`Waker`] that can be registered into and woken
/// from concurrently.
///
/// Access to `waker` is arbitrated by `state`, a bit set built from
/// `REGISTERING` and `WAKING` (`WAITING` when both are clear). Only the caller
/// that acquires the slot through `state` touches the cell.
pub struct AtomicWaker {
    /// Coordination bits: `WAITING`, `REGISTERING`, `WAKING`, or both set.
    state: AtomicUsize,
    /// The currently registered waker, guarded by the `state` protocol.
    waker: UnsafeCell<Option<Waker>>,
}

// SAFETY: `Waker` is `Send`, so moving the whole value (and the cell it owns)
// to another thread is sound.
unsafe impl Send for AtomicWaker {}
// SAFETY: shared (`&self`) access to `waker` only happens after acquiring the
// slot through `state` in `register`/`take`, which grant it to one caller at a
// time; `Waker` itself is `Send + Sync`.
unsafe impl Sync for AtomicWaker {}
impl AtomicWaker {
    /// Creates an `AtomicWaker` in the `WAITING` state with no registered waker.
    #[inline]
    pub fn new() -> Self {
        Self {
            state: AtomicUsize::new(WAITING),
            waker: UnsafeCell::new(None),
        }
    }

    /// Registers `waker` to be woken by the next [`wake`](Self::wake) (or
    /// returned by the next [`take`](Self::take)), replacing any previously
    /// registered waker.
    ///
    /// - If a wake is in progress (`WAKING`), `waker` is woken immediately
    ///   instead of being stored, so the notification is not lost.
    /// - If a `take` arrives while this call holds the slot, the waker just
    ///   stored is taken back out and woken before this call returns.
    /// - If another `register` is concurrently in progress, this call does
    ///   nothing.
    #[inline]
    pub fn register(&self, waker: &Waker) {
        // Clone *before* acquiring the REGISTERING lock. `Waker::clone` runs the
        // waker's vtable code, which may panic; panicking while holding the lock
        // would leave `state` stuck at REGISTERING, making every later
        // `register` a no-op and every `take` return `None`.
        let new_waker = waker.clone();

        match self.state.compare_exchange(
            WAITING,
            REGISTERING,
            Ordering::Acquire,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                // SAFETY: we hold REGISTERING, and `take` only touches the cell
                // after observing WAITING, so access to the cell is exclusive.
                let old_waker = self
                    .waker
                    .with_mut(|w| unsafe { (*w).replace(new_waker) });

                match self.state.compare_exchange(
                    REGISTERING,
                    WAITING,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => {
                        // Drop the previous waker only after releasing the slot.
                        drop(old_waker);
                    }
                    Err(_) => {
                        // A `take` set WAKING while we held the slot; it returned
                        // `None`, so delivering the wakeup is our job.
                        // SAFETY: state is REGISTERING | WAKING, which neither
                        // `register` nor `take` will acquire, so we still have
                        // exclusive access to the cell.
                        let waker = self.waker.with_mut(|w| unsafe { (*w).take() });
                        self.state.store(WAITING, Ordering::Release);
                        drop(old_waker);
                        if let Some(waker) = waker {
                            waker.wake();
                        }
                    }
                }
            }
            Err(WAKING) => {
                // A wake is in progress and may have missed this registration;
                // wake the caller directly (consuming the clone made above).
                new_waker.wake();
            }
            Err(_) => {
                // Another `register` holds the slot (REGISTERING, possibly with
                // WAKING); concurrent registration is not supported, so ignore.
            }
        }
    }

    /// Removes and returns the registered waker, leaving the slot empty.
    ///
    /// Returns `None` when no waker is registered, or when another `register`
    /// or `take` currently owns the slot. In the `register` case the WAKING bit
    /// set here causes that `register` to wake its waker on release.
    #[inline]
    pub fn take(&self) -> Option<Waker> {
        match self.state.fetch_or(WAKING, Ordering::AcqRel) {
            WAITING => {
                // SAFETY: we moved the state from WAITING to WAKING, so no other
                // caller can acquire the slot until we store WAITING below.
                let waker = self.waker.with_mut(|w| unsafe { (*w).take() });
                self.state.store(WAITING, Ordering::Release);
                waker
            }
            _ => None,
        }
    }

    /// Wakes the registered waker, if any, consuming the registration.
    #[inline]
    pub fn wake(&self) {
        if let Some(waker) = self.take() {
            waker.wake();
        }
    }
}

impl Default for AtomicWaker {
    /// Equivalent to [`AtomicWaker::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Drop for AtomicWaker {
    /// Empties the waker slot, dropping any registered waker.
    fn drop(&mut self) {
        // SAFETY: `&mut self` rules out any concurrent `register`/`take`, so the
        // cell can be accessed without consulting `state`.
        let leftover = self.waker.with_mut(|slot| unsafe { (*slot).take() });
        drop(leftover);
    }
}
impl core::fmt::Debug for AtomicWaker {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let state = self.state.load(Ordering::Acquire);
let state_str = match state {
WAITING => "Waiting",
REGISTERING => "Registering",
WAKING => "Waking",
_ => "Unknown",
};
f.debug_struct("AtomicWaker")
.field("state", &state_str)
.finish()
}
}
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Test waker that counts how many times it has been woken.
    struct CountingWaker(std::sync::atomic::AtomicUsize);

    impl CountingWaker {
        fn count(&self) -> usize {
            self.0.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    impl std::task::Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        }
    }

    #[test]
    fn test_basic_register_and_take() {
        let atomic_waker = AtomicWaker::new();
        let waker = futures::task::noop_waker();
        atomic_waker.register(&waker);
        let taken = atomic_waker.take();
        assert!(taken.is_some());
        let taken2 = atomic_waker.take();
        assert!(taken2.is_none());
    }

    #[test]
    fn test_wake_notifies_registered_waker_once() {
        let counter = Arc::new(CountingWaker(std::sync::atomic::AtomicUsize::new(0)));
        let waker = std::task::Waker::from(Arc::clone(&counter));
        let atomic_waker = AtomicWaker::new();

        // Nothing registered yet: waking is a no-op.
        atomic_waker.wake();
        assert_eq!(counter.count(), 0);

        atomic_waker.register(&waker);
        atomic_waker.wake();
        assert_eq!(counter.count(), 1);

        // The registration was consumed by the first wake.
        atomic_waker.wake();
        assert_eq!(counter.count(), 1);
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;
        let atomic_waker = Arc::new(AtomicWaker::new());
        let waker = futures::task::noop_waker();
        let aw1 = atomic_waker.clone();
        let w1 = waker.clone();
        let h1 = thread::spawn(move || {
            for _ in 0..100 {
                aw1.register(&w1);
            }
        });
        let aw2 = atomic_waker.clone();
        let h2 = thread::spawn(move || {
            for _ in 0..100 {
                aw2.take();
            }
        });
        h1.join().unwrap();
        h2.join().unwrap();
        // Every `register` and `take` releases the slot before returning, so
        // once both threads have finished the state must be back to WAITING.
        assert_eq!(atomic_waker.state.load(Ordering::SeqCst), WAITING);
    }
}