use alloc::sync::{Arc, Weak};
use core::fmt;
use core::future::Future;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{Context, Poll};
use futures_core::Stream;
use futures_util::task::AtomicWaker;
use crate::{Listen, Listener};
/// Receiving half of a wake flag: a boolean "notified" flag that a task can await.
///
/// Created together with a [`WakeFlagListener`] by [`WakeFlag::new`]. Non-empty message
/// batches delivered to the listener set the flag. [`WakeFlag::wait`] (or polling this as
/// a [`Stream`]) observes and clears it. Once the flag is unset and every clone of the
/// listener has been dropped, `wait` returns `false` and the stream ends.
pub struct WakeFlag {
    /// State shared with the listeners, which hold only `Weak` references to it.
    shared: Arc<WakeFlagShared>,
}
/// Sending half of a wake flag: a [`Listener`] that sets the flag of its [`WakeFlag`]
/// whenever it receives a non-empty batch of messages.
///
/// Listeners may be cloned. The `WakeFlag` treats its listeners as gone only once every
/// clone has been dropped.
#[derive(Clone)]
pub struct WakeFlagListener {
    /// Weak, so listeners do not keep the flag's state alive. Once the `WakeFlag` is gone,
    /// `upgrade` fails and `receive` reports `false`.
    shared: Weak<WakeFlagShared>,
    /// Strong counterpart of `WakeFlagShared::listeners_alive`. When every clone of this
    /// handle is dropped, that `Weak`'s strong count reaches zero.
    _alive: Arc<()>,
}
/// Future awaited by [`WakeFlag::wait`]. It resolves to the first `Ready` result of
/// polling the shared flag state.
#[derive(Debug)]
struct WakeFlagFuture<'a> {
    /// Borrowed from the `WakeFlag` that is being waited on.
    shared: &'a WakeFlagShared,
    /// Set once `poll` has returned `Ready`. Polling again after that panics.
    done: bool,
}
/// State shared between a [`WakeFlag`] and its [`WakeFlagListener`]s.
#[derive(Debug)]
struct WakeFlagShared {
    /// Set to `true` by notifications. Swapped back to `false` when the waiter observes it.
    notified: AtomicBool,
    /// Weak counterpart of the listeners' `_alive` handle. A strong count of zero means
    /// every listener has been dropped.
    listeners_alive: Weak<()>,
    /// Waker registered by the most recent poll that found nothing to report.
    waker: AtomicWaker,
}
impl WakeFlag {
const SET_ORDERING: Ordering = Ordering::Release;
const GET_CLEAR_ORDERING: Ordering = Ordering::Acquire;
#[must_use]
pub fn new(wake_immediately: bool) -> (Self, WakeFlagListener) {
let strong_alive = Arc::new(());
let listeners_alive = Arc::downgrade(&strong_alive);
let shared = Arc::new(WakeFlagShared {
notified: AtomicBool::new(wake_immediately),
listeners_alive,
waker: AtomicWaker::new(),
});
let listener = WakeFlagListener {
shared: Arc::downgrade(&shared),
_alive: strong_alive,
};
(Self { shared }, listener)
}
#[must_use]
pub fn listening<L>(wake_immediately: bool, source: L) -> Self
where
L: Listen,
L::Listener: crate::FromListener<WakeFlagListener, L::Msg>,
{
let (flag, listener) = Self::new(wake_immediately);
source.listen(listener);
flag
}
#[inline]
#[must_use]
pub async fn wait(&mut self) -> bool {
WakeFlagFuture {
shared: &self.shared,
done: false,
}
.await
}
#[inline]
pub fn notify(&self) {
self.shared.notify_message();
}
}
impl<M> Listener<M> for WakeFlagListener {
    /// Sets the flag if `messages` is non-empty.
    ///
    /// Returns whether the [`WakeFlag`] still exists. After it has been dropped, this
    /// returns `false` and does nothing.
    fn receive(&self, messages: &[M]) -> bool {
        match self.shared.upgrade() {
            None => false,
            Some(shared) => {
                // An empty batch only probes liveness; it does not count as a notification.
                if !messages.is_empty() {
                    shared.notify_message();
                }
                true
            }
        }
    }
}
impl<M> crate::FromListener<WakeFlagListener, M> for WakeFlagListener {
    /// Identity conversion: the listener is already of the target type, so it is
    /// returned unchanged.
    fn from_listener(listener: WakeFlagListener) -> Self {
        core::convert::identity(listener)
    }
}
impl Drop for WakeFlagListener {
    /// Wakes any task waiting on the corresponding [`WakeFlag`], so it can observe that
    /// this listener is gone.
    ///
    /// This listener's strong `_alive` handle is released *before* waking. Rust drops a
    /// struct's fields only after `Drop::drop` returns, so waking first would cause a race:
    ///
    /// 1. A task on another thread polls while `listeners_alive.strong_count()` still
    ///    counts this listener.
    /// 2. It finds nothing to report, registers its waker, and returns `Pending`.
    /// 3. The final decrement then happens with nothing left to wake the task, which hangs.
    fn drop(&mut self) {
        // `mem::take` swaps in a fresh, unrelated `Arc<()>` (via `Arc::default`) so that
        // the real handle can be released now. The placeholder is freed with the other
        // fields after this method returns.
        drop(core::mem::take(&mut self._alive));

        if let Some(shared) = self.shared.upgrade() {
            // Waking is harmless if other clones are still alive. The waiter just re-checks.
            shared.waker.wake();
        }
    }
}
impl Future for WakeFlagFuture<'_> {
    type Output = bool;

    /// Polls the shared flag state and resolves with its answer.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        assert!(!self.done);
        // `WakeFlagFuture` is `Unpin` (it holds only a reference and a bool), so it can be
        // accessed mutably through the pin.
        let this = self.get_mut();
        let result = this.shared.poll(cx);
        // `done` was `false` (asserted above), so this only changes it when ready.
        this.done = result.is_ready();
        result
    }
}
impl WakeFlagShared {
    /// Sets the flag, then wakes the registered waker (if any).
    fn notify_message(&self) {
        self.notified.store(true, WakeFlag::SET_ORDERING);
        self.waker.wake();
    }

    /// Checks the flag, registering `cx`'s waker if there is nothing to report yet.
    ///
    /// The return value depends on the flag and the listeners:
    /// - `Ready(true)` if the flag was set; it is cleared.
    /// - `Ready(false)` if the flag was unset and no listeners remain.
    /// - `Pending` otherwise.
    fn poll(&self, cx: &mut Context<'_>) -> Poll<bool> {
        // Fast path: skip registering a waker if there is already an answer.
        if let Some(answer) = self.get_and_clear() {
            return Poll::Ready(answer);
        }
        // Register, then check again. Without the second check, a notification or
        // listener drop between the first check and registration would wake a stale (or
        // no) waker and be missed.
        self.waker.register(cx.waker());
        self.get_and_clear().map_or(Poll::Pending, Poll::Ready)
    }

    /// Atomically reads the flag and clears it.
    ///
    /// The result depends on the flag and the listeners:
    /// - `Some(true)` if the flag was set.
    /// - `Some(false)` if it was unset and no listener remains.
    /// - `None` if it was unset and listeners still exist.
    fn get_and_clear(&self) -> Option<bool> {
        if self.notified.swap(false, WakeFlag::GET_CLEAR_ORDERING) {
            return Some(true);
        }
        let no_listeners = self.listeners_alive.strong_count() == 0;
        if no_listeners {
            Some(false)
        } else {
            None
        }
    }
}
// NOTE(review): presumably asserted manually because `AtomicWaker` does not implement
// `RefUnwindSafe`. The shared state holds only an `AtomicBool`, a `Weak<()>`, and an
// `AtomicWaker`, so a panic should leave at most a stale flag or waker. Confirm this
// reasoning before relying on it.
impl core::panic::RefUnwindSafe for WakeFlagShared {}
impl fmt::Debug for WakeFlag {
    /// Formats as `WakeFlag(<current flag value>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A relaxed snapshot is enough for formatting; nothing is synchronized on it.
        let is_set = self.shared.notified.load(Ordering::Relaxed);
        let mut tuple = f.debug_tuple("WakeFlag");
        tuple.field(&is_set);
        tuple.finish()
    }
}
impl fmt::Debug for WakeFlagListener {
    /// Formats whether the [`WakeFlag`] still exists and, if it does, its current value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to the struct makes this fail to compile.
        let Self { shared, _alive: _ } = self;
        let mut ds = f.debug_struct("WakeFlagListener");
        match shared.upgrade() {
            Some(flag_state) => {
                ds.field("alive", &true);
                ds.field("value", &flag_state.notified.load(Ordering::Relaxed));
            }
            None => {
                ds.field("alive", &false);
            }
        }
        ds.finish()
    }
}
impl fmt::Pointer for WakeFlag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Arc::as_ptr(&self.shared).fmt(f)
}
}
impl fmt::Pointer for WakeFlagListener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.shared.as_ptr().fmt(f)
}
}
impl Stream for WakeFlag {
    type Item = ();

    /// Yields `Some(())` each time the flag is found set, clearing it.
    ///
    /// Ends with `None` once the flag is unset and every listener has been dropped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.shared.poll(cx) {
            Poll::Ready(true) => Poll::Ready(Some(())),
            Poll::Ready(false) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}