use chrono::{DateTime, Datelike, Duration, TimeZone, Timelike, Utc};
use std::{collections::BTreeSet, error::Error, fmt, num, str::FromStr};
/// Error returned when a cron expression, or one of its fields, cannot be parsed.
///
/// Besides `Debug`, the type is `Clone`, `PartialEq` and `Eq` (all payload types
/// support these), so callers can compare errors directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The expression is rejected as a whole (e.g. it has too many fields), or
    /// no matching time was found within the search window.
    InvalidCron,
    /// A `a-b` range item is malformed, reversed, or out of bounds.
    InvalidRange,
    /// A single value or a step is out of bounds.
    InvalidValue,
    /// A numeric component is not a valid unsigned integer.
    ParseIntError(num::ParseIntError),
    /// An integer conversion (`TryFrom`) failed.
    TryFromIntError(num::TryFromIntError),
}
/// Day of the week, numbered as in cron: `Sun = 0` through `Sat = 6`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dow {
    Sun = 0,
    Mon = 1,
    Tue = 2,
    Wed = 3,
    Thu = 4,
    Fri = 5,
    Sat = 6,
}

impl FromStr for Dow {
    type Err = ();

    /// Parses a three-letter English day abbreviation (`SUN`..`SAT`),
    /// ignoring ASCII case. Any other input yields `Err(())`.
    ///
    /// The comparison is ASCII-only on purpose. It avoids allocating an
    /// uppercased copy. It also avoids Unicode case mapping, under which
    /// e.g. `"ſun"` (long s) would uppercase to `"SUN"` and be accepted.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const NAMES: [(&str, Dow); 7] = [
            ("SUN", Dow::Sun),
            ("MON", Dow::Mon),
            ("TUE", Dow::Tue),
            ("WED", Dow::Wed),
            ("THU", Dow::Thu),
            ("FRI", Dow::Fri),
            ("SAT", Dow::Sat),
        ];
        NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, dow)| dow)
            .ok_or(())
    }
}
impl fmt::Display for ParseError {
    /// Renders a short lowercase description of the error.
    ///
    /// Wrapped integer errors are displayed with their own `Display` message,
    /// unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::InvalidCron => f.write_str("invalid cron"),
            // Previously rendered as "invalid input", which did not name the problem.
            Self::InvalidRange => f.write_str("invalid range"),
            Self::InvalidValue => f.write_str("invalid value"),
            Self::ParseIntError(err) => fmt::Display::fmt(err, f),
            Self::TryFromIntError(err) => fmt::Display::fmt(err, f),
        }
    }
}
impl Error for ParseError {}
impl From<num::ParseIntError> for ParseError {
fn from(err: num::ParseIntError) -> Self {
Self::ParseIntError(err)
}
}
impl From<num::TryFromIntError> for ParseError {
fn from(err: num::TryFromIntError) -> Self {
Self::TryFromIntError(err)
}
}
pub fn parse<TZ: TimeZone>(cron: &str, dt: &DateTime<TZ>) -> Result<DateTime<TZ>, ParseError> {
let tz = dt.timezone();
let mut next = Utc.from_local_datetime(&dt.naive_local()).unwrap() + Duration::minutes(1);
let fields: Vec<&str> = cron.split_whitespace().collect();
if fields.len() > 5 {
return Err(ParseError::InvalidCron);
}
next = Utc
.ymd(next.year(), next.month(), next.day())
.and_hms(next.hour(), next.minute(), 0);
let result = loop {
if next.year() - dt.year() > 4 {
return Err(ParseError::InvalidCron);
}
let month = parse_field(fields[3], 1, 12)?;
if !month.contains(&next.month()) {
if next.month() == 12 {
next = Utc.ymd(next.year() + 1, 1, 1).and_hms(0, 0, 0);
} else {
next = Utc.ymd(next.year(), next.month() + 1, 1).and_hms(0, 0, 0);
}
continue;
}
let do_m = parse_field(fields[2], 1, 31)?;
if !do_m.contains(&next.day()) {
next = next + Duration::days(1);
next = Utc
.ymd(next.year(), next.month(), next.day())
.and_hms(0, 0, 0);
continue;
}
let hour = parse_field(fields[1], 0, 23)?;
if !hour.contains(&next.hour()) {
next = next + Duration::hours(1);
next = Utc
.ymd(next.year(), next.month(), next.day())
.and_hms(next.hour(), 0, 0);
continue;
}
let minute = parse_field(fields[0], 0, 59)?;
if !minute.contains(&next.minute()) {
next = next + Duration::minutes(1);
continue;
}
let do_w = parse_field(fields[4], 0, 6)?;
if !do_w.contains(&next.weekday().num_days_from_sunday()) {
next = next + Duration::days(1);
continue;
}
if let Some(dt) = tz.from_local_datetime(&next.naive_local()).latest() {
break dt;
}
next = next + Duration::minutes(1);
};
Ok(result)
}
/// Parses a single cron field into the set of values it matches, each within `min..=max`.
///
/// A field is a comma-separated list of items. Each item is one of:
/// - `*`: every value in `min..=max`;
/// - `a`: a single value;
/// - `a-b`: an inclusive range.
///
/// Any item may carry a `/n` step suffix:
/// - `*/n` and `a-b/n` take every `n`th value of the range;
/// - `a/n` means `a-max/n`.
///
/// Values may be decimal numbers or three-letter day names (`SUN`..`SAT`,
/// ASCII case-insensitive), which map to 0..=6 in any field. Empty items
/// (e.g. from `1,,2`) are ignored.
///
/// # Errors
/// - `ParseError::InvalidValue` if a single value (or step start) lies outside
///   `min..=max`, or a step is 0 or greater than `max`.
/// - `ParseError::InvalidRange` if a range is malformed (not exactly two
///   bounds), reversed, or has a bound outside `min..=max`.
/// - `ParseError::ParseIntError` if a value or step is not a valid number.
pub fn parse_field(field: &str, min: u32, max: u32) -> Result<BTreeSet<u32>, ParseError> {
    /// Parses one value: a day-of-week name or a decimal number.
    fn parse_value(text: &str) -> Result<u32, ParseError> {
        match Dow::from_str(text) {
            Ok(dow) => Ok(dow as u32),
            Err(()) => Ok(text.parse::<u32>()?),
        }
    }

    let mut values = BTreeSet::new();
    for item in field.split(',').filter(|s| !s.is_empty()) {
        // Split off an optional `/step` suffix.
        let (base, step) = match item.split_once('/') {
            Some((base, step)) => {
                let step: u32 = step.parse()?;
                // Reject a zero step, which would make `step_by` panic, and
                // a step larger than the field's maximum.
                if step == 0 || step > max {
                    return Err(ParseError::InvalidValue);
                }
                (base, Some(step))
            }
            None => (item, None),
        };

        let (start, end) = if base == "*" {
            (min, max)
        } else if base.contains('-') {
            let bounds: Vec<&str> = base.split('-').collect();
            if bounds.len() != 2 {
                return Err(ParseError::InvalidRange);
            }
            let lo = parse_value(bounds[0])?;
            let hi = parse_value(bounds[1])?;
            if lo < min || lo > hi || hi > max {
                return Err(ParseError::InvalidRange);
            }
            (lo, hi)
        } else {
            let value = parse_value(base)?;
            if value < min || value > max {
                return Err(ParseError::InvalidValue);
            }
            // `a/n` runs from `a` to the top of the field; plain `a` is just `a`.
            (value, if step.is_some() { max } else { value })
        };

        // `u32 -> usize` is lossless on 32- and 64-bit targets.
        values.extend((start..=end).step_by(step.unwrap_or(1) as usize));
    }
    Ok(values)
}