#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
#![cfg_attr(feature = "docs", doc = "## Feature flags")]
#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![deny(missing_docs)]
#![deny(unsafe_code)]
#![deny(unreachable_pub)]
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use tokio_util::sync::CancellationToken;
mod ext;
pub use ext::*;
/// A counted handle to a [`ContextTrackerInner`], owned by each live [`Context`].
///
/// Instances are created through [`ContextTrackerInner::child`], which increments
/// `active_count`; dropping the handle decrements it again.
#[derive(Debug)]
struct ContextTracker(Arc<ContextTrackerInner>);

impl Drop for ContextTracker {
    fn drop(&mut self) {
        // SeqCst is required: this "decrement, then read `stopped`" sequence races with
        // `ContextTrackerInner::stop`, which does "write `stopped`, then read
        // `active_count`". With Relaxed (or acquire/release) ordering both sides may read
        // the stale value, so neither would notify and `ContextTrackerInner::wait` could
        // stay parked forever.
        let prev_active_count = self
            .0
            .active_count
            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
        // The last live handle of a stopped tracker wakes every waiter.
        if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::SeqCst) {
            self.0.notify.notify_waiters();
        }
    }
}

/// Shared state that counts live [`ContextTracker`]s and lets a [`Handler`] wait for
/// them to be dropped.
#[derive(Debug)]
struct ContextTrackerInner {
    /// Set by [`ContextTrackerInner::stop`]; nothing in this file resets it.
    stopped: AtomicBool,
    /// Number of [`ContextTracker`]s currently alive for this instance.
    active_count: AtomicUsize,
    /// Wakes tasks parked in [`ContextTrackerInner::wait`].
    notify: tokio::sync::Notify,
}

impl ContextTrackerInner {
    /// Creates a tracker that is not stopped and has no live children.
    fn new() -> Arc<Self> {
        Arc::new(Self {
            stopped: AtomicBool::new(false),
            active_count: AtomicUsize::new(0),
            notify: tokio::sync::Notify::new(),
        })
    }

    /// Registers one more live child and returns the handle that unregisters it on drop.
    fn child(self: &Arc<Self>) -> ContextTracker {
        self.active_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        ContextTracker(Arc::clone(self))
    }

    /// Marks this tracker as stopped.
    ///
    /// If no child is alive at this point, waiters are woken right away. The drop path
    /// only notifies when `stopped` is already set. Without this check, a task that
    /// entered [`ContextTrackerInner::wait`] while children were alive, and whose
    /// children were all dropped *before* this call, would never be woken.
    fn stop(&self) {
        self.stopped.store(true, std::sync::atomic::Ordering::SeqCst);
        if self.active_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
            self.notify.notify_waiters();
        }
    }

    /// Waits until no child is alive.
    ///
    /// Returns immediately when the live count is already zero. Otherwise it parks until
    /// a wake-up is issued. That only happens once the tracker is stopped and the live
    /// count has reached zero.
    async fn wait(&self) {
        // Create the `Notified` future *before* reading the count: tokio guarantees it
        // observes `notify_waiters` calls from the moment it is created (even unpolled),
        // so a final drop racing with the check below cannot be missed.
        let notify = self.notify.notified();
        if self.active_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
            return;
        }
        notify.await;
    }
}
/// A cancellation signal handed to tasks, obtained from a [`Handler`].
///
/// While a `Context` (or any clone of it) is alive, it is counted as active by the
/// handler it was created from; [`Handler::wait`] and [`Handler::shutdown`] wait for
/// those registrations to be dropped.
#[derive(Debug)]
pub struct Context {
    // Token observed by `done` / `is_done`.
    token: CancellationToken,
    // Registration with the owning handler's tracker; released when this field is dropped.
    tracker: ContextTracker,
}
impl Clone for Context {
fn clone(&self) -> Self {
Self {
token: self.token.clone(),
tracker: self.tracker.0.child(),
}
}
}
impl Context {
#[must_use]
pub fn new() -> (Self, Handler) {
Handler::global().new_child()
}
#[must_use]
pub fn new_child(&self) -> (Self, Handler) {
let token = self.token.child_token();
let tracker = ContextTrackerInner::new();
(
Self {
tracker: tracker.child(),
token: token.clone(),
},
Handler {
token: Arc::new(TokenDropGuard(token)),
tracker,
},
)
}
#[must_use]
pub fn global() -> Self {
Handler::global().context()
}
pub async fn done(&self) {
self.token.cancelled().await;
}
pub async fn into_done(self) {
self.done().await;
}
#[must_use]
pub fn is_done(&self) -> bool {
self.token.is_cancelled()
}
}
/// Owns a [`CancellationToken`] and cancels it when dropped.
///
/// In this file [`Handler`] stores it behind an `Arc`, so the token is cancelled when
/// the last clone of that handler goes away.
#[derive(Debug)]
struct TokenDropGuard(CancellationToken);

impl TokenDropGuard {
    /// Returns a fresh child of the guarded token.
    #[must_use]
    fn child(&self) -> CancellationToken {
        let TokenDropGuard(token) = self;
        token.child_token()
    }

    /// Cancels the guarded token.
    fn cancel(&self) {
        let TokenDropGuard(token) = self;
        token.cancel();
    }
}

impl Drop for TokenDropGuard {
    fn drop(&mut self) {
        // Dropping the guard is an implicit cancel.
        self.0.cancel();
    }
}
/// Controls a group of [`Context`]s: cancels them and waits for them to be dropped.
///
/// Cloning a `Handler` only clones two `Arc`s, and every clone controls the same
/// contexts. When the last clone is dropped, the shared `TokenDropGuard` is dropped and
/// cancels the token.
#[derive(Debug, Clone)]
pub struct Handler {
    // Shared drop guard; the token is cancelled when the last `Arc` is dropped.
    token: Arc<TokenDropGuard>,
    // Counts the live contexts created from this handler.
    tracker: Arc<ContextTrackerInner>,
}
impl Default for Handler {
fn default() -> Self {
Self::new()
}
}
impl Handler {
    /// Creates a new root handler with its own token, independent of [`Handler::global`].
    #[must_use]
    pub fn new() -> Handler {
        let token = CancellationToken::new();
        let tracker = ContextTrackerInner::new();
        Handler {
            token: Arc::new(TokenDropGuard(token)),
            tracker,
        }
    }

    /// Returns the process-wide handler, creating it on first use.
    ///
    /// [`Context::new`] and [`Context::global`] derive from this handler, so cancelling
    /// it cancels every context created through them.
    #[must_use]
    pub fn global() -> &'static Self {
        static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
        GLOBAL.get_or_init(Handler::new)
    }

    /// Cancels this handler, then waits as described in [`Handler::done`].
    pub async fn shutdown(&self) {
        self.cancel();
        self.done().await;
    }

    /// Waits until this handler is cancelled, then waits as described in [`Handler::wait`].
    pub async fn done(&self) {
        self.token.0.cancelled().await;
        self.wait().await;
    }

    /// Waits for the contexts created from this handler to be dropped.
    ///
    /// Returns immediately if none are alive. Otherwise it waits for a wake-up, and that
    /// wake-up is only issued after the handler has been cancelled.
    pub async fn wait(&self) {
        self.tracker.wait().await;
    }

    /// Returns a new context whose token is a child of this handler's token and which is
    /// counted by this handler until dropped.
    #[must_use]
    pub fn context(&self) -> Context {
        Context {
            token: self.token.child(),
            tracker: self.tracker.child(),
        }
    }

    /// Creates a child context/handler pair below this handler.
    ///
    /// Cancelling this handler also cancels the child pair. The child pair tracks its own
    /// contexts, so this handler does not wait for them.
    #[must_use]
    pub fn new_child(&self) -> (Context, Handler) {
        self.context().new_child()
    }

    /// Cancels this handler's token (and every context derived from it) and marks its
    /// tracker as stopped.
    pub fn cancel(&self) {
        self.tracker.stop();
        self.token.cancel();
    }

    /// Returns `true` once this handler has been cancelled.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.token.0.is_cancelled()
    }
}
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    // `cargo test` runs these tests on parallel threads of a single process, so they all
    // share `Handler::global()`. Because `global_handler` cancels that global handler,
    // every other test builds its contexts from a private root `Handler` instead of
    // `Context::new()`. The root must stay bound to a variable until the end of the test:
    // dropping it would cancel its token and, with it, every child.

    #[tokio::test]
    async fn new() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        let handler = Handler::default();
        assert!(!handler.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());
        handler.cancel();
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        child_handler.cancel();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        // `ctx` is still alive, so shutdown cannot finish.
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_err()
        );
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(
            ctx.into_done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .wait()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // This is the only test allowed to touch the process-wide handler, so the
        // `Context::new()` path is exercised here, before the global handler is cancelled.
        let (ctx, ctx_handler) = Context::new();
        assert!(!ctx_handler.is_done());
        assert!(!ctx.is_done());

        let handler = Handler::global();
        assert!(!handler.is_done());
        handler.cancel();
        assert!(handler.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());

        // Pairs created from the global handler before the cancel are cancelled too.
        assert!(ctx_handler.is_done());
        assert!(ctx.is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}
// Changelog module, only compiled with the `docs` feature (it is the target of the
// `[changelog]` link in the crate docs). Its contents are presumably generated by the
// `scuffle_changelog::changelog` attribute macro — confirm in that crate.
#[cfg(feature = "docs")]
#[scuffle_changelog::changelog]
pub mod changelog {}