Skip to main content

a3s_code_core/
scheduler.rs

1//! Session-scoped prompt scheduler — backs `/loop`, `/cron-list`, `/cron-cancel`.
2//!
3//! Users schedule recurring prompts with `/loop 5m check the deployment`.
4//! The scheduler fires them via a tokio channel; `AgentSession` drains the
5//! channel after each `send()` and runs the fired prompts.
6//!
7//! ## Design
8//!
9//! - `CronScheduler` holds a `std::sync::Mutex<CronSchedulerInner>` so it can
10//!   be accessed from synchronous slash command handlers.
11//! - A background tokio task holds a `Weak<CronScheduler>` and ticks every
12//!   second. It exits automatically when the session is dropped.
13//! - Fired prompts are delivered via `tokio::sync::mpsc::UnboundedSender`.
14//! - Deterministic jitter (0–10% of period, capped at 15 min) is derived
15//!   from the task ID so the same task always fires with the same offset.
16
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicBool, Ordering};
19use std::sync::{Arc, Mutex};
20use std::time::{Duration, Instant};
21use tokio::sync::mpsc;
22
/// Upper bound on how many tasks a single session may have scheduled at once.
const MAX_TASKS: usize = 50;
/// Interval used by `/loop` when the user gives none (10 minutes).
const DEFAULT_INTERVAL_SECS: u64 = 10 * 60;
/// Age after which a recurring task is retired (3 days).
const MAX_RECURRING_AGE: Duration = Duration::from_secs(3 * 24 * 60 * 60);
/// Largest jitter ever added to a task's fire time (15 minutes).
const MAX_JITTER: Duration = Duration::from_secs(15 * 60);
31
32// ─── Public Types ────────────────────────────────────────────────────────────
33
/// A pending scheduled task.
///
/// `next_fire` is the earliest instant at which the scheduler's ticker will
/// deliver `prompt`; for recurring tasks it is moved forward after each fire.
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    /// Short task identifier (also the key under which the task is stored).
    pub id: String,
    /// Prompt text delivered when the task fires.
    pub prompt: String,
    /// Base period between fires, before jitter is added.
    pub interval: Duration,
    /// `true`: re-armed after each fire; `false`: removed after its first fire.
    pub recurring: bool,
    /// Creation time; used to retire recurring tasks past the age limit.
    pub created_at: Instant,
    /// Earliest instant at which the task is due to fire.
    pub next_fire: Instant,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
}
44
/// A task fire event delivered to `AgentSession` over the scheduler's channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledFire {
    /// ID of the task that fired.
    pub task_id: String,
    /// Prompt text to run.
    pub prompt: String,
}
50
/// Snapshot of a task used for display (avoids exposing `Instant` in public API).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskInfo {
    /// Task identifier.
    pub id: String,
    /// Prompt text the task delivers.
    pub prompt: String,
    /// Base interval in whole seconds.
    pub interval_secs: u64,
    /// Whether the task is re-armed after firing.
    pub recurring: bool,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
    /// Seconds until the next fire (0 if overdue).
    pub next_fire_in_secs: u64,
}
61
62// ─── Scheduler ───────────────────────────────────────────────────────────────
63
64struct CronSchedulerInner {
65    tasks: HashMap<String, ScheduledTask>,
66}
67
68/// Session-scoped prompt scheduler.
69///
70/// Create with [`CronScheduler::new`], then call [`CronScheduler::start`] to
71/// launch the background ticker.
72pub struct CronScheduler {
73    inner: Mutex<CronSchedulerInner>,
74    prompt_tx: mpsc::UnboundedSender<ScheduledFire>,
75    stopped: AtomicBool,
76}
77
78impl CronScheduler {
79    /// Create a new scheduler and its receiver channel.
80    ///
81    /// Call [`CronScheduler::start`] after wrapping in `Arc` to start the ticker.
82    pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<ScheduledFire>) {
83        let (tx, rx) = mpsc::unbounded_channel();
84        let scheduler = Arc::new(Self {
85            inner: Mutex::new(CronSchedulerInner {
86                tasks: HashMap::new(),
87            }),
88            prompt_tx: tx,
89            stopped: AtomicBool::new(false),
90        });
91        (scheduler, rx)
92    }
93
94    /// Spawn the background 1-second ticker.
95    ///
96    /// The task holds a `Weak` reference so it exits automatically when the
97    /// `Arc<CronScheduler>` is dropped (i.e., when the session is dropped).
98    pub fn start(scheduler: Arc<Self>) {
99        let weak = Arc::downgrade(&scheduler);
100        drop(scheduler);
101        tokio::spawn(async move {
102            let mut interval = tokio::time::interval(Duration::from_secs(1));
103            interval.tick().await; // consume the immediate first tick
104            loop {
105                interval.tick().await;
106                match weak.upgrade() {
107                    Some(s) => {
108                        if s.stopped.load(Ordering::Relaxed) {
109                            break;
110                        }
111                        s.tick();
112                    }
113                    None => break, // session dropped — exit cleanly
114                }
115            }
116        });
117    }
118
119    /// Stop the background ticker and clear all scheduled tasks.
120    pub fn stop(&self) {
121        self.stopped.store(true, Ordering::Relaxed);
122        if let Ok(mut inner) = self.inner.lock() {
123            inner.tasks.clear();
124        }
125    }
126
127    /// Check all tasks and fire any that are due.
128    fn tick(&self) {
129        let now = Instant::now();
130        let mut to_fire: Vec<(String, String)> = Vec::new();
131        let mut to_remove: Vec<String> = Vec::new();
132
133        {
134            let inner = match self.inner.lock() {
135                Ok(g) => g,
136                Err(_) => return,
137            };
138            for (id, task) in &inner.tasks {
139                if now >= task.next_fire {
140                    to_fire.push((id.clone(), task.prompt.clone()));
141                    let age = now - task.created_at;
142                    if !task.recurring || age >= MAX_RECURRING_AGE {
143                        to_remove.push(id.clone());
144                    }
145                }
146            }
147        }
148
149        if to_fire.is_empty() {
150            return;
151        }
152
153        {
154            let mut inner = match self.inner.lock() {
155                Ok(g) => g,
156                Err(_) => return,
157            };
158            for (id, prompt) in &to_fire {
159                if let Some(task) = inner.tasks.get_mut(id) {
160                    task.fire_count += 1;
161                    if task.recurring && !to_remove.contains(id) {
162                        let jitter = compute_jitter(id, task.interval);
163                        task.next_fire = Instant::now() + task.interval + jitter;
164                    }
165                }
166                let _ = self.prompt_tx.send(ScheduledFire {
167                    task_id: id.clone(),
168                    prompt: prompt.clone(),
169                });
170            }
171            for id in &to_remove {
172                inner.tasks.remove(id);
173            }
174        }
175    }
176
177    /// Schedule a new task. Returns the task ID on success.
178    pub fn create_task(
179        &self,
180        prompt: String,
181        interval: Duration,
182        recurring: bool,
183    ) -> Result<String, String> {
184        let mut inner = self
185            .inner
186            .lock()
187            .map_err(|_| "scheduler lock poisoned".to_string())?;
188        if inner.tasks.len() >= MAX_TASKS {
189            return Err(format!(
190                "maximum of {MAX_TASKS} scheduled tasks reached; cancel one with /cron-cancel"
191            ));
192        }
193        let id = new_task_id();
194        let jitter = compute_jitter(&id, interval);
195        let now = Instant::now();
196        inner.tasks.insert(
197            id.clone(),
198            ScheduledTask {
199                id: id.clone(),
200                prompt,
201                interval,
202                recurring,
203                created_at: now,
204                next_fire: now + interval + jitter,
205                fire_count: 0,
206            },
207        );
208        Ok(id)
209    }
210
211    /// List all active tasks sorted by ID.
212    pub fn list_tasks(&self) -> Vec<ScheduledTaskInfo> {
213        let inner = match self.inner.lock() {
214            Ok(g) => g,
215            Err(_) => return vec![],
216        };
217        let now = Instant::now();
218        let mut tasks: Vec<_> = inner
219            .tasks
220            .values()
221            .map(|t| ScheduledTaskInfo {
222                id: t.id.clone(),
223                prompt: t.prompt.clone(),
224                interval_secs: t.interval.as_secs(),
225                recurring: t.recurring,
226                fire_count: t.fire_count,
227                next_fire_in_secs: if t.next_fire > now {
228                    (t.next_fire - now).as_secs()
229                } else {
230                    0
231                },
232            })
233            .collect();
234        tasks.sort_by(|a, b| a.id.cmp(&b.id));
235        tasks
236    }
237
238    /// Cancel a task by ID. Returns `true` if it existed.
239    pub fn cancel_task(&self, id: &str) -> bool {
240        self.inner
241            .lock()
242            .ok()
243            .map(|mut g| g.tasks.remove(id).is_some())
244            .unwrap_or(false)
245    }
246
247    /// Number of active tasks.
248    pub fn task_count(&self) -> usize {
249        self.inner.lock().map(|g| g.tasks.len()).unwrap_or(0)
250    }
251}
252
253// ─── Helpers ─────────────────────────────────────────────────────────────────
254
255/// Generate a short random task ID (8 hex chars via UUID v4).
256fn new_task_id() -> String {
257    let id = uuid::Uuid::new_v4().to_string();
258    id[..8].to_string()
259}
260
261/// Deterministic jitter: 0–10% of `interval`, capped at 15 min.
262///
263/// Derived from the task ID hash so the same task always fires with the
264/// same relative offset (avoids thundering herd when many tasks share the
265/// same interval).
266fn compute_jitter(id: &str, interval: Duration) -> Duration {
267    use std::collections::hash_map::DefaultHasher;
268    use std::hash::{Hash, Hasher};
269    let mut h = DefaultHasher::new();
270    id.hash(&mut h);
271    let fraction = (h.finish() % 1000) as f64 / 10000.0; // 0.000 .. 0.099
272    let raw_secs = (interval.as_secs_f64() * fraction) as u64;
273    Duration::from_secs(raw_secs.min(MAX_JITTER.as_secs()))
274}
275
276// ─── Interval / Arg Parsing ──────────────────────────────────────────────────
277
/// Parse a duration string like `"5m"`, `"30s"`, `"2h"`, `"1d"`.
///
/// The string must be an unsigned integer immediately followed by one unit
/// character (`s`, `m`, `h`, `d`). Returns `Some(Duration)` on success and
/// `None` otherwise. `None` also covers values whose length in seconds
/// overflows `u64`.
///
/// Arbitrary user text (including non-ASCII words) is safe to pass: the
/// unit is split off by character, not by byte.
pub fn parse_interval(s: &str) -> Option<Duration> {
    // Split on the last *character* — slicing at `len - 1` bytes would panic
    // on a trailing multibyte character such as `é`.
    let unit = s.chars().next_back()?;
    let num_part = &s[..s.len() - unit.len_utf8()];
    let secs_per_unit: u64 = match unit {
        's' => 1,
        'm' => 60,
        'h' => 3_600,
        'd' => 86_400,
        _ => return None,
    };
    let n: u64 = num_part.parse().ok()?;
    // Reject instead of overflowing (debug panic / release wrap-around).
    n.checked_mul(secs_per_unit).map(Duration::from_secs)
}
295
296/// Parse the argument string from `/loop <args>`.
297///
298/// Supports three forms:
299/// - Leading interval: `/loop 5m check the build` → (5m, "check the build")
300/// - Trailing clause:  `/loop check the build every 2h` → (2h, "check the build")
301/// - No interval:      `/loop check the build` → (10m default, "check the build")
302///
303/// Returns `(interval, prompt)`.
304pub fn parse_loop_args(args: &str) -> (Duration, String) {
305    let args = args.trim();
306
307    // Try leading interval: first whitespace-separated token is a valid interval
308    if let Some(space) = args.find(char::is_whitespace) {
309        let first = &args[..space];
310        let rest = args[space..].trim();
311        if let Some(interval) = parse_interval(first) {
312            if !rest.is_empty() {
313                return (interval, rest.to_string());
314            }
315        }
316    }
317
318    // Try trailing "every <interval>": last " every <token>" in the string.
319    // `find_every_clause` returns the byte position of the leading space before "every".
320    const EVERY_NEEDLE: &str = " every ";
321    if let Some(every_pos) = find_every_clause(args) {
322        let prompt_part = args[..every_pos].trim();
323        let interval_token = args[every_pos + EVERY_NEEDLE.len()..]
324            .split_whitespace()
325            .next()
326            .unwrap_or("");
327        if let Some(interval) = parse_interval(interval_token) {
328            if !prompt_part.is_empty() {
329                return (interval, prompt_part.to_string());
330            }
331        }
332    }
333
334    // Default: 10 minutes, full args as prompt
335    (Duration::from_secs(DEFAULT_INTERVAL_SECS), args.to_string())
336}
337
338/// Find the position of the last ` every <valid-interval>` clause in `s`.
339fn find_every_clause(s: &str) -> Option<usize> {
340    let needle = " every ";
341    let mut best: Option<usize> = None;
342    let mut search_from = 0;
343    while let Some(rel) = s[search_from..].find(needle) {
344        let abs = search_from + rel;
345        let after = s[abs + needle.len()..]
346            .split_whitespace()
347            .next()
348            .unwrap_or("");
349        if parse_interval(after).is_some() {
350            best = Some(abs);
351        }
352        search_from = abs + 1;
353    }
354    best
355}
356
/// Render a number of seconds compactly, e.g. `"5m"`, `"2h 30m"`, `"1d 1h"`.
///
/// Zero renders as `"now"`. Days, hours and minutes are listed when non-zero.
/// Seconds appear only when no coarser unit does.
pub fn format_duration(secs: u64) -> String {
    if secs == 0 {
        return "now".to_string();
    }
    let components = [
        (secs / 86_400, 'd'),
        (secs % 86_400 / 3_600, 'h'),
        (secs % 3_600 / 60, 'm'),
    ];
    let mut pieces: Vec<String> = components
        .iter()
        .filter(|&&(value, _)| value > 0)
        .map(|&(value, unit)| format!("{value}{unit}"))
        .collect();
    let leftover = secs % 60;
    if pieces.is_empty() && leftover > 0 {
        pieces.push(format!("{leftover}s"));
    }
    pieces.join(" ")
}
381
382// ─── Tests ───────────────────────────────────────────────────────────────────
383
#[cfg(test)]
mod tests {
    use super::*;

    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    /// Assert that `/loop <args>` parses into the given interval and prompt.
    fn assert_loop(args: &str, want_secs: u64, want_prompt: &str) {
        let (interval, prompt) = parse_loop_args(args);
        assert_eq!(interval, secs(want_secs), "interval for {args:?}");
        assert_eq!(prompt, want_prompt, "prompt for {args:?}");
    }

    #[test]
    fn test_parse_interval() {
        let table: [(&str, Option<Duration>); 8] = [
            ("30s", Some(secs(30))),
            ("5m", Some(secs(300))),
            ("2h", Some(secs(7200))),
            ("1d", Some(secs(86400))),
            ("bad", None),
            ("", None),
            ("m", None),
            ("0m", Some(secs(0))),
        ];
        for &(input, want) in table.iter() {
            assert_eq!(parse_interval(input), want, "input {input:?}");
        }
    }

    #[test]
    fn test_parse_loop_args_leading_interval() {
        assert_loop("5m check the deployment", 300, "check the deployment");
    }

    #[test]
    fn test_parse_loop_args_trailing_every() {
        assert_loop("monitor memory usage every 2h", 7200, "monitor memory usage");
    }

    #[test]
    fn test_parse_loop_args_default_interval() {
        assert_loop("check the build", 600, "check the build");
    }

    #[test]
    fn test_parse_loop_args_leading_seconds() {
        assert_loop("30s ping the server", 30, "ping the server");
    }

    #[test]
    fn test_format_duration() {
        let table: [(u64, &str); 7] = [
            (0, "now"),
            (30, "30s"),
            (300, "5m"),
            (3600, "1h"),
            (3660, "1h 1m"),
            (86400, "1d"),
            (90000, "1d 1h"),
        ];
        for &(input, want) in table.iter() {
            assert_eq!(format_duration(input), want, "input {input}");
        }
    }

    #[test]
    fn test_create_and_list_tasks() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("hello".to_string(), secs(60), true)
            .expect("first task fits under the limit");
        assert_eq!(id.len(), 8);

        let listed = scheduler.list_tasks();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.prompt, "hello");
        assert_eq!(only.interval_secs, 60);
        assert!(only.recurring);
        assert_eq!(only.fire_count, 0);
    }

    #[test]
    fn test_cancel_task() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("test".to_string(), secs(60), true)
            .expect("task creation succeeds");
        assert!(scheduler.cancel_task(&id), "first cancel removes the task");
        assert!(!scheduler.cancel_task(&id), "second cancel finds nothing");
        assert!(scheduler.list_tasks().is_empty());
    }

    #[test]
    fn test_max_tasks_limit() {
        let (scheduler, _rx) = CronScheduler::new();
        for i in 0..MAX_TASKS {
            scheduler
                .create_task(format!("task {i}"), secs(60), true)
                .expect("below the task limit");
        }
        let err = scheduler
            .create_task("overflow".to_string(), secs(60), true)
            .expect_err("limit must be enforced");
        assert!(err.contains("maximum"));
    }

    #[test]
    fn test_compute_jitter_deterministic() {
        let interval = secs(600);
        assert_eq!(
            compute_jitter("abc123", interval),
            compute_jitter("abc123", interval)
        );
    }

    #[test]
    fn test_compute_jitter_bounded() {
        for id in ["aaa", "bbb", "ccc", "ddd"].iter() {
            // At most 10% of a 600s interval.
            assert!(compute_jitter(id, secs(600)) <= secs(60));
        }
        // Long intervals hit the 15-minute cap rather than the 10% ceiling.
        assert!(compute_jitter("aaa", secs(10 * 3600)) <= MAX_JITTER);
    }
}