use crate::{
locks::{lock, Lock},
FillQueue,
};
use alloc::sync::{Arc, Weak};
use core::mem::ManuallyDrop;
use docfg::docfg;
/// Producer side of a blocking completion flag.
///
/// The handle owns a strong reference to the queue shared with its
/// [`Subscribe`] counterparts; cloning the handle clones that reference.
/// Subscribers report the flag as marked once no strong reference is left,
/// i.e. after every clone has been dropped or consumed.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Debug, Clone)]
pub struct Flag {
// Strong reference to the shared queue of waiter locks.
inner: Arc<FlagQueue>,
}
/// Consumer side of a blocking completion flag.
///
/// Holds only a weak reference to the queue shared with [`Flag`], so a
/// subscriber never keeps the flag alive by merely existing. Cloning yields
/// another subscriber of the same flag.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Debug, Clone)]
pub struct Subscribe {
// Weak reference to the shared queue; it stops upgrading once every `Flag` is gone.
inner: Weak<FlagQueue>,
}
impl Flag {
#[inline]
pub unsafe fn into_raw(self) -> *const FillQueue<Lock> {
Arc::into_raw(self.inner).cast()
}
#[inline]
pub unsafe fn from_raw(ptr: *const FillQueue<Lock>) -> Self {
Self {
inner: Arc::from_raw(ptr.cast()),
}
}
#[inline]
pub fn has_subscriber(&self) -> bool {
return Arc::weak_count(&self.inner) > 0;
}
#[inline]
pub fn mark(self) {}
#[inline]
pub fn silent_drop(self) {
if let Ok(inner) = Arc::try_unwrap(self.inner) {
inner.silent_drop()
}
}
}
impl Subscribe {
    /// Returns `true` once no strong reference to the shared queue remains,
    /// i.e. every [`Flag`] handle is gone.
    #[inline]
    pub fn is_marked(&self) -> bool {
        self.inner.strong_count() == 0
    }

    /// Waits for the flag to be marked.
    ///
    /// Returns immediately when the shared queue can no longer be upgraded
    /// (the flag is already gone). Otherwise a fresh lock is registered in the
    /// queue and the call waits on its subscriber half.
    #[inline]
    pub fn wait(self) {
        let queue = match self.inner.upgrade() {
            Some(queue) => queue,
            None => return,
        };
        let (waker, sub) = lock();
        queue.0.push(waker);
        // Release the temporary strong reference before waiting so this
        // subscriber does not keep the flag alive.
        drop(queue);
        sub.wait()
    }

    /// Waits for the flag to be marked, giving up after `dur`.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Timeout`] when the flag is still not marked after the
    /// wait has ended. The verdict is taken from [`Subscribe::is_marked`], not
    /// from the result of the underlying timed wait.
    #[docfg(feature = "std")]
    #[inline]
    pub fn wait_timeout(self, dur: core::time::Duration) -> Result<(), crate::Timeout> {
        let queue = match self.inner.upgrade() {
            Some(queue) => queue,
            // The flag is already gone: nothing to wait for.
            None => return Ok(()),
        };
        let (waker, sub) = lock();
        queue.0.push(waker);
        // Release the temporary strong reference before waiting so this
        // subscriber does not keep the flag alive.
        drop(queue);
        sub.wait_timeout(dur);
        if self.is_marked() {
            Ok(())
        } else {
            Err(crate::Timeout)
        }
    }
}
/// Creates a connected [`Flag`] / [`Subscribe`] pair sharing a new, empty queue.
///
/// The flag owns the strong reference; the subscriber only a weak one.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub fn flag() -> (Flag, Subscribe) {
    let inner = Arc::new(FlagQueue(FillQueue::new()));
    let subscriber = Subscribe {
        inner: Arc::downgrade(&inner),
    };
    (Flag { inner }, subscriber)
}
// State shared between `Flag` and `Subscribe`: the queue of waiter locks.
// `Flag::into_raw` / `Flag::from_raw` cast between pointers to this wrapper
// and pointers to the wrapped `FillQueue<Lock>`, hence `repr(transparent)`.
#[repr(transparent)]
#[derive(Debug)]
struct FlagQueue(pub FillQueue<Lock>);
impl FlagQueue {
    /// Tears the queue down without waking any registered waiter.
    ///
    /// Every queued [`Lock`] is released through `Lock::silent_drop` instead of
    /// `Lock::wake`, and the `Drop` implementation of `FlagQueue` (which would
    /// wake the waiters) is bypassed.
    #[inline]
    pub fn silent_drop(self) {
        // Suppress `FlagQueue::drop`, which would wake every queued lock.
        let mut this = ManuallyDrop::new(self);
        this.0.chop_mut().for_each(Lock::silent_drop);
        // Drop the inner queue itself. Pointing `drop_in_place` at the
        // `ManuallyDrop` wrapper would be a no-op (the wrapper has no drop
        // glue), so the field is targeted directly.
        //
        // SAFETY: `this.0` is a live, initialized `FillQueue` owned by this
        // function; it is not used after this point and `ManuallyDrop`
        // guarantees it is not dropped a second time.
        unsafe { core::ptr::drop_in_place(&mut this.0) };
    }
}
impl Drop for FlagQueue {
    /// Wakes every lock still registered in the queue.
    #[inline]
    fn drop(&mut self) {
        for waiter in self.0.chop_mut() {
            Lock::wake(waiter);
        }
    }
}
cfg_if::cfg_if! {
if #[cfg(feature = "futures")] {
use core::{future::Future, task::{Waker, Poll}};
use futures::future::FusedFuture;
/// Creates a connected [`AsyncFlag`] / [`AsyncSubscribe`] pair sharing a new,
/// empty waker queue.
///
/// The flag owns the strong reference; the subscriber only a weak one.
#[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
#[inline]
pub fn async_flag() -> (AsyncFlag, AsyncSubscribe) {
    let inner = Arc::new(AsyncFlagQueue(FillQueue::new()));
    let subscriber = AsyncSubscribe {
        inner: Some(Arc::downgrade(&inner)),
    };
    (AsyncFlag { inner }, subscriber)
}
/// Producer side of an asynchronous completion flag.
///
/// The handle owns a strong reference to the waker queue shared with its
/// [`AsyncSubscribe`] counterparts; cloning the handle clones that reference.
/// Subscribers report the flag as marked once no strong reference is left.
#[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
#[derive(Debug, Clone)]
pub struct AsyncFlag {
// Strong reference to the shared queue of subscriber wakers.
inner: Arc<AsyncFlagQueue>
}
impl AsyncFlag {
#[inline]
pub unsafe fn into_raw (self) -> *const FillQueue<Waker> {
Arc::into_raw(self.inner).cast()
}
#[inline]
pub unsafe fn from_raw (ptr: *const FillQueue<Waker>) -> Self {
Self { inner: Arc::from_raw(ptr.cast()) }
}
#[inline]
pub fn has_subscriber(&self) -> bool {
return Arc::weak_count(&self.inner) > 0
}
#[inline]
pub fn mark (self) {}
#[inline]
pub fn subscribe (&self) -> AsyncSubscribe {
AsyncSubscribe {
inner: Some(Arc::downgrade(&self.inner))
}
}
#[inline]
pub fn silent_drop (self) {
if let Ok(inner) = Arc::try_unwrap(self.inner) {
inner.silent_drop()
}
}
}
/// Consumer side of an asynchronous completion flag.
///
/// As a future it stays pending while the shared queue can still be upgraded
/// and resolves to `()` once it cannot. Cloning yields another subscriber of
/// the same flag.
#[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
#[derive(Debug, Clone)]
pub struct AsyncSubscribe {
// Weak reference to the shared waker queue; `None` for a subscriber created
// already marked or one whose future has completed.
inner: Option<Weak<AsyncFlagQueue>>
}
impl AsyncSubscribe {
    /// Creates a subscriber that is attached to no flag and therefore reports
    /// itself as marked from the start.
    #[inline]
    pub fn marked() -> AsyncSubscribe {
        Self { inner: None }
    }

    /// Returns `true` unless this subscriber still references a queue that has
    /// at least one strong reference.
    #[inline]
    pub fn is_marked(&self) -> bool {
        let flag_alive =
            crate::is_some_and(self.inner.as_ref(), |weak| weak.strong_count() > 0);
        !flag_alive
    }
}
impl Future for AsyncSubscribe {
    type Output = ();

    /// Resolves once the shared queue can no longer be upgraded.
    ///
    /// While the queue is still alive, a clone of the current task's waker is
    /// pushed onto it on every poll and `Pending` is returned. After the
    /// upgrade fails, the weak reference is cleared so later polls return
    /// `Ready` without touching the queue.
    #[inline]
    fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
        let upgraded = match self.inner {
            Some(ref weak) => weak.upgrade(),
            // Already completed, or created through `AsyncSubscribe::marked`.
            None => return Poll::Ready(()),
        };
        match upgraded {
            Some(queue) => {
                queue.0.push(cx.waker().clone());
                Poll::Pending
            }
            None => {
                self.inner = None;
                Poll::Ready(())
            }
        }
    }
}
impl FusedFuture for AsyncSubscribe {
    /// Reports termination once the subscriber no longer holds a queue reference.
    #[inline]
    fn is_terminated(&self) -> bool {
        matches!(self.inner, None)
    }
}
/// State shared between `AsyncFlag` and `AsyncSubscribe`: the queue of
/// subscriber wakers.
///
/// `AsyncFlag::into_raw` / `AsyncFlag::from_raw` cast between pointers to this
/// wrapper and pointers to the wrapped `FillQueue<Waker>`. `repr(transparent)`
/// guarantees both share one layout, which the default representation does not.
#[repr(transparent)]
#[derive(Debug)]
struct AsyncFlagQueue(pub FillQueue<Waker>);
impl AsyncFlagQueue {
    /// Tears the queue down without calling `Waker::wake` on any queued waker.
    ///
    /// The `Drop` implementation of `AsyncFlagQueue` (which would wake the
    /// wakers) is bypassed.
    #[inline]
    pub fn silent_drop(self) {
        // Suppress `AsyncFlagQueue::drop`, which would wake every queued waker.
        let mut this = ManuallyDrop::new(self);
        // Detach the queued wakers and discard the iterator unconsumed.
        // NOTE(review): presumably dropping the iterator releases the
        // remaining wakers -- confirm against `ChopIter`.
        drop::<crate::prelude::ChopIter<Waker>>(this.0.chop_mut());
        // SAFETY: `this.0` is a live, initialized `FillQueue` owned by this
        // function; it is not used after this point and `ManuallyDrop`
        // guarantees it is not dropped a second time.
        unsafe { core::ptr::drop_in_place(&mut this.0) }
    }
}
impl Drop for AsyncFlagQueue {
    /// Wakes every waker still registered in the queue.
    #[inline]
    fn drop(&mut self) {
        for waker in self.0.chop_mut() {
            Waker::wake(waker);
        }
    }
}
}
}
#[cfg(all(feature = "std", test))]
mod tests {
    use super::{flag, Flag};
    use core::time::Duration;
    use std::thread;

    /// Marking works with no subscriber, and a waiting subscriber is released
    /// by a flag that went through the raw-pointer round trip.
    #[test]
    fn test_normal_conditions() {
        let (unobserved, _) = flag();
        unobserved.mark();

        let (producer, consumer) = flag();
        let producer = unsafe {
            let raw = Flag::into_raw(producer);
            Flag::from_raw(raw)
        };
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            producer.mark();
        });
        consumer.wait();
    }

    /// A timed wait that expires before the flag is silently dropped reports
    /// a timeout.
    #[test]
    fn test_silent_drop() {
        let (producer, consumer) = flag();
        let waiter = thread::spawn(move || consumer.wait_timeout(Duration::from_millis(100)));
        thread::sleep(Duration::from_millis(200));
        producer.silent_drop();
        let outcome = waiter.join().unwrap();
        assert!(outcome.is_err());
    }

    /// Ten threads each wait ten times on clones of one subscriber; all of them
    /// finish once every clone of the flag has been marked.
    #[test]
    fn test_stressed_conditions() {
        let (producer, consumer) = flag();
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let consumer = consumer.clone();
                thread::spawn(move || {
                    for _ in 0..10 {
                        consumer.clone().wait();
                    }
                })
            })
            .collect();

        thread::sleep(Duration::from_millis(100));
        (0..9).for_each(|_| producer.clone().mark());
        producer.mark();

        for worker in workers {
            worker.join().unwrap();
        }
    }
}
#[cfg(all(feature = "futures", test))]
mod async_tests {
    use super::{async_flag, AsyncFlag};
    use core::time::Duration;
    use std::time::Instant;

    /// `is_marked` flips when the flag is marked, and an awaiting subscriber is
    /// released by a flag that went through the raw-pointer round trip.
    #[tokio::test]
    async fn test_async_normal_conditions() {
        let (producer, consumer) = async_flag();
        assert!(!consumer.is_marked());
        producer.mark();
        assert!(consumer.is_marked());

        let (producer, mut consumer) = async_flag();
        let producer = unsafe {
            let raw = AsyncFlag::into_raw(producer);
            AsyncFlag::from_raw(raw)
        };
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            producer.mark();
        });
        (&mut consumer).await;
        assert!(consumer.is_marked());
    }

    /// A silently dropped flag must not complete the subscriber before the
    /// 200 ms timeout wrapped around it.
    #[tokio::test]
    async fn test_silent_drop() {
        let (producer, consumer) = async_flag();
        let waiter = tokio::spawn(tokio::time::timeout(
            Duration::from_millis(200),
            async move {
                let started = Instant::now();
                consumer.await;
                started.elapsed()
            },
        ));
        tokio::time::sleep(Duration::from_millis(100)).await;
        producer.silent_drop();

        if let Ok(elapsed) = waiter.await.unwrap() {
            if elapsed < Duration::from_millis(200) {
                panic!("{elapsed:?}");
            }
        }
    }

    /// One hundred tasks awaiting clones of the same subscriber all complete
    /// after a single `mark`.
    #[tokio::test]
    async fn test_async_stressed_conditions() {
        let (producer, consumer) = async_flag();
        let tasks: Vec<_> = (0..100)
            .map(|_| {
                let mut consumer = consumer.clone();
                tokio::spawn(async move {
                    (&mut consumer).await;
                    assert!(consumer.is_marked());
                })
            })
            .collect();

        tokio::time::sleep(Duration::from_millis(100)).await;
        producer.mark();

        for task in tasks {
            task.await.unwrap();
        }
    }
}