easy_schedule/
lib.rs

1use async_trait::async_trait;
2use std::fmt::{self, Debug};
3use time::{Date, OffsetDateTime, Time, macros::format_description};
4use tokio::{
5    select,
6    time::{Duration, Instant, sleep, sleep_until},
7};
8use tokio_util::sync::CancellationToken;
9use tracing::{error, instrument};
10
/// Convenience re-exports for users of this crate.
///
/// Brings the scheduler types declared in this file into scope together with
/// the `async_trait` attribute and `CancellationToken`, both of which appear
/// in the [`Notifiable`] trait's definition.
pub mod prelude {
    pub use super::{Notifiable, Scheduler, Skip, Task};
    pub use async_trait::async_trait;
    pub use tokio_util::sync::CancellationToken;
}
16
/// A rule describing moments at which a scheduled task should be skipped
/// instead of run.
#[derive(Debug, Clone)]
pub enum Skip {
    /// Skip one fixed calendar date.
    Date(Date),
    /// Skip a range of calendar dates, given as `(start, end)`.
    DateRange(Date, Date),
    /// Skip the listed days.
    ///
    /// 1: Monday, 2: Tuesday, 3: Wednesday, 4: Thursday, 5: Friday, 6: Saturday, 7: Sunday
    Day(Vec<u8>),
    /// Skip a range of days, given as `(start, end)`.
    ///
    /// 1: Monday, 2: Tuesday, 3: Wednesday, 4: Thursday, 5: Friday, 6: Saturday, 7: Sunday
    DayRange(usize, usize),
    /// Skip one fixed time of day.
    Time(Time),
    /// Skip a range of times of day, given as `(start, end)`.
    ///
    /// `end` must be greater than `start`.
    TimeRange(Time, Time),
    /// Skip nothing.
    None,
}
40
41impl Default for Skip {
42    fn default() -> Self {
43        Self::None
44    }
45}
46
47impl fmt::Display for Skip {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        match self {
50            Skip::Date(date) => write!(f, "date: {}", date),
51            Skip::DateRange(start, end) => write!(f, "date range: {} - {}", start, end),
52            Skip::Day(day) => write!(f, "day: {:?}", day),
53            Skip::DayRange(start, end) => write!(f, "day range: {} - {}", start, end),
54            Skip::Time(time) => write!(f, "time: {}", time),
55            Skip::TimeRange(start, end) => write!(f, "time range: {} - {}", start, end),
56            Skip::None => write!(f, "none"),
57        }
58    }
59}
60
61impl Skip {
62    /// check if the time is skipped
63    pub fn is_skip(&self, time: OffsetDateTime) -> bool {
64        match self {
65            Skip::Date(date) => time.date() == *date,
66            Skip::DateRange(start, end) => time.date() >= *start && time.date() <= *end,
67            Skip::Day(day) => day.contains(&(time.day() + 1)),
68            Skip::DayRange(start, end) => {
69                time.day() + 1 >= *start as u8 && time.day() + 1 <= *end as u8
70            }
71            Skip::Time(time) => time.hour() == time.hour() && time.minute() == time.minute(),
72            Skip::TimeRange(start, end) => {
73                assert!(start < end, "start must be less than end");
74                time.hour() >= start.hour()
75                    && time.hour() <= end.hour()
76                    && time.minute() >= start.minute()
77                    && time.minute() <= end.minute()
78            }
79            Skip::None => false,
80        }
81    }
82}
83
/// The schedule on which a [`Notifiable`] is triggered.
///
/// The optional `Vec<Skip>` carried by the first three variants lists the
/// skip rules that are consulted when the task is about to fire.
#[derive(Debug, Clone)]
pub enum Task {
    /// Fire once after waiting the given number of seconds.
    Wait(u64, Option<Vec<Skip>>),
    /// Fire repeatedly, sleeping the given number of seconds between firings.
    Interval(u64, Option<Vec<Skip>>),
    /// Fire at the given time of day.
    At(Time, Option<Vec<Skip>>),
    /// Fire once at the given exact date and time.
    Once(OffsetDateTime),
}
95
96impl PartialEq for Task {
97    fn eq(&self, other: &Self) -> bool {
98        match (self, other) {
99            (Task::Wait(a, _), Task::Wait(b, _)) => a == b,
100            (Task::Interval(a, _), Task::Interval(b, _)) => a == b,
101            (Task::At(a, _), Task::At(b, _)) => a == b,
102            (Task::Once(a), Task::Once(b)) => a == b,
103            _ => false,
104        }
105    }
106}
107
108impl From<&str> for Task {
109    ///
110    /// - wait=10
111    /// - interval=10
112    /// - at=10:00
113    /// - once=2024-01-01 10:00:00
114    fn from(s: &str) -> Self {
115        let parts = s.split("=").collect::<Vec<&str>>();
116        let task = parts[0];
117        let value = parts[1..].join("");
118        match task {
119            "wait" => {
120                let seconds = value.parse::<u64>().unwrap();
121                Task::Wait(seconds, None)
122            }
123            "interval" => {
124                let seconds = value.parse::<u64>().unwrap();
125                Task::Interval(seconds, None)
126            }
127            "at" => {
128                let format = format_description!("[hour]:[minute]");
129                let time = Time::parse(&value, &format).expect("parse time failed");
130                Task::At(time, None)
131            }
132            "once" => {
133                let format = format_description!(
134                    "[year]-[month]-[day] [hour]:[minute]:[second] [offset_hour sign:mandatory]"
135                );
136                println!("value: {}", value);
137                let datetime =
138                    OffsetDateTime::parse(&value, &format).expect("parse datetime failed");
139                Task::Once(datetime)
140            }
141            _ => {
142                panic!("invalid task: {}", task);
143            }
144        }
145    }
146}
147
148impl From<String> for Task {
149    fn from(s: String) -> Self {
150        Self::from(s.as_str())
151    }
152}
153
154impl From<&String> for Task {
155    fn from(s: &String) -> Self {
156        Self::from(s.as_str())
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// `wait=` parses the same way from `&str`, `String` and `&String`.
    #[test]
    fn test_from_string() {
        let expected = Task::Wait(10, None);
        let owned = "wait=10".to_string();

        assert_eq!(Task::from("wait=10"), expected);
        assert_eq!(Task::from(&owned), expected);
        assert_eq!(Task::from(owned), expected);
    }

    /// `interval=` yields an interval task without skip rules.
    #[test]
    fn test_from_string_interval() {
        assert_eq!(Task::from("interval=10"), Task::Interval(10, None));
    }

    /// `at=` parses an `hour:minute` time of day.
    #[test]
    fn test_from_string_at() {
        let ten_o_clock = Time::from_hms(10, 0, 0).unwrap();
        assert_eq!(Task::from("at=10:00"), Task::At(ten_o_clock, None));
    }

    /// `once=` parses a full timestamp with a signed hour offset.
    #[test]
    fn test_from_string_once() {
        // 2024-01-01 10:00:00 +08 is 2024-01-01 02:00:00 UTC.
        let expected = OffsetDateTime::from_unix_timestamp(1704074400).unwrap();
        assert_eq!(
            Task::from("once=2024-01-01 10:00:00 +08"),
            Task::Once(expected)
        );
    }
}
195
196impl fmt::Display for Task {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        match self {
199            Task::Wait(wait, skip) => {
200                let skip = skip
201                    .clone()
202                    .unwrap_or_default()
203                    .into_iter()
204                    .map(|s| s.to_string())
205                    .collect::<Vec<String>>()
206                    .join(", ");
207                write!(f, "wait: {} {}", wait, skip)
208            }
209            Task::Interval(interval, skip) => {
210                let skip = skip
211                    .clone()
212                    .unwrap_or_default()
213                    .into_iter()
214                    .map(|s| s.to_string())
215                    .collect::<Vec<String>>()
216                    .join(", ");
217                write!(f, "interval: {} {}", interval, skip)
218            }
219            Task::At(time, skip) => {
220                let skip = skip
221                    .clone()
222                    .unwrap_or_default()
223                    .into_iter()
224                    .map(|s| s.to_string())
225                    .collect::<Vec<String>>()
226                    .join(", ");
227                write!(f, "at: {} {}", time, skip)
228            }
229            Task::Once(time) => write!(f, "once: {}", time),
230        }
231    }
232}
233
/// A task that can be scheduled by a [`Scheduler`].
///
/// Implementors describe when they want to run via [`get_schedule`] and
/// receive callbacks through [`on_time`] and [`on_skip`]. Each callback is
/// handed a clone of the scheduler's cancellation token.
///
/// [`get_schedule`]: Notifiable::get_schedule
/// [`on_time`]: Notifiable::on_time
/// [`on_skip`]: Notifiable::on_skip
#[async_trait]
pub trait Notifiable: Sync + Send {
    /// Returns the schedule on which this task wants to be triggered.
    fn get_schedule(&self) -> Task;

    /// Called when the schedule fires and the firing is not skipped.
    ///
    /// The default implementation cancels `cancel` on the first trigger.
    async fn on_time(&self, cancel: CancellationToken) {
        cancel.cancel();
    }

    /// Called instead of `on_time` when a firing is skipped.
    ///
    /// The default implementation does nothing.
    async fn on_skip(&self, _cancel: CancellationToken) {
        // do nothing
    }
}
252
/// Starts [`Notifiable`] tasks and owns the cancellation token they share.
pub struct Scheduler {
    /// Token cloned into every task started through `run`; cancelling it is
    /// how `stop` signals those tasks.
    cancel: CancellationToken,
}
256
257impl Scheduler {
258    /// create a new scheduler
259    pub fn new() -> Self {
260        Self {
261            cancel: CancellationToken::new(),
262        }
263    }
264
265    /// run the task
266    pub async fn run<T: Notifiable + 'static>(&self, task: T) {
267        let schedule = task.get_schedule();
268        let cancel = self.cancel.clone();
269
270        match schedule {
271            Task::Wait(..) => {
272                Scheduler::run_wait(task, cancel.clone()).await;
273            }
274            Task::Interval(..) => {
275                Scheduler::run_interval(task, cancel.clone()).await;
276            }
277            Task::At(..) => {
278                Scheduler::run_at(task, cancel.clone()).await;
279            }
280            Task::Once(..) => {
281                Scheduler::run_once(task, cancel.clone()).await;
282            }
283        }
284    }
285
286    /// stop the scheduler
287    ///
288    /// this will cancel all the tasks
289    pub fn stop(&self) {
290        self.cancel.cancel();
291    }
292
293    /// get the cancel token
294    pub fn get_cancel(&self) -> CancellationToken {
295        self.cancel.clone()
296    }
297}
298
299fn get_next_time(now: OffsetDateTime, time: Time) -> OffsetDateTime {
300    let mut next = now.replace_time(time);
301    if next < now {
302        next = next + time::Duration::days(1);
303    }
304    next
305}
306
307fn get_now() -> Option<OffsetDateTime> {
308    match OffsetDateTime::now_local() {
309        Ok(now) => Some(now),
310        Err(e) => {
311            error!("failed to get local time: {}", e);
312            None
313        }
314    }
315}
316
317impl Scheduler {
318    /// run wait task
319    #[instrument(skip(task, cancel))]
320    async fn run_wait<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
321        if let Task::Wait(wait, skip) = task.get_schedule() {
322            let task_ref = task;
323            tokio::task::spawn(async move {
324                select! {
325                    _ = cancel.cancelled() => {
326                        return;
327                    }
328                    _ = sleep(Duration::from_secs(wait)) => {
329                        tracing::debug!(wait, "wait seconds");
330                    }
331                };
332                if let Some(now) = get_now() {
333                    if let Some(skip) = skip {
334                        if skip.iter().any(|s| s.is_skip(now)) {
335                            task_ref.on_skip(cancel.clone()).await;
336                            return;
337                        }
338                    }
339                    task_ref.on_time(cancel.clone()).await;
340                }
341            });
342        }
343    }
344
345    /// run interval task
346    #[instrument(skip(task, cancel))]
347    async fn run_interval<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
348        if let Task::Interval(interval, skip) = task.get_schedule() {
349            let task_ref = task;
350            tokio::task::spawn(async move {
351                loop {
352                    select! {
353                        _ = cancel.cancelled() => {
354                            return;
355                        }
356                        _ = sleep(Duration::from_secs(interval)) => {
357                            tracing::debug!(interval, "interval");
358                        }
359                    };
360                    if let Some(now) = get_now() {
361                        if let Some(ref skip) = skip {
362                            if skip.iter().any(|s| s.is_skip(now)) {
363                                task_ref.on_skip(cancel.clone()).await;
364                                continue;
365                            }
366                        }
367                        task_ref.on_time(cancel.clone()).await;
368                    }
369                }
370            });
371        }
372    }
373
374    /// run at task
375    #[instrument(skip(task, cancel))]
376    async fn run_at<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
377        if let Task::At(time, skip) = task.get_schedule() {
378            let task_ref = task;
379            tokio::task::spawn(async move {
380                let now = if let Some(now) = get_now() {
381                    now
382                } else {
383                    return;
384                };
385                let mut next = get_next_time(now, time);
386                loop {
387                    let now = if let Some(now) = get_now() {
388                        now
389                    } else {
390                        return;
391                    };
392                    let seconds = (next - now).as_seconds_f64() as u64;
393                    let instant = Instant::now() + Duration::from_secs(seconds);
394                    select! {
395                        _ = cancel.cancelled() => {
396                            return;
397                        }
398                        _ = sleep_until(instant) => {
399                            tracing::debug!("at time");
400                        }
401                    }
402
403                    if let Some(skip) = skip.clone() {
404                        if skip.iter().any(|s| s.is_skip(now)) {
405                            task_ref.on_skip(cancel.clone()).await;
406                            return;
407                        }
408                    }
409
410                    task_ref.on_time(cancel.clone()).await;
411
412                    next += time::Duration::days(1);
413                }
414            });
415        }
416    }
417
418    /// run once task
419    #[instrument(skip(task, cancel))]
420    async fn run_once<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
421        if let Task::Once(next) = task.get_schedule() {
422            let task_ref = task;
423            tokio::task::spawn(async move {
424                if let Some(now) = get_now() {
425                    if next < now {
426                        task_ref.on_skip(cancel.clone()).await;
427                        return;
428                    }
429                    let seconds = (next - now).as_seconds_f64() as u64;
430                    let instant = Instant::now() + Duration::from_secs(seconds);
431
432                    select! {
433                        _ = cancel.cancelled() => {
434                            return;
435                        }
436                        _ = sleep_until(instant) => {
437                            tracing::debug!("once time");
438                        }
439                    }
440                    task_ref.on_time(cancel.clone()).await;
441                }
442            });
443        }
444    }
445}