1use std::sync::Arc;
43use std::time::Duration;
44
45#[cfg(test)]
46use mock_instant::global::SystemTime;
47#[cfg(not(test))]
48use std::time::SystemTime;
49
50use tracing::{debug, warn};
51
52use tokio::sync::Mutex;
53use tokio::time::sleep;
54use tokio::{select, signal, spawn};
55
56#[cfg(all(feature = "instant", test))]
57use mock_instant::global::Instant;
58#[cfg(all(feature = "instant", not(test)))]
59use std::time::Instant;
60
// The `instant` and `system` features each provide `RunTimer` and its helper functions, so they
// cannot be enabled together; report that misconfiguration clearly instead of through a cascade
// of duplicate-definition errors.
#[cfg(all(feature = "instant", feature = "system"))]
compile_error!("features `instant` and `system` are mutually exclusive; enable only one of them");

/// Clock type used to record task start times and next-run deadlines (`instant` feature).
#[cfg(feature = "instant")]
type RunTimer = Instant;

/// Clock type used to record task start times and next-run deadlines (`system` feature).
#[cfg(feature = "system")]
type RunTimer = SystemTime;
66
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Under `cfg(test)` `SystemTime` is `mock_instant`'s global mock clock, so tests control the
/// value returned here.
///
/// # Panics
///
/// Panics if the clock reports a time earlier than the Unix epoch.
fn now_since_epoch_millis() -> u128 {
    let elapsed_since_epoch = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH);
    elapsed_since_epoch.expect("Y2k happened?").as_millis()
}
74
/// Reads the scheduler clock selected by the `instant` / `system` feature.
#[cfg(any(feature = "instant", feature = "system"))]
fn run_timer_now() -> RunTimer {
    RunTimer::now()
}

/// Time elapsed from `old` to `now` on the `Instant` clock.
#[cfg(feature = "instant")]
fn duration_since(now: RunTimer, old: RunTimer) -> Duration {
    now - old
}

/// Time elapsed from `old` to `now` on the `SystemTime` clock.
///
/// # Panics
///
/// Panics if `old` is later than `now`.
#[cfg(feature = "system")]
fn duration_since(now: RunTimer, old: RunTimer) -> Duration {
    let elapsed = now.duration_since(old);
    elapsed.expect("Old before now?")
}
94
/// A unit of recurring work that a `TaskManager` runs on a fixed schedule.
///
/// Implementors must be `Send + Sync` because each run is spawned onto the Tokio runtime.
#[async_trait::async_trait]
pub trait AsyncTask: Send + Sync {
    /// Performs one run of the task.
    ///
    /// An `Err` is logged by the manager as a warning; it does not unschedule the task.
    async fn run(&self) -> Result<(), String>;
}
101
/// Scheduling state kept by the manager for one registered task.
struct ManagedTask {
    /// Name used in log messages.
    name: String,
    /// Time between consecutive scheduled runs.
    interval: Duration,
    /// Shift of the run grid relative to whole multiples of `interval` since the Unix epoch.
    offset: Duration,

    /// The work to execute on each run.
    task: Arc<dyn AsyncTask>,
    /// When the current run started, or `None` when no run is in progress.
    started_at: Option<RunTimer>,
    /// Earliest time at which the next run may start.
    next_run: RunTimer,
}
112
113impl ManagedTask {
114 fn new(name: String, interval: Duration, offset: Duration, task: Arc<dyn AsyncTask>) -> Self {
115 Self {
116 name,
117 interval,
118 offset,
119 task,
120 started_at: None,
121 next_run: run_timer_now(),
122 }
123 }
124
125 fn started_at(&self) -> Option<RunTimer> {
126 self.started_at
127 }
128
129 fn start(&mut self) {
130 self.started_at = Some(run_timer_now());
131 }
132
133 fn stop(&mut self) {
134 self.started_at = None;
135 }
136}
137
/// Schedules registered [`AsyncTask`]s to run at fixed intervals.
///
/// Cloning is cheap and yields a handle to the same task list, so tasks added through one clone
/// are seen by a scheduler running from another.
#[derive(Clone)]
pub struct TaskManager {
    /// Registered tasks; each entry has its own lock so per-task state can be updated independently.
    tasks: Arc<Mutex<Vec<Arc<Mutex<ManagedTask>>>>>,
    /// How long the scheduler sleeps between checks for due tasks.
    scheduler_tick: Duration,
}
145
146impl Default for TaskManager {
147 fn default() -> Self {
149 Self::new(500)
150 }
151}
152
153impl TaskManager {
154 pub fn new(millis: u64) -> Self {
157 TaskManager {
158 tasks: Arc::new(Mutex::new(Vec::new())),
159 scheduler_tick: Duration::from_millis(millis),
160 }
161 }
162
163 pub async fn add<T>(&self, name: &str, interval: Duration, task: T)
177 where
178 T: AsyncTask + 'static,
179 {
180 self.add_offset(name, interval, Duration::ZERO, task).await
181 }
182
183 pub async fn add_offset<T>(&self, name: &str, interval: Duration, offset: Duration, task: T)
195 where
196 T: AsyncTask + 'static,
197 {
198 if interval == Duration::ZERO {
199 panic!("Interval must be nonzero!");
200 }
201 if offset >= interval {
202 panic!("Offset must be strictly less than interval!");
203 }
204
205 let mut tasks = self.tasks.lock().await;
206
207 let managed = ManagedTask::new(name.to_owned(), interval, offset, Arc::new(task));
208 tasks.push(Arc::new(Mutex::new(managed)));
209 }
210
211 pub async fn run(&self) {
213 debug!(
214 "Initializing Recurring Tasks Manager using {}",
215 if cfg!(feature = "instant") {
216 "Instant"
217 } else if cfg!(feature = "system") {
218 "SystemTime"
219 } else {
220 "UNKNOWN"
221 }
222 );
223
224 for managed_task in self.tasks.lock().await.iter() {
225 let mut managed = managed_task.lock().await;
226
227 let initial_delay = calculate_initial_delay(managed.interval, managed.offset);
228
229 debug!(
230 "Starting task {} in {} ms",
231 managed.name,
232 initial_delay.as_millis(),
233 );
234
235 managed.next_run = run_timer_now() + initial_delay;
236 }
237
238 let tasks = self.tasks.clone();
239 loop {
240 let tasks = tasks.lock().await;
241 for managed_task in tasks.iter() {
242 let mut managed = managed_task.lock().await;
243 let task_name = managed.name.clone();
244
245 let now = run_timer_now();
246 let prev_run = managed.next_run;
247 if now >= prev_run {
248 if let Some(started_at) = managed.started_at() {
250 debug!(
251 "Skipping run for task {task_name} (previous run from {:?} not finished)",
252 started_at
253 );
254 } else {
255 managed.start();
257 let interval = managed.interval;
258 let next_run = prev_run + interval;
259 managed.next_run = if next_run >= now {
261 next_run
262 } else {
263 let diff = duration_since(now, next_run);
264 warn!(
265 "Falling behind schedule on {task_name} by {} ms",
266 diff.as_millis()
267 );
268 now + interval
269 };
270
271 let managed_task = managed_task.clone();
272 spawn(async move {
273 debug!("Running task {task_name}");
274 if let Err(e) = managed_task.lock().await.task.run().await {
275 warn!("Error in task {task_name}: {e}");
276 }
277 managed_task.lock().await.stop();
278 });
279 }
280 }
281 }
282
283 sleep(self.scheduler_tick).await;
284 }
285 }
286
287 pub async fn run_with_signal(&self) {
289 let manager = self.clone();
290
291 let run_handle = spawn(async move {
292 manager.run().await;
293 });
294
295 select! {
296 _ = signal::ctrl_c() => {
297 warn!("Ctrl+C received, shutting down recurring tasks...");
298 }
299 _ = run_handle => {}
300 }
301 }
302}
303
304fn calculate_initial_delay(interval: Duration, offset: Duration) -> Duration {
308 let now_since_epoch_millis = now_since_epoch_millis();
309 let interval_millis = interval.as_millis();
310 let offset_millis = offset.as_millis();
311
312 let next_scheduled_time =
315 (now_since_epoch_millis / interval_millis) * interval_millis + offset_millis;
316 let scheduled_from_now = if next_scheduled_time > now_since_epoch_millis {
318 next_scheduled_time - now_since_epoch_millis
319 } else {
320 next_scheduled_time + interval_millis - now_since_epoch_millis
321 };
322 Duration::from_millis(scheduled_from_now as u64)
323}
324
#[cfg(test)]
mod tests {
    use mock_instant::global::MockClock;

    use super::*;

    /// `mock_instant::global` keeps one clock shared by every thread, and the test harness runs
    /// tests in parallel. Tests that set the mock clock hold this lock, so no other test can
    /// overwrite the time between `set_system_time` and `calculate_initial_delay`.
    static CLOCK_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`CLOCK_LOCK`], ignoring poisoning so one failing test does not fail the others.
    fn lock_clock() -> std::sync::MutexGuard<'static, ()> {
        CLOCK_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets the mock wall clock to `now` (since the Unix epoch) and returns the initial delay.
    fn delay_at(now: Duration, interval: Duration, offset: Duration) -> Duration {
        MockClock::set_system_time(now);
        calculate_initial_delay(interval, offset)
    }

    pub struct TestTask;

    #[async_trait::async_trait]
    impl AsyncTask for TestTask {
        async fn run(&self) -> Result<(), String> {
            Ok(())
        }
    }

    #[test]
    fn half_offset() {
        let _clock = lock_clock();
        let interval = Duration::from_secs(60);
        let offset = Duration::from_secs(30);

        assert_eq!(
            delay_at(Duration::from_secs(0), interval, offset),
            offset,
            "0 is offset"
        );

        assert_eq!(
            delay_at(offset, interval, offset),
            interval,
            "offset is interval"
        );

        let diff = Duration::from_secs(15);
        assert_eq!(
            delay_at(offset - diff, interval, offset),
            diff,
            "less than offset is offset remainder"
        );

        let diff = Duration::from_secs(15);
        assert_eq!(
            delay_at(offset + diff, interval, offset),
            interval - diff,
            "more than offset is interval remainder"
        );
    }

    #[test]
    fn quarter_offset() {
        let _clock = lock_clock();
        let interval = Duration::from_secs(60);
        let offset = Duration::from_secs(15);

        assert_eq!(
            delay_at(Duration::from_secs(0), interval, offset),
            offset,
            "0 is offset"
        );

        assert_eq!(
            delay_at(offset, interval, offset),
            interval,
            "offset is interval"
        );

        let diff = Duration::from_secs(5);
        assert_eq!(
            delay_at(offset - diff, interval, offset),
            diff,
            "less than offset is offset remainder"
        );

        let diff = Duration::from_secs(15);
        assert_eq!(
            delay_at(offset + diff, interval, offset),
            interval - diff,
            "more than offset is interval remainder"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Interval must be nonzero!")]
    async fn interval_nonzero() {
        TaskManager::default()
            .add("Fails", Duration::from_secs(0), TestTask {})
            .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Offset must be strictly less than interval!")]
    async fn offset_match_interval() {
        TaskManager::default()
            .add_offset(
                "Fails",
                Duration::from_secs(10),
                Duration::from_secs(10),
                TestTask {},
            )
            .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Offset must be strictly less than interval!")]
    async fn offset_exceed_interval() {
        TaskManager::default()
            .add_offset(
                "Fails",
                Duration::from_secs(10),
                Duration::from_secs(20),
                TestTask {},
            )
            .await;
    }
}
429}