use std::error::Error;
use std::fmt;
use std::future::Future;
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::Instant;
/// Failure modes reported by the timeout and cancellation helpers in this module.
///
/// `TimedOut` is produced when a time limit elapses before the wrapped future
/// finishes; `Cancelled` when waiting on a [`CancellationToken`] completes first.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum AsyncControlError {
    /// The time budget (a duration or a deadline) ran out first.
    TimedOut,
    /// The associated cancellation token completed first.
    Cancelled,
}

impl fmt::Display for AsyncControlError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the message first, then emit it with a single write.
        let message = match self {
            Self::TimedOut => "operation timed out",
            Self::Cancelled => "operation was cancelled",
        };
        f.write_str(message)
    }
}

/// No underlying cause is carried, so `source()` keeps its default of `None`.
impl Error for AsyncControlError {}
/// Owning side of a cancellation channel.
///
/// Clones of a source share the same underlying channel, so cancelling any
/// clone cancels all of them. Tokens from [`CancellationSource::new`] or
/// [`CancellationSource::token`] observe the cancellation.
#[derive(Debug, Clone)]
pub struct CancellationSource {
    sender: watch::Sender<bool>,
}

impl CancellationSource {
    /// Creates a new, not-yet-cancelled source together with a first token.
    #[must_use]
    pub fn new() -> (Self, CancellationToken) {
        let (sender, receiver) = watch::channel(false);
        (Self { sender }, CancellationToken { receiver })
    }

    /// Returns a new token subscribed to this source.
    ///
    /// A token created after [`cancel`](Self::cancel) already reports
    /// cancellation.
    #[must_use]
    pub fn token(&self) -> CancellationToken {
        CancellationToken {
            receiver: self.sender.subscribe(),
        }
    }

    /// Requests cancellation. Calling it again has no further effect.
    ///
    /// The request is recorded even when no token is currently alive, so
    /// [`is_cancelled`](Self::is_cancelled) and tokens created later still
    /// observe it.
    pub fn cancel(&self) {
        // `watch::Sender::send` refuses to store the value (returning an error)
        // while no receiver exists, which silently lost the cancellation.
        // `send_if_modified` always stores it and wakes receivers only on the
        // first false -> true transition.
        self.sender.send_if_modified(|cancelled| {
            let newly_cancelled = !*cancelled;
            *cancelled = true;
            newly_cancelled
        });
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// source or any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        *self.sender.borrow()
    }
}
/// Receiving side of a cancellation channel created by [`CancellationSource`].
///
/// Every clone of a token observes the same source.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    receiver: watch::Receiver<bool>,
}

impl CancellationToken {
    /// Returns `true` if the linked source has been cancelled.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let flag: bool = *self.receiver.borrow();
        flag
    }

    /// Waits until cancellation is requested.
    ///
    /// This also returns once every [`CancellationSource`] handle for the
    /// channel has been dropped. In that case [`is_cancelled`](Self::is_cancelled)
    /// stays `false`, so callers that must tell the two apart should check it
    /// afterwards.
    pub async fn cancelled(&mut self) {
        let mut observed = *self.receiver.borrow();
        while !observed {
            if self.receiver.changed().await.is_err() {
                // Every sender is gone, so no cancellation can arrive anymore.
                break;
            }
            observed = *self.receiver.borrow_and_update();
        }
    }
}
/// Handle used to request shutdown of the linked [`ShutdownSignal`]s.
#[derive(Debug, Clone)]
pub struct ShutdownTrigger {
    source: CancellationSource,
}

impl ShutdownTrigger {
    /// Requests shutdown by cancelling the underlying cancellation source.
    pub fn shutdown(&self) {
        let Self { source } = self;
        source.cancel();
    }

    /// Returns an additional signal linked to this trigger.
    #[must_use]
    pub fn signal(&self) -> ShutdownSignal {
        let token = self.source.token();
        ShutdownSignal { token }
    }

    /// Reports whether the underlying cancellation source has been cancelled.
    #[must_use]
    pub fn is_shutdown_requested(&self) -> bool {
        let Self { source } = self;
        source.is_cancelled()
    }
}
/// Listener side of a shutdown channel, obtained from [`shutdown_signal`] or
/// [`ShutdownTrigger::signal`].
#[derive(Debug, Clone)]
pub struct ShutdownSignal {
    token: CancellationToken,
}

impl ShutdownSignal {
    /// Returns `true` if shutdown has been requested through a linked trigger.
    #[must_use]
    pub fn is_shutdown_requested(&self) -> bool {
        let Self { token } = self;
        token.is_cancelled()
    }

    /// Waits for a shutdown request.
    ///
    /// This also returns once every linked [`ShutdownTrigger`] has been dropped.
    /// In that case [`is_shutdown_requested`](Self::is_shutdown_requested)
    /// still reports `false`.
    pub async fn wait(&mut self) {
        let token = &mut self.token;
        token.cancelled().await;
    }
}
/// Creates a linked shutdown trigger/signal pair backed by one cancellation
/// source.
#[must_use]
pub fn shutdown_signal() -> (ShutdownTrigger, ShutdownSignal) {
    let (source, token) = CancellationSource::new();
    let trigger = ShutdownTrigger { source };
    let signal = ShutdownSignal { token };
    (trigger, signal)
}
/// Runs `future` for at most `duration`.
///
/// # Errors
///
/// Returns [`AsyncControlError::TimedOut`] if `duration` elapses before
/// `future` completes.
pub async fn with_timeout<F, T>(duration: Duration, future: F) -> Result<T, AsyncControlError>
where
    F: Future<Output = T>,
{
    let outcome = tokio::time::timeout(duration, future).await;
    outcome.map_err(|_elapsed| AsyncControlError::TimedOut)
}
/// Runs `future` until the absolute instant `deadline`.
///
/// # Errors
///
/// Returns [`AsyncControlError::TimedOut`] if `deadline` is reached before
/// `future` completes.
pub async fn with_deadline<F, T>(deadline: Instant, future: F) -> Result<T, AsyncControlError>
where
    F: Future<Output = T>,
{
    let limited = tokio::time::timeout_at(deadline, future);
    let outcome = limited.await;
    outcome.map_err(|_elapsed| AsyncControlError::TimedOut)
}
pub async fn run_until_cancelled<F, T>(
mut token: CancellationToken,
future: F,
) -> Result<T, AsyncControlError>
where
F: Future<Output = T>,
{
tokio::select! {
biased;
_ = token.cancelled() => Err(AsyncControlError::Cancelled),
value = future => Ok(value),
}
}
/// Runs `future` for at most `duration`, stopping early if `token` reports
/// cancellation.
///
/// Cancellation is checked before the timeout on every poll (`biased`), so
/// when both are ready at the same time the result is `Cancelled`.
///
/// # Errors
///
/// - [`AsyncControlError::Cancelled`] if waiting on `token` completes first.
///   This includes every source for the token being dropped.
/// - [`AsyncControlError::TimedOut`] if `duration` elapses before `future`
///   completes.
pub async fn with_timeout_or_cancel<F, T>(
    duration: Duration,
    mut token: CancellationToken,
    future: F,
) -> Result<T, AsyncControlError>
where
    F: Future<Output = T>,
{
    // The timeout's deadline is fixed here, right before the race starts.
    let bounded = tokio::time::timeout(duration, future);
    tokio::select! {
        biased;
        () = token.cancelled() => Err(AsyncControlError::Cancelled),
        outcome = bounded => outcome.map_err(|_elapsed| AsyncControlError::TimedOut),
    }
}
#[cfg(test)]
mod tests {
// Async tests run on the tokio test runtime. Timer-based cases use
// `start_paused = true` so tokio can auto-advance the paused clock instead of
// waiting in real time.
use std::future::pending;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Notify;
use tokio::time::{Duration, Instant, sleep};
use super::*;
/// Increments the shared `counter` when dropped. A test uses it to observe
/// that a cancelled future (which owns the guard) was destroyed.
struct DropCounter {
counter: Arc<AtomicUsize>,
}
impl Drop for DropCounter {
fn drop(&mut self) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
}
// Pins the public `Display` text of each variant and that neither carries a
// `source()`.
#[test]
fn async_control_error_formats_public_error_messages() {
assert_eq!(
AsyncControlError::TimedOut.to_string(),
"operation timed out"
);
assert_eq!(
AsyncControlError::Cancelled.to_string(),
"operation was cancelled"
);
assert!(AsyncControlError::TimedOut.source().is_none());
assert!(AsyncControlError::Cancelled.source().is_none());
}
// An immediately-ready future beats a generous timeout.
#[tokio::test]
async fn with_timeout_returns_value_before_deadline() {
let actual = with_timeout(Duration::from_secs(1), async { 7 }).await;
assert_eq!(actual, Ok(7));
}
// The inner sleep (1s) outlasts the 10ms budget, so the timeout wins.
#[tokio::test(start_paused = true)]
async fn with_timeout_reports_elapsed_operation() {
let actual = with_timeout(Duration::from_millis(10), async {
sleep(Duration::from_secs(1)).await;
7
})
.await;
assert_eq!(actual, Err(AsyncControlError::TimedOut));
}
// Same as above, but with an absolute deadline instead of a duration.
#[tokio::test(start_paused = true)]
async fn with_deadline_reports_elapsed_operation() {
let deadline = Instant::now() + Duration::from_millis(10);
let actual = with_deadline(deadline, async {
sleep(Duration::from_secs(1)).await;
7
})
.await;
assert_eq!(actual, Err(AsyncControlError::TimedOut));
}
// `_source` is kept alive, so the token never completes and the timeout decides.
#[tokio::test(start_paused = true)]
async fn with_timeout_or_cancel_reports_timeout_when_token_is_idle() {
let (_source, token) = CancellationSource::new();
let actual = with_timeout_or_cancel(Duration::from_millis(10), token, async {
sleep(Duration::from_secs(1)).await;
7
})
.await;
assert_eq!(actual, Err(AsyncControlError::TimedOut));
}
// A live, uncancelled source lets a ready future through.
#[tokio::test]
async fn run_until_cancelled_returns_value_before_cancellation() {
let (_source, token) = CancellationSource::new();
let actual = run_until_cancelled(token, async { 7 }).await;
assert_eq!(actual, Ok(7));
}
// Dropping the only source ends `cancelled()` without setting the cancelled
// flag: waiting completes, but `is_cancelled()` stays false.
#[tokio::test]
async fn cancellation_token_completes_when_all_sources_are_dropped() {
let (source, mut token) = CancellationSource::new();
drop(source);
token.cancelled().await;
assert!(!token.is_cancelled());
}
// A dropped source is reported as `Cancelled` by `run_until_cancelled`.
#[tokio::test]
async fn run_until_cancelled_reports_cancelled_when_source_is_dropped() {
let (source, token) = CancellationSource::new();
drop(source);
let actual = run_until_cancelled(token, pending::<()>()).await;
assert_eq!(actual, Err(AsyncControlError::Cancelled));
}
// Cancels while the inner future is parked and checks that the future (and its
// `DropCounter` guard) is dropped exactly once. `started` ensures the inner
// future has been polled and created its guard before `cancel()` runs. If the
// main task is not yet waiting, `notify_one` stores a permit for it.
#[tokio::test]
async fn run_until_cancelled_reports_cancellation_and_drops_future() {
let (source, token) = CancellationSource::new();
let dropped = Arc::new(AtomicUsize::new(0));
let started = Arc::new(Notify::new());
let task = tokio::spawn({
let dropped = Arc::clone(&dropped);
let started = Arc::clone(&started);
async move {
run_until_cancelled(token, async move {
let _guard = DropCounter { counter: dropped };
started.notify_one();
pending::<()>().await;
7
})
.await
}
});
started.notified().await;
source.cancel();
let actual = task.await.unwrap();
assert_eq!(actual, Err(AsyncControlError::Cancelled));
assert_eq!(dropped.load(Ordering::SeqCst), 1);
}
// The token is cancelled before the call; the biased select checks the token
// first, so `Cancelled` is reported rather than `TimedOut`.
#[tokio::test(start_paused = true)]
async fn with_timeout_or_cancel_prefers_cancellation() {
let (source, token) = CancellationSource::new();
source.cancel();
let actual = with_timeout_or_cancel(Duration::from_millis(10), token, async {
sleep(Duration::from_secs(1)).await;
7
})
.await;
assert_eq!(actual, Err(AsyncControlError::Cancelled));
}
// A dropped source also ends `with_timeout_or_cancel` with `Cancelled`.
#[tokio::test]
async fn with_timeout_or_cancel_reports_cancelled_when_source_is_dropped() {
let (source, token) = CancellationSource::new();
drop(source);
let actual = with_timeout_or_cancel(Duration::from_secs(1), token, pending::<()>()).await;
assert_eq!(actual, Err(AsyncControlError::Cancelled));
}
// Two independent signals (the original and one from `trigger.signal()`) both
// wake on a single `shutdown()` and report the request.
#[tokio::test]
async fn shutdown_signal_notifies_all_listeners() {
let (trigger, mut signal) = shutdown_signal();
let mut second = trigger.signal();
let first_task = tokio::spawn(async move {
signal.wait().await;
signal.is_shutdown_requested()
});
let second_task = tokio::spawn(async move {
second.wait().await;
second.is_shutdown_requested()
});
trigger.shutdown();
assert!(first_task.await.unwrap());
assert!(second_task.await.unwrap());
assert!(trigger.is_shutdown_requested());
}
// Dropping the trigger ends `wait()` without marking shutdown as requested.
#[tokio::test]
async fn shutdown_signal_waits_until_trigger_is_dropped() {
let (trigger, mut signal) = shutdown_signal();
drop(trigger);
signal.wait().await;
assert!(!signal.is_shutdown_requested());
}
}