hyperi_rustlib/concurrency/
periodic.rs1use std::future::Future;
31use std::time::Duration;
32
33use tokio::task::JoinHandle;
34use tokio::time::{MissedTickBehavior, interval};
35use tokio_util::sync::CancellationToken;
36use tracing::warn;
37
38use super::error::TickError;
39
40pub trait PeriodicTask: Send + 'static {
47 fn tick(&mut self) -> impl Future<Output = Result<(), TickError>> + Send;
49
50 fn shutdown(&mut self) -> impl Future<Output = Result<(), TickError>> + Send {
53 std::future::ready(Ok(()))
54 }
55}
56
57pub struct PeriodicWorker {
63 join: JoinHandle<()>,
64}
65
66impl PeriodicWorker {
67 pub fn spawn<T: PeriodicTask>(
72 mut task: T,
73 interval_duration: Duration,
74 shutdown: CancellationToken,
75 ) -> Self {
76 let join = tokio::spawn(async move {
77 let mut tick = interval(interval_duration);
78 tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
79 tick.tick().await;
81
82 loop {
83 tokio::select! {
84 biased;
85 () = shutdown.cancelled() => {
86 if let Err(e) = task.shutdown().await {
87 warn!(error = %e, "periodic task shutdown hook failed");
88 }
89 return;
90 }
91 _ = tick.tick() => {
92 if let Err(e) = task.tick().await {
93 warn!(error = %e, "periodic task tick failed");
94 }
95 }
96 }
97 }
98 });
99 Self { join }
100 }
101
102 pub async fn join(self) -> Result<(), tokio::task::JoinError> {
104 self.join.await
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Instant;

    /// Counts ticks and relies on the default (no-op) shutdown hook.
    struct CountingTask {
        ticks: Arc<AtomicU32>,
    }

    impl PeriodicTask for CountingTask {
        async fn tick(&mut self) -> Result<(), TickError> {
            self.ticks.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Counts both ticks and shutdown-hook invocations.
    struct ShutdownTask {
        ticks: Arc<AtomicU32>,
        shutdown_called: Arc<AtomicU32>,
    }

    impl PeriodicTask for ShutdownTask {
        async fn tick(&mut self) -> Result<(), TickError> {
            self.ticks.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn shutdown(&mut self) -> Result<(), TickError> {
            self.shutdown_called.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Counts ticks but fails every one of them.
    struct FailingTask {
        ticks: Arc<AtomicU32>,
    }

    impl PeriodicTask for FailingTask {
        async fn tick(&mut self) -> Result<(), TickError> {
            self.ticks.fetch_add(1, Ordering::SeqCst);
            Err(TickError::Generic("simulated".into()))
        }
    }

    #[tokio::test]
    async fn tick_fires_at_interval() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            CountingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(20),
            shutdown.clone(),
        );
        // ~110ms at a 20ms period is ~5 ticks; the range allows scheduler slack.
        tokio::time::sleep(Duration::from_millis(110)).await;
        shutdown.cancel();
        // Joining surfaces any panic inside the worker task.
        worker.join().await.expect("clean exit");
        let n = ticks.load(Ordering::SeqCst);
        assert!((4..=7).contains(&n), "got {n} ticks, expected 4-7");
    }

    #[tokio::test]
    async fn first_tick_is_delayed_not_immediate() {
        let ticks = Arc::new(AtomicU32::new(0));
        // A long period means any tick observed early would be the
        // immediate first tick that the worker is supposed to skip.
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            CountingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(100),
            shutdown.clone(),
        );
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(ticks.load(Ordering::SeqCst), 0);
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        // Shutdown was observed before the first (still distant) tick.
        assert_eq!(ticks.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn shutdown_hook_called_exactly_once() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown_called = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            ShutdownTask {
                ticks: ticks.clone(),
                shutdown_called: shutdown_called.clone(),
            },
            // Long enough that no tick can fire before cancellation.
            Duration::from_mins(1),
            shutdown.clone(),
        );
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        assert_eq!(shutdown_called.load(Ordering::SeqCst), 1);
        assert_eq!(ticks.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn failing_tick_does_not_stop_worker() {
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            FailingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(15),
            shutdown.clone(),
        );
        // ~80ms at a 15ms period is ~5 ticks; require at least 3.
        tokio::time::sleep(Duration::from_millis(80)).await;
        shutdown.cancel();
        // The worker must survive repeated tick errors and still exit cleanly.
        worker.join().await.expect("clean exit");
        let n = ticks.load(Ordering::SeqCst);
        assert!(n >= 3, "got {n} ticks, expected >=3 even with errors");
    }

    #[tokio::test]
    async fn biased_select_prioritises_shutdown_over_tick() {
        let ticks = Arc::new(AtomicU32::new(0));
        // A 1ms period keeps the tick branch almost always ready; with `biased`
        // the worker must still observe cancellation promptly.
        let shutdown = CancellationToken::new();
        let worker = PeriodicWorker::spawn(
            CountingTask {
                ticks: ticks.clone(),
            },
            Duration::from_millis(1),
            shutdown.clone(),
        );
        let t0 = Instant::now();
        tokio::time::sleep(Duration::from_millis(20)).await;
        shutdown.cancel();
        worker.join().await.expect("clean exit");
        let elapsed = t0.elapsed();
        assert!(
            elapsed < Duration::from_millis(500),
            "worker took {elapsed:?} to shut down (expected <500ms)",
        );
    }
}