use std::{
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::{Duration, Instant},
};
/// Cooperative, shareable cancellation flag.
///
/// Clones share a single `Arc<AtomicBool>`, so calling [`cancel`](Self::cancel)
/// on any clone is observed by every other clone. This type offers no method
/// to clear the flag once it has been set.
#[derive(Clone, Debug, Default)]
pub struct CancellationToken(Arc<AtomicBool>);
impl CancellationToken {
    /// Creates a token that has not been cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Marks the token (and all of its clones) as cancelled.
    ///
    /// Calling this more than once is harmless; it always stores `true`.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.0;
        // Release pairs with the Acquire load in `is_cancelled`, so writes made
        // before cancelling are visible to a thread that observes the flag set.
        flag.store(true, Ordering::Release);
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or any clone of it.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.0;
        flag.load(Ordering::Acquire)
    }
}
/// Thread-safe cap on the total number of nodes that may be scanned.
///
/// Charges are applied atomically, so concurrent callers sharing one budget
/// never lose updates. The limit is inclusive: scanning exactly `max_nodes`
/// nodes is allowed, and only a total strictly greater than it trips the budget.
/// Charges are never rolled back. A charge that pushes the total over the limit
/// is still recorded, and the total saturates at `usize::MAX` instead of
/// overflowing.
#[derive(Debug)]
pub struct NodeScanBudget {
    /// Inclusive upper bound on the number of scanned nodes.
    max_nodes: usize,
    /// Running total of all nodes charged so far.
    scanned: AtomicUsize,
}
impl NodeScanBudget {
    /// Creates a budget allowing up to `max_nodes` scanned nodes, starting at zero.
    #[must_use]
    pub const fn new(max_nodes: usize) -> Self {
        Self {
            max_nodes,
            scanned: AtomicUsize::new(0),
        }
    }

    /// Returns the configured (inclusive) node limit.
    #[must_use]
    pub const fn max_nodes(&self) -> usize {
        self.max_nodes
    }

    /// Returns the number of nodes charged so far (saturating at `usize::MAX`).
    #[must_use]
    pub fn scanned(&self) -> usize {
        self.scanned.load(Ordering::Acquire)
    }

    /// Atomically adds `nodes` to the running total.
    ///
    /// # Errors
    ///
    /// Returns [`CancellationCause::NodeScanBudgetExceeded`] when the new total
    /// exceeds `max_nodes`. `scanned` in the error is that new total. The
    /// charge stays recorded either way.
    fn note_nodes_scanned(&self, nodes: usize) -> Result<(), CancellationCause> {
        // `fetch_update` runs the compare-exchange retry loop for us. The
        // closure always returns `Some`, so the update cannot be rejected.
        // Both arms carry the value observed just before our charge.
        let previous = match self.scanned.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |current| Some(current.saturating_add(nodes)),
        ) {
            Ok(previous) | Err(previous) => previous,
        };
        // Recompute the value we stored; this matches what the closure produced.
        let total = previous.saturating_add(nodes);
        if total > self.max_nodes {
            return Err(CancellationCause::NodeScanBudgetExceeded {
                limit: self.max_nodes,
                scanned: total,
            });
        }
        Ok(())
    }
}
/// Reason a [`CancellationChecker`] asked the current operation to stop.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CancellationCause {
    /// The attached [`CancellationToken`] was cancelled.
    Cancelled,
    /// The deadline was reached. `elapsed` is how far past the deadline the
    /// check ran (`now - deadline`), not the total running time.
    Timeout { elapsed: Duration },
    /// The shared [`NodeScanBudget`] was exceeded. `limit` is the configured
    /// maximum. `scanned` is the running total including the charge that
    /// crossed the limit.
    NodeScanBudgetExceeded { limit: usize, scanned: usize },
}
/// Cheap, copyable bundle of the cancellation sources a long-running
/// operation should poll. It holds an optional shared token, an optional
/// deadline, and an optional node-scan budget.
///
/// When several sources have tripped, they are reported in a fixed priority
/// order: token, then deadline, then node budget.
#[derive(Clone, Copy, Debug)]
pub struct CancellationChecker<'a> {
    /// Cooperative cancellation flag, if any.
    token: Option<&'a CancellationToken>,
    /// Instant at or after which [`check`](Self::check) reports a timeout.
    deadline: Option<Instant>,
    /// Shared budget charged by [`note_nodes_scanned`](Self::note_nodes_scanned).
    node_scan_budget: Option<&'a NodeScanBudget>,
}
impl<'a> CancellationChecker<'a> {
    /// Builds a checker from an optional token and deadline, without a node budget.
    #[must_use]
    pub const fn new(token: Option<&'a CancellationToken>, deadline: Option<Instant>) -> Self {
        Self::new_with_node_scan_budget(token, deadline, None)
    }

    /// Builds a checker from all three optional cancellation sources.
    #[must_use]
    pub const fn new_with_node_scan_budget(
        token: Option<&'a CancellationToken>,
        deadline: Option<Instant>,
        node_scan_budget: Option<&'a NodeScanBudget>,
    ) -> Self {
        Self {
            token,
            deadline,
            node_scan_budget,
        }
    }

    /// Builds a checker with no sources; it never reports cancellation.
    #[must_use]
    pub const fn disabled() -> Self {
        Self::new_with_node_scan_budget(None, None, None)
    }

    /// Returns `true` when no token, deadline, or node budget is attached.
    #[must_use]
    #[inline(always)]
    pub const fn is_disabled(&self) -> bool {
        !(self.token.is_some() || self.deadline.is_some() || self.node_scan_budget.is_some())
    }

    /// Polls the token and the deadline. The node budget is not touched.
    ///
    /// # Errors
    ///
    /// - [`CancellationCause::Cancelled`] if the token has been cancelled.
    /// - Otherwise, [`CancellationCause::Timeout`] if the current time is at
    ///   or past the deadline.
    #[inline]
    pub fn check(&self) -> Result<(), CancellationCause> {
        if let Some(token) = self.token {
            if token.is_cancelled() {
                return Err(CancellationCause::Cancelled);
            }
        }
        let Some(deadline) = self.deadline else {
            return Ok(());
        };
        let now = Instant::now();
        if now < deadline {
            Ok(())
        } else {
            Err(CancellationCause::Timeout {
                elapsed: now.duration_since(deadline),
            })
        }
    }

    /// Runs [`check`](Self::check), then charges `nodes` to the node budget
    /// if one is attached.
    ///
    /// If `check` fails, the budget is left untouched.
    ///
    /// # Errors
    ///
    /// Any error from [`check`](Self::check), or
    /// [`CancellationCause::NodeScanBudgetExceeded`] when this charge pushes
    /// the budget past its limit.
    #[inline]
    pub fn note_nodes_scanned(&self, nodes: usize) -> Result<(), CancellationCause> {
        self.check()?;
        self.node_scan_budget
            .map_or(Ok(()), |budget| budget.note_nodes_scanned(nodes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn disabled_checker_never_trips() {
        let checker = CancellationChecker::disabled();
        assert!(checker.is_disabled());
        assert_eq!(checker.check(), Ok(()));
    }

    #[test]
    fn checker_with_token_is_not_disabled() {
        let token = CancellationToken::new();
        let checker = CancellationChecker::new(Some(&token), None);
        assert!(!checker.is_disabled());
    }

    #[test]
    fn checker_with_deadline_is_not_disabled() {
        let deadline = Instant::now();
        let checker = CancellationChecker::new(None, Some(deadline));
        assert!(!checker.is_disabled());
    }

    #[test]
    fn checker_with_node_scan_budget_is_not_disabled() {
        let budget = NodeScanBudget::new(10);
        let checker = CancellationChecker::new_with_node_scan_budget(None, None, Some(&budget));
        assert!(!checker.is_disabled());
    }

    #[test]
    fn token_wins_over_deadline_when_both_tripped() {
        let token = CancellationToken::new();
        token.cancel();
        let elapsed_deadline = Instant::now() - Duration::from_secs(1);
        let checker = CancellationChecker::new(Some(&token), Some(elapsed_deadline));
        assert_eq!(checker.check(), Err(CancellationCause::Cancelled));
    }

    #[test]
    fn token_wins_over_node_scan_budget_when_both_tripped() {
        let token = CancellationToken::new();
        token.cancel();
        let budget = NodeScanBudget::new(0);
        let checker =
            CancellationChecker::new_with_node_scan_budget(Some(&token), None, Some(&budget));
        assert_eq!(
            checker.note_nodes_scanned(1),
            Err(CancellationCause::Cancelled)
        );
        // A failed `check` must not charge the budget.
        assert_eq!(budget.scanned(), 0);
    }

    #[test]
    fn deadline_wins_over_node_scan_budget_when_both_tripped() {
        let elapsed_deadline = Instant::now() - Duration::from_secs(1);
        let budget = NodeScanBudget::new(0);
        let checker = CancellationChecker::new_with_node_scan_budget(
            None,
            Some(elapsed_deadline),
            Some(&budget),
        );
        assert!(matches!(
            checker.note_nodes_scanned(1),
            Err(CancellationCause::Timeout { .. })
        ));
        assert_eq!(budget.scanned(), 0);
    }

    #[test]
    fn deadline_reported_when_only_deadline_tripped() {
        let elapsed_deadline = Instant::now() - Duration::from_secs(1);
        let checker = CancellationChecker::new(None, Some(elapsed_deadline));
        assert!(matches!(
            checker.check(),
            Err(CancellationCause::Timeout { .. })
        ));
    }

    /// `elapsed` measures time past the deadline, so it is at least as large
    /// as the amount by which the deadline was already in the past.
    #[test]
    fn timeout_elapsed_is_measured_from_the_deadline() {
        let elapsed_deadline = Instant::now() - Duration::from_secs(1);
        let checker = CancellationChecker::new(None, Some(elapsed_deadline));
        match checker.check() {
            Err(CancellationCause::Timeout { elapsed }) => {
                assert!(elapsed >= Duration::from_secs(1));
            }
            other => panic!("expected a timeout, got {other:?}"),
        }
    }

    #[test]
    fn live_token_with_future_deadline_passes() {
        let token = CancellationToken::new();
        let future_deadline = Instant::now() + Duration::from_secs(3600);
        let checker = CancellationChecker::new(Some(&token), Some(future_deadline));
        assert_eq!(checker.check(), Ok(()));
    }

    /// Clones share one flag, so cancelling any clone cancels them all.
    #[test]
    fn cancellation_is_visible_through_clones() {
        let token = CancellationToken::new();
        let twin = token.clone();
        assert!(!twin.is_cancelled());
        token.cancel();
        assert!(twin.is_cancelled());
    }

    #[test]
    fn node_scan_budget_reports_its_limit() {
        let budget = NodeScanBudget::new(7);
        assert_eq!(budget.max_nodes(), 7);
        assert_eq!(budget.scanned(), 0);
    }

    #[test]
    fn node_scan_budget_trips_after_crossing_limit() {
        let budget = NodeScanBudget::new(3);
        let checker = CancellationChecker::new_with_node_scan_budget(None, None, Some(&budget));
        assert_eq!(checker.note_nodes_scanned(2), Ok(()));
        assert_eq!(
            checker.note_nodes_scanned(2),
            Err(CancellationCause::NodeScanBudgetExceeded {
                limit: 3,
                scanned: 4
            })
        );
        assert_eq!(budget.scanned(), 4);
    }

    /// The limit is inclusive: reaching it exactly is fine; only exceeding it trips.
    #[test]
    fn node_scan_budget_allows_exactly_the_limit() {
        let budget = NodeScanBudget::new(3);
        let checker = CancellationChecker::new_with_node_scan_budget(None, None, Some(&budget));
        assert_eq!(checker.note_nodes_scanned(3), Ok(()));
        assert_eq!(checker.note_nodes_scanned(0), Ok(()));
        assert_eq!(budget.scanned(), 3);
        assert_eq!(
            checker.note_nodes_scanned(1),
            Err(CancellationCause::NodeScanBudgetExceeded {
                limit: 3,
                scanned: 4
            })
        );
    }

    /// Huge charges saturate at `usize::MAX` instead of overflowing or panicking.
    #[test]
    fn node_scan_budget_saturates_instead_of_overflowing() {
        let budget = NodeScanBudget::new(10);
        let checker = CancellationChecker::new_with_node_scan_budget(None, None, Some(&budget));
        let exceeded = Err(CancellationCause::NodeScanBudgetExceeded {
            limit: 10,
            scanned: usize::MAX,
        });
        assert_eq!(checker.note_nodes_scanned(usize::MAX), exceeded);
        assert_eq!(checker.note_nodes_scanned(1), exceeded);
        assert_eq!(budget.scanned(), usize::MAX);
    }

    #[test]
    fn checker_without_budget_ignores_node_counts() {
        let checker = CancellationChecker::new(None, None);
        assert_eq!(checker.note_nodes_scanned(usize::MAX), Ok(()));
    }

    /// Concurrent charges through copies of one checker must not lose updates.
    #[test]
    fn node_scan_budget_counts_concurrent_charges() {
        let budget = NodeScanBudget::new(1_000);
        let checker = CancellationChecker::new_with_node_scan_budget(None, None, Some(&budget));
        std::thread::scope(|scope| {
            for _ in 0..4 {
                scope.spawn(move || {
                    for _ in 0..100 {
                        assert_eq!(checker.note_nodes_scanned(1), Ok(()));
                    }
                });
            }
        });
        assert_eq!(budget.scanned(), 400);
    }
}