use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
/// A cloneable cancellation flag that can run hooks when it is cancelled.
///
/// Every clone shares the same state through an `Arc`. Cancelling any clone
/// cancels them all. Hooks registered through one clone run when any clone is
/// cancelled.
#[derive(Clone, Default)]
pub struct CancelToken {
    inner: Arc<Inner>,
}
impl std::fmt::Debug for CancelToken {
    // Renders as `CancelToken { cancelled: <bool> }`; registered hooks are
    // intentionally not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.is_cancelled();
        let mut builder = f.debug_struct("CancelToken");
        builder.field("cancelled", &cancelled);
        builder.finish()
    }
}
/// A cancellation callback. It is boxed so closures of different types can
/// share one hook list. It is `Send + Sync` so it can live in the shared token
/// state and be invoked from whichever thread calls `cancel`.
type Hook = Box<dyn Fn() + Sync + Send>;
/// State shared by every clone of a `CancelToken`.
#[derive(Default)]
struct Inner {
    /// Registered hooks. A `None` entry is a slot released by a dropped guard;
    /// later registrations reuse such slots before growing the list.
    hooks: Mutex<Vec<Option<Hook>>>,
    /// Set to `true` by `cancel`; never reset.
    cancelled: AtomicBool,
}
impl CancelToken {
    /// Creates a token that is not cancelled and has no hooks registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Cancels the token and runs every currently registered hook.
    ///
    /// Only the first call has an effect. While holding the internal lock, it
    /// flips the flag and drains the hook list. It then runs the drained hooks
    /// after releasing the lock. Later calls return without running anything.
    ///
    /// Because hooks run outside the lock, a hook may call `cancel` or
    /// `register` on this same token without deadlocking. It may also drop a
    /// `CancelGuard` for this token. The trade-off: dropping a guard
    /// concurrently with `cancel` does not stop a hook that `cancel` has
    /// already drained.
    ///
    /// If a hook panics, the panic propagates to the caller. The remaining
    /// drained hooks are then dropped without running.
    pub fn cancel(&self) {
        let drained = {
            let mut slots = self.lock_hooks();
            // `register` checks the flag under this same lock. So every hook
            // is either drained here or fired by `register`: never both,
            // never neither.
            if self.inner.cancelled.swap(true, Ordering::SeqCst) {
                return;
            }
            std::mem::take(&mut *slots)
        };
        for hook in drained.iter().flatten() {
            hook();
        }
    }

    /// Returns `true` once `cancel` has been called on this token or on any
    /// clone of it.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.load(Ordering::SeqCst)
    }

    /// Registers `hook` to run when the token is cancelled.
    ///
    /// If the token is already cancelled, `hook` runs immediately on the
    /// calling thread, after the internal lock is released. The returned guard
    /// is then inert.
    ///
    /// Otherwise the hook stays registered until the returned guard is dropped
    /// or the token is cancelled, whichever happens first. If cancellation
    /// comes first, the hook runs exactly once.
    pub(crate) fn register(&self, hook: Hook) -> CancelGuard {
        let mut slots = self.lock_hooks();
        // Checked under the lock so this cannot interleave with `cancel`.
        // Checking before locking and again after unlocking let a racing
        // `cancel` fire the hook twice.
        if self.is_cancelled() {
            drop(slots);
            hook();
            return CancelGuard {
                token: self.clone(),
                idx: usize::MAX,
            };
        }
        // Reuse a slot freed by a dropped guard before growing the list.
        let idx = match slots.iter().position(Option::is_none) {
            Some(free) => {
                slots[free] = Some(hook);
                free
            }
            None => {
                slots.push(Some(hook));
                slots.len() - 1
            }
        };
        drop(slots);
        CancelGuard {
            token: self.clone(),
            idx,
        }
    }

    /// Locks the hook list, recovering it if a previous holder panicked.
    ///
    /// The list carries no invariants beyond those `Vec` maintains itself, so
    /// it is safe to continue after a poisoning panic. Without recovery,
    /// `register` would panic while `cancel` silently skipped every hook.
    fn lock_hooks(&self) -> std::sync::MutexGuard<'_, Vec<Option<Hook>>> {
        self.inner
            .hooks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Keeps a hook registered with a `CancelToken` for as long as it is alive.
///
/// Dropping the guard unregisters the hook. `idx` is the hook's slot in the
/// token's hook list. It is `usize::MAX` for an inert guard, whose hook already
/// ran because the token was cancelled at registration time.
#[must_use = "dropping a CancelGuard immediately unregisters its hook"]
pub(crate) struct CancelGuard {
    token: CancelToken,
    idx: usize,
}
impl Drop for CancelGuard {
    /// Unregisters the hook by clearing its slot in the token's hook list.
    ///
    /// Inert guards (`idx == usize::MAX`) and slots that no longer exist are
    /// ignored. A poisoned lock is recovered instead of skipped, so the slot is
    /// always cleared; skipping it would leave a dropped guard's hook
    /// registered.
    ///
    /// The removed hook is destroyed only after the lock is released. Its
    /// captured state may therefore touch this token in its destructor without
    /// deadlocking.
    fn drop(&mut self) {
        if self.idx == usize::MAX {
            return;
        }
        let removed = {
            let mut hooks = self
                .token
                .inner
                .hooks
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            hooks.get_mut(self.idx).and_then(Option::take)
        };
        drop(removed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Builds a hook that increments `counter` each time it runs.
    fn counting_hook(counter: &Arc<AtomicUsize>) -> Hook {
        let counter = Arc::clone(counter);
        Box::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[test]
    fn cancel_sets_flag_and_fires_hooks() {
        let t = CancelToken::new();
        assert!(!t.is_cancelled());
        let fired = Arc::new(AtomicBool::new(false));
        let f2 = fired.clone();
        let _g = t.register(Box::new(move || f2.store(true, Ordering::SeqCst)));
        t.cancel();
        assert!(t.is_cancelled());
        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn register_after_cancel_fires_immediately() {
        let t = CancelToken::new();
        t.cancel();
        let fired = Arc::new(AtomicBool::new(false));
        let f2 = fired.clone();
        let _g = t.register(Box::new(move || f2.store(true, Ordering::SeqCst)));
        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn dropped_guard_is_not_invoked() {
        let t = CancelToken::new();
        let count = Arc::new(AtomicUsize::new(0));
        {
            let _g = t.register(counting_hook(&count));
        }
        t.cancel();
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn slot_reuse_keeps_active_hooks() {
        let t = CancelToken::new();
        let count = Arc::new(AtomicUsize::new(0));
        let g1 = t.register(counting_hook(&count));
        {
            // Dropping g2 frees its slot so that g3 can reuse it.
            let _g2 = t.register(counting_hook(&count));
        }
        let _g3 = t.register(counting_hook(&count));
        drop(g1);
        t.cancel();
        // Only g3's hook is still registered.
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn clones_share_cancellation() {
        let t = CancelToken::new();
        let clone = t.clone();
        let count = Arc::new(AtomicUsize::new(0));
        let _g = t.register(counting_hook(&count));
        clone.cancel();
        assert!(t.is_cancelled());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn every_live_hook_fires() {
        let t = CancelToken::new();
        let count = Arc::new(AtomicUsize::new(0));
        let _guards: Vec<_> = (0..3).map(|_| t.register(counting_hook(&count))).collect();
        t.cancel();
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn guard_dropped_after_cancel_is_harmless() {
        let t = CancelToken::new();
        let count = Arc::new(AtomicUsize::new(0));
        let g = t.register(counting_hook(&count));
        t.cancel();
        drop(g);
        t.cancel();
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn debug_reports_cancellation_state() {
        let t = CancelToken::new();
        assert_eq!(format!("{:?}", t), "CancelToken { cancelled: false }");
        t.cancel();
        assert_eq!(format!("{:?}", t), "CancelToken { cancelled: true }");
    }
}