easy_schedule/
lib.rs

1use async_trait::async_trait;
2use std::fmt::{self, Debug};
3use time::{
4    Date, OffsetDateTime, Time,
5    macros::{format_description, offset},
6};
7use tokio::{
8    select,
9    time::{Duration, Instant, sleep, sleep_until},
10};
11use tokio_util::sync::CancellationToken;
12use tracing::instrument;
13
14pub mod prelude {
15    pub use super::{Notifiable, Scheduler, Skip, Task};
16    pub use async_trait::async_trait;
17    pub use tokio_util::sync::CancellationToken;
18}
19
20#[derive(Debug, Clone)]
21pub enum Skip {
22    /// skip fixed date
23    Date(Date),
24    /// skip date range
25    DateRange(Date, Date),
26    /// skip days
27    ///
28    /// 1: Monday, 2: Tuesday, 3: Wednesday, 4: Thursday, 5: Friday, 6: Saturday, 7: Sunday
29    Day(Vec<u8>),
30    /// skip days range
31    ///
32    /// 1: Monday, 2: Tuesday, 3: Wednesday, 4: Thursday, 5: Friday, 6: Saturday, 7: Sunday
33    DayRange(usize, usize),
34    /// skip fixed time
35    Time(Time),
36    /// skip time range
37    ///
38    /// end must be greater than start
39    TimeRange(Time, Time),
40    /// no skip
41    None,
42}
43
44impl Default for Skip {
45    fn default() -> Self {
46        Self::None
47    }
48}
49
50impl fmt::Display for Skip {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        match self {
53            Skip::Date(date) => write!(f, "date: {}", date),
54            Skip::DateRange(start, end) => write!(f, "date range: {} - {}", start, end),
55            Skip::Day(day) => write!(f, "day: {:?}", day),
56            Skip::DayRange(start, end) => write!(f, "day range: {} - {}", start, end),
57            Skip::Time(time) => write!(f, "time: {}", time),
58            Skip::TimeRange(start, end) => write!(f, "time range: {} - {}", start, end),
59            Skip::None => write!(f, "none"),
60        }
61    }
62}
63
64impl Skip {
65    /// check if the time is skipped
66    pub fn is_skip(&self, time: OffsetDateTime) -> bool {
67        match self {
68            Skip::Date(date) => time.date() == *date,
69            Skip::DateRange(start, end) => time.date() >= *start && time.date() <= *end,
70            Skip::Day(day) => day.contains(&(time.day() + 1)),
71            Skip::DayRange(start, end) => {
72                time.day() + 1 >= *start as u8 && time.day() + 1 <= *end as u8
73            }
74            Skip::Time(time) => time.hour() == time.hour() && time.minute() == time.minute(),
75            Skip::TimeRange(start, end) => {
76                assert!(start < end, "start must be less than end");
77                time.hour() >= start.hour()
78                    && time.hour() <= end.hour()
79                    && time.minute() >= start.minute()
80                    && time.minute() <= end.minute()
81            }
82            Skip::None => false,
83        }
84    }
85}
86
87#[derive(Debug, Clone)]
88pub enum Task {
89    /// wait seconds
90    Wait(u64, Option<Vec<Skip>>),
91    /// interval seconds
92    Interval(u64, Option<Vec<Skip>>),
93    /// at time
94    At(Time, Option<Vec<Skip>>),
95    /// exact time
96    Once(OffsetDateTime),
97}
98
99impl PartialEq for Task {
100    fn eq(&self, other: &Self) -> bool {
101        match (self, other) {
102            (Task::Wait(a, _), Task::Wait(b, _)) => a == b,
103            (Task::Interval(a, _), Task::Interval(b, _)) => a == b,
104            (Task::At(a, _), Task::At(b, _)) => a == b,
105            (Task::Once(a), Task::Once(b)) => a == b,
106            _ => false,
107        }
108    }
109}
110
111impl From<&str> for Task {
112    ///
113    /// - wait=10
114    /// - interval=10
115    /// - at=10:00
116    /// - once=2024-01-01 10:00:00
117    fn from(s: &str) -> Self {
118        let parts = s.split("=").collect::<Vec<&str>>();
119        let task = parts[0];
120        let value = parts[1..].join("");
121        match task {
122            "wait" => {
123                let seconds = value.parse::<u64>().unwrap();
124                Task::Wait(seconds, None)
125            }
126            "interval" => {
127                let seconds = value.parse::<u64>().unwrap();
128                Task::Interval(seconds, None)
129            }
130            "at" => {
131                let format = format_description!("[hour]:[minute]");
132                let time = Time::parse(&value, &format).expect("parse time failed");
133                Task::At(time, None)
134            }
135            "once" => {
136                let format = format_description!(
137                    "[year]-[month]-[day] [hour]:[minute]:[second] [offset_hour sign:mandatory]"
138                );
139                let datetime =
140                    OffsetDateTime::parse(&value, &format).expect("parse datetime failed");
141                Task::Once(datetime)
142            }
143            _ => Task::Wait(5, None),
144        }
145    }
146}
147
148impl From<String> for Task {
149    fn from(s: String) -> Self {
150        Self::from(s.as_str())
151    }
152}
153
154impl From<&String> for Task {
155    fn from(s: &String) -> Self {
156        Self::from(s.as_str())
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// The `&str`, `String` and `&String` conversions must all agree.
    #[test]
    fn test_from_string() {
        let expected = Task::Wait(10, None);
        let owned = "wait=10".to_string();

        assert_eq!(Task::from("wait=10"), expected);
        assert_eq!(Task::from(owned.clone()), expected);
        assert_eq!(Task::from(&owned), expected);
    }

    /// `interval=<seconds>` yields an interval task without skip rules.
    #[test]
    fn test_from_string_interval() {
        assert_eq!(Task::from("interval=10"), Task::Interval(10, None));
    }

    /// `at=<hour>:<minute>` yields a time-of-day task with zero seconds.
    #[test]
    fn test_from_string_at() {
        let ten_o_clock = Time::from_hms(10, 0, 0).unwrap();
        assert_eq!(Task::from("at=10:00"), Task::At(ten_o_clock, None));
    }

    /// `once=<datetime> <offset>` yields the matching absolute instant.
    #[test]
    fn test_from_string_once() {
        let expected = OffsetDateTime::from_unix_timestamp(1704074400).unwrap();
        assert_eq!(
            Task::from("once=2024-01-01 10:00:00 +08"),
            Task::Once(expected)
        );
    }
}
195
196impl fmt::Display for Task {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        match self {
199            Task::Wait(wait, skip) => {
200                let skip = skip
201                    .clone()
202                    .unwrap_or_default()
203                    .into_iter()
204                    .map(|s| s.to_string())
205                    .collect::<Vec<String>>()
206                    .join(", ");
207                write!(f, "wait: {} {}", wait, skip)
208            }
209            Task::Interval(interval, skip) => {
210                let skip = skip
211                    .clone()
212                    .unwrap_or_default()
213                    .into_iter()
214                    .map(|s| s.to_string())
215                    .collect::<Vec<String>>()
216                    .join(", ");
217                write!(f, "interval: {} {}", interval, skip)
218            }
219            Task::At(time, skip) => {
220                let skip = skip
221                    .clone()
222                    .unwrap_or_default()
223                    .into_iter()
224                    .map(|s| s.to_string())
225                    .collect::<Vec<String>>()
226                    .join(", ");
227                write!(f, "at: {} {}", time, skip)
228            }
229            Task::Once(time) => write!(f, "once: {}", time),
230        }
231    }
232}
233
234/// a task that can be scheduled
235#[async_trait]
236pub trait Notifiable: Sync + Send + Debug {
237    /// get the schedule type
238    fn get_schedule(&self) -> Task;
239
240    /// called when the task is scheduled
241    ///
242    /// Default cancel on first trigger
243    async fn on_time(&self, cancel: CancellationToken) {
244        cancel.cancel();
245    }
246
247    /// called when the task is skipped
248    async fn on_skip(&self, _cancel: CancellationToken) {
249        // do nothing
250    }
251}
252
253pub struct Scheduler {
254    cancel: CancellationToken,
255}
256
257impl Scheduler {
258    /// create a new scheduler
259    pub fn new() -> Self {
260        Self {
261            cancel: CancellationToken::new(),
262        }
263    }
264
265    /// run the task
266    pub async fn run<T: Notifiable + 'static>(&self, task: T) {
267        let schedule = task.get_schedule();
268        let cancel = self.cancel.clone();
269
270        match schedule {
271            Task::Wait(..) => {
272                Scheduler::run_wait(task, cancel.clone()).await;
273            }
274            Task::Interval(..) => {
275                Scheduler::run_interval(task, cancel.clone()).await;
276            }
277            Task::At(..) => {
278                Scheduler::run_at(task, cancel.clone()).await;
279            }
280            Task::Once(..) => {
281                Scheduler::run_once(task, cancel.clone()).await;
282            }
283        }
284    }
285
286    /// stop the scheduler
287    ///
288    /// this will cancel all the tasks
289    pub fn stop(&self) {
290        self.cancel.cancel();
291    }
292
293    /// get the cancel token
294    pub fn get_cancel(&self) -> CancellationToken {
295        self.cancel.clone()
296    }
297}
298
299fn get_next_time(now: OffsetDateTime, time: Time) -> OffsetDateTime {
300    let mut next = now.replace_time(time);
301    if next < now {
302        next = next + time::Duration::days(1);
303    }
304    next
305}
306
307fn get_now() -> OffsetDateTime {
308    // FIXME:
309    OffsetDateTime::now_utc().to_offset(offset!(+8))
310}
311
312impl Scheduler {
313    /// run wait task
314    #[instrument(skip(cancel))]
315    async fn run_wait<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
316        if let Task::Wait(wait, skip) = task.get_schedule() {
317            let task_ref = task;
318            tokio::task::spawn(async move {
319                select! {
320                    _ = cancel.cancelled() => {
321                        return;
322                    }
323                    _ = sleep(Duration::from_secs(wait)) => {
324                        tracing::debug!(wait, "wait seconds");
325                    }
326                };
327                let now = get_now();
328                if let Some(skip) = skip {
329                    if skip.iter().any(|s| s.is_skip(now)) {
330                        task_ref.on_skip(cancel.clone()).await;
331                        return;
332                    }
333                }
334                task_ref.on_time(cancel.clone()).await;
335            });
336        }
337    }
338
339    /// run interval task
340    #[instrument(skip(cancel))]
341    async fn run_interval<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
342        if let Task::Interval(interval, skip) = task.get_schedule() {
343            let task_ref = task;
344            tokio::task::spawn(async move {
345                loop {
346                    select! {
347                        _ = cancel.cancelled() => {
348                            return;
349                        }
350                        _ = sleep(Duration::from_secs(interval)) => {
351                            tracing::debug!(interval, "interval");
352                        }
353                    };
354                    let now = get_now();
355                    if let Some(ref skip) = skip {
356                        if skip.iter().any(|s| s.is_skip(now)) {
357                            task_ref.on_skip(cancel.clone()).await;
358                            continue;
359                        }
360                    }
361                    task_ref.on_time(cancel.clone()).await;
362                }
363            });
364        }
365    }
366
367    /// run at task
368    #[instrument(skip(cancel))]
369    async fn run_at<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
370        if let Task::At(time, skip) = task.get_schedule() {
371            let task_ref = task;
372            tokio::task::spawn(async move {
373                let now = get_now();
374                let mut next = get_next_time(now, time);
375                loop {
376                    let now = get_now();
377                    let seconds = (next - now).as_seconds_f64() as u64;
378                    let instant = Instant::now() + Duration::from_secs(seconds);
379                    select! {
380                        _ = cancel.cancelled() => {
381                            return;
382                        }
383                        _ = sleep_until(instant) => {
384                            tracing::debug!("at time");
385                        }
386                    }
387
388                    if let Some(skip) = skip.clone() {
389                        if skip.iter().any(|s| s.is_skip(now)) {
390                            task_ref.on_skip(cancel.clone()).await;
391                            return;
392                        }
393                    }
394
395                    task_ref.on_time(cancel.clone()).await;
396
397                    next += time::Duration::days(1);
398                }
399            });
400        }
401    }
402
403    /// run once task
404    #[instrument(skip(task, cancel))]
405    async fn run_once<T: Notifiable + 'static>(task: T, cancel: CancellationToken) {
406        if let Task::Once(next) = task.get_schedule() {
407            let task_ref = task;
408            tokio::task::spawn(async move {
409                let now = get_now();
410                if next < now {
411                    task_ref.on_skip(cancel.clone()).await;
412                    return;
413                }
414                let seconds = (next - now).as_seconds_f64();
415                let instant = Instant::now() + Duration::from_secs(seconds as u64);
416
417                select! {
418                    _ = cancel.cancelled() => {
419                        return;
420                    }
421                    _ = sleep_until(instant) => {
422                        tracing::debug!("once time");
423                    }
424                }
425                task_ref.on_time(cancel.clone()).await;
426            });
427        }
428    }
429}