use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
/// Cooperative cancellation handle built from a shared boolean flag and a
/// shared [`Notify`].
///
/// Both fields live behind an [`Arc`], so cloning the token yields another
/// handle onto the *same* flag and notifier rather than an independent token.
#[derive(Debug, Clone)]
pub struct CancellationToken {
/// Shared flag; `true` while the token is in the cancelled state.
cancelled: Arc<AtomicBool>,
/// Shared notifier signalled whenever the flag is set, used to wake async waiters.
notify: Arc<Notify>,
}
impl CancellationToken {
    /// Creates a token in the non-cancelled state.
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        }
    }

    /// Creates a token that is already cancelled.
    pub fn cancelled_now() -> Self {
        let t = Self::new();
        t.cancel();
        t
    }

    /// Marks the token (and every clone of it) as cancelled and wakes all
    /// tasks currently waiting in [`cancelled`](Self::cancelled).
    pub fn cancel(&self) {
        // Publish the flag before notifying so that any woken waiter
        // observes `true` when it re-checks.
        self.cancelled.store(true, Ordering::SeqCst);
        self.notify.notify_waiters();
    }

    /// Returns `true` if the token is currently in the cancelled state.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Completes once the token is cancelled; returns immediately if it
    /// already is.
    ///
    /// If the token is cancelled and then [`reset`](Self::reset) before a
    /// woken waiter gets to re-check the flag, that waiter goes back to
    /// waiting for the next cancellation.
    pub async fn cancelled(&self) {
        loop {
            // Create the `Notified` future *before* reading the flag.
            // `notify_waiters()` stores no permit, so a `cancel()` that lands
            // between a flag check and a later `notified()` call would be
            // lost and this future would hang forever. A `Notified` future
            // receives `notify_waiters()` wake-ups from the moment it is
            // created, which closes that window.
            let notified = self.notify.notified();
            if self.is_cancelled() {
                return;
            }
            notified.await;
        }
    }

    /// Clears the cancelled flag so the token can be reused.
    ///
    /// This only flips the flag; it does not notify anyone.
    pub fn reset(&self) {
        self.cancelled.store(false, Ordering::SeqCst);
    }

    /// Synchronous cancellation check for use with `?`.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Cancelled` carrying a copy of `reason` when the
    /// token is cancelled; otherwise returns `Ok(())`.
    pub fn check(&self, reason: &str) -> crate::error::Result<()> {
        if self.is_cancelled() {
            Err(crate::error::CognisError::Cancelled(reason.to_string()))
        } else {
            Ok(())
        }
    }
}
impl Default for CancellationToken {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A token from `new()` starts out not cancelled.
    #[test]
    fn initial_state_is_not_cancelled() {
        let fresh = CancellationToken::new();
        assert!(!fresh.is_cancelled());
    }

    /// `cancel()` flips the flag observed by `is_cancelled()`.
    #[test]
    fn cancel_sets_cancelled() {
        let tok = CancellationToken::new();
        tok.cancel();
        assert!(tok.is_cancelled());
    }

    /// `reset()` returns a cancelled token to the non-cancelled state.
    #[test]
    fn reset_clears_cancellation() {
        let tok = CancellationToken::new();
        tok.cancel();
        assert!(tok.is_cancelled());
        tok.reset();
        assert!(!tok.is_cancelled());
    }

    /// Cancelling one handle is visible through a clone of it.
    #[test]
    fn clones_share_state() {
        let source = CancellationToken::new();
        let observer = source.clone();
        source.cancel();
        assert!(observer.is_cancelled());
    }

    /// `Default::default()` yields a non-cancelled token.
    #[test]
    fn default_is_not_cancelled() {
        let tok = CancellationToken::default();
        assert!(!tok.is_cancelled());
    }

    /// A pending `cancelled()` future completes after another task cancels.
    #[tokio::test]
    async fn cancelled_future_resolves_when_signalled() {
        let waiter = CancellationToken::new();
        let trigger = waiter.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            trigger.cancel();
        });
        let outcome = tokio::time::timeout(Duration::from_secs(2), waiter.cancelled()).await;
        outcome.expect("cancelled() should resolve when token is cancelled");
    }

    /// `cancelled()` on an already-cancelled token completes within the timeout.
    #[tokio::test]
    async fn cancelled_future_resolves_immediately_when_already_cancelled() {
        let tok = CancellationToken::new();
        tok.cancel();
        let outcome = tokio::time::timeout(Duration::from_millis(50), tok.cancelled()).await;
        outcome.expect("cancelled() should resolve immediately when already cancelled");
    }

    /// `cancelled_now()` builds a token that reports cancelled straight away.
    #[test]
    fn cancelled_now_constructs_pre_cancelled() {
        let tok = CancellationToken::cancelled_now();
        assert!(tok.is_cancelled());
    }

    /// `check()` is `Ok` while the token is live.
    #[test]
    fn check_method_returns_ok_when_not_cancelled() {
        let tok = CancellationToken::new();
        let verdict = tok.check("unused");
        assert!(verdict.is_ok());
    }

    /// `check()` yields `Cancelled` carrying the supplied reason once cancelled.
    #[test]
    fn check_method_returns_err_when_cancelled() {
        let tok = CancellationToken::new();
        tok.cancel();
        let failure = tok.check("stop now").unwrap_err();
        match failure {
            crate::error::CognisError::Cancelled(reason) => assert_eq!(reason, "stop now"),
            other => panic!("expected Cancelled, got {other:?}"),
        }
    }
}