1use std::sync::Arc;
41use std::time::Duration;
42
43use tracing::{debug, error, info, warn};
44
45use futures::future::join_all;
46
47use tokio::task::JoinHandle;
48use tokio::time::{interval, sleep, timeout};
49use tokio::{select, spawn};
50pub use tokio_util::sync::CancellationToken;
51
52#[cfg(test)]
53use mock_instant::global::SystemTime;
54#[cfg(not(test))]
55use std::time::SystemTime;
56
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Under `cfg(test)` the `SystemTime` in scope is `mock_instant`'s, so tests can pin "now".
///
/// # Panics
///
/// Panics if the clock reads earlier than the Unix epoch.
fn now_since_epoch_millis() -> u128 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("Y2k happened?");
    since_epoch.as_millis()
}
64
/// A unit of recurring work that [`TaskManager`] invokes on a schedule.
///
/// Implementors must be `Send + Sync`: the manager stores each task behind an `Arc` and executes
/// every run on a freshly spawned tokio task.
#[async_trait::async_trait]
pub trait AsyncTask: Send + Sync {
    /// Performs a single run of the task.
    ///
    /// `cancel` is a child of the manager's cancellation token, so it is cancelled when the
    /// manager is shut down; long-running implementations should watch it and return early.
    /// Returning `Err(message)` only logs a warning — the task stays scheduled.
    async fn run(&self, cancel: CancellationToken) -> Result<(), String>;
}
75
76#[derive(Clone)]
78struct ManagedTask {
79 name: String,
80 interval: Duration,
81 offset: Duration,
82 task: Arc<dyn AsyncTask>,
83}
84
85impl ManagedTask {
86 fn new(name: String, interval: Duration, offset: Duration, task: Arc<dyn AsyncTask>) -> Self {
87 Self {
88 name,
89 interval,
90 offset,
91 task,
92 }
93 }
94}
95
96#[derive(Clone)]
98pub struct TaskManager {
99 tasks: Vec<ManagedTask>,
100}
101
102impl TaskManager {
103 pub fn new() -> Self {
104 TaskManager { tasks: Vec::new() }
105 }
106
107 pub fn add<T>(&mut self, name: &str, interval: Duration, task: T)
121 where
122 T: AsyncTask + 'static,
123 {
124 self.add_offset(name, interval, Duration::ZERO, task)
125 }
126
127 pub fn add_offset<T>(&mut self, name: &str, interval: Duration, offset: Duration, task: T)
139 where
140 T: AsyncTask + 'static,
141 {
142 if interval == Duration::ZERO {
143 panic!("Interval must be nonzero!");
144 }
145 if offset >= interval {
146 panic!("Offset must be strictly less than interval!");
147 }
148
149 let managed = ManagedTask::new(name.to_owned(), interval, offset, Arc::new(task));
150 self.tasks.push(managed);
151 }
152
153 pub async fn run_forever(self) {
155 self.run_with_cancel(CancellationToken::new()).await
156 }
157
158 async fn task_spawn(
159 managed: ManagedTask,
160 running: Option<JoinHandle<()>>,
161 cancel: CancellationToken,
162 ) -> Option<JoinHandle<()>> {
163 if running.as_ref().is_some_and(|h| !h.is_finished()) {
165 debug!(
166 "Skipping run for task {} (previous run not finished)",
167 managed.name
168 );
169 running
170 } else {
171 let handle = spawn(async move {
172 debug!("Running task {}", managed.name);
173 if let Err(e) = managed.task.run(cancel).await {
174 warn!("Error in task {}: {e}", managed.name);
175 }
176 });
177 Some(handle)
178 }
179 }
180
181 pub async fn run_with_cancel(self, cancel: CancellationToken) {
188 join_all(self.tasks.clone().into_iter().map(|managed| {
189 let cancel = cancel.clone();
190 let mut running = None;
191 let initial_delay = calculate_initial_delay(managed.interval, managed.offset);
192
193 info!(
194 "Starting task {} in {} ms",
195 managed.name,
196 initial_delay.as_millis(),
197 );
198
199 spawn(async move {
200 select! {
201 _ = sleep(initial_delay) => {
202 let mut ticker = interval(managed.interval);
203
204 loop {
205 select! {
206 _ = ticker.tick() => {
207 let managed = managed.clone();
208 let cancel = cancel.child_token();
209 running = Self::task_spawn(managed, running, cancel).await;
210 }
211 _ = cancel.cancelled() => {
212 debug!("Cancelled Recurring Tasks Manager loop for '{}'", managed.name);
213 break;
214 }
215 }
216 }
217 }
218 _ = cancel.cancelled() => {
219 debug!("Cancelled Recurring Tasks Manager sleep for '{}'", managed.name);
220 }
221 }
222 })
223 }))
224 .await;
225 }
226
227 pub async fn run_with_signal(self, wait: Duration) {
231 let cancel = CancellationToken::new();
232
233 let mut handle = spawn({
234 let cancel = cancel.child_token();
235 async move {
236 self.run_with_cancel(cancel).await;
237 }
238 });
239
240 select! {
241 res = &mut handle => {
242 error!("Manager stopped unexpectedly: {res:?}")
243 }
244 _ = shutdown_signal() => {
245 warn!("Shutdown signal received, stopping recurring tasks...");
246 cancel.cancel();
248 match timeout(wait, &mut handle).await {
250 Ok(_) => debug!("Shutdown complete"),
251 Err(_) => {
252 warn!("Aborting tasks after timeout");
253 handle.abort();
254 let _ = handle.await;
256 }
257 }
258 }
259 }
260 }
261}
262
263async fn shutdown_signal() {
264 let sigint = async {
265 let _ = tokio::signal::ctrl_c().await;
266 };
267
268 #[cfg(unix)]
269 let sigterm = async {
270 use tokio::signal::unix::{SignalKind, signal};
271 if let Ok(mut s) = signal(SignalKind::terminate()) {
272 s.recv().await;
273 }
274 };
275
276 #[cfg(not(unix))]
277 let sigterm = std::future::pending::<()>();
278
279 tokio::select! {
280 _ = sigint => {},
281 _ = sigterm => {},
282 }
283}
284
285fn calculate_initial_delay(interval: Duration, offset: Duration) -> Duration {
289 let now_since_epoch_millis = now_since_epoch_millis();
290 let interval_millis = interval.as_millis();
291 let offset_millis = offset.as_millis();
292
293 let next_scheduled_time =
296 (now_since_epoch_millis / interval_millis) * interval_millis + offset_millis;
297 let scheduled_from_now = if next_scheduled_time > now_since_epoch_millis {
299 next_scheduled_time - now_since_epoch_millis
300 } else {
301 next_scheduled_time + interval_millis - now_since_epoch_millis
302 };
303 Duration::from_millis(scheduled_from_now as u64)
304}
305
// Tests for the initial-delay arithmetic, `add_offset` argument validation, and cancellation of
// a running manager.
//
// NOTE(review): `MockClock::set_system_time` comes from `mock_instant::global`, which is
// presumably a process-wide clock. Cargo runs tests in parallel threads, so `half_offset` and
// `quarter_offset` may race each other (and the cancellation tests) on it. Consider the
// thread-local clock variant, or serialize these tests.
#[cfg(test)]
mod tests {
    use std::sync::Once;

    use tokio::sync::Mutex;

    use mock_instant::global::MockClock;

    use super::*;

    // Guards one-time installation of the tracing subscriber across tests.
    static INIT: Once = Once::new();

    /// Installs a `tracing` subscriber whose filter comes from the environment (falling back to
    /// `info`) and which writes through the test harness's captured output. Only the first call
    /// has an effect.
    // Not called by the tests below; presumably kept for ad-hoc debugging, hence the allow.
    #[allow(unused)]
    pub fn init_logging() {
        use tracing_subscriber::{EnvFilter, fmt};

        INIT.call_once(|| {
            let filter =
                EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
            fmt().with_env_filter(filter).with_test_writer().init();
        });
    }

    /// Test task that counts how many times `run` has been invoked; clones share the counter.
    #[derive(Clone)]
    pub struct TestTask {
        count: Arc<Mutex<usize>>,
    }

    impl TestTask {
        /// Creates a task whose run count starts at zero.
        pub fn new() -> Self {
            Self {
                count: Arc::new(Mutex::new(0)),
            }
        }

        /// Returns how many runs have completed so far.
        pub async fn count(&self) -> usize {
            *self.count.lock().await
        }
    }

    #[async_trait::async_trait]
    impl AsyncTask for TestTask {
        // Ignores cancellation; each run just bumps the shared counter and succeeds.
        async fn run(&self, _cancel: CancellationToken) -> Result<(), String> {
            let mut count = self.count.lock().await;
            *count += 1;
            Ok(())
        }
    }

    // 60 s interval, 30 s offset: checks the delay at, before and after the offset point.
    #[test]
    fn half_offset() {
        let interval = Duration::from_secs(60);
        let offset = Duration::from_secs(30);

        MockClock::set_system_time(Duration::from_secs(0));
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(delay, offset, "0 is offset");

        // Exactly on the offset point: the slot is not strictly in the future, so wait a full
        // interval.
        MockClock::set_system_time(offset);
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(delay, interval, "offset is interval");

        let diff = Duration::from_secs(15);
        MockClock::set_system_time(offset - diff);
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(delay, diff, "less than offset is offset remainder");

        let diff = Duration::from_secs(15);
        MockClock::set_system_time(offset + diff);
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(
            delay,
            interval - diff,
            "more than offset is interval remainder"
        );
    }

    // Same checks as `half_offset` with a 15 s offset.
    #[test]
    fn quarter_offset() {
        let interval = Duration::from_secs(60);
        let offset = Duration::from_secs(15);

        MockClock::set_system_time(Duration::from_secs(0));
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(delay, offset, "0 is offset");

        MockClock::set_system_time(offset);
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(delay, interval, "offset is interval");

        let diff = Duration::from_secs(5);
        MockClock::set_system_time(offset - diff);
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(delay, diff, "less than offset is offset remainder");

        let diff = Duration::from_secs(15);
        MockClock::set_system_time(offset + diff);
        let delay = calculate_initial_delay(interval, offset);
        assert_eq!(
            delay,
            interval - diff,
            "more than offset is interval remainder"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Interval must be nonzero!")]
    async fn interval_nonzero() {
        TaskManager::new().add("Fails", Duration::from_secs(0), TestTask::new());
    }

    #[tokio::test]
    #[should_panic(expected = "Offset must be strictly less than interval!")]
    async fn offset_match_interval() {
        TaskManager::new().add_offset(
            "Fails",
            Duration::from_secs(10),
            Duration::from_secs(10),
            TestTask::new(),
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Offset must be strictly less than interval!")]
    async fn offset_exceed_interval() {
        TaskManager::new().add_offset(
            "Fails",
            Duration::from_secs(10),
            Duration::from_secs(20),
            TestTask::new(),
        );
    }

    // Timing assumption: with the mock clock at a multiple of 100 ms (e.g. 0, presumably its
    // default, or any value set by the offset tests), the initial delay is a full 100 ms. That
    // gives one run by 120 ms and two by 240 ms.
    #[tokio::test]
    async fn run_cancelled() {
        let mut manager = TaskManager::new();
        let task = TestTask::new();
        let cancel = CancellationToken::new();

        manager.add("Test", Duration::from_millis(100), task.clone());

        let mut run = spawn({
            let cancel = cancel.clone();
            async move { manager.run_with_cancel(cancel).await }
        });

        // Watchdog: checks run counts, cancels, then fails if the manager is still running
        // 120 ms later.
        let mut test = spawn({
            let cancel = cancel.clone();

            async move {
                sleep(Duration::from_millis(120)).await;
                assert_eq!(task.count().await, 1);
                sleep(Duration::from_millis(120)).await;
                assert_eq!(task.count().await, 2);

                cancel.cancel();
                sleep(Duration::from_millis(120)).await;
                panic!("Cancel did not stop manager");
            }
        });

        // Success means the manager returned after (and only after) cancellation. If the
        // watchdog finishes first, either an assertion failed or cancel did not stop the manager.
        select! {
            res = &mut run => {
                if res.is_err() || !cancel.is_cancelled() {
                    panic!("Manager stopped unexpectedly: {res:?}");
                }
            }
            res = &mut test => {
                run.abort();
                res.unwrap();
            }
        }
    }

    // Cancels during the initial delay (the interval is 10 s); the manager must return well
    // within 120 ms.
    #[tokio::test]
    async fn run_cancelled_early() {
        let mut manager = TaskManager::new();
        let task = TestTask::new();
        let cancel = CancellationToken::new();

        manager.add("Test", Duration::from_millis(10000), task.clone());

        let mut run = spawn({
            let cancel = cancel.clone();
            async move { manager.run_with_cancel(cancel).await }
        });

        let mut test = spawn({
            let cancel = cancel.clone();

            async move {
                cancel.cancel();
                sleep(Duration::from_millis(120)).await;
                panic!("Cancel did not stop manager");
            }
        });

        select! {
            res = &mut run => {
                if res.is_err() || !cancel.is_cancelled() {
                    panic!("Manager stopped unexpectedly: {res:?}");
                }
            }
            res = &mut test => {
                run.abort();
                res.unwrap();
            }
        }
    }
}