1use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::{Duration, Instant};
20use tokio::sync::mpsc;
21
/// Upper bound on simultaneously registered tasks; `create_task` rejects more.
const MAX_TASKS: usize = 50;
/// Interval (in seconds, i.e. 10 minutes) used by `parse_loop_args` when the
/// arguments do not specify one.
const DEFAULT_INTERVAL_SECS: u64 = 600;
/// Age (3 days) after which a recurring task is removed, following the first
/// fire at or past this age.
const MAX_RECURRING_AGE: Duration = Duration::from_secs(3 * 24 * 3600);
/// Cap (15 minutes) on the deterministic per-task jitter from `compute_jitter`.
const MAX_JITTER: Duration = Duration::from_secs(15 * 60);

/// A task registered with the scheduler.
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    /// Short identifier generated when the task is created.
    pub id: String,
    /// Text delivered in each `ScheduledFire` for this task.
    pub prompt: String,
    /// Nominal delay between fires; a per-task jitter is added on top.
    pub interval: Duration,
    /// When `false`, the task is removed after its first fire.
    pub recurring: bool,
    /// Creation time, used to enforce `MAX_RECURRING_AGE`.
    pub created_at: Instant,
    /// Moment at which the task next becomes due.
    pub next_fire: Instant,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
}
43
/// One firing of a scheduled task, sent on the scheduler's channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledFire {
    /// Id of the task that fired.
    pub task_id: String,
    /// Prompt text stored on the task when it fired.
    pub prompt: String,
}
49
/// Read-only snapshot of a task, as produced by `CronScheduler::list_tasks`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskInfo {
    /// Task identifier.
    pub id: String,
    /// Prompt text the task delivers.
    pub prompt: String,
    /// Repeat interval in whole seconds (sub-second part truncated).
    pub interval_secs: u64,
    /// Whether the task repeats after firing.
    pub recurring: bool,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
    /// Whole seconds until the next fire; 0 when the task is already due.
    pub next_fire_in_secs: u64,
}
60
61struct CronSchedulerInner {
64 tasks: HashMap<String, ScheduledTask>,
65}
66
67pub struct CronScheduler {
72 inner: Mutex<CronSchedulerInner>,
73 prompt_tx: mpsc::UnboundedSender<ScheduledFire>,
74}
75
76impl CronScheduler {
77 pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<ScheduledFire>) {
81 let (tx, rx) = mpsc::unbounded_channel();
82 let scheduler = Arc::new(Self {
83 inner: Mutex::new(CronSchedulerInner {
84 tasks: HashMap::new(),
85 }),
86 prompt_tx: tx,
87 });
88 (scheduler, rx)
89 }
90
91 pub fn start(scheduler: Arc<Self>) {
96 let weak = Arc::downgrade(&scheduler);
97 drop(scheduler);
98 tokio::spawn(async move {
99 let mut interval = tokio::time::interval(Duration::from_secs(1));
100 interval.tick().await; loop {
102 interval.tick().await;
103 match weak.upgrade() {
104 Some(s) => s.tick(),
105 None => break, }
107 }
108 });
109 }
110
111 fn tick(&self) {
113 let now = Instant::now();
114 let mut to_fire: Vec<(String, String)> = Vec::new();
115 let mut to_remove: Vec<String> = Vec::new();
116
117 {
118 let inner = match self.inner.lock() {
119 Ok(g) => g,
120 Err(_) => return,
121 };
122 for (id, task) in &inner.tasks {
123 if now >= task.next_fire {
124 to_fire.push((id.clone(), task.prompt.clone()));
125 let age = now - task.created_at;
126 if !task.recurring || age >= MAX_RECURRING_AGE {
127 to_remove.push(id.clone());
128 }
129 }
130 }
131 }
132
133 if to_fire.is_empty() {
134 return;
135 }
136
137 {
138 let mut inner = match self.inner.lock() {
139 Ok(g) => g,
140 Err(_) => return,
141 };
142 for (id, prompt) in &to_fire {
143 if let Some(task) = inner.tasks.get_mut(id) {
144 task.fire_count += 1;
145 if task.recurring && !to_remove.contains(id) {
146 let jitter = compute_jitter(id, task.interval);
147 task.next_fire = Instant::now() + task.interval + jitter;
148 }
149 }
150 let _ = self.prompt_tx.send(ScheduledFire {
151 task_id: id.clone(),
152 prompt: prompt.clone(),
153 });
154 }
155 for id in &to_remove {
156 inner.tasks.remove(id);
157 }
158 }
159 }
160
161 pub fn create_task(
163 &self,
164 prompt: String,
165 interval: Duration,
166 recurring: bool,
167 ) -> Result<String, String> {
168 let mut inner = self
169 .inner
170 .lock()
171 .map_err(|_| "scheduler lock poisoned".to_string())?;
172 if inner.tasks.len() >= MAX_TASKS {
173 return Err(format!(
174 "maximum of {MAX_TASKS} scheduled tasks reached; cancel one with /cron-cancel"
175 ));
176 }
177 let id = new_task_id();
178 let jitter = compute_jitter(&id, interval);
179 let now = Instant::now();
180 inner.tasks.insert(
181 id.clone(),
182 ScheduledTask {
183 id: id.clone(),
184 prompt,
185 interval,
186 recurring,
187 created_at: now,
188 next_fire: now + interval + jitter,
189 fire_count: 0,
190 },
191 );
192 Ok(id)
193 }
194
195 pub fn list_tasks(&self) -> Vec<ScheduledTaskInfo> {
197 let inner = match self.inner.lock() {
198 Ok(g) => g,
199 Err(_) => return vec![],
200 };
201 let now = Instant::now();
202 let mut tasks: Vec<_> = inner
203 .tasks
204 .values()
205 .map(|t| ScheduledTaskInfo {
206 id: t.id.clone(),
207 prompt: t.prompt.clone(),
208 interval_secs: t.interval.as_secs(),
209 recurring: t.recurring,
210 fire_count: t.fire_count,
211 next_fire_in_secs: if t.next_fire > now {
212 (t.next_fire - now).as_secs()
213 } else {
214 0
215 },
216 })
217 .collect();
218 tasks.sort_by(|a, b| a.id.cmp(&b.id));
219 tasks
220 }
221
222 pub fn cancel_task(&self, id: &str) -> bool {
224 self.inner
225 .lock()
226 .ok()
227 .map(|mut g| g.tasks.remove(id).is_some())
228 .unwrap_or(false)
229 }
230
231 pub fn task_count(&self) -> usize {
233 self.inner.lock().map(|g| g.tasks.len()).unwrap_or(0)
234 }
235}
236
237fn new_task_id() -> String {
241 let id = uuid::Uuid::new_v4().to_string();
242 id[..8].to_string()
243}
244
245fn compute_jitter(id: &str, interval: Duration) -> Duration {
251 use std::collections::hash_map::DefaultHasher;
252 use std::hash::{Hash, Hasher};
253 let mut h = DefaultHasher::new();
254 id.hash(&mut h);
255 let fraction = (h.finish() % 1000) as f64 / 10000.0; let raw_secs = (interval.as_secs_f64() * fraction) as u64;
257 Duration::from_secs(raw_secs.min(MAX_JITTER.as_secs()))
258}
259
/// Parses a compact interval such as `"30s"`, `"5m"`, `"2h"` or `"1d"`.
///
/// The last character is the unit (`s`, `m`, `h`, `d`). Everything before it
/// must parse as a `u64` count.
///
/// Returns `None` in any of these cases:
/// - the input is empty;
/// - the unit is unknown;
/// - the count is missing or non-numeric;
/// - the total number of seconds does not fit in a `u64`.
pub fn parse_interval(s: &str) -> Option<Duration> {
    // Split off the final *character* rather than the final byte. This way a
    // multi-byte ending (e.g. "café") cannot put the slice inside a UTF-8
    // sequence and panic.
    let unit = s.chars().next_back()?;
    let num_part = &s[..s.len() - unit.len_utf8()];
    let secs_per_unit: u64 = match unit {
        's' => 1,
        'm' => 60,
        'h' => 3_600,
        'd' => 86_400,
        _ => return None,
    };
    // An empty `num_part` (input was just a unit, e.g. "m") fails to parse here.
    let n: u64 = num_part.parse().ok()?;
    // Reject counts whose total seconds overflow u64, instead of panicking
    // (debug builds) or silently wrapping (release builds).
    n.checked_mul(secs_per_unit).map(Duration::from_secs)
}
279
280pub fn parse_loop_args(args: &str) -> (Duration, String) {
289 let args = args.trim();
290
291 if let Some(space) = args.find(char::is_whitespace) {
293 let first = &args[..space];
294 let rest = args[space..].trim();
295 if let Some(interval) = parse_interval(first) {
296 if !rest.is_empty() {
297 return (interval, rest.to_string());
298 }
299 }
300 }
301
302 const EVERY_NEEDLE: &str = " every ";
305 if let Some(every_pos) = find_every_clause(args) {
306 let prompt_part = args[..every_pos].trim();
307 let interval_token = args[every_pos + EVERY_NEEDLE.len()..]
308 .split_whitespace()
309 .next()
310 .unwrap_or("");
311 if let Some(interval) = parse_interval(interval_token) {
312 if !prompt_part.is_empty() {
313 return (interval, prompt_part.to_string());
314 }
315 }
316 }
317
318 (Duration::from_secs(DEFAULT_INTERVAL_SECS), args.to_string())
320}
321
322fn find_every_clause(s: &str) -> Option<usize> {
324 let needle = " every ";
325 let mut best: Option<usize> = None;
326 let mut search_from = 0;
327 while let Some(rel) = s[search_from..].find(needle) {
328 let abs = search_from + rel;
329 let after = s[abs + needle.len()..]
330 .split_whitespace()
331 .next()
332 .unwrap_or("");
333 if parse_interval(after).is_some() {
334 best = Some(abs);
335 }
336 search_from = abs + 1;
337 }
338 best
339}
340
/// Renders a second count as a compact, human-readable duration.
///
/// `0` renders as `"now"`. Non-zero day, hour and minute components are joined
/// with spaces (e.g. `"1d 1h"`). Seconds are shown only when no larger unit is
/// present (e.g. `"30s"`).
pub fn format_duration(secs: u64) -> String {
    if secs == 0 {
        return "now".to_string();
    }
    let coarse = [
        (secs / 86_400, "d"),
        (secs % 86_400 / 3_600, "h"),
        (secs % 3_600 / 60, "m"),
    ];
    let mut parts: Vec<String> = coarse
        .iter()
        .filter(|(amount, _)| *amount > 0)
        .map(|(amount, unit)| format!("{amount}{unit}"))
        .collect();
    let seconds = secs % 60;
    if parts.is_empty() && seconds > 0 {
        parts.push(format!("{seconds}s"));
    }
    parts.join(" ")
}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_loop_args` on `input` and checks both halves of the result.
    fn check_loop_args(input: &str, expected_secs: u64, expected_prompt: &str) {
        let (interval, prompt) = parse_loop_args(input);
        assert_eq!(interval, Duration::from_secs(expected_secs), "interval for {input:?}");
        assert_eq!(prompt, expected_prompt, "prompt for {input:?}");
    }

    #[test]
    fn test_parse_interval() {
        let cases = [
            ("30s", Some(Duration::from_secs(30))),
            ("5m", Some(Duration::from_secs(300))),
            ("2h", Some(Duration::from_secs(7200))),
            ("1d", Some(Duration::from_secs(86400))),
            ("bad", None),
            ("", None),
            ("m", None),
            ("0m", Some(Duration::from_secs(0))),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_interval(input), expected, "parse_interval({input:?})");
        }
    }

    #[test]
    fn test_parse_loop_args_leading_interval() {
        check_loop_args("5m check the deployment", 300, "check the deployment");
    }

    #[test]
    fn test_parse_loop_args_trailing_every() {
        check_loop_args("monitor memory usage every 2h", 7200, "monitor memory usage");
    }

    #[test]
    fn test_parse_loop_args_default_interval() {
        check_loop_args("check the build", 600, "check the build");
    }

    #[test]
    fn test_parse_loop_args_leading_seconds() {
        check_loop_args("30s ping the server", 30, "ping the server");
    }

    #[test]
    fn test_format_duration() {
        let cases = [
            (0, "now"),
            (30, "30s"),
            (300, "5m"),
            (3600, "1h"),
            (3660, "1h 1m"),
            (86400, "1d"),
            (90000, "1d 1h"),
        ];
        for (secs, expected) in cases {
            assert_eq!(format_duration(secs), expected, "format_duration({secs})");
        }
    }

    #[test]
    fn test_create_and_list_tasks() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("hello".to_string(), Duration::from_secs(60), true)
            .expect("first task should be accepted");
        assert_eq!(id.len(), 8);

        let listed = scheduler.list_tasks();
        assert_eq!(listed.len(), 1);
        let info = &listed[0];
        assert_eq!(info.prompt, "hello");
        assert_eq!(info.interval_secs, 60);
        assert!(info.recurring);
        assert_eq!(info.fire_count, 0);
    }

    #[test]
    fn test_cancel_task() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("test".to_string(), Duration::from_secs(60), true)
            .expect("task should be accepted");
        assert!(scheduler.cancel_task(&id), "first cancel removes the task");
        assert!(!scheduler.cancel_task(&id), "second cancel finds nothing");
        assert!(scheduler.list_tasks().is_empty());
    }

    #[test]
    fn test_max_tasks_limit() {
        let (scheduler, _rx) = CronScheduler::new();
        for n in 0..MAX_TASKS {
            let created =
                scheduler.create_task(format!("task {n}"), Duration::from_secs(60), true);
            assert!(created.is_ok(), "task {n} should fit under the limit");
        }
        let err = scheduler
            .create_task("overflow".to_string(), Duration::from_secs(60), true)
            .expect_err("one task past the limit must be rejected");
        assert!(err.contains("maximum"));
    }

    #[test]
    fn test_compute_jitter_deterministic() {
        let interval = Duration::from_secs(600);
        assert_eq!(compute_jitter("abc123", interval), compute_jitter("abc123", interval));
    }

    #[test]
    fn test_compute_jitter_bounded() {
        // 10-minute interval: jitter stays below 10% of the interval.
        for id in ["aaa", "bbb", "ccc", "ddd"] {
            assert!(compute_jitter(id, Duration::from_secs(600)) <= Duration::from_secs(60));
        }
        // 10-hour interval: the uncapped 10% bound exceeds MAX_JITTER, so the cap applies.
        assert!(compute_jitter("aaa", Duration::from_secs(10 * 3600)) <= MAX_JITTER);
    }
}
480}