use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
/// Returns the instant `budget` after `now`.
///
/// If `now + budget` is not representable (e.g. a huge budget), the deadline is instead
/// placed one 365-day year after `now`, which is effectively "never" for a watchdog.
fn deadline_after(now: Instant, budget: Duration) -> Instant {
    const FALLBACK: Duration = Duration::from_secs(365 * 24 * 60 * 60);
    match now.checked_add(budget) {
        Some(deadline) => deadline,
        None => now + FALLBACK,
    }
}
/// Bookkeeping guarded by the watchdog's mutex.
struct State {
    /// Id assigned to the next hold; incremented on every `hold` call.
    next_id: u64,
    /// Instant at which the currently active clock runs out.
    deadline: Instant,
    /// Outstanding holds, oldest first. Each entry pairs a hold id with the time to put back
    /// on the clock once that hold is released.
    suspended: Vec<(u64, Duration)>,
}
/// Enforces a time budget, cancelling a [`CancellationToken`] when it runs out.
///
/// The budget can be suspended temporarily with [`Watchdog::hold`]. Every change to the
/// active deadline is published on a `watch` channel so [`Watchdog::run`] can re-arm its timer.
pub struct Watchdog {
    /// Publishes the currently active deadline to the `run` loop.
    deadline_tx: watch::Sender<Instant>,
    /// Active deadline plus the stack of suspended clocks.
    state: Mutex<State>,
}
impl Watchdog {
pub fn new(budget: Duration) -> Self {
let deadline = deadline_after(Instant::now(), budget);
let (deadline_tx, _) = watch::channel(deadline);
Self {
state: Mutex::new(State { deadline, suspended: Vec::new(), next_id: 0 }),
deadline_tx,
}
}
pub async fn run(self: Arc<Self>, elapsed: Arc<AtomicBool>, token: CancellationToken) {
let mut deadline_rx = self.deadline_tx.subscribe();
loop {
let deadline = *deadline_rx.borrow_and_update();
if Instant::now() >= deadline {
elapsed.store(true, Ordering::SeqCst);
token.cancel();
return;
}
tokio::select! {
_ = tokio::time::sleep_until(deadline) => {}
_ = deadline_rx.changed() => {}
}
}
}
pub fn hold(self: &Arc<Self>, budget: Duration) -> WatchdogHold {
#[allow(clippy::expect_used)]
let mut state = self.state.lock().expect("watchdog state poisoned");
let now = Instant::now();
let id = state.next_id;
state.next_id += 1;
let remaining = state.deadline.duration_since(now);
state.suspended.push((id, remaining));
state.deadline = deadline_after(now, budget);
self.deadline_tx.send_replace(state.deadline);
drop(state);
WatchdogHold { watchdog: self.clone(), id }
}
fn release(&self, id: u64) {
#[allow(clippy::expect_used)]
let mut state = self.state.lock().expect("watchdog state poisoned");
let Some(index) = state.suspended.iter().position(|(hold_id, _)| *hold_id == id) else {
return;
};
let (_, saved) = state.suspended.remove(index);
if index == state.suspended.len() {
state.deadline = deadline_after(Instant::now(), saved);
self.deadline_tx.send_replace(state.deadline);
} else {
state.suspended[index].1 = saved;
}
}
}
/// Guard returned by [`Watchdog::hold`]. Dropping it ends the hold.
#[must_use = "dropping a WatchdogHold immediately releases the hold"]
pub struct WatchdogHold {
    /// Watchdog this hold belongs to.
    watchdog: Arc<Watchdog>,
    /// Identifies this hold's entry in the watchdog's suspended stack.
    id: u64,
}
impl Drop for WatchdogHold {
    /// Ends the hold by passing its id back to the owning watchdog.
    fn drop(&mut self) {
        let Self { watchdog, id } = self;
        watchdog.release(*id);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields to the runtime several times so the spawned watchdog task gets to run.
    async fn settle() {
        let mut yields_left = 10;
        while yields_left > 0 {
            tokio::task::yield_now().await;
            yields_left -= 1;
        }
    }

    /// Moves the paused clock forward by `step`, then lets the watchdog task catch up.
    async fn advance_and_settle(step: Duration) {
        tokio::time::advance(step).await;
        settle().await;
    }

    /// Builds a watchdog with `budget` and spawns its `run` loop on the test runtime.
    fn spawn_watchdog(
        budget: Duration,
    ) -> (Arc<Watchdog>, Arc<AtomicBool>, CancellationToken, tokio::task::JoinHandle<()>) {
        let watchdog = Arc::new(Watchdog::new(budget));
        let elapsed = Arc::new(AtomicBool::new(false));
        let token = CancellationToken::new();
        let timer = Arc::clone(&watchdog).run(Arc::clone(&elapsed), token.clone());
        (watchdog, elapsed, token, tokio::spawn(timer))
    }

    #[tokio::test(start_paused = true)]
    async fn fires_at_deadline() {
        let (_watchdog, elapsed, token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;
        // One millisecond short of the budget: still quiet.
        advance_and_settle(Duration::from_millis(999)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "fired before the deadline");
        // Past the budget: the task fires and exits on its own.
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
        assert!(token.is_cancelled());
    }

    #[tokio::test(start_paused = true)]
    async fn hold_freezes_script_clock_and_restores_remaining() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;
        // 400ms into a 1s budget, so the hold should save 600ms.
        advance_and_settle(Duration::from_millis(400)).await;
        let hold = watchdog.hold(Duration::from_secs(10));
        advance_and_settle(Duration::from_secs(5)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "fired while the script clock was frozen");
        drop(hold);
        advance_and_settle(Duration::from_millis(599)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "restored remaining was shortened");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
    }

    #[tokio::test(start_paused = true)]
    async fn hold_budget_overrun_fires() {
        let (watchdog, elapsed, token, handle) = spawn_watchdog(Duration::from_secs(60));
        settle().await;
        // The hold's own 500ms budget is enforced even though the script clock has plenty left.
        let _hold = watchdog.hold(Duration::from_millis(500));
        tokio::time::advance(Duration::from_millis(501)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst), "hold overran its budget but didn't fire");
        assert!(token.is_cancelled());
    }

    #[tokio::test(start_paused = true)]
    async fn nested_holds_restore_in_lifo_order() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;
        let outer = watchdog.hold(Duration::from_secs(10));
        // 2s into the outer hold, leaving it 8s.
        advance_and_settle(Duration::from_secs(2)).await;
        let inner = watchdog.hold(Duration::from_secs(30));
        advance_and_settle(Duration::from_secs(20)).await;
        assert!(!elapsed.load(Ordering::SeqCst));
        drop(inner);
        advance_and_settle(Duration::from_millis(7_999)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "outer remaining was shortened");
        drop(outer);
        advance_and_settle(Duration::from_millis(999)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "script remaining was shortened");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
    }

    #[tokio::test(start_paused = true)]
    async fn out_of_order_release_keeps_chain_consistent() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;
        let first = watchdog.hold(Duration::from_secs(10));
        let second = watchdog.hold(Duration::from_secs(30));
        // Releasing the older hold first must not disturb the newer, still-active one.
        drop(first);
        advance_and_settle(Duration::from_secs(20)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "second hold's budget was lost");
        drop(second);
        advance_and_settle(Duration::from_millis(999)).await;
        assert!(!elapsed.load(Ordering::SeqCst), "script remaining was lost");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
    }

    #[tokio::test(start_paused = true)]
    async fn hold_acquired_after_fire_is_harmless() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_millis(10));
        tokio::time::advance(Duration::from_millis(11)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
        // The run loop is gone; taking and releasing a hold must not panic.
        let hold = watchdog.hold(Duration::from_secs(5));
        drop(hold);
    }
}