use crate::shim::atomic::{AtomicU8, Ordering};
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use super::atomic_waker::AtomicWaker;
/// No pending notification and no waiter parked on the notifier.
const EMPTY: u8 = 0;
/// A waiter has registered its waker and is parked until `notify_one` is called.
const WAITING: u8 = 1;
/// A notification has been posted and not yet consumed by a waiter.
const NOTIFIED: u8 = 2;
/// A notification primitive for a single waiting task.
///
/// `notify_one` posts a notification (repeated calls collapse into one pending
/// notification) and wakes the waiter if it is parked; `notified()` returns a future
/// that completes by consuming that notification. Only one waker slot is kept, so
/// only one waiter at a time is supported.
pub struct SingleWaiterNotify {
    /// Current state: one of `EMPTY`, `WAITING` or `NOTIFIED`.
    state: AtomicU8,
    /// Waker of the parked waiter; `notify_one` wakes it when the state was `WAITING`.
    waker: AtomicWaker,
}
impl core::fmt::Debug for SingleWaiterNotify {
    /// Renders the notifier as `SingleWaiterNotify { state: "<name>" }`, where `<name>` is
    /// the symbolic name of the current state, or `"Unknown"` for an unexpected value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let raw = self.state.load(Ordering::Acquire);
        let name = match raw {
            NOTIFIED => "Notified",
            WAITING => "Waiting",
            EMPTY => "Empty",
            _ => "Unknown",
        };
        let mut builder = f.debug_struct("SingleWaiterNotify");
        builder.field("state", &name);
        builder.finish()
    }
}
impl Default for SingleWaiterNotify {
fn default() -> Self {
Self::new()
}
}
impl SingleWaiterNotify {
#[inline]
pub fn new() -> Self {
Self {
state: AtomicU8::new(EMPTY),
waker: AtomicWaker::new(),
}
}
#[inline]
pub fn notified(&self) -> Notified<'_> {
Notified {
notify: self,
registered: false,
}
}
#[inline]
pub fn notify_one(&self) {
let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
if prev_state == WAITING {
self.waker.wake();
}
}
#[inline]
fn register_waker(&self, waker: &core::task::Waker) -> bool {
self.waker.register(waker);
let current_state = self.state.load(Ordering::Acquire);
if current_state == NOTIFIED {
self.state.store(EMPTY, Ordering::Release);
return true;
}
match self
.state
.compare_exchange(EMPTY, WAITING, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => {
if self.state.load(Ordering::Acquire) == NOTIFIED {
self.state.store(EMPTY, Ordering::Release);
true
} else {
false
}
}
Err(state) => {
if state == NOTIFIED {
self.state.store(EMPTY, Ordering::Release);
true
} else {
if self.state.load(Ordering::Acquire) == NOTIFIED {
self.state.store(EMPTY, Ordering::Release);
true
} else {
false
}
}
}
}
}
}
/// Future returned by `SingleWaiterNotify::notified`.
///
/// Resolves once it consumes a pending notification. If it is dropped while it has
/// parked the notifier in the `WAITING` state, its `Drop` impl rolls that back to `EMPTY`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Notified<'a> {
    /// The notifier this future waits on.
    notify: &'a SingleWaiterNotify,
    /// Set by `poll` once this future has gone through waker registration; `Drop` only
    /// rolls a `WAITING` state back to `EMPTY` when this is set.
    registered: bool,
}
impl<'a> core::fmt::Debug for Notified<'a> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let state = self.notify.state.load(Ordering::Acquire);
let state_str = match state {
EMPTY => "Empty",
WAITING => "Waiting",
NOTIFIED => "Notified",
_ => "Unknown",
};
f.debug_struct("Notified")
.field("state", &state_str)
.field("registered", &self.registered)
.finish()
}
}
impl Future for Notified<'_> {
    type Output = ();

    /// Completes once a notification has been consumed; otherwise registers `cx`'s waker,
    /// leaves the notifier `WAITING`, and returns `Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let notify = self.notify;

        // Fast path: consume a pending notification without (re-)registering the waker.
        // A single RMW (not load + store) so the consume reads the latest `notify_one` in
        // modification order and acquires the writes it published.
        if notify
            .state
            .compare_exchange(NOTIFIED, EMPTY, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            self.registered = false;
            return Poll::Ready(());
        }

        // From here this future may park the notifier in WAITING, so Drop must undo it.
        self.registered = true;
        if notify.register_waker(cx.waker()) {
            // The notification was consumed and this future no longer owns the WAITING
            // state. Clearing the flag keeps a later drop of this completed future from
            // resetting a WAITING state that belongs to a subsequent `Notified`.
            self.registered = false;
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
impl Drop for Notified<'_> {
    /// Withdraws this future's registration when it is dropped.
    ///
    /// Only a `WAITING` state is rolled back to `EMPTY`. A pending notification
    /// (`NOTIFIED`) is left for the next waiter, and an `EMPTY` state is left untouched.
    fn drop(&mut self) {
        if !self.registered {
            return;
        }
        let state = &self.notify.state;
        // Failure (state is not WAITING) is expected and deliberately ignored.
        let _ = state.compare_exchange(WAITING, EMPTY, Ordering::Relaxed, Ordering::Relaxed);
    }
}
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};

    /// Spawns a detached task that sleeps for `delay` and then calls `notify_one`.
    fn notify_after(notify: &Arc<SingleWaiterNotify>, delay: Duration) {
        let notifier = Arc::clone(notify);
        tokio::spawn(async move {
            sleep(delay).await;
            notifier.notify_one();
        });
    }

    #[tokio::test]
    async fn test_notify_before_wait() {
        let shared = Arc::new(SingleWaiterNotify::new());
        shared.notify_one();
        shared.notified().await;
    }

    #[tokio::test]
    async fn test_notify_after_wait() {
        let shared = Arc::new(SingleWaiterNotify::new());
        notify_after(&shared, Duration::from_millis(10));
        shared.notified().await;
    }

    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let shared = Arc::new(SingleWaiterNotify::new());
        for _ in 0..10 {
            notify_after(&shared, Duration::from_millis(5));
            shared.notified().await;
        }
    }

    #[tokio::test]
    async fn test_concurrent_notify() {
        let shared = Arc::new(SingleWaiterNotify::new());
        (0..5).for_each(|_| notify_after(&shared, Duration::from_millis(10)));
        shared.notified().await;
    }

    #[tokio::test]
    async fn test_notify_no_waiter() {
        let local = SingleWaiterNotify::new();
        local.notify_one();
        local.notify_one();
        local.notified().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let shared = Arc::new(SingleWaiterNotify::new());
        for round in 0..100u64 {
            notify_after(&shared, Duration::from_micros(round % 10));
            shared.notified().await;
        }
    }

    #[tokio::test]
    async fn test_immediate_notification_race() {
        for _ in 0..100 {
            let shared = Arc::new(SingleWaiterNotify::new());
            let waiter = {
                let mine = Arc::clone(&shared);
                tokio::spawn(async move { mine.notified().await })
            };
            shared.notify_one();
            let joined = tokio::time::timeout(Duration::from_millis(100), waiter)
                .await
                .expect("Should not timeout");
            joined.expect("Task should complete");
        }
    }
}