1use std::collections::BinaryHeap;
2use std::fmt::{Debug, Formatter};
3use std::future::Future;
4use std::panic;
5use std::panic::AssertUnwindSafe;
6use std::str::FromStr;
7use std::sync::Arc;
8use chrono::{DateTime, Local, TimeZone, Utc};
9use cron::Schedule;
10use tokio::sync::Notify;
11use tokio::sync::RwLock;
12use tokio::time::timeout;
13use tracing::{trace, info, error};
14use futures::FutureExt;
15
/// Type-erased job body stored in the scheduler queue and invoked on every tick.
///
/// Async jobs are adapted to this shape by `syncify_job`, which spawns the
/// job's future instead of awaiting it.
type JobFunction = Arc<dyn Fn() + Sync + Send + 'static>;
17
18struct InnerScheduler<T: TimeZone = Utc> {
19 scheduled_jobs: RwLock<BinaryHeap<ScheduledJob<T>>>,
20 notify: Notify,
21 timezone: T,
22}
23
24#[derive(Clone)]
25pub struct Scheduler<T: TimeZone = Utc> {
26 inner: Arc<InnerScheduler<T>>,
27}
28
29impl Scheduler<Utc> {
30 pub fn utc() -> Scheduler<Utc> {
31 Scheduler::new_in_timezone(Utc)
32 }
33
34 pub fn local() -> Scheduler<Local> {
35 Scheduler::new_in_timezone(Local)
36 }
37}
38
39impl<Tz: TimeZone + Send + Sync + Debug + 'static> Scheduler<Tz>
40 where
41 Tz::Offset: Send + Sync,
42{
43 pub fn new_in_timezone(tz: Tz) -> Self {
44 let r = Self {
45 inner: Arc::new(InnerScheduler {
46 scheduled_jobs: RwLock::new(BinaryHeap::new()),
47 notify: Notify::new(),
48 timezone: tz,
49 }),
50 };
51 r.run();
52 r
53 }
54
55 pub fn add(&mut self, job: Job) {
56 let Job {
57 cron_line,
58 func,
59 name,
60 } = job;
61 let cron = Schedule::from_str(&cron_line).unwrap();
62 let job = ScheduledJob {
63 dt: cron.upcoming(self.inner.timezone.clone()).next().unwrap(),
64 cron,
65 func,
66 name,
67 };
68 let inner = self.inner.clone();
69 tokio::spawn(async move {
70 info!(name=%job.name, cron_line=?cron_line, first_dt=?job.dt, "Added job to cron tab");
71 let mut queue = inner.scheduled_jobs.write()
72 .await;
73 queue.push(job);
74 drop(queue);
75 inner.notify.notify_one();
77 });
78 }
79
80 pub fn cancel_by_name(&mut self, name: &str) {
81 let name = name.to_string();
82 let inner = self.inner.clone();
83 tokio::spawn(async move {
84 let mut queue = inner.scheduled_jobs.write().await;
85 queue.retain(|job| job.name != name);
86 });
87 }
88
89 fn run(&self) {
90 let inner = self.inner.clone();
91 tokio::spawn(async move {
92 loop {
93 let mut lock = inner.scheduled_jobs.write()
94 .await;
95 let now = Utc::now();
96 while let Some(next) = lock.peek() {
97 if next.dt > now {
98 break;
99 }
100 let to_run = lock.pop().unwrap();
101 let this = to_run.dt.clone();
102 let f = AssertUnwindSafe(to_run.func.clone());
103 let res = panic::catch_unwind(move || f());
104 if res.is_err() {
105 error!(name=%to_run.name, this_dt=?this, "Cron job panicked");
106 }
107 let next = to_run.next(inner.timezone.clone());
108 info!(name=%next.name, this_dt=?this, next_dt=?next.dt, "Ran job (Async job is running in background)");
109 lock.push(next);
110 }
111 let t = lock.peek().map(|s| s.dt.with_timezone(&Utc) - now).unwrap_or(DateTime::<Utc>::MAX_UTC - now);
112 drop(lock);
113 trace!(sec=t.num_seconds(), "Sleep cron main loop until timeout or added job");
114 let _ = timeout(t.to_std().unwrap(), inner.notify.notified()).await;
116 }
117 });
118 }
119}
120
121struct ScheduledJob<T: TimeZone = Utc> {
122 dt: DateTime<T>,
123 cron: Schedule,
124 func: JobFunction,
125 name: String,
126}
127
128impl<Tz: TimeZone> ScheduledJob<Tz> {
129 pub fn next(mut self, tz: Tz) -> Self {
130 self.dt = self.cron.upcoming(tz).next().unwrap();
131 self
132 }
133}
134
135impl<Tz: TimeZone> PartialEq<ScheduledJob<Tz>> for ScheduledJob<Tz> {
136 fn eq(&self, other: &Self) -> bool {
137 self.dt.eq(&other.dt)
138 }
139}
140
141impl<Tz: TimeZone> Eq for ScheduledJob<Tz> {}
142
143
144impl<Tz: TimeZone> PartialOrd<ScheduledJob<Tz>> for ScheduledJob<Tz> {
145 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
146 Some(self.cmp(other))
147 }
148}
149
150impl<Tz: TimeZone> Ord for ScheduledJob<Tz> {
153 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
154 self.dt.cmp(&other.dt).reverse()
155 }
156}
157
158impl Debug for ScheduledJob {
159 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("ScheduledJob")
161 .field("dt", &self.dt)
162 .field("cron", &self.cron)
163 .field("name", &self.name)
164 .finish()
165 }
166}
167
168pub struct Job {
169 name: String,
170 cron_line: String,
171 func: JobFunction,
172}
173
174fn syncify_job<F, Fut>(name: &str, f: F) -> JobFunction
175 where
176 F: Fn() -> Fut + Send + Sync + 'static,
177 Fut: Future<Output=()> + Send + 'static,
178{
179 let name = name.to_string();
180 Arc::new(move || {
181 let fut = f();
182 let name = name.clone();
183 tokio::spawn(async move {
184 let res = AssertUnwindSafe(fut).catch_unwind().await;
185 if res.is_err() {
186 error!(name=%name, "Cron job panicked during async execution");
187 }
188 });
189 })
190}
191
192impl Job {
193 pub fn new<S, F, Fut>(cron: S, func: F) -> Self
194 where
195 F: Fn() -> Fut + Send + Sync + 'static,
196 Fut: Future<Output=()> + Send + 'static,
197 S: Into<String>,
198 {
199 Self {
200 name: "".to_string(),
201 cron_line: cron.into(),
202 func: syncify_job("", func),
203 }
204 }
205
206 pub fn new_sync<S, F>(cron: S, func: F) -> Self
208 where
209 F: Fn() -> () + Send + Sync + 'static,
210 S: Into<String> {
211 Self {
212 name: "".to_string(),
213 cron_line: cron.into(),
214 func: Arc::new(func),
215 }
216 }
217
218 pub fn named<S, F, Fut>(name: &str, cron: S, func: F) -> Self
220 where
221 F: Fn() -> Fut + Send + Sync + 'static,
222 Fut: Future<Output=()> + Send + 'static,
223 S: Into<String>,
224 {
225 Self {
226 name: name.to_string(),
227 cron_line: cron.into(),
228 func: syncify_job(name, func),
229 }
230 }
231
232 pub fn named_sync<S, F>(name: &str, cron: S, func: F) -> Self
233 where
234 F: Fn() -> () + Send + Sync + 'static,
235 S: Into<String>,
236 {
237 Self {
238 name: name.to_string(),
239 cron_line: cron.into(),
240 func: Arc::new(func),
241 }
242 }
243}
244
/// Builds a 7-field cron expression (`sec min hour day-of-month month
/// day-of-week year`) that fires at `:00:00` of every hour matched by
/// `hour_spec`, every day.
///
/// `hour_spec` is inserted verbatim as the hour field (e.g. `"3"`, `"9-17"`,
/// `"*/6"`). It is validated only when the job is added to a scheduler.
pub fn daily(hour_spec: &str) -> String {
    format!("0 0 {hour} * * * *", hour = hour_spec)
}
256
/// Builds a 7-field cron expression (`sec min hour day-of-month month
/// day-of-week year`) that fires at second 0 of every minute matched by
/// `minute_spec`, every hour.
///
/// `minute_spec` is inserted verbatim as the minute field (e.g. `"1"`,
/// `"0,30"`). It is validated only when the job is added to a scheduler.
pub fn hourly(minute_spec: &str) -> String {
    format!("0 {minute} * * * * *", minute = minute_spec)
}
268
/// Builds a 7-field cron expression (`sec min hour day-of-month month
/// day-of-week year`) that fires at `:00:00` of the hours matched by
/// `hour_spec`, on the weekdays matched by `week_spec`.
///
/// Both specs are inserted verbatim (e.g. `week_spec = "Mon"` or `"1-5"`).
/// They are validated only when the job is added to a scheduler.
pub fn weekly(week_spec: &str, hour_spec: &str) -> String {
    format!("0 0 {hour} * * {weekday} *", hour = hour_spec, weekday = week_spec)
}
280
/// Builds a 7-field cron expression (`sec min hour day-of-month month
/// day-of-week year`) that fires at `:00:00` of the hours matched by
/// `hour_spec`, on the days of the month matched by `day_spec`.
///
/// Both specs are inserted verbatim and validated only when the job is added
/// to a scheduler.
pub fn monthly(day_spec: &str, hour_spec: &str) -> String {
    format!("0 0 {hour} {day} * * *", hour = hour_spec, day = day_spec)
}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn it_works() {
        // Another test may already have installed a global subscriber; that
        // is not a failure of this test, so the error is ignored.
        let _ = tracing::subscriber::set_global_default(
            tracing_subscriber::FmtSubscriber::builder()
                .with_max_level(tracing::Level::TRACE)
                .finish(),
        );

        async fn async_func() {
            println!("Hello, world!");
        }

        let mut scheduler = Scheduler::local();
        let counter = Arc::new(AtomicUsize::new(0));

        let c = counter.clone();
        scheduler.add(Job::new("*/2 * * * * *", move || {
            let counter = c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                println!("Hello, world!");
            }
        }));

        scheduler.add(Job::new_sync("*/1 * * * * *", move || {
            println!("Hello, world!");
        }));

        scheduler.add(Job::new("*/1 * * * * *", async_func));
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        // A 3 s window always contains one or two even-second boundaries.
        let result = counter.load(Ordering::SeqCst);
        assert!((1..=2).contains(&result), "expected 1..=2 runs, got {result}");
    }

    #[tokio::test]
    async fn test_fancy() {
        async fn async_fn_with_args(counter: Arc<AtomicUsize>) {
            counter.fetch_add(1, Ordering::SeqCst);
        }

        let mut scheduler = Scheduler::local();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler.add(Job::named_sync("foo", hourly("1"), move || {
            println!("One minute into the hour!");
        }));

        scheduler.add(Job::named("foo", hourly("2"), move || {
            async move {
                println!("Two minutes into the hour!");
            }
        }));

        let c = counter.clone();
        scheduler.add(Job::named("increase-counter", "*/2 * * * * * *", move || {
            async_fn_with_args(c.clone())
        }));

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        let result = counter.load(Ordering::SeqCst);
        assert!((1..=2).contains(&result), "expected 1..=2 runs, got {result}");
    }

    #[tokio::test]
    async fn test_panic_doesnt_take_everything_down() {
        let mut scheduler = Scheduler::local();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler.add(Job::named_sync("causes-panic", "* * * * * * *", move || {
            panic!("This should not take down the scheduler!");
        }));

        scheduler.add(Job::named("panics-in-async", "* * * * * * *", move || {
            async move {
                panic!("This should not take down the scheduler!");
            }
        }));

        let c = counter.clone();
        scheduler.add(Job::named_sync("still-alive", "* * * * * * *", move || {
            c.fetch_add(1, Ordering::SeqCst);
        }));

        // A 2.5 s window contains at least two whole-second ticks. The second
        // tick happens after both panicking jobs have already fired, so two or
        // more "still-alive" runs prove the dispatch loop survived the panics.
        tokio::time::sleep(std::time::Duration::from_millis(2500)).await;
        let ticks = counter.load(Ordering::SeqCst);
        assert!(ticks >= 2, "scheduler stopped after a job panicked (ticks = {ticks})");
    }
}