use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Maximum number of tasks that may be scheduled at once; `create_task` rejects more.
const MAX_TASKS: usize = 50;
/// Interval (ten minutes) used by `parse_loop_args` when the input names none.
const DEFAULT_INTERVAL_SECS: u64 = 10 * 60;
/// A recurring task that fires at or beyond this age (three days) is removed instead of rescheduled.
const MAX_RECURRING_AGE: Duration = Duration::from_secs(3 * 24 * 60 * 60);
/// Upper bound (fifteen minutes) on the per-task jitter returned by `compute_jitter`.
const MAX_JITTER: Duration = Duration::from_secs(15 * 60);
/// A single task tracked by the scheduler.
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    /// Short task id; also the key under which the scheduler stores the task.
    pub id: String,
    /// Prompt text delivered with every fire.
    pub prompt: String,
    /// Base delay between fires (a per-task jitter is added on top).
    pub interval: Duration,
    /// Whether the task is rescheduled after firing (otherwise it fires once).
    pub recurring: bool,
    /// When the task was created; used to retire recurring tasks after a maximum age.
    pub created_at: Instant,
    /// The instant at or after which the task is next due.
    pub next_fire: Instant,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
}
/// Message sent on the scheduler's channel each time a task fires.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledFire {
    /// Id of the task that fired.
    pub task_id: String,
    /// The task's prompt text.
    pub prompt: String,
}
/// Read-only snapshot of a task, as returned by `CronScheduler::list_tasks`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskInfo {
    /// Task id.
    pub id: String,
    /// Prompt text delivered when the task fires.
    pub prompt: String,
    /// Base interval in whole seconds (jitter excluded).
    pub interval_secs: u64,
    /// Whether the task is rescheduled after firing.
    pub recurring: bool,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
    /// Whole seconds until the next fire; 0 if the task is already due.
    pub next_fire_in_secs: u64,
}
/// Mutable scheduler state, guarded by `CronScheduler::inner`.
#[derive(Default)]
struct CronSchedulerInner {
    /// All live tasks, keyed by task id.
    tasks: HashMap<String, ScheduledTask>,
}
/// Scheduler that delivers task prompts over a channel on fixed, optionally
/// recurring, intervals.
///
/// Create one with `CronScheduler::new` and drive it with `CronScheduler::start`.
pub struct CronScheduler {
    /// Task table; every access goes through this lock.
    inner: Mutex<CronSchedulerInner>,
    /// Channel on which a `ScheduledFire` is sent whenever a task comes due.
    prompt_tx: mpsc::UnboundedSender<ScheduledFire>,
    /// Set by `stop`; the background loop exits once it observes this flag.
    stopped: AtomicBool,
}
impl CronScheduler {
    /// Creates a scheduler together with the receiving end of its fire channel.
    ///
    /// A [`ScheduledFire`] is sent on the returned receiver every time a task
    /// comes due. Nothing fires until [`CronScheduler::start`] is called.
    pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<ScheduledFire>) {
        let (prompt_tx, prompt_rx) = mpsc::unbounded_channel();
        let scheduler = Arc::new(Self {
            inner: Mutex::new(CronSchedulerInner {
                tasks: HashMap::new(),
            }),
            prompt_tx,
            stopped: AtomicBool::new(false),
        });
        (scheduler, prompt_rx)
    }

    /// Spawns the background loop that calls `tick` once per second.
    ///
    /// The loop holds only a weak reference, so it exits once every `Arc` to
    /// the scheduler has been dropped, or at the first tick after
    /// [`CronScheduler::stop`]. It uses `tokio::spawn`, so it must be called
    /// from inside a Tokio runtime.
    pub fn start(scheduler: Arc<Self>) {
        let weak = Arc::downgrade(&scheduler);
        drop(scheduler);
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(1));
            // The first tick completes immediately; consume it so the first
            // check happens one second after start.
            interval.tick().await;
            loop {
                interval.tick().await;
                match weak.upgrade() {
                    Some(s) => {
                        if s.stopped.load(Ordering::Relaxed) {
                            break;
                        }
                        s.tick();
                    }
                    None => break,
                }
            }
        });
    }

    /// Stops the background loop (at its next tick) and discards all tasks.
    pub fn stop(&self) {
        self.stopped.store(true, Ordering::Relaxed);
        if let Ok(mut inner) = self.inner.lock() {
            inner.tasks.clear();
        }
    }

    /// Fires every due task, then reschedules or removes it.
    ///
    /// The whole pass happens under a single lock acquisition. A task removed
    /// by `cancel_task` or `stop` therefore can never fire afterwards.
    ///
    /// After firing, a task is removed if any of these holds:
    /// - it is one-shot;
    /// - it is recurring and at least `MAX_RECURRING_AGE` old;
    /// - its next fire time cannot be represented as an `Instant`.
    fn tick(&self) {
        let now = Instant::now();
        let mut inner = match self.inner.lock() {
            Ok(guard) => guard,
            Err(_) => return,
        };
        inner.tasks.retain(|id, task| {
            if now < task.next_fire {
                return true;
            }
            task.fire_count += 1;
            // A send error only means the receiver was dropped; there is no
            // one left to deliver to, so it is deliberately ignored.
            let _ = self.prompt_tx.send(ScheduledFire {
                task_id: id.clone(),
                prompt: task.prompt.clone(),
            });
            if !task.recurring || now.duration_since(task.created_at) >= MAX_RECURRING_AGE {
                return false;
            }
            match Self::next_fire_after(now, id, task.interval) {
                Some(next) => {
                    task.next_fire = next;
                    true
                }
                // Unrepresentable next fire time: retire the task rather than panic.
                None => false,
            }
        });
    }

    /// Computes `now + interval + jitter(id)`, or `None` if the result overflows.
    fn next_fire_after(now: Instant, id: &str, interval: Duration) -> Option<Instant> {
        interval
            .checked_add(compute_jitter(id, interval))
            .and_then(|delay| now.checked_add(delay))
    }

    /// Registers a task that first fires after `interval` plus a per-task jitter.
    ///
    /// Returns the new task's 8-character id.
    ///
    /// # Errors
    /// - the scheduler lock is poisoned;
    /// - `MAX_TASKS` tasks are already scheduled;
    /// - `interval` is too large for the fire time to be represented.
    pub fn create_task(
        &self,
        prompt: String,
        interval: Duration,
        recurring: bool,
    ) -> Result<String, String> {
        let mut inner = self
            .inner
            .lock()
            .map_err(|_| "scheduler lock poisoned".to_string())?;
        if inner.tasks.len() >= MAX_TASKS {
            return Err(format!(
                "maximum of {MAX_TASKS} scheduled tasks reached; cancel one with /cron-cancel"
            ));
        }
        // Ids carry only 32 random bits. A collision would silently overwrite
        // an existing task, so draw again until the id is unused.
        let id = loop {
            let candidate = new_task_id();
            if !inner.tasks.contains_key(&candidate) {
                break candidate;
            }
        };
        let now = Instant::now();
        let next_fire = Self::next_fire_after(now, &id, interval).ok_or_else(|| {
            format!(
                "interval of {}s is too large to schedule",
                interval.as_secs()
            )
        })?;
        inner.tasks.insert(
            id.clone(),
            ScheduledTask {
                id: id.clone(),
                prompt,
                interval,
                recurring,
                created_at: now,
                next_fire,
                fire_count: 0,
            },
        );
        Ok(id)
    }

    /// Returns a snapshot of all tasks, sorted by id.
    ///
    /// Returns an empty list if the lock is poisoned.
    pub fn list_tasks(&self) -> Vec<ScheduledTaskInfo> {
        let inner = match self.inner.lock() {
            Ok(g) => g,
            Err(_) => return Vec::new(),
        };
        let now = Instant::now();
        let mut tasks: Vec<ScheduledTaskInfo> = inner
            .tasks
            .values()
            .map(|t| ScheduledTaskInfo {
                id: t.id.clone(),
                prompt: t.prompt.clone(),
                interval_secs: t.interval.as_secs(),
                recurring: t.recurring,
                fire_count: t.fire_count,
                // Tasks that are already due (awaiting the next tick) report 0.
                next_fire_in_secs: t.next_fire.saturating_duration_since(now).as_secs(),
            })
            .collect();
        // Ids are unique map keys, so an unstable sort gives a deterministic order.
        tasks.sort_unstable_by(|a, b| a.id.cmp(&b.id));
        tasks
    }

    /// Removes the task with `id`.
    ///
    /// Returns `true` if it existed; returns `false` if it did not exist or
    /// the lock is poisoned.
    pub fn cancel_task(&self, id: &str) -> bool {
        self.inner
            .lock()
            .ok()
            .map(|mut g| g.tasks.remove(id).is_some())
            .unwrap_or(false)
    }

    /// Number of scheduled tasks (0 if the lock is poisoned).
    pub fn task_count(&self) -> usize {
        self.inner.lock().map(|g| g.tasks.len()).unwrap_or(0)
    }
}
/// Generates a short task id: the first 8 characters of a random v4 UUID's
/// string form.
fn new_task_id() -> String {
    let mut short_id = uuid::Uuid::new_v4().to_string();
    // The hyphenated form begins with 8 ASCII hex digits, so byte 8 is a char boundary.
    short_id.truncate(8);
    short_id
}
/// Returns a deterministic, per-id delay to add to each scheduled fire.
///
/// The id is hashed into one of 1000 buckets. That bucket selects a fraction
/// in `[0.0, 0.0999]` of `interval`. The result is truncated to whole seconds
/// and capped at `MAX_JITTER`.
///
/// Determinism holds for a given build only: std documents `DefaultHasher`'s
/// algorithm as unspecified across Rust releases.
fn compute_jitter(id: &str, interval: Duration) -> Duration {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let digest = {
        let mut hasher = DefaultHasher::new();
        id.hash(&mut hasher);
        hasher.finish()
    };
    let bucket = digest % 1000;
    let fraction = bucket as f64 / 10000.0;
    let jitter_secs = (interval.as_secs_f64() * fraction) as u64;
    let cap_secs = MAX_JITTER.as_secs();
    Duration::from_secs(if jitter_secs < cap_secs { jitter_secs } else { cap_secs })
}
/// Parses an interval written as `<non-negative integer><unit>`.
///
/// The unit is one of `s`, `m`, `h`, `d`. For example, `"30s"`, `"5m"`,
/// `"2h"` and `"1d"` are accepted.
///
/// Returns `None` for any of:
/// - a missing number or an unknown unit;
/// - non-ASCII trailing characters (previously this panicked on a non-char-boundary split);
/// - a value whose length in seconds overflows `u64` (previously this
///   panicked in debug builds and wrapped in release builds).
pub fn parse_interval(s: &str) -> Option<Duration> {
    let unit = s.chars().next_back()?;
    // Slice off the whole last character, not just its last byte.
    let num_part = &s[..s.len() - unit.len_utf8()];
    if num_part.is_empty() {
        return None;
    }
    let n: u64 = num_part.parse().ok()?;
    let secs_per_unit: u64 = match unit {
        's' => 1,
        'm' => 60,
        'h' => 3600,
        'd' => 86_400,
        _ => return None,
    };
    n.checked_mul(secs_per_unit).map(Duration::from_secs)
}
/// Splits loop-command arguments into an interval and a prompt.
///
/// Two forms are recognized, tried in this order:
/// 1. `<interval> <prompt>`, e.g. `"5m check the deployment"`.
/// 2. `<prompt> every <interval>`, e.g. `"monitor memory every 2h"`. Any
///    words after the interval token are discarded.
///
/// If neither form applies, the whole trimmed input becomes the prompt and
/// the interval defaults to `DEFAULT_INTERVAL_SECS`.
pub fn parse_loop_args(args: &str) -> (Duration, String) {
    let args = args.trim();

    let leading_form = args.find(char::is_whitespace).and_then(|split| {
        let (head, tail) = args.split_at(split);
        let tail = tail.trim();
        let interval = parse_interval(head)?;
        (!tail.is_empty()).then(|| (interval, tail.to_string()))
    });
    if let Some(parsed) = leading_form {
        return parsed;
    }

    const EVERY_NEEDLE: &str = " every ";
    let trailing_form = find_every_clause(args).and_then(|pos| {
        let prompt = args[..pos].trim();
        let token = args[pos + EVERY_NEEDLE.len()..]
            .split_whitespace()
            .next()
            .unwrap_or("");
        let interval = parse_interval(token)?;
        (!prompt.is_empty()).then(|| (interval, prompt.to_string()))
    });

    trailing_form
        .unwrap_or_else(|| (Duration::from_secs(DEFAULT_INTERVAL_SECS), args.to_string()))
}
/// Returns the byte offset of the last `" every "` in `s` that is immediately
/// followed by a word `parse_interval` accepts.
///
/// Overlapping occurrences are considered: `"x every every 5m"` yields the
/// second one.
fn find_every_clause(s: &str) -> Option<usize> {
    const NEEDLE: &str = " every ";
    s.char_indices()
        .rev()
        .map(|(pos, _)| pos)
        .filter(|&pos| s[pos..].starts_with(NEEDLE))
        .find(|&pos| {
            let token = s[pos + NEEDLE.len()..]
                .split_whitespace()
                .next()
                .unwrap_or("");
            parse_interval(token).is_some()
        })
}
/// Formats a number of seconds for display, e.g. `"1d 1h"` or `"1h 1m"`.
///
/// Non-zero day, hour and minute components are joined with spaces. Seconds
/// appear only when the total is under a minute. Zero renders as `"now"`.
pub fn format_duration(secs: u64) -> String {
    if secs == 0 {
        return "now".to_string();
    }
    let days = secs / 86_400;
    let hours = secs % 86_400 / 3_600;
    let minutes = secs % 3_600 / 60;
    let seconds = secs % 60;

    let mut pieces: Vec<String> = [(days, 'd'), (hours, 'h'), (minutes, 'm')]
        .iter()
        .filter(|(amount, _)| *amount > 0)
        .map(|(amount, suffix)| format!("{amount}{suffix}"))
        .collect();
    if pieces.is_empty() && seconds > 0 {
        pieces.push(format!("{seconds}s"));
    }
    pieces.join(" ")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `parse_loop_args(args)` yields the given interval (seconds) and prompt.
    fn assert_loop_args(args: &str, expected_secs: u64, expected_prompt: &str) {
        let (interval, prompt) = parse_loop_args(args);
        assert_eq!(interval, Duration::from_secs(expected_secs));
        assert_eq!(prompt, expected_prompt);
    }

    #[test]
    fn test_parse_interval() {
        let cases: &[(&str, Option<u64>)] = &[
            ("30s", Some(30)),
            ("5m", Some(300)),
            ("2h", Some(7200)),
            ("1d", Some(86400)),
            ("bad", None),
            ("", None),
            ("m", None),
            ("0m", Some(0)),
        ];
        for &(input, expected_secs) in cases {
            assert_eq!(
                parse_interval(input),
                expected_secs.map(Duration::from_secs),
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn test_parse_loop_args_leading_interval() {
        assert_loop_args("5m check the deployment", 300, "check the deployment");
    }

    #[test]
    fn test_parse_loop_args_trailing_every() {
        assert_loop_args("monitor memory usage every 2h", 7200, "monitor memory usage");
    }

    #[test]
    fn test_parse_loop_args_default_interval() {
        assert_loop_args("check the build", 600, "check the build");
    }

    #[test]
    fn test_parse_loop_args_leading_seconds() {
        assert_loop_args("30s ping the server", 30, "ping the server");
    }

    #[test]
    fn test_format_duration() {
        let cases = [
            (0, "now"),
            (30, "30s"),
            (300, "5m"),
            (3600, "1h"),
            (3660, "1h 1m"),
            (86400, "1d"),
            (90000, "1d 1h"),
        ];
        for &(secs, expected) in &cases {
            assert_eq!(format_duration(secs), expected);
        }
    }

    #[test]
    fn test_create_and_list_tasks() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("hello".to_string(), Duration::from_secs(60), true)
            .expect("creating a task should succeed");
        assert_eq!(id.len(), 8);

        let listed = scheduler.list_tasks();
        assert_eq!(listed.len(), 1);
        let task = &listed[0];
        assert_eq!(task.prompt, "hello");
        assert_eq!(task.interval_secs, 60);
        assert!(task.recurring);
        assert_eq!(task.fire_count, 0);
    }

    #[test]
    fn test_cancel_task() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("test".to_string(), Duration::from_secs(60), true)
            .expect("creating a task should succeed");
        assert!(scheduler.cancel_task(&id));
        // A second cancel finds nothing to remove.
        assert!(!scheduler.cancel_task(&id));
        assert_eq!(scheduler.list_tasks().len(), 0);
    }

    #[test]
    fn test_max_tasks_limit() {
        let (scheduler, _rx) = CronScheduler::new();
        for n in 0..MAX_TASKS {
            let created = scheduler.create_task(format!("task {n}"), Duration::from_secs(60), true);
            assert!(created.is_ok());
        }
        let err = scheduler
            .create_task("overflow".to_string(), Duration::from_secs(60), true)
            .unwrap_err();
        assert!(err.contains("maximum"));
    }

    #[test]
    fn test_compute_jitter_deterministic() {
        let interval = Duration::from_secs(600);
        assert_eq!(compute_jitter("abc123", interval), compute_jitter("abc123", interval));
    }

    #[test]
    fn test_compute_jitter_bounded() {
        // For a 600s interval the jitter stays under 10% of the interval.
        for id in ["aaa", "bbb", "ccc", "ddd"].iter() {
            assert!(compute_jitter(id, Duration::from_secs(600)) <= Duration::from_secs(60));
        }
        // For a 10h interval the 10% bound exceeds the cap, so MAX_JITTER applies.
        assert!(compute_jitter("aaa", Duration::from_secs(10 * 3600)) <= MAX_JITTER);
    }
}