use tokio::sync::watch;
/// A cancellation flag that starts out `false` and can be set once via
/// [`cancel`](Self::cancel).
///
/// The flag lives in a `tokio::sync::watch` channel, so async code can wait for
/// it to flip and [`CancellationChild`] tokens can subscribe to it.
#[derive(Debug)]
pub struct CancellationToken {
    /// Publishes the cancelled flag; children subscribe to this sender.
    sender: watch::Sender<bool>,
    /// Reads the flag; holding it also keeps the channel open so `send` succeeds.
    receiver: watch::Receiver<bool>,
}
impl CancellationToken {
pub fn new() -> Self {
let (sender, receiver) = watch::channel(false);
Self { sender, receiver }
}
pub fn cancel(&self) {
let _ = self.sender.send(true);
}
pub fn is_cancelled(&self) -> bool {
*self.receiver.borrow()
}
pub fn child(&self) -> CancellationChild {
let parent = self.sender.subscribe();
let (local, local_rx) = watch::channel(false);
CancellationChild {
parent,
local,
local_rx,
}
}
pub async fn cancelled(&self) {
let mut rx = self.receiver.clone();
if *rx.borrow() {
return;
}
loop {
if rx.changed().await.is_err() {
return;
}
if *rx.borrow() {
return;
}
}
}
}
impl Default for CancellationToken {
fn default() -> Self {
Self::new()
}
}
/// A token derived from a [`CancellationToken`] via [`CancellationToken::child`].
///
/// It counts as cancelled when either the parent token or the child itself has
/// been cancelled; cancelling the child never affects the parent.
#[derive(Debug)]
pub struct CancellationChild {
    /// Subscription to the parent token's cancelled flag.
    parent: watch::Receiver<bool>,
    /// This child's own flag, set by `CancellationChild::cancel`.
    local: watch::Sender<bool>,
    /// Reader for `local`; holding it also keeps that channel open so `send` succeeds.
    local_rx: watch::Receiver<bool>,
}
impl CancellationChild {
    /// Cancels this child only; the parent token is unaffected.
    pub fn cancel(&self) {
        // `send` only fails once every receiver is gone; `self.local_rx` keeps one alive.
        let _ = self.local.send(true);
    }

    /// Returns `true` if either the parent token or this child has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        *self.parent.borrow() || *self.local_rx.borrow()
    }

    /// Completes once [`is_cancelled`](Self::is_cancelled) returns `true`, and
    /// returns immediately if it already does.
    ///
    /// Dropping the parent [`CancellationToken`] without cancelling it does not
    /// cancel this child: from then on only [`cancel`](Self::cancel) on the
    /// child can complete the future.
    pub async fn cancelled(&self) {
        // The stored receivers are only ever read via `borrow`, so their "seen"
        // version never advances. Clones inherit that version, so a cancel that
        // lands after the `is_cancelled` check below still makes `changed()` fire.
        let mut parent_rx = self.parent.clone();
        let mut local_rx = self.local_rx.clone();
        let mut parent_open = true;
        loop {
            // Re-check after every wake-up. This also covers a parent that was
            // cancelled and then dropped: the last value stays readable.
            if self.is_cancelled() {
                return;
            }
            tokio::select! {
                res = parent_rx.changed(), if parent_open => {
                    // Err means the parent token was dropped and can no longer be
                    // cancelled. Stop polling it instead of treating the drop as
                    // cancellation; `is_cancelled` would still report `false`.
                    if res.is_err() {
                        parent_open = false;
                    }
                }
                res = local_rx.changed() => {
                    // `self.local` owns this sender, so it cannot close while
                    // `&self` is borrowed; bail out rather than spin if it ever does.
                    if res.is_err() {
                        return;
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Generous bound for futures that should resolve without any real waiting.
    const READY: Duration = Duration::from_millis(100);
    /// How long a future must stay pending to count as "not completed".
    const PENDING: Duration = Duration::from_millis(20);

    #[test]
    fn token_starts_uncancelled() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn cancel_propagates() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn token_cancelled_completes_after_cancel() {
        let token = CancellationToken::new();
        token.cancel();
        tokio::time::timeout(READY, token.cancelled())
            .await
            .expect("token.cancelled() should complete immediately once cancelled");
    }

    #[tokio::test]
    async fn token_cancelled_stays_pending_until_cancel() {
        let token = CancellationToken::new();
        assert!(
            tokio::time::timeout(PENDING, token.cancelled())
                .await
                .is_err(),
            "token.cancelled() must not complete before cancel()"
        );
    }

    #[tokio::test]
    async fn token_cancelled_wakes_on_later_cancel() {
        let token = CancellationToken::new();
        tokio::time::timeout(READY, async {
            tokio::join!(token.cancelled(), async {
                // Yield once so `cancelled()` is polled (and pending) before cancelling.
                tokio::task::yield_now().await;
                token.cancel();
            });
        })
        .await
        .expect("a pending token.cancelled() should wake when cancel() is called");
    }

    #[tokio::test]
    async fn child_inherits_parent_cancel() {
        let parent = CancellationToken::new();
        let child = parent.child();
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(child.is_cancelled());
        tokio::time::timeout(READY, child.cancelled())
            .await
            .expect("child.cancelled() should complete immediately when parent is cancelled");
    }

    #[tokio::test]
    async fn child_wakes_on_later_parent_cancel() {
        let parent = CancellationToken::new();
        let child = parent.child();
        tokio::time::timeout(READY, async {
            tokio::join!(child.cancelled(), async {
                // Yield once so `child.cancelled()` is polled (and pending) first.
                tokio::task::yield_now().await;
                parent.cancel();
            });
        })
        .await
        .expect("a pending child.cancelled() should wake when the parent is cancelled");
    }

    #[tokio::test]
    async fn child_stays_pending_until_cancel() {
        let parent = CancellationToken::new();
        let child = parent.child();
        assert!(
            tokio::time::timeout(PENDING, child.cancelled())
                .await
                .is_err(),
            "child.cancelled() must not complete before any cancel()"
        );
    }

    #[tokio::test]
    async fn child_local_cancel_independent() {
        let parent = CancellationToken::new();
        let child = parent.child();
        child.cancel();
        assert!(child.is_cancelled());
        assert!(!parent.is_cancelled());
        tokio::time::timeout(READY, child.cancelled())
            .await
            .expect("child.cancelled() should complete immediately after local cancel");
    }
}