1use chrono::{DateTime, Datelike, Duration, TimeZone, Timelike, Utc};
31use std::{collections::BTreeSet, error::Error, fmt, num, str::FromStr};
32
/// Errors produced while parsing a cron expression or computing its next run time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The expression is unusable as a whole: it does not have exactly five fields, or no
    /// matching time was found within the search window.
    InvalidCron,
    /// A range or step item is malformed (e.g. `1-2-3`, `1/2/3`) or a range's start is
    /// greater than its end.
    InvalidRange,
    /// A value lies outside the field's allowed bounds, or a step is zero or too large.
    InvalidValue,
    /// A number in the expression could not be parsed.
    ParseIntError(num::ParseIntError),
    /// An integer conversion failed.
    TryFromIntError(num::TryFromIntError),
    /// A date/time could not be constructed or resolved. Despite the name, this is also
    /// returned when calendar components do not form a valid UTC date/time.
    InvalidTimezone,
}
42
/// A day of the week whose discriminants match cron's numeric day-of-week field
/// (Sunday = 0 through Saturday = 6).
///
/// Parsed from three-letter English abbreviations via `FromStr`, and turned into its
/// numeric value with `as u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dow {
    Sun = 0,
    Mon = 1,
    Tue = 2,
    Wed = 3,
    Thu = 4,
    Fri = 5,
    Sat = 6,
}
53
54impl FromStr for Dow {
55 type Err = ();
56
57 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 match &*s.to_uppercase() {
59 "SUN" => Ok(Self::Sun),
60 "MON" => Ok(Self::Mon),
61 "TUE" => Ok(Self::Tue),
62 "WED" => Ok(Self::Wed),
63 "THU" => Ok(Self::Thu),
64 "FRI" => Ok(Self::Fri),
65 "SAT" => Ok(Self::Sat),
66 _ => Err(()),
67 }
68 }
69}
70
71impl fmt::Display for ParseError {
72 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73 match *self {
74 Self::InvalidCron => write!(f, "invalid cron"),
75 Self::InvalidRange => write!(f, "invalid input"),
76 Self::InvalidValue => write!(f, "invalid value"),
77 Self::ParseIntError(ref err) => err.fmt(f),
78 Self::TryFromIntError(ref err) => err.fmt(f),
79 Self::InvalidTimezone => write!(f, "invalid timezone"),
80 }
81 }
82}
83
// Marker impl using the default `Error` methods, so `source()` returns `None`; the
// wrapped integer errors are surfaced through the `Display` impl instead.
impl Error for ParseError {}
85
86impl From<num::ParseIntError> for ParseError {
87 fn from(err: num::ParseIntError) -> Self {
88 Self::ParseIntError(err)
89 }
90}
91
92impl From<num::TryFromIntError> for ParseError {
93 fn from(err: num::TryFromIntError) -> Self {
94 Self::TryFromIntError(err)
95 }
96}
97
98pub fn parse<TZ: TimeZone>(cron: &str, dt: &DateTime<TZ>) -> Result<DateTime<TZ>, ParseError> {
126 let tz = dt.timezone();
127
128 let fields: Vec<&str> = cron.split_whitespace().collect();
129 let [
130 minute_str,
131 hour_str,
132 day_of_month_str,
133 month_str,
134 day_of_week_str,
135 ] = fields.as_slice()
136 else {
137 return Err(ParseError::InvalidCron);
138 };
139
140 let mut next = match Utc.from_local_datetime(&dt.naive_local()) {
141 chrono::LocalResult::Single(datetime) => datetime + Duration::minutes(1),
142 chrono::LocalResult::Ambiguous(earlier, _later) => earlier + Duration::minutes(1),
143 chrono::LocalResult::None => return Err(ParseError::InvalidTimezone),
144 };
145
146 next = make_utc_datetime(
147 next.year(),
148 next.month(),
149 next.day(),
150 next.hour(),
151 next.minute(),
152 0,
153 )?;
154
155 let result = loop {
156 if next.year() - dt.year() > 4 {
158 return Err(ParseError::InvalidCron);
159 }
160
161 let month = parse_field(month_str, 1, 12)?;
163 if !month.contains(&next.month()) {
164 next = make_utc_datetime(
165 if next.month() == 12 {
166 next.year() + 1
167 } else {
168 next.year()
169 },
170 if next.month() == 12 {
171 1
172 } else {
173 next.month() + 1
174 },
175 1,
176 0,
177 0,
178 0,
179 )?;
180 continue;
181 }
182
183 let do_m = parse_field(day_of_month_str, 1, 31)?;
185 if !do_m.contains(&next.day()) {
186 next += Duration::days(1);
187 next = make_utc_datetime(next.year(), next.month(), next.day(), 0, 0, 0)?;
188 continue;
189 }
190
191 let hour = parse_field(hour_str, 0, 23)?;
193 if !hour.contains(&next.hour()) {
194 next += Duration::hours(1);
195 next = make_utc_datetime(next.year(), next.month(), next.day(), next.hour(), 0, 0)?;
196 continue;
197 }
198
199 let minute = parse_field(minute_str, 0, 59)?;
201 if !minute.contains(&next.minute()) {
202 next += Duration::minutes(1);
203 continue;
204 }
205
206 let do_w = parse_field(day_of_week_str, 0, 6)?;
208 if !do_w.contains(&next.weekday().num_days_from_sunday()) {
209 next += Duration::days(1);
210 continue;
211 }
212
213 match tz.from_local_datetime(&next.naive_local()) {
215 chrono::LocalResult::Single(dt) => break dt,
216 chrono::LocalResult::Ambiguous(earlier, _later) => break earlier,
217 chrono::LocalResult::None => {
218 next += Duration::minutes(1);
219 }
220 }
221 };
222
223 Ok(result)
224}
225
226pub fn parse_field(field: &str, min: u32, max: u32) -> Result<BTreeSet<u32>, ParseError> {
283 let mut values = BTreeSet::<u32>::new();
284
285 let fields: Vec<&str> = field.split(',').filter(|s| !s.is_empty()).collect();
287
288 for field in fields {
290 match field {
291 "*" => {
293 for i in min..=max {
294 values.insert(i);
295 }
296 }
297
298 f if f.starts_with("*/") => {
300 let step: u32 = f.trim_start_matches("*/").parse()?;
301
302 if step == 0 || step > max {
303 return Err(ParseError::InvalidValue);
304 }
305
306 for i in (min..=max).step_by(step as usize) {
307 values.insert(i);
308 }
309 }
310
311 f if f.contains('/') => {
313 let tmp_fields: Vec<&str> = f.split('/').collect();
314 let [range_part, step_part] = tmp_fields.as_slice() else {
315 return Err(ParseError::InvalidRange);
316 };
317
318 let step: u32 = step_part.parse()?;
320
321 if step == 0 || step > max {
322 return Err(ParseError::InvalidValue);
323 }
324
325 if range_part.contains('-') {
327 let tmp_range: Vec<&str> = range_part.split('-').collect();
328 let [start_str, end_str] = tmp_range.as_slice() else {
329 return Err(ParseError::InvalidRange);
330 };
331
332 let start = parse_cron_value(start_str, min, max)?;
333 let end = parse_cron_value(end_str, min, max)?;
334
335 if start > end {
336 return Err(ParseError::InvalidRange);
337 }
338
339 for i in (start..=end).step_by(step as usize) {
340 values.insert(i);
341 }
342 } else {
343 let start = parse_cron_value(range_part, min, max)?;
344
345 for i in (start..=max).step_by(step as usize) {
346 values.insert(i);
347 }
348 }
349 }
350
351 f if f.contains('-') => {
353 let tmp_fields: Vec<&str> = f.split('-').collect();
354 let [start_str, end_str] = tmp_fields.as_slice() else {
355 return Err(ParseError::InvalidRange);
356 };
357
358 let start = parse_cron_value(start_str, min, max)?;
359 let end = parse_cron_value(end_str, min, max)?;
360
361 if start > end {
362 return Err(ParseError::InvalidRange);
363 }
364 for i in start..=end {
365 values.insert(i);
366 }
367 }
368
369 _ => {
371 let value = parse_cron_value(field, min, max)?;
372 values.insert(value);
373 }
374 }
375 }
376
377 Ok(values)
378}
379
380fn parse_cron_value(value: &str, min: u32, max: u32) -> Result<u32, ParseError> {
382 if let Ok(dow) = Dow::from_str(value) {
383 Ok(dow as u32)
384 } else {
385 let v: u32 = value.parse()?;
386 if v < min || v > max {
387 return Err(ParseError::InvalidValue);
388 }
389 Ok(v)
390 }
391}
392
393fn make_utc_datetime(
395 year: i32,
396 month: u32,
397 day: u32,
398 hour: u32,
399 minute: u32,
400 second: u32,
401) -> Result<DateTime<Utc>, ParseError> {
402 match Utc.with_ymd_and_hms(year, month, day, hour, minute, second) {
403 chrono::LocalResult::Single(datetime) => Ok(datetime),
404 chrono::LocalResult::Ambiguous(earlier, _later) => Ok(earlier),
405 chrono::LocalResult::None => Err(ParseError::InvalidTimezone),
406 }
407}
408
#[cfg(test)]
// Tests use `expect` to fail loudly, with a message, on unexpected `Err` values.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_make_utc_datetime_valid() {
        let dt = make_utc_datetime(2024, 1, 15, 10, 30, 45).expect("Should be valid");
        assert_eq!(dt.year(), 2024);
        assert_eq!(dt.month(), 1);
        assert_eq!(dt.day(), 15);
        assert_eq!(dt.hour(), 10);
        assert_eq!(dt.minute(), 30);
        assert_eq!(dt.second(), 45);
    }

    #[test]
    fn test_make_utc_datetime_leap_year() {
        assert!(make_utc_datetime(2024, 2, 29, 12, 0, 0).is_ok());
    }

    #[test]
    fn test_make_utc_datetime_invalid_date() {
        assert!(make_utc_datetime(2024, 2, 30, 12, 0, 0).is_err());

        assert!(make_utc_datetime(2023, 2, 29, 12, 0, 0).is_err());

        assert!(make_utc_datetime(2024, 4, 31, 12, 0, 0).is_err());
    }

    #[test]
    fn test_make_utc_datetime_invalid_time() {
        assert!(make_utc_datetime(2024, 1, 15, 24, 0, 0).is_err());

        assert!(make_utc_datetime(2024, 1, 15, 12, 60, 0).is_err());

        assert!(make_utc_datetime(2024, 1, 15, 12, 0, 60).is_err());
    }

    #[test]
    fn test_make_utc_datetime_invalid_month() {
        assert!(make_utc_datetime(2024, 0, 15, 12, 0, 0).is_err());
        assert!(make_utc_datetime(2024, 13, 15, 12, 0, 0).is_err());
    }

    #[test]
    fn test_make_utc_datetime_boundary_values() {
        assert!(make_utc_datetime(2024, 1, 1, 0, 0, 0).is_ok());

        assert!(make_utc_datetime(2024, 1, 1, 23, 59, 59).is_ok());

        assert!(make_utc_datetime(2024, 12, 31, 23, 59, 59).is_ok());
    }

    #[test]
    fn test_parse_error_display() {
        let err = ParseError::InvalidCron;
        assert_eq!(format!("{err}"), "invalid cron");

        let err = ParseError::InvalidRange;
        assert_eq!(format!("{err}"), "invalid input");

        let err = ParseError::InvalidValue;
        assert_eq!(format!("{err}"), "invalid value");

        let parse_int_err = "abc".parse::<u32>().expect_err("Should fail");
        let err = ParseError::ParseIntError(parse_int_err);
        assert!(format!("{err}").contains("invalid digit"));

        let try_from_err = u8::try_from(256u32).expect_err("Should fail");
        let err = ParseError::TryFromIntError(try_from_err);
        assert!(format!("{err}").contains("out of range"));

        let err = ParseError::InvalidTimezone;
        assert_eq!(format!("{err}"), "invalid timezone");
    }

    #[test]
    fn test_parse_error_from_try_from_int_error() {
        let try_from_err = u8::try_from(256u32).expect_err("Should fail");
        let parse_err: ParseError = try_from_err.into();
        assert!(matches!(parse_err, ParseError::TryFromIntError(_)));
    }

    #[test]
    fn test_parse_error_implements_error_trait() {
        let err: Box<dyn Error> = Box::new(ParseError::InvalidCron);
        assert_eq!(err.to_string(), "invalid cron");
    }

    #[test]
    fn test_dow_from_str_is_case_insensitive() {
        assert_eq!(Dow::from_str("sun"), Ok(Dow::Sun));
        assert_eq!(Dow::from_str("Sat"), Ok(Dow::Sat));
        assert_eq!(Dow::from_str("funday"), Err(()));
    }

    #[test]
    fn test_parse_field_wildcard_and_steps() {
        assert_eq!(
            parse_field("*", 0, 6).expect("valid field"),
            (0..=6).collect::<BTreeSet<u32>>()
        );
        assert_eq!(
            parse_field("*/15", 0, 59).expect("valid field"),
            BTreeSet::from([0, 15, 30, 45])
        );
        assert_eq!(
            parse_field("10-20/5", 0, 59).expect("valid field"),
            BTreeSet::from([10, 15, 20])
        );
        assert_eq!(
            parse_field("50/5", 0, 59).expect("valid field"),
            BTreeSet::from([50, 55])
        );
    }

    #[test]
    fn test_parse_field_lists_ranges_and_names() {
        assert_eq!(
            parse_field("1,3,3,,5", 1, 31).expect("valid field"),
            BTreeSet::from([1, 3, 5])
        );
        assert_eq!(
            parse_field("mon-fri", 0, 6).expect("valid field"),
            BTreeSet::from([1, 2, 3, 4, 5])
        );
    }

    #[test]
    fn test_parse_field_errors() {
        assert!(matches!(parse_field("*/0", 0, 59), Err(ParseError::InvalidValue)));
        assert!(matches!(parse_field("60", 0, 59), Err(ParseError::InvalidValue)));
        assert!(matches!(parse_field("5-1", 0, 59), Err(ParseError::InvalidRange)));
        assert!(matches!(parse_field("1-2-3", 0, 59), Err(ParseError::InvalidRange)));
        assert!(matches!(parse_field("1/2/3", 0, 59), Err(ParseError::InvalidRange)));
        assert!(matches!(parse_field("abc", 0, 59), Err(ParseError::ParseIntError(_))));
    }

    #[test]
    fn test_parse_next_minute_is_strictly_after_start() {
        let start = make_utc_datetime(2024, 1, 15, 10, 15, 30).expect("valid start");
        let next = parse("* * * * *", &start).expect("cron should match");
        assert_eq!(next, make_utc_datetime(2024, 1, 15, 10, 16, 0).expect("valid expected"));
    }

    #[test]
    fn test_parse_next_top_of_hour() {
        let start = make_utc_datetime(2024, 1, 15, 10, 15, 0).expect("valid start");
        let next = parse("0 * * * *", &start).expect("cron should match");
        assert_eq!(next, make_utc_datetime(2024, 1, 15, 11, 0, 0).expect("valid expected"));
    }

    #[test]
    fn test_parse_next_weekly() {
        // 2024-01-15 is a Monday; the next Monday 09:30 after 10:00 is a week later.
        let start = make_utc_datetime(2024, 1, 15, 10, 0, 0).expect("valid start");
        let next = parse("30 9 * * MON", &start).expect("cron should match");
        assert_eq!(next, make_utc_datetime(2024, 1, 22, 9, 30, 0).expect("valid expected"));
    }

    #[test]
    fn test_parse_next_month_rollover() {
        let start = make_utc_datetime(2024, 1, 31, 12, 0, 0).expect("valid start");
        let next = parse("0 0 1 * *", &start).expect("cron should match");
        assert_eq!(next, make_utc_datetime(2024, 2, 1, 0, 0, 0).expect("valid expected"));
    }

    #[test]
    fn test_parse_errors() {
        let start = make_utc_datetime(2024, 1, 15, 10, 0, 0).expect("valid start");
        assert!(matches!(parse("* * * *", &start), Err(ParseError::InvalidCron)));
        assert!(matches!(parse("* * * * * *", &start), Err(ParseError::InvalidCron)));
        assert!(matches!(parse("60 * * * *", &start), Err(ParseError::InvalidValue)));
        // February 30th never exists, so the bounded search gives up.
        assert!(matches!(parse("0 0 30 2 *", &start), Err(ParseError::InvalidCron)));
    }
}