use std::future::Future;
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio::time::{MissedTickBehavior, interval};
use tokio_util::sync::CancellationToken;
use tracing::warn;
use super::error::TickError;
/// A unit of work that [`PeriodicWorker`] drives on a fixed interval.
///
/// Implementors must be `Send + 'static` because [`PeriodicWorker::spawn`]
/// moves them into a spawned Tokio task.
pub trait PeriodicTask: Send + 'static {
    /// Performs one iteration of the periodic work.
    ///
    /// # Errors
    ///
    /// When driven by [`PeriodicWorker`], a returned [`TickError`] is logged
    /// and the schedule keeps running.
    fn tick(&mut self) -> impl Future<Output = Result<(), TickError>> + Send;

    /// Cleanup hook. When driven by [`PeriodicWorker`], it is awaited once
    /// after the worker's cancellation token fires.
    ///
    /// The default implementation does nothing and succeeds.
    ///
    /// # Errors
    ///
    /// When driven by [`PeriodicWorker`], a returned [`TickError`] is logged
    /// before the worker exits.
    fn shutdown(&mut self) -> impl Future<Output = Result<(), TickError>> + Send {
        async { Ok(()) }
    }
}
/// Handle to a background task started by [`PeriodicWorker::spawn`].
///
/// The worker keeps running until the cancellation token passed to `spawn` is
/// cancelled. Dropping this handle detaches the task rather than stopping it.
/// Use [`PeriodicWorker::join`] to wait for the task to finish.
#[derive(Debug)]
pub struct PeriodicWorker {
    /// Handle of the spawned Tokio task that runs the tick loop.
    join: JoinHandle<()>,
}
impl PeriodicWorker {
pub fn spawn<T: PeriodicTask>(
mut task: T,
interval_duration: Duration,
shutdown: CancellationToken,
) -> Self {
let join = tokio::spawn(async move {
let mut tick = interval(interval_duration);
tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
tick.tick().await;
loop {
tokio::select! {
biased;
() = shutdown.cancelled() => {
if let Err(e) = task.shutdown().await {
warn!(error = %e, "periodic task shutdown hook failed");
}
return;
}
_ = tick.tick() => {
if let Err(e) = task.tick().await {
warn!(error = %e, "periodic task tick failed");
}
}
}
}
});
Self { join }
}
pub async fn join(self) -> Result<(), tokio::task::JoinError> {
self.join.await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Instant;

    /// Counts ticks and relies on the default (no-op) shutdown hook.
    struct CountingTask {
        ticks: Arc<AtomicU32>,
    }

    impl PeriodicTask for CountingTask {
        async fn tick(&mut self) -> Result<(), TickError> {
            self.ticks.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Counts both ticks and shutdown-hook invocations.
    struct ShutdownTask {
        ticks: Arc<AtomicU32>,
        shutdown_called: Arc<AtomicU32>,
    }

    impl PeriodicTask for ShutdownTask {
        async fn tick(&mut self) -> Result<(), TickError> {
            self.ticks.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn shutdown(&mut self) -> Result<(), TickError> {
            self.shutdown_called.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Counts ticks but reports every one of them as a failure.
    struct FailingTask {
        ticks: Arc<AtomicU32>,
    }

    impl PeriodicTask for FailingTask {
        async fn tick(&mut self) -> Result<(), TickError> {
            self.ticks.fetch_add(1, Ordering::SeqCst);
            Err(TickError::Generic("simulated".into()))
        }
    }

    // These tests run on real (wall-clock) time, so timing bounds are kept
    // loose. Every test joins the worker after cancelling. That makes the
    // counters final before they are read, and turns a panic inside the worker
    // task into a test failure instead of losing it with a dropped JoinHandle.

    #[tokio::test]
    async fn tick_fires_at_interval() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            CountingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(20),
            shutdown.clone(),
        );
        tokio::time::sleep(Duration::from_millis(110)).await;
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        let n = ticks.load(Ordering::SeqCst);
        assert!((4..=7).contains(&n), "got {n} ticks, expected 4-7");
    }

    #[tokio::test]
    async fn first_tick_is_delayed_not_immediate() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            CountingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(100),
            shutdown.clone(),
        );
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(ticks.load(Ordering::SeqCst), 0);
        shutdown.cancel();
        worker.join().await.expect("clean exit");
    }

    #[tokio::test]
    async fn shutdown_hook_called_exactly_once() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown_called = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            ShutdownTask {
                ticks: ticks.clone(),
                shutdown_called: shutdown_called.clone(),
            },
            Duration::from_mins(1),
            shutdown.clone(),
        );
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        assert_eq!(shutdown_called.load(Ordering::SeqCst), 1);
        assert_eq!(ticks.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn failing_tick_does_not_stop_worker() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            FailingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(15),
            shutdown.clone(),
        );
        tokio::time::sleep(Duration::from_millis(80)).await;
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        let n = ticks.load(Ordering::SeqCst);
        assert!(n >= 3, "got {n} ticks, expected >=3 even with errors");
    }

    #[tokio::test]
    async fn biased_select_prioritises_shutdown_over_tick() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            CountingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(1),
            shutdown.clone(),
        );
        let t0 = Instant::now();
        tokio::time::sleep(Duration::from_millis(20)).await;
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        let elapsed = t0.elapsed();
        assert!(
            elapsed < Duration::from_millis(500),
            "worker took {elapsed:?} to shut down (expected <500ms)",
        );
    }
}