use std::sync::Arc;
use tokio::sync::watch;
/// Creates a linked [`CancellationHandle`] / [`CancellationToken`] pair.
///
/// Both halves share one `watch` channel whose value starts as `false`
/// (not cancelled). The handle flips it to `true`; the token observes it.
/// Either half can be cloned to fan the signal out further.
pub fn cancellation_pair() -> (CancellationHandle, CancellationToken) {
    let (sender, receiver) = watch::channel(false);
    let handle = CancellationHandle {
        sender: Arc::new(sender),
    };
    let token = CancellationToken { receiver };
    (handle, token)
}
/// The signalling side of a cancellation pair, created by `cancellation_pair`.
///
/// Clones share the same underlying sender (via `Arc`), so cancelling through
/// any clone is visible to every linked `CancellationToken`.
#[derive(Clone, Debug)]
pub struct CancellationHandle {
    // Shared so that cloned handles all drive the same watch channel.
    sender: Arc<watch::Sender<bool>>,
}
impl CancellationHandle {
    /// Signals cancellation to every linked `CancellationToken`.
    ///
    /// Calling this more than once is harmless: the state simply stays cancelled.
    pub fn cancel(&self) {
        // `send` refuses to store the value (returning `Err`) once every receiver
        // has been dropped, which would leave `is_cancelled()` on this handle
        // reporting `false` after `cancel()`. `send_replace` always stores the new
        // value and still wakes any live receivers.
        self.sender.send_replace(true);
    }

    /// Returns `true` once `cancel` has been called on this handle or any clone of it.
    pub fn is_cancelled(&self) -> bool {
        *self.sender.borrow()
    }
}
/// The observing side of a cancellation pair, created by `cancellation_pair`.
///
/// Tokens can be cloned freely; every clone observes the same cancellation state.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    // Receives the shared `bool` flag; `true` means cancellation was requested.
    receiver: watch::Receiver<bool>,
}
impl CancellationToken {
    /// Returns `true` if the linked handle has requested cancellation.
    pub fn is_cancelled(&self) -> bool {
        let flag = self.receiver.borrow();
        *flag
    }

    /// Waits until cancellation is requested; returns at once if it already was.
    ///
    /// This future also completes when every `CancellationHandle` has been
    /// dropped without cancelling (the channel closes). In that case
    /// `is_cancelled()` still reports `false`.
    /// NOTE(review): confirm that "handle dropped" is meant to wake waiters.
    pub async fn cancelled(&self) {
        // Work on a private clone so `&self` stays shared; `changed()` needs `&mut`.
        let mut watcher = self.receiver.clone();
        loop {
            // Copy the flag out so the borrow guard is released before awaiting.
            let cancelled_now = *watcher.borrow();
            if cancelled_now {
                return;
            }
            // An error means the sender side is gone; stop waiting.
            if watcher.changed().await.is_err() {
                return;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound for a wait that is expected to resolve immediately.
    const PROMPT_WAIT: Duration = Duration::from_millis(100);

    /// `cancel()` must be observable through the token's synchronous query.
    #[test]
    fn test_cancellation_signal_sync() {
        let (handle, token) = cancellation_pair();

        let before = token.is_cancelled();
        assert!(!before, "token should not be cancelled initially");

        handle.cancel();

        let after = token.is_cancelled();
        assert!(after, "token should be cancelled after handle.cancel()");
    }

    /// After `cancel()`, awaiting `cancelled()` must resolve without timing out.
    #[tokio::test]
    async fn test_cancellation_signal_async() {
        let (handle, token) = cancellation_pair();
        assert!(!token.is_cancelled());

        handle.cancel();
        assert!(token.is_cancelled());

        let waited = tokio::time::timeout(PROMPT_WAIT, token.cancelled()).await;
        assert!(
            waited.is_ok(),
            "token.cancelled().await should complete immediately after cancel, not time out"
        );
    }
}