use std::future::Future;
pub use tokio_util::sync::CancellationToken;
use crate::errors::CoreError;
/// Creates a fresh, uncancelled root of a cancellation tree.
///
/// Narrower scopes are derived from it with [`round_token`] and
/// [`child_token`]; cancelling the root cancels every token derived from it.
#[must_use = "a root token is only useful if it is kept and shared"]
pub fn root_token() -> CancellationToken {
    CancellationToken::new()
}
/// Derives a round-scoped token from `parent`.
///
/// The returned token becomes cancelled whenever `parent` (or any of its
/// ancestors) is cancelled, but cancelling the returned token does not
/// cancel `parent`.
#[must_use = "the derived token is only useful if it is kept and observed"]
pub fn round_token(parent: &CancellationToken) -> CancellationToken {
    parent.child_token()
}
/// Derives a child token from a `round` token.
///
/// The returned token becomes cancelled whenever `round` (or any of its
/// ancestors) is cancelled, but cancelling the returned token does not
/// cancel `round` or anything above it.
#[must_use = "the derived token is only useful if it is kept and observed"]
pub fn child_token(round: &CancellationToken) -> CancellationToken {
    round.child_token()
}
/// Helpers that translate a cancellation token's state into [`CoreError`]
/// results instead of panics or sentinel values.
pub trait CancelExt {
    /// Returns `Ok(())` while the token is live and a cancellation
    /// [`CoreError`] once it has been cancelled.
    fn error_if_cancelled(&self) -> Result<(), CoreError>;

    /// Drives `fut` unless the token is cancelled first.
    ///
    /// The returned future resolves to `Ok(output)` when `fut` completes, or
    /// to a cancellation [`CoreError`] when cancellation wins the race.
    fn guard<F: Future<Output = T> + Send, T>(
        &self,
        fut: F,
    ) -> impl Future<Output = Result<T, CoreError>> + Send;
}
impl CancelExt for CancellationToken {
    /// Fails with `CoreError::cancelled()` once this token has been cancelled.
    fn error_if_cancelled(&self) -> Result<(), CoreError> {
        if !self.is_cancelled() {
            return Ok(());
        }
        Err(CoreError::cancelled())
    }

    /// Races `fut` against this token's cancellation.
    ///
    /// `biased;` makes `select!` poll the cancellation branch first on every
    /// poll. An already-cancelled token therefore wins without `fut` ever
    /// being polled. A cancellation observed on the same poll in which `fut`
    /// becomes ready also wins.
    async fn guard<F: Future<Output = T> + Send, T>(&self, fut: F) -> Result<T, CoreError> {
        tokio::select! {
            biased;
            () = self.cancelled() => Err(CoreError::cancelled()),
            value = fut => Ok(value),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    #[test]
    fn root_starts_uncancelled() {
        let t = root_token();
        assert!(!t.is_cancelled());
        assert!(t.error_if_cancelled().is_ok());
    }

    #[test]
    fn cancel_down_the_chain_propagates_parent_to_child() {
        let parent = root_token();
        let round = round_token(&parent);
        let child = child_token(&round);
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(parent.is_cancelled());
        assert!(round.is_cancelled());
        assert!(child.is_cancelled());
    }

    #[test]
    fn cancel_does_not_propagate_up() {
        let parent = root_token();
        let round = round_token(&parent);
        let child = child_token(&round);
        child.cancel();
        assert!(child.is_cancelled());
        assert!(!round.is_cancelled());
        assert!(!parent.is_cancelled());
    }

    #[test]
    fn error_if_cancelled_yields_typed_error_not_panic() {
        let t = root_token();
        t.cancel();
        let err = t.error_if_cancelled().unwrap_err();
        assert!(err.is_cancelled());
        assert_eq!(err.to_string(), "cancelled");
    }

    #[tokio::test]
    async fn guard_returns_ok_when_future_wins() {
        let t = root_token();
        let out = t.guard(async { 7 }).await;
        assert_eq!(out.unwrap(), 7);
    }

    #[tokio::test]
    async fn guard_returns_typed_cancel_when_token_wins() {
        let t = root_token();
        let t2 = t.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(5)).await;
            t2.cancel();
        });
        let out: Result<(), CoreError> = t
            .guard(async {
                tokio::time::sleep(Duration::from_secs(3600)).await;
            })
            .await;
        assert!(out.unwrap_err().is_cancelled());
    }

    // Pins the `biased;` ordering in `guard`. With the token already
    // cancelled, the cancellation branch must win even though the future is
    // immediately ready. The future must never be polled at all.
    #[tokio::test]
    async fn guard_prefers_cancellation_and_never_polls_when_already_cancelled() {
        let t = root_token();
        t.cancel();
        let polled = AtomicBool::new(false);
        let out = t
            .guard(async {
                polled.store(true, Ordering::SeqCst);
                7
            })
            .await;
        assert!(out.unwrap_err().is_cancelled());
        assert!(!polled.load(Ordering::SeqCst));
    }
}