1use async_trait::async_trait;
2use std::fmt::{self, Debug};
3use time::{Date, OffsetDateTime, Time};
4use tokio::{
5 select,
6 time::{Duration, Instant, sleep, sleep_until},
7};
8pub use tokio_util::sync::CancellationToken;
9use tracing::{error, instrument};
10
11#[derive(Debug, Clone)]
12pub enum Skip {
13 Date(Date),
15 DateRange(Date, Date),
17 Day(Vec<u8>),
21 DayRange(usize, usize),
25 Time(Time),
27 TimeRange(Time, Time),
31 None,
33}
34
35impl Default for Skip {
36 fn default() -> Self {
37 Self::None
38 }
39}
40
41impl fmt::Display for Skip {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 Skip::Date(date) => write!(f, "date: {}", date),
45 Skip::DateRange(start, end) => write!(f, "date range: {} - {}", start, end),
46 Skip::Day(day) => write!(f, "day: {:?}", day),
47 Skip::DayRange(start, end) => write!(f, "day range: {} - {}", start, end),
48 Skip::Time(time) => write!(f, "time: {}", time),
49 Skip::TimeRange(start, end) => write!(f, "time range: {} - {}", start, end),
50 Skip::None => write!(f, "none"),
51 }
52 }
53}
54
55impl Skip {
56 pub fn is_skip(&self, time: OffsetDateTime) -> bool {
58 match self {
59 Skip::Date(date) => time.date() == *date,
60 Skip::DateRange(start, end) => time.date() >= *start && time.date() <= *end,
61 Skip::Day(day) => day.contains(&(time.day() + 1)),
62 Skip::DayRange(start, end) => {
63 time.day() + 1 >= *start as u8 && time.day() + 1 <= *end as u8
64 }
65 Skip::Time(time) => time.hour() == time.hour() && time.minute() == time.minute(),
66 Skip::TimeRange(start, end) => {
67 assert!(start < end, "start must be less than end");
68 time.hour() >= start.hour()
69 && time.hour() <= end.hour()
70 && time.minute() >= start.minute()
71 && time.minute() <= end.minute()
72 }
73 Skip::None => false,
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
79pub enum Task {
80 Wait(u64, Option<Vec<Skip>>),
82 Interval(u64, Option<Vec<Skip>>),
84 At(Time, Option<Vec<Skip>>),
86 Once(OffsetDateTime),
88}
89
90impl fmt::Display for Task {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 match self {
93 Task::Wait(wait, skip) => {
94 let skip = skip
95 .clone()
96 .unwrap_or_default()
97 .into_iter()
98 .map(|s| s.to_string())
99 .collect::<Vec<String>>()
100 .join(", ");
101 write!(f, "wait: {} {}", wait, skip)
102 }
103 Task::Interval(interval, skip) => {
104 let skip = skip
105 .clone()
106 .unwrap_or_default()
107 .into_iter()
108 .map(|s| s.to_string())
109 .collect::<Vec<String>>()
110 .join(", ");
111 write!(f, "interval: {} {}", interval, skip)
112 }
113 Task::At(time, skip) => {
114 let skip = skip
115 .clone()
116 .unwrap_or_default()
117 .into_iter()
118 .map(|s| s.to_string())
119 .collect::<Vec<String>>()
120 .join(", ");
121 write!(f, "at: {} {}", time, skip)
122 }
123 Task::Once(time) => write!(f, "once: {}", time),
124 }
125 }
126}
127
128#[async_trait]
130pub trait ScheduledTask: Sync + Send {
131 fn get_schedule(&self) -> Task;
133
134 async fn on_time(&self, cancel: CancellationToken);
136
137 async fn on_skip(&self, cancel: CancellationToken);
139}
140
141pub struct Scheduler {
142 cancel: CancellationToken,
143}
144
145impl Scheduler {
146 pub fn new() -> Self {
148 Self {
149 cancel: CancellationToken::new(),
150 }
151 }
152
153 pub async fn start<T: ScheduledTask + 'static>(&self, task: T) {
155 let schedule = task.get_schedule();
156 let cancel = self.cancel.clone();
157
158 match schedule {
159 Task::Wait(..) => {
160 Scheduler::run_wait(task, cancel.clone()).await;
161 }
162 Task::Interval(..) => {
163 Scheduler::run_interval(task, cancel.clone()).await;
164 }
165 Task::At(..) => {
166 Scheduler::run_at(task, cancel.clone()).await;
167 }
168 Task::Once(..) => {
169 Scheduler::run_once(task, cancel.clone()).await;
170 }
171 }
172 }
173
174 pub fn stop(&self) {
178 self.cancel.cancel();
179 }
180
181 pub fn get_cancel(&self) -> CancellationToken {
183 self.cancel.clone()
184 }
185}
186
187fn get_next_time(now: OffsetDateTime, time: Time) -> OffsetDateTime {
188 let mut next = now.replace_time(time);
189 if next < now {
190 next = next + time::Duration::days(1);
191 }
192 next
193}
194
195fn get_now() -> Option<OffsetDateTime> {
196 match OffsetDateTime::now_local() {
197 Ok(now) => Some(now),
198 Err(e) => {
199 error!("failed to get local time: {}", e);
200 None
201 }
202 }
203}
204
205impl Scheduler {
206 #[instrument(skip(task, cancel))]
208 async fn run_wait<T: ScheduledTask + 'static>(task: T, cancel: CancellationToken) {
209 if let Task::Wait(wait, skip) = task.get_schedule() {
210 let task_ref = task;
211 tokio::task::spawn(async move {
212 select! {
213 _ = cancel.cancelled() => {
214 return;
215 }
216 _ = sleep(Duration::from_secs(wait)) => {
217 tracing::debug!(wait, "wait seconds");
218 }
219 };
220 if let Some(now) = get_now() {
221 if let Some(skip) = skip {
222 if skip.iter().any(|s| s.is_skip(now)) {
223 task_ref.on_skip(cancel.clone()).await;
224 return;
225 }
226 }
227 task_ref.on_time(cancel.clone()).await;
228 }
229 });
230 }
231 }
232
233 #[instrument(skip(task, cancel))]
235 async fn run_interval<T: ScheduledTask + 'static>(task: T, cancel: CancellationToken) {
236 if let Task::Interval(interval, skip) = task.get_schedule() {
237 let task_ref = task;
238 tokio::task::spawn(async move {
239 loop {
240 select! {
241 _ = cancel.cancelled() => {
242 return;
243 }
244 _ = sleep(Duration::from_secs(interval)) => {
245 tracing::debug!(interval, "interval");
246 }
247 };
248 if let Some(now) = get_now() {
249 if let Some(ref skip) = skip {
250 if skip.iter().any(|s| s.is_skip(now)) {
251 task_ref.on_skip(cancel.clone()).await;
252 continue;
253 }
254 }
255 task_ref.on_time(cancel.clone()).await;
256 }
257 }
258 });
259 }
260 }
261
262 #[instrument(skip(task, cancel))]
264 async fn run_at<T: ScheduledTask + 'static>(task: T, cancel: CancellationToken) {
265 if let Task::At(time, skip) = task.get_schedule() {
266 let task_ref = task;
267 tokio::task::spawn(async move {
268 let now = if let Some(now) = get_now() {
269 now
270 } else {
271 return;
272 };
273 let mut next = get_next_time(now, time);
274 loop {
275 let now = if let Some(now) = get_now() {
276 now
277 } else {
278 return;
279 };
280 let seconds = (next - now).as_seconds_f64() as u64;
281 let instant = Instant::now() + Duration::from_secs(seconds);
282 select! {
283 _ = cancel.cancelled() => {
284 return;
285 }
286 _ = sleep_until(instant) => {
287 tracing::debug!("at time");
288 }
289 }
290
291 if let Some(skip) = skip.clone() {
292 if skip.iter().any(|s| s.is_skip(now)) {
293 task_ref.on_skip(cancel.clone()).await;
294 return;
295 }
296 }
297
298 task_ref.on_time(cancel.clone()).await;
299
300 next += time::Duration::days(1);
301 }
302 });
303 }
304 }
305
306 #[instrument(skip(task, cancel))]
308 async fn run_once<T: ScheduledTask + 'static>(task: T, cancel: CancellationToken) {
309 if let Task::Once(next) = task.get_schedule() {
310 let task_ref = task;
311 tokio::task::spawn(async move {
312 if let Some(now) = get_now() {
313 if next < now {
314 task_ref.on_skip(cancel.clone()).await;
315 return;
316 }
317 let seconds = (next - now).as_seconds_f64() as u64;
318 let instant = Instant::now() + Duration::from_secs(seconds);
319
320 select! {
321 _ = cancel.cancelled() => {
322 return;
323 }
324 _ = sleep_until(instant) => {
325 tracing::debug!("once time");
326 }
327 }
328 task_ref.on_time(cancel.clone()).await;
329 }
330 });
331 }
332 }
333}