use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use wacore::runtime::Runtime;
pub struct FlushScope {
count: AtomicUsize,
idle: event_listener::Event,
closed: std::sync::Mutex<bool>,
}
impl Default for FlushScope {
fn default() -> Self {
Self::new()
}
}
impl FlushScope {
pub fn new() -> Self {
Self {
count: AtomicUsize::new(0),
idle: event_listener::Event::new(),
closed: std::sync::Mutex::new(false),
}
}
pub fn close(&self) {
*self.closed.lock().unwrap_or_else(|e| e.into_inner()) = true;
}
pub fn reopen(&self) {
*self.closed.lock().unwrap_or_else(|e| e.into_inner()) = false;
}
pub fn spawn<F>(self: &Arc<Self>, rt: &dyn Runtime, fut: F)
where
F: Future<Output = ()> + Send + 'static,
{
{
let closed = self.closed.lock().unwrap_or_else(|e| e.into_inner());
if *closed {
return;
}
self.count.fetch_add(1, Ordering::Relaxed);
}
let guard = DecrementOnDrop {
scope: Arc::clone(self),
};
rt.spawn(Box::pin(async move {
let _guard = guard;
fut.await;
}))
.detach();
}
pub async fn flush(&self, rt: &dyn Runtime, timeout: Duration) {
use wacore::time::Instant;
let deadline = Instant::now() + timeout;
loop {
let listener = self.idle.listen();
if self.count.load(Ordering::Relaxed) == 0 {
return;
}
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero()
|| wacore::runtime::timeout(rt, remaining, listener)
.await
.is_err()
{
log::warn!(
"FlushScope timed out with {} pending task(s)",
self.count.load(Ordering::Relaxed)
);
return;
}
}
}
#[cfg(test)]
pub fn pending(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
}
struct DecrementOnDrop {
scope: Arc<FlushScope>,
}
impl Drop for DecrementOnDrop {
fn drop(&mut self) {
if self.scope.count.fetch_sub(1, Ordering::Relaxed) == 1 {
self.scope.idle.notify(usize::MAX);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicBool;
use wacore::time::Instant;
fn rt() -> Arc<dyn Runtime> {
Arc::new(crate::runtime_impl::TokioRuntime)
}
#[tokio::test]
async fn flush_returns_after_tasks_complete() {
let scope = Arc::new(FlushScope::new());
let runtime = rt();
for _ in 0..5 {
let r = Arc::clone(&runtime);
scope.spawn(&*runtime, async move {
r.sleep(Duration::from_millis(20)).await;
});
}
assert_eq!(scope.pending(), 5);
let start = Instant::now();
scope.flush(&*runtime, Duration::from_secs(1)).await;
assert!(start.elapsed() < Duration::from_millis(500));
assert_eq!(scope.pending(), 0);
}
#[tokio::test]
async fn flush_returns_immediately_when_idle() {
let scope = Arc::new(FlushScope::new());
let runtime = rt();
let start = Instant::now();
scope.flush(&*runtime, Duration::from_secs(5)).await;
assert!(start.elapsed() < Duration::from_millis(10));
}
#[tokio::test]
async fn close_rejects_new_tasks_until_reopened() {
let scope = Arc::new(FlushScope::new());
let runtime = rt();
let ran = Arc::new(AtomicBool::new(false));
scope.close();
let ran_closed = Arc::clone(&ran);
scope.spawn(&*runtime, async move {
ran_closed.store(true, Ordering::Relaxed);
});
scope.flush(&*runtime, Duration::from_secs(1)).await;
assert!(!ran.load(Ordering::Relaxed));
scope.reopen();
let ran_open = Arc::clone(&ran);
scope.spawn(&*runtime, async move {
ran_open.store(true, Ordering::Relaxed);
});
scope.flush(&*runtime, Duration::from_secs(1)).await;
assert!(ran.load(Ordering::Relaxed));
}
#[tokio::test]
async fn flush_honors_timeout_and_logs() {
let scope = Arc::new(FlushScope::new());
let runtime = rt();
let r = Arc::clone(&runtime);
scope.spawn(&*runtime, async move {
r.sleep(Duration::from_secs(10)).await;
});
let start = Instant::now();
scope.flush(&*runtime, Duration::from_millis(50)).await;
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(50));
assert!(elapsed < Duration::from_millis(500));
assert_eq!(scope.pending(), 1);
}
#[tokio::test]
async fn decrement_runs_when_future_is_aborted_mid_flight() {
use std::task::{Context, Poll};
let scope = Arc::new(FlushScope::new());
scope.count.fetch_add(1, Ordering::Relaxed);
let scope_for_fut = Arc::clone(&scope);
let mut fut: std::pin::Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
let _guard = DecrementOnDrop {
scope: scope_for_fut,
};
futures::future::pending::<()>().await;
});
let waker = futures::task::noop_waker();
let mut cx = Context::from_waker(&waker);
assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
assert_eq!(scope.pending(), 1);
drop(fut);
assert_eq!(
scope.pending(),
0,
"DecrementOnDrop must fire when the in-flight future is dropped"
);
}
#[tokio::test]
async fn decrement_runs_when_future_is_dropped_before_first_poll() {
let scope = Arc::new(FlushScope::new());
scope.count.fetch_add(1, Ordering::Relaxed);
let guard = DecrementOnDrop {
scope: Arc::clone(&scope),
};
let never_polled_fut = async move {
let _guard = guard;
futures::future::pending::<()>().await;
};
assert_eq!(scope.pending(), 1);
drop(never_polled_fut);
assert_eq!(
scope.pending(),
0,
"guard must drop with the never-polled future"
);
}
#[tokio::test]
async fn ran_bodies_observe_completion() {
let scope = Arc::new(FlushScope::new());
let runtime = rt();
let done = Arc::new(AtomicBool::new(false));
let done_clone = Arc::clone(&done);
scope.spawn(&*runtime, async move {
done_clone.store(true, Ordering::Relaxed);
});
scope.flush(&*runtime, Duration::from_secs(1)).await;
assert!(done.load(Ordering::Relaxed));
}
}