1use std::collections::HashMap;
18use std::sync::atomic::{AtomicBool, Ordering};
19use std::sync::{Arc, Mutex};
20use std::time::{Duration, Instant};
21use tokio::sync::mpsc;
22
/// Upper bound on simultaneously scheduled tasks; `create_task` rejects more.
const MAX_TASKS: usize = 50;
/// Interval used by `parse_loop_args` when the input names none (10 minutes).
const DEFAULT_INTERVAL_SECS: u64 = 600;
/// A recurring task at least this old is removed right after it fires (3 days).
const MAX_RECURRING_AGE: Duration = Duration::from_secs(3 * 24 * 3600);
/// Upper bound on the per-task jitter added to each fire time (15 minutes).
const MAX_JITTER: Duration = Duration::from_secs(15 * 60);

/// A prompt registered with the scheduler, either one-shot or recurring.
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    /// Short task identifier; also used as the key in the scheduler's task map.
    pub id: String,
    /// Prompt text delivered each time the task fires.
    pub prompt: String,
    /// Base delay between fires; a per-task jitter is added on top.
    pub interval: Duration,
    /// `true` to re-arm after each fire, `false` for a one-shot task.
    pub recurring: bool,
    /// When the task was registered; compared against `MAX_RECURRING_AGE`.
    pub created_at: Instant,
    /// Earliest instant at which the task fires next.
    pub next_fire: Instant,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
}
44
/// A task that has come due, as delivered on the scheduler's channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledFire {
    /// Identifier of the task that fired.
    pub task_id: String,
    /// The task's prompt text.
    pub prompt: String,
}
50
/// Read-only snapshot of one scheduled task, as returned by `list_tasks`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskInfo {
    /// Task identifier.
    pub id: String,
    /// Prompt text the task delivers.
    pub prompt: String,
    /// Base interval in whole seconds (any sub-second part is truncated).
    pub interval_secs: u64,
    /// Whether the task re-arms after firing.
    pub recurring: bool,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
    /// Whole seconds until the next fire; 0 when the task is already due.
    pub next_fire_in_secs: u64,
}
61
62struct CronSchedulerInner {
65 tasks: HashMap<String, ScheduledTask>,
66}
67
68pub struct CronScheduler {
73 inner: Mutex<CronSchedulerInner>,
74 prompt_tx: mpsc::UnboundedSender<ScheduledFire>,
75 stopped: AtomicBool,
76}
77
78impl CronScheduler {
79 pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<ScheduledFire>) {
83 let (tx, rx) = mpsc::unbounded_channel();
84 let scheduler = Arc::new(Self {
85 inner: Mutex::new(CronSchedulerInner {
86 tasks: HashMap::new(),
87 }),
88 prompt_tx: tx,
89 stopped: AtomicBool::new(false),
90 });
91 (scheduler, rx)
92 }
93
94 pub fn start(scheduler: Arc<Self>) {
99 let weak = Arc::downgrade(&scheduler);
100 drop(scheduler);
101 tokio::spawn(async move {
102 let mut interval = tokio::time::interval(Duration::from_secs(1));
103 interval.tick().await; loop {
105 interval.tick().await;
106 match weak.upgrade() {
107 Some(s) => {
108 if s.stopped.load(Ordering::Relaxed) {
109 break;
110 }
111 s.tick();
112 }
113 None => break, }
115 }
116 });
117 }
118
119 pub fn stop(&self) {
121 self.stopped.store(true, Ordering::Relaxed);
122 if let Ok(mut inner) = self.inner.lock() {
123 inner.tasks.clear();
124 }
125 }
126
127 fn tick(&self) {
129 let now = Instant::now();
130 let mut to_fire: Vec<(String, String)> = Vec::new();
131 let mut to_remove: Vec<String> = Vec::new();
132
133 {
134 let inner = match self.inner.lock() {
135 Ok(g) => g,
136 Err(_) => return,
137 };
138 for (id, task) in &inner.tasks {
139 if now >= task.next_fire {
140 to_fire.push((id.clone(), task.prompt.clone()));
141 let age = now - task.created_at;
142 if !task.recurring || age >= MAX_RECURRING_AGE {
143 to_remove.push(id.clone());
144 }
145 }
146 }
147 }
148
149 if to_fire.is_empty() {
150 return;
151 }
152
153 {
154 let mut inner = match self.inner.lock() {
155 Ok(g) => g,
156 Err(_) => return,
157 };
158 for (id, prompt) in &to_fire {
159 if let Some(task) = inner.tasks.get_mut(id) {
160 task.fire_count += 1;
161 if task.recurring && !to_remove.contains(id) {
162 let jitter = compute_jitter(id, task.interval);
163 task.next_fire = Instant::now() + task.interval + jitter;
164 }
165 }
166 let _ = self.prompt_tx.send(ScheduledFire {
167 task_id: id.clone(),
168 prompt: prompt.clone(),
169 });
170 }
171 for id in &to_remove {
172 inner.tasks.remove(id);
173 }
174 }
175 }
176
177 pub fn create_task(
179 &self,
180 prompt: String,
181 interval: Duration,
182 recurring: bool,
183 ) -> Result<String, String> {
184 let mut inner = self
185 .inner
186 .lock()
187 .map_err(|_| "scheduler lock poisoned".to_string())?;
188 if inner.tasks.len() >= MAX_TASKS {
189 return Err(format!(
190 "maximum of {MAX_TASKS} scheduled tasks reached; cancel one with /cron-cancel"
191 ));
192 }
193 let id = new_task_id();
194 let jitter = compute_jitter(&id, interval);
195 let now = Instant::now();
196 inner.tasks.insert(
197 id.clone(),
198 ScheduledTask {
199 id: id.clone(),
200 prompt,
201 interval,
202 recurring,
203 created_at: now,
204 next_fire: now + interval + jitter,
205 fire_count: 0,
206 },
207 );
208 Ok(id)
209 }
210
211 pub fn list_tasks(&self) -> Vec<ScheduledTaskInfo> {
213 let inner = match self.inner.lock() {
214 Ok(g) => g,
215 Err(_) => return vec![],
216 };
217 let now = Instant::now();
218 let mut tasks: Vec<_> = inner
219 .tasks
220 .values()
221 .map(|t| ScheduledTaskInfo {
222 id: t.id.clone(),
223 prompt: t.prompt.clone(),
224 interval_secs: t.interval.as_secs(),
225 recurring: t.recurring,
226 fire_count: t.fire_count,
227 next_fire_in_secs: if t.next_fire > now {
228 (t.next_fire - now).as_secs()
229 } else {
230 0
231 },
232 })
233 .collect();
234 tasks.sort_by(|a, b| a.id.cmp(&b.id));
235 tasks
236 }
237
238 pub fn cancel_task(&self, id: &str) -> bool {
240 self.inner
241 .lock()
242 .ok()
243 .map(|mut g| g.tasks.remove(id).is_some())
244 .unwrap_or(false)
245 }
246
247 pub fn task_count(&self) -> usize {
249 self.inner.lock().map(|g| g.tasks.len()).unwrap_or(0)
250 }
251}
252
253fn new_task_id() -> String {
257 let id = uuid::Uuid::new_v4().to_string();
258 id[..8].to_string()
259}
260
261fn compute_jitter(id: &str, interval: Duration) -> Duration {
267 use std::collections::hash_map::DefaultHasher;
268 use std::hash::{Hash, Hasher};
269 let mut h = DefaultHasher::new();
270 id.hash(&mut h);
271 let fraction = (h.finish() % 1000) as f64 / 10000.0; let raw_secs = (interval.as_secs_f64() * fraction) as u64;
273 Duration::from_secs(raw_secs.min(MAX_JITTER.as_secs()))
274}
275
/// Parses a compact interval such as `30s`, `5m`, `2h` or `1d`.
///
/// The last character is the unit (`s`, `m`, `h`, `d`) and everything before
/// it must parse as a `u64`. `0m` is accepted and yields a zero duration.
///
/// Returns `None` when:
/// - the input is empty;
/// - the unit is unknown;
/// - the number is malformed;
/// - the total number of seconds overflows `u64`.
pub fn parse_interval(s: &str) -> Option<Duration> {
    // Take the unit as a whole `char`. Slicing off the last *byte* panics when
    // the input ends in a multi-byte character (e.g. "5é").
    let unit = s.chars().next_back()?;
    let digits = &s[..s.len() - unit.len_utf8()];
    let n: u64 = digits.parse().ok()?;
    let secs_per_unit: u64 = match unit {
        's' => 1,
        'm' => 60,
        'h' => 3600,
        'd' => 86400,
        _ => return None,
    };
    // Checked so a count near `u64::MAX` yields `None` instead of panicking
    // (debug builds) or silently wrapping (release builds).
    n.checked_mul(secs_per_unit).map(Duration::from_secs)
}
295
296pub fn parse_loop_args(args: &str) -> (Duration, String) {
305 let args = args.trim();
306
307 if let Some(space) = args.find(char::is_whitespace) {
309 let first = &args[..space];
310 let rest = args[space..].trim();
311 if let Some(interval) = parse_interval(first) {
312 if !rest.is_empty() {
313 return (interval, rest.to_string());
314 }
315 }
316 }
317
318 const EVERY_NEEDLE: &str = " every ";
321 if let Some(every_pos) = find_every_clause(args) {
322 let prompt_part = args[..every_pos].trim();
323 let interval_token = args[every_pos + EVERY_NEEDLE.len()..]
324 .split_whitespace()
325 .next()
326 .unwrap_or("");
327 if let Some(interval) = parse_interval(interval_token) {
328 if !prompt_part.is_empty() {
329 return (interval, prompt_part.to_string());
330 }
331 }
332 }
333
334 (Duration::from_secs(DEFAULT_INTERVAL_SECS), args.to_string())
336}
337
338fn find_every_clause(s: &str) -> Option<usize> {
340 let needle = " every ";
341 let mut best: Option<usize> = None;
342 let mut search_from = 0;
343 while let Some(rel) = s[search_from..].find(needle) {
344 let abs = search_from + rel;
345 let after = s[abs + needle.len()..]
346 .split_whitespace()
347 .next()
348 .unwrap_or("");
349 if parse_interval(after).is_some() {
350 best = Some(abs);
351 }
352 search_from = abs + 1;
353 }
354 best
355}
356
/// Renders a number of seconds as a compact human-readable string.
///
/// Zero becomes `"now"`. Otherwise the non-zero day, hour and minute
/// components are joined with spaces (e.g. `"1d 1h"`). Seconds appear only
/// when the value is under one minute, so `61` renders as `"1m"`.
pub fn format_duration(secs: u64) -> String {
    if secs == 0 {
        return "now".to_string();
    }
    let days = secs / 86400;
    let hours = (secs % 86400) / 3600;
    let minutes = (secs % 3600) / 60;
    let seconds = secs % 60;

    let parts: Vec<String> = [(days, 'd'), (hours, 'h'), (minutes, 'm')]
        .iter()
        .filter(|&&(amount, _)| amount > 0)
        .map(|&(amount, unit)| format!("{amount}{unit}"))
        .collect();

    if parts.is_empty() && seconds > 0 {
        format!("{seconds}s")
    } else {
        parts.join(" ")
    }
}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_loop_args` on `input` and checks both halves of the result.
    fn assert_loop_args(input: &str, expected_secs: u64, expected_prompt: &str) {
        let (interval, prompt) = parse_loop_args(input);
        assert_eq!(interval, Duration::from_secs(expected_secs));
        assert_eq!(prompt, expected_prompt);
    }

    #[test]
    fn test_parse_interval() {
        let cases: [(&str, Option<u64>); 8] = [
            ("30s", Some(30)),
            ("5m", Some(300)),
            ("2h", Some(7200)),
            ("1d", Some(86400)),
            ("bad", None),
            ("", None),
            ("m", None),
            ("0m", Some(0)),
        ];
        for &(input, expected_secs) in cases.iter() {
            assert_eq!(
                parse_interval(input),
                expected_secs.map(Duration::from_secs),
                "parse_interval({input:?})"
            );
        }
    }

    #[test]
    fn test_parse_loop_args_leading_interval() {
        assert_loop_args("5m check the deployment", 300, "check the deployment");
    }

    #[test]
    fn test_parse_loop_args_trailing_every() {
        assert_loop_args("monitor memory usage every 2h", 7200, "monitor memory usage");
    }

    #[test]
    fn test_parse_loop_args_default_interval() {
        assert_loop_args("check the build", 600, "check the build");
    }

    #[test]
    fn test_parse_loop_args_leading_seconds() {
        assert_loop_args("30s ping the server", 30, "ping the server");
    }

    #[test]
    fn test_format_duration() {
        let cases = [
            (0, "now"),
            (30, "30s"),
            (300, "5m"),
            (3600, "1h"),
            (3660, "1h 1m"),
            (86400, "1d"),
            (90000, "1d 1h"),
        ];
        for &(secs, expected) in cases.iter() {
            assert_eq!(format_duration(secs), expected, "format_duration({secs})");
        }
    }

    #[test]
    fn test_create_and_list_tasks() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("hello".to_string(), Duration::from_secs(60), true)
            .expect("first task must be accepted");
        assert_eq!(id.len(), 8);

        let listed = scheduler.list_tasks();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.prompt, "hello");
        assert_eq!(only.interval_secs, 60);
        assert!(only.recurring);
        assert_eq!(only.fire_count, 0);
    }

    #[test]
    fn test_cancel_task() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("test".to_string(), Duration::from_secs(60), true)
            .expect("task must be accepted");
        assert!(scheduler.cancel_task(&id));
        // A second cancel finds nothing left to remove.
        assert!(!scheduler.cancel_task(&id));
        assert!(scheduler.list_tasks().is_empty());
    }

    #[test]
    fn test_max_tasks_limit() {
        let (scheduler, _rx) = CronScheduler::new();
        (0..MAX_TASKS).for_each(|n| {
            scheduler
                .create_task(format!("task {n}"), Duration::from_secs(60), true)
                .expect("tasks below the limit must be accepted");
        });
        let err = scheduler
            .create_task("overflow".to_string(), Duration::from_secs(60), true)
            .unwrap_err();
        assert!(err.contains("maximum"));
    }

    #[test]
    fn test_compute_jitter_deterministic() {
        let interval = Duration::from_secs(600);
        assert_eq!(
            compute_jitter("abc123", interval),
            compute_jitter("abc123", interval)
        );
    }

    #[test]
    fn test_compute_jitter_bounded() {
        // With a 10-minute interval the jitter stays within one minute.
        let ten_minutes = Duration::from_secs(600);
        for id in ["aaa", "bbb", "ccc", "ddd"].iter() {
            assert!(compute_jitter(id, ten_minutes) <= Duration::from_secs(60));
        }
        // For long intervals the MAX_JITTER cap applies.
        assert!(compute_jitter("aaa", Duration::from_secs(10 * 3600)) <= MAX_JITTER);
    }
}