1use std::sync::atomic::{AtomicBool, Ordering};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23
24use tokio::sync::watch;
25use tokio::time::Instant;
26use tokio_util::sync::CancellationToken;
27
28fn deadline_after(now: Instant, budget: Duration) -> Instant {
31 now.checked_add(budget)
32 .unwrap_or_else(|| now + Duration::from_secs(60 * 60 * 24 * 365))
33}
34
35struct State {
36 deadline: Instant,
39 suspended: Vec<(u64, Duration)>,
44 next_id: u64,
45}
46
47pub struct Watchdog {
53 state: Mutex<State>,
54 deadline_tx: watch::Sender<Instant>,
55}
56
57impl Watchdog {
58 pub fn new(budget: Duration) -> Self {
60 let deadline = deadline_after(Instant::now(), budget);
61 let (deadline_tx, _) = watch::channel(deadline);
62 Self {
63 state: Mutex::new(State { deadline, suspended: Vec::new(), next_id: 0 }),
64 deadline_tx,
65 }
66 }
67
68 pub async fn run(self: Arc<Self>, elapsed: Arc<AtomicBool>, token: CancellationToken) {
72 let mut deadline_rx = self.deadline_tx.subscribe();
73 loop {
74 let deadline = *deadline_rx.borrow_and_update();
75 if Instant::now() >= deadline {
76 elapsed.store(true, Ordering::SeqCst);
77 token.cancel();
78 return;
79 }
80 tokio::select! {
81 _ = tokio::time::sleep_until(deadline) => {}
82 _ = deadline_rx.changed() => {}
85 }
86 }
87 }
88
89 pub fn hold(self: &Arc<Self>, budget: Duration) -> WatchdogHold {
92 #[allow(clippy::expect_used)]
93 let mut state = self.state.lock().expect("watchdog state poisoned");
94 let now = Instant::now();
95 let id = state.next_id;
96 state.next_id += 1;
97 let remaining = state.deadline.duration_since(now);
100 state.suspended.push((id, remaining));
101 state.deadline = deadline_after(now, budget);
102 self.deadline_tx.send_replace(state.deadline);
103 drop(state);
104 WatchdogHold { watchdog: self.clone(), id }
105 }
106
107 fn release(&self, id: u64) {
108 #[allow(clippy::expect_used)]
109 let mut state = self.state.lock().expect("watchdog state poisoned");
110 let Some(index) = state.suspended.iter().position(|(hold_id, _)| *hold_id == id) else {
111 return;
113 };
114 let (_, saved) = state.suspended.remove(index);
115 if index == state.suspended.len() {
116 state.deadline = deadline_after(Instant::now(), saved);
119 self.deadline_tx.send_replace(state.deadline);
120 } else {
121 state.suspended[index].1 = saved;
125 }
126 }
127}
128
129pub struct WatchdogHold {
133 watchdog: Arc<Watchdog>,
134 id: u64,
135}
136
137impl Drop for WatchdogHold {
138 fn drop(&mut self) {
139 self.watchdog.release(self.id);
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields to the scheduler several times so the spawned `run` task gets polled and can
    /// react to clock advances or deadline changes before the test inspects the flags.
    async fn settle() {
        // NOTE(review): ten yields is an empirical margin, not a guarantee — revisit if flaky.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
    }

    /// Builds a watchdog with `budget` and spawns its `run` loop.
    ///
    /// Returns the watchdog, the `elapsed` flag, the cancellation token, and the task handle.
    fn spawn_watchdog(
        budget: Duration,
    ) -> (Arc<Watchdog>, Arc<AtomicBool>, CancellationToken, tokio::task::JoinHandle<()>) {
        let watchdog = Arc::new(Watchdog::new(budget));
        let elapsed = Arc::new(AtomicBool::new(false));
        let token = CancellationToken::new();
        let handle = tokio::spawn(watchdog.clone().run(elapsed.clone(), token.clone()));
        (watchdog, elapsed, token, handle)
    }

    // 1s budget: still quiet at t=999ms, fired (flag set, token cancelled) by t=1001ms.
    #[tokio::test(start_paused = true)]
    async fn fires_at_deadline() {
        let (_watchdog, elapsed, token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;
        tokio::time::advance(Duration::from_millis(999)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "fired before the deadline");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
        assert!(token.is_cancelled());
    }

    #[tokio::test(start_paused = true)]
    async fn hold_freezes_script_clock_and_restores_remaining() {
        // t=400ms: 600ms of the 1s script budget remain.
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;
        tokio::time::advance(Duration::from_millis(400)).await;
        settle().await;

        // A 10s hold; 5s pass inside it, well within the hold's budget.
        let hold = watchdog.hold(Duration::from_secs(10));
        tokio::time::advance(Duration::from_secs(5)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "fired while the script clock was frozen");

        // Dropping at t=5.4s resumes the script clock with its 600ms: deadline t=6.0s.
        drop(hold);
        tokio::time::advance(Duration::from_millis(599)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "restored remaining was shortened");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
    }

    // Plenty of script time (60s), but the 500ms hold itself is overrun, which must fire.
    #[tokio::test(start_paused = true)]
    async fn hold_budget_overrun_fires() {
        let (watchdog, elapsed, token, handle) = spawn_watchdog(Duration::from_secs(60));
        settle().await;
        let _hold = watchdog.hold(Duration::from_millis(500));
        tokio::time::advance(Duration::from_millis(501)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst), "hold overran its budget but didn't fire");
        assert!(token.is_cancelled());
    }

    #[tokio::test(start_paused = true)]
    async fn nested_holds_restore_in_lifo_order() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;

        // t=0: outer hold saves the script's 1s; outer deadline t=10s.
        let outer = watchdog.hold(Duration::from_secs(10));
        tokio::time::advance(Duration::from_secs(2)).await;
        settle().await;
        // t=2s: inner hold saves outer's remaining 8s; inner deadline t=32s.
        let inner = watchdog.hold(Duration::from_secs(30));
        tokio::time::advance(Duration::from_secs(20)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst));

        // t=22s: dropping inner resumes outer with 8s, deadline t=30s.
        drop(inner);
        tokio::time::advance(Duration::from_millis(7_999)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "outer remaining was shortened");
        // t=29.999s: dropping outer resumes the script with 1s, deadline t=30.999s.
        drop(outer);
        tokio::time::advance(Duration::from_millis(999)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "script remaining was shortened");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
    }

    #[tokio::test(start_paused = true)]
    async fn out_of_order_release_keeps_chain_consistent() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_secs(1));
        settle().await;

        // Both holds at t=0: `first` saves the script's 1s, `second` saves first's 10s.
        // Dropping `first` (not the newest) hands the script's 1s to `second`; the running
        // deadline stays second's t=30s.
        let first = watchdog.hold(Duration::from_secs(10));
        let second = watchdog.hold(Duration::from_secs(30));
        drop(first);
        tokio::time::advance(Duration::from_secs(20)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "second hold's budget was lost");
        // t=20s: dropping `second` resumes the script with 1s, deadline t=21s.
        drop(second);
        tokio::time::advance(Duration::from_millis(999)).await;
        settle().await;
        assert!(!elapsed.load(Ordering::SeqCst), "script remaining was lost");
        tokio::time::advance(Duration::from_millis(2)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
    }

    // The 10ms budget fires and the task returns; taking and dropping a hold afterwards must
    // not panic.
    #[tokio::test(start_paused = true)]
    async fn hold_acquired_after_fire_is_harmless() {
        let (watchdog, elapsed, _token, handle) = spawn_watchdog(Duration::from_millis(10));
        tokio::time::advance(Duration::from_millis(11)).await;
        handle.await.expect("timer task");
        assert!(elapsed.load(Ordering::SeqCst));
        let hold = watchdog.hold(Duration::from_secs(5));
        drop(hold);
    }
}