use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use std::time::{Duration, Instant};
/// A cheaply clonable, thread-safe flag used to request cooperative cancellation.
///
/// Every field lives behind an `Arc`, so all clones of a token share one underlying state:
/// calling [`CancellationToken::cancel`] on any clone is visible through every other clone.
/// Besides the flag itself, the token records when cancellation was first requested and when
/// a call to [`CancellationToken::is_cancelled`] first returned `true`. The gap between those
/// two moments is reported by [`CancellationToken::cancellation_observation_latency`].
#[derive(Debug, Clone)]
#[must_use = "cancellation token should be passed to handler for cooperative timeout"]
pub struct CancellationToken {
    // Set to `true` once cancellation has been requested.
    cancelled: std::sync::Arc<AtomicBool>,
    // Set to `true` once `is_cancelled` has returned `true` at least once.
    observed_cancelled: std::sync::Arc<AtomicBool>,
    // Time of the first `cancel()` call.
    cancellation_requested_at: std::sync::Arc<OnceLock<Instant>>,
    // Time at which `is_cancelled` first returned `true`.
    cancellation_observed_at: std::sync::Arc<OnceLock<Instant>>,
}

impl CancellationToken {
    /// Creates a fresh token that is not cancelled and has no recorded timestamps.
    pub fn new() -> Self {
        Self {
            cancelled: std::sync::Arc::new(AtomicBool::new(false)),
            observed_cancelled: std::sync::Arc::new(AtomicBool::new(false)),
            cancellation_requested_at: std::sync::Arc::new(OnceLock::new()),
            cancellation_observed_at: std::sync::Arc::new(OnceLock::new()),
        }
    }

    /// Returns `true` if cancellation has been requested.
    ///
    /// A `true` result also marks the cancellation as observed. The first such call records
    /// the observation time. From then on,
    /// [`was_cancellation_observed`](Self::was_cancellation_observed) reports `true`.
    pub fn is_cancelled(&self) -> bool {
        let cancelled = self.cancelled.load(Ordering::SeqCst);
        if cancelled {
            // Record the timestamp *before* publishing the observed flag. The `OnceLock`
            // initialisation then happens-before the SeqCst store below. Any thread that
            // reads `observed_cancelled == true` is therefore guaranteed to also see the
            // observation timestamp, so it never sees a half-recorded observation.
            let _ = self.cancellation_observed_at.get_or_init(Instant::now);
            self.observed_cancelled.store(true, Ordering::SeqCst);
        }
        cancelled
    }

    /// Requests cancellation.
    ///
    /// This is idempotent: only the first call records the request time.
    pub fn cancel(&self) {
        // Timestamp first, flag second. A thread that sees the flag can then rely on the
        // request time already being recorded.
        let _ = self.cancellation_requested_at.get_or_init(Instant::now);
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once some call to [`is_cancelled`](Self::is_cancelled) has returned `true`.
    ///
    /// Calling [`cancel`](Self::cancel) alone does not set this; the cancellation must be
    /// polled through `is_cancelled`.
    pub fn was_cancellation_observed(&self) -> bool {
        self.observed_cancelled.load(Ordering::SeqCst)
    }

    /// Returns a shared handle to the raw cancellation flag.
    ///
    /// This handle bypasses the token's bookkeeping:
    /// - Loading it directly does not mark the cancellation as observed.
    /// - Storing `true` into it directly does not record a request time.
    ///
    /// In the second case, [`cancellation_observation_latency`](Self::cancellation_observation_latency)
    /// stays `None` until [`cancel`](Self::cancel) is called.
    pub fn cancelled_flag(&self) -> std::sync::Arc<AtomicBool> {
        std::sync::Arc::clone(&self.cancelled)
    }

    /// Returns the time between the first `cancel()` and the first `is_cancelled()` that
    /// returned `true`.
    ///
    /// Returns `None` if either event has not happened yet. If the two clock readings appear
    /// out of order, the result saturates to zero instead of panicking.
    pub fn cancellation_observation_latency(&self) -> Option<Duration> {
        let requested = self.cancellation_requested_at.get().copied()?;
        let observed = self.cancellation_observed_at.get().copied()?;
        Some(observed.saturating_duration_since(requested))
    }
}

impl Default for CancellationToken {
    /// Equivalent to [`CancellationToken::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Convenience wrapper that owns a single [`CancellationToken`].
///
/// Every method forwards to the wrapped token. The token keeps its state behind shared
/// pointers. As a result, clones of a context, and clones of the token obtained through
/// [`CancellationContext::token`], all see the same cancellation.
#[derive(Debug, Clone)]
pub struct CancellationContext {
    token: CancellationToken,
}

impl CancellationContext {
    /// Builds a context around a brand-new, uncancelled token.
    pub fn new() -> Self {
        let token = CancellationToken::new();
        Self { token }
    }

    /// Borrows the wrapped token, e.g. to hand a clone of it to a handler.
    pub fn token(&self) -> &CancellationToken {
        &self.token
    }

    /// Requests cancellation on the wrapped token.
    pub fn cancel(&self) {
        self.token().cancel();
    }

    /// Reports whether the wrapped token's cancellation has been observed via `is_cancelled`.
    pub fn was_cancellation_observed(&self) -> bool {
        self.token().was_cancellation_observed()
    }

    /// Returns the wrapped token's request-to-observation latency, if both have happened.
    pub fn cancellation_observation_latency(&self) -> Option<Duration> {
        self.token().cancellation_observation_latency()
    }
}

impl Default for CancellationContext {
    /// Same as [`CancellationContext::new`].
    fn default() -> Self {
        CancellationContext::new()
    }
}
/// Unit tests for `CancellationToken` and `CancellationContext`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cancellation_token_defaults_to_not_cancelled() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn cancellation_token_can_be_cancelled() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancellation_observation_is_recorded_when_polled_after_cancel() {
        let token = CancellationToken::new();
        assert!(!token.was_cancellation_observed());
        assert_eq!(token.cancellation_observation_latency(), None);
        token.cancel();
        assert!(token.is_cancelled());
        assert!(token.was_cancellation_observed());
        assert!(token.cancellation_observation_latency().is_some());
    }

    #[test]
    fn cancel_without_polling_is_not_observed() {
        let token = CancellationToken::new();
        token.cancel();
        // Requesting cancellation alone must not count as the handler noticing it.
        assert!(!token.was_cancellation_observed());
        assert_eq!(token.cancellation_observation_latency(), None);
        assert!(token.is_cancelled());
        assert!(token.was_cancellation_observed());
    }

    #[test]
    fn cancellation_is_idempotent() {
        let token = CancellationToken::new();
        token.cancel();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn cloned_tokens_share_cancellation_state() {
        let token = CancellationToken::new();
        let clone = token.clone();
        clone.cancel();
        assert!(token.is_cancelled());
        // An observation made through one handle is visible through the other.
        assert!(clone.was_cancellation_observed());
    }

    #[test]
    fn cancelled_flag_reflects_cancel_without_marking_observation() {
        let token = CancellationToken::new();
        let flag = token.cancelled_flag();
        assert!(!flag.load(Ordering::SeqCst));
        token.cancel();
        assert!(flag.load(Ordering::SeqCst));
        // Reading the raw flag bypasses the observation bookkeeping.
        assert!(!token.was_cancellation_observed());
    }

    #[test]
    fn default_constructors_produce_fresh_state() {
        assert!(!CancellationToken::default().is_cancelled());
        let context = CancellationContext::default();
        assert!(!context.token().is_cancelled());
        assert!(!context.was_cancellation_observed());
    }

    #[test]
    fn cancellation_context_creates_fresh_token() {
        let context = CancellationContext::new();
        assert!(!context.token().is_cancelled());
    }

    #[test]
    fn cancellation_context_can_cancel() {
        let context = CancellationContext::new();
        context.cancel();
        assert!(context.token().is_cancelled());
    }

    #[test]
    fn cloned_context_shares_token() {
        let context = CancellationContext::new();
        let clone = context.clone();
        clone.cancel();
        assert!(context.token().is_cancelled());
    }

    #[test]
    fn cancellation_context_reports_observation_state() {
        let context = CancellationContext::new();
        assert!(!context.was_cancellation_observed());
        assert_eq!(context.cancellation_observation_latency(), None);
        context.cancel();
        assert!(context.token().is_cancelled());
        assert!(context.was_cancellation_observed());
        assert!(context.cancellation_observation_latency().is_some());
    }
}