use std::time::Duration;
use chrono::{DateTime, Utc};
use crate::{Error, Result};
/// A parsed schedule: either a cron expression or a fixed repeat interval.
///
/// Values are built by [`Schedule::parse`]; [`Schedule::next_tick`] computes
/// the firing time that follows a given instant.
pub(crate) enum Schedule {
/// Cron-driven schedule backed by a parsed `croner::Cron`.
/// (Boxed — presumably to keep the enum itself small; confirm.)
Cron(Box<croner::Cron>),
/// Fixed-interval schedule, produced from the `@every <duration>` syntax.
Interval(Duration),
}
impl Schedule {
    /// Parses a schedule specification into a [`Schedule`].
    ///
    /// Leading and trailing whitespace is ignored. Recognised forms:
    ///
    /// - `@yearly` / `@annually`, `@monthly`, `@weekly`, `@daily` / `@midnight`
    ///   and `@hourly` — aliases expanded to fixed six-field cron expressions;
    /// - `@every <duration>` — a fixed interval such as `@every 1h30m`
    ///   (the duration is parsed by `parse_duration`);
    /// - anything else — handed to the cron parser unchanged.
    ///
    /// # Errors
    ///
    /// Returns an `unprocessable_entity` error when the duration or the cron
    /// expression cannot be parsed.
    pub(crate) fn parse(input: &str) -> Result<Self> {
        let trimmed = input.trim();

        // Named aliases map onto fixed six-field cron expressions.
        let alias_expr = match trimmed {
            "@yearly" | "@annually" => Some("0 0 0 1 1 *"),
            "@monthly" => Some("0 0 0 1 * *"),
            "@weekly" => Some("0 0 0 * * 0"),
            "@daily" | "@midnight" => Some("0 0 0 * * *"),
            "@hourly" => Some("0 0 * * * *"),
            _ => None,
        };
        if let Some(expr) = alias_expr {
            return Self::parse_cron(expr);
        }

        // NOTE(review): `@every 0s` is accepted and yields a zero interval, so
        // `next_tick` would return `from` itself — confirm callers tolerate it.
        if let Some(dur_str) = trimmed.strip_prefix("@every ") {
            return parse_duration(dur_str.trim()).map(Self::Interval);
        }

        Self::parse_cron(trimmed)
    }

    /// Parses `expr` with a croner parser configured so that the seconds
    /// field is optional, and wraps the result in [`Schedule::Cron`].
    ///
    /// # Errors
    ///
    /// Returns an `unprocessable_entity` error carrying the offending
    /// expression and croner's own error text.
    fn parse_cron(expr: &str) -> Result<Self> {
        let parser = croner::parser::CronParser::builder()
            .seconds(croner::parser::Seconds::Optional)
            .build();
        let cron = parser.parse(expr).map_err(|e| {
            Error::unprocessable_entity(format!("invalid cron expression '{expr}': {e}"))
        })?;
        Ok(Self::Cron(Box::new(cron)))
    }

    /// Returns the next firing time relative to `from`.
    ///
    /// - `Cron`: the next occurrence reported by croner, or `None` if croner
    ///   returns an error.
    /// - `Interval`: `from + interval`, or `None` if the interval does not fit
    ///   into a `chrono::Duration` or the resulting timestamp would overflow.
    ///
    /// This never panics: intervals come from user input, so out-of-range
    /// values are reported as "no next tick" instead of aborting the caller.
    pub(crate) fn next_tick(&self, from: DateTime<Utc>) -> Option<DateTime<Utc>> {
        match self {
            // The `false` argument is croner's `inclusive` flag (per its API
            // docs), i.e. `from` itself is presumably never returned.
            Self::Cron(cron) => cron.find_next_occurrence(&from, false).ok(),
            Self::Interval(dur) => chrono::Duration::from_std(*dur)
                .ok()
                .and_then(|step| from.checked_add_signed(step)),
        }
    }
}
/// Parses a compact duration such as `30s`, `15m`, `2h` or `1h30m15s`.
///
/// The input is a sequence of `<digits><unit>` components where the unit is
/// `h` (hours), `m` (minutes) or `s` (seconds). Components are summed, so
/// units may repeat or appear in any order. No whitespace, signs, fractions
/// or other units are accepted.
///
/// # Errors
///
/// Returns an `unprocessable_entity` error when:
/// - a character other than a digit or `h`/`m`/`s` is encountered;
/// - a unit is not preceded by a number that fits in `u64`;
/// - the input ends with digits that have no unit;
/// - the input contains no components at all;
/// - the total number of seconds does not fit in `u64`.
fn parse_duration(s: &str) -> Result<Duration> {
    let mut total_secs: u64 = 0;
    // Byte offset at which the digit run for the pending component begins.
    let mut num_start = 0usize;
    let mut found_any = false;

    for (idx, ch) in s.char_indices() {
        let secs_per_unit: u64 = match ch {
            // Digits just extend the pending run `s[num_start..=idx]`.
            '0'..='9' => continue,
            'h' => 3600,
            'm' => 60,
            's' => 1,
            _ => {
                return Err(Error::unprocessable_entity(format!(
                    "invalid duration '{s}': unexpected character '{ch}'"
                )));
            }
        };

        // Everything since the previous unit is ASCII digits (anything else
        // returned above), so slicing here is always on char boundaries.
        let n: u64 = s[num_start..idx].parse().map_err(|_| {
            Error::unprocessable_entity(format!(
                "invalid duration '{s}': bad number before '{ch}'"
            ))
        })?;

        // Checked arithmetic: the value is user input and must not panic
        // (debug builds) or silently wrap (release builds).
        total_secs = n
            .checked_mul(secs_per_unit)
            .and_then(|component| total_secs.checked_add(component))
            .ok_or_else(|| {
                Error::unprocessable_entity(format!(
                    "invalid duration '{s}': value too large"
                ))
            })?;

        num_start = idx + ch.len_utf8();
        found_any = true;
    }

    // Digits left over after the last unit were never consumed.
    if num_start < s.len() {
        return Err(Error::unprocessable_entity(format!(
            "invalid duration '{s}': trailing number without unit (use h, m, or s)"
        )));
    }
    if !found_any {
        return Err(Error::unprocessable_entity(format!(
            "invalid duration '{s}': no duration components found"
        )));
    }
    Ok(Duration::from_secs(total_secs))
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_duration: accepted inputs ----

    /// A single hours component is converted to seconds.
    #[test]
    fn parse_duration_hours() {
        let parsed = parse_duration("2h").unwrap();
        assert_eq!(parsed, Duration::from_secs(7200));
    }

    /// A single minutes component is converted to seconds.
    #[test]
    fn parse_duration_minutes() {
        let parsed = parse_duration("15m").unwrap();
        assert_eq!(parsed, Duration::from_secs(900));
    }

    /// A single seconds component is taken as-is.
    #[test]
    fn parse_duration_seconds() {
        let parsed = parse_duration("30s").unwrap();
        assert_eq!(parsed, Duration::from_secs(30));
    }

    /// Several components in one string are summed.
    #[test]
    fn parse_duration_combined() {
        let cases = [("1h30m", 5400), ("2h30m15s", 9015)];
        for &(input, expected_secs) in cases.iter() {
            assert_eq!(
                parse_duration(input).unwrap(),
                Duration::from_secs(expected_secs)
            );
        }
    }

    // ---- parse_duration: rejected inputs ----

    /// `d` is not a supported unit.
    #[test]
    fn parse_duration_rejects_days() {
        let message_holder = parse_duration("1d").unwrap_err();
        assert!(message_holder.message().contains("invalid duration"));
    }

    /// Milliseconds are not supported.
    #[test]
    fn parse_duration_rejects_ms() {
        let message_holder = parse_duration("500ms").unwrap_err();
        assert!(message_holder.message().contains("invalid duration"));
    }

    /// Digits without a unit are rejected with a dedicated message.
    #[test]
    fn parse_duration_rejects_bare_number() {
        let message_holder = parse_duration("30").unwrap_err();
        assert!(message_holder
            .message()
            .contains("trailing number without unit"));
    }

    /// An empty string contains no components.
    #[test]
    fn parse_duration_rejects_empty() {
        let message_holder = parse_duration("").unwrap_err();
        assert!(message_holder.message().contains("no duration components"));
    }

    // ---- Schedule::parse ----

    /// Every named alias parses to a cron schedule that has an upcoming tick.
    #[test]
    fn parse_named_aliases() {
        const ALIASES: [&str; 7] = [
            "@yearly",
            "@annually",
            "@monthly",
            "@weekly",
            "@daily",
            "@midnight",
            "@hourly",
        ];
        let reference = Utc::now();
        for alias in ALIASES.iter() {
            let schedule = Schedule::parse(alias).unwrap();
            let is_cron = matches!(schedule, Schedule::Cron(_));
            assert!(is_cron, "{alias} should parse as Cron");
            let upcoming = schedule.next_tick(reference);
            assert!(upcoming.is_some(), "{alias} should have a future tick");
        }
    }

    /// `@every <duration>` produces the interval variant with the parsed duration.
    #[test]
    fn parse_every_interval() {
        match Schedule::parse("@every 5m").unwrap() {
            Schedule::Interval(dur) => assert_eq!(dur, Duration::from_secs(300)),
            Schedule::Cron(_) => panic!("expected Interval variant"),
        }
    }

    /// A six-field cron expression is accepted and yields a next tick.
    #[test]
    fn parse_standard_cron() {
        let schedule = Schedule::parse("0 30 * * * *").unwrap();
        assert!(matches!(schedule, Schedule::Cron(_)));
        let upcoming = schedule.next_tick(Utc::now());
        assert!(upcoming.is_some());
    }

    /// Garbage input surfaces the cron parser's error.
    #[test]
    fn parse_invalid_cron_returns_error() {
        // `.err().unwrap()` rather than `unwrap_err()`: the latter needs the
        // `Ok` type to be `Debug`, and `Schedule` has no visible `Debug` derive.
        let outcome = Schedule::parse("not a cron");
        let message_holder = outcome.err().unwrap();
        assert!(message_holder.message().contains("invalid cron expression"));
    }
}