use core::future::Future;
use core::pin::Pin;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::{Context, Poll};
use std::sync::Mutex;
use futures::future::FusedFuture;
use crate::linked_list::LinkedList;
use crate::utils::is_unpin;
use crate::utils::notify::{State, Waiter};
#[derive(Debug)]
pub struct Notify {
state: AtomicUsize,
waiters: Mutex<LinkedList<Waiter>>,
}
impl Notify {
#[inline]
pub fn new() -> Self {
Self {
state: AtomicUsize::new(0),
waiters: Mutex::default(),
}
}
pub fn notify_all(&self) {
self.state.store(0, Ordering::SeqCst);
let mut waiters = self.waiters.lock().unwrap();
#[allow(clippy::significant_drop_in_scrutinee)]
for waiter in waiters.iter_mut() {
let waiter = unsafe { waiter.get() };
waiter.notified = true;
if let Some(waker) = &waiter.waker {
waker.wake_by_ref();
}
}
}
pub fn notify_one(&self) {
let waiters = self.waiters.lock().unwrap();
#[allow(clippy::significant_drop_in_scrutinee)]
match waiters.front() {
Some(waiter) => {
let waiter = unsafe { waiter.get() };
waiter.notified = true;
if let Some(waker) = &waiter.waker {
waker.wake_by_ref();
}
}
None => self.state.store(1, Ordering::SeqCst),
}
}
pub fn notified(&self) -> Notified<'_> {
Notified {
notify: self,
state: State::Init,
waiter: Waiter::new(),
}
}
}
impl Default for Notify {
#[inline]
fn default() -> Self {
Self::new()
}
}
unsafe impl Send for Notify {}
unsafe impl Sync for Notify {}
#[derive(Debug)]
pub struct Notified<'a> {
notify: &'a Notify,
state: State,
waiter: Waiter,
}
impl<'a> Notified<'a> {
#[inline]
fn state_mut(self: Pin<&mut Self>) -> &mut State {
is_unpin::<State>();
unsafe { &mut self.get_unchecked_mut().state }
}
}
impl<'a> Future for Notified<'a> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state {
State::Init => {
let res =
self.notify
.state
.compare_exchange(1, 0, Ordering::SeqCst, Ordering::SeqCst);
if res.is_ok() {
*self.state_mut() = State::Done;
return Poll::Ready(());
}
let mut waiters = self.notify.waiters.lock().unwrap();
unsafe {
self.waiter.get().waker = Some(cx.waker().clone());
waiters.push_back((&self.waiter).into());
}
drop(waiters);
*self.state_mut() = State::Pending;
Poll::Pending
}
State::Pending => {
let mut waiters = self.notify.waiters.lock().unwrap();
let waiter = unsafe { self.waiter.get() };
if waiter.notified {
unsafe {
waiters.remove((&self.waiter).into());
}
*self.state_mut() = State::Done;
Poll::Ready(())
} else {
let update = match &waiter.waker {
Some(waker) => !waker.will_wake(cx.waker()),
None => true,
};
if update {
waiter.waker = Some(cx.waker().clone());
}
drop(waiters);
Poll::Pending
}
}
State::Done => Poll::Ready(()),
}
}
}
impl<'a> Drop for Notified<'a> {
fn drop(&mut self) {
if self.state == State::Pending {
let mut waiters = self.notify.waiters.lock().unwrap();
unsafe {
waiters.remove((&self.waiter).into());
}
}
}
}
impl<'a> FusedFuture for Notified<'a> {
#[inline]
fn is_terminated(&self) -> bool {
self.state == State::Done
}
}
unsafe impl<'a> Send for Notified<'a> {}
unsafe impl<'a> Sync for Notified<'a> {}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::Duration;
use std::vec::Vec;
use tokio::sync::mpsc;
use super::Notify;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_notify_all() {
let notify = Arc::new(Notify::new());
let (tx, mut rx) = mpsc::channel(5);
for _ in 0..5 {
let handle = notify.clone();
let tx = tx.clone();
tokio::task::spawn(async move {
handle.notified().await;
let _ = tx.send(()).await;
});
}
tokio::time::sleep(Duration::new(5, 0)).await;
notify.notify_all();
for _ in 0..5 {
let _ = rx.recv().await;
}
}
#[test]
fn test_notify_notified() {
let notify = Notify::new();
let _: Vec<_> = (0..5).map(|_| notify.notified()).collect();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_notify_one() {
let notify = Arc::new(Notify::new());
let handle = notify.clone();
tokio::task::spawn(async move {
handle.notified().await;
});
tokio::time::sleep(Duration::new(1, 0)).await;
notify.notify_one();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_notify_one_stored() {
let notify = Arc::new(Notify::new());
notify.notify_one();
notify.notified().await;
}
}