Skip to main content

a3s_code_core/
scheduler.rs

1//! Session-scoped prompt scheduler — backs `/loop`, `/cron-list`, `/cron-cancel`.
2//!
3//! Users schedule recurring prompts with `/loop 5m check the deployment`.
4//! The scheduler fires them via a tokio channel; `AgentSession` drains the
5//! channel after each `send()` and runs the fired prompts.
6//!
7//! ## Design
8//!
9//! - `CronScheduler` holds a `std::sync::Mutex<CronSchedulerInner>` so it can
10//!   be accessed from synchronous slash command handlers.
11//! - A background tokio task holds a `Weak<CronScheduler>` and ticks every
12//!   second. It exits automatically when the session is dropped.
13//! - Fired prompts are delivered via `tokio::sync::mpsc::UnboundedSender`.
14//! - Deterministic jitter (0–10% of period, capped at 15 min) is derived
15//!   from the task ID so the same task always fires with the same offset.
16
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::{Duration, Instant};
20use tokio::sync::mpsc;
21
/// Upper bound on how many tasks a single session may hold at once.
const MAX_TASKS: usize = 50;
/// Interval (in seconds) used when `/loop` is given no explicit interval: ten minutes.
const DEFAULT_INTERVAL_SECS: u64 = 10 * 60;
/// Age after which a recurring task fires one last time and is then dropped: three days.
const MAX_RECURRING_AGE: Duration = Duration::from_secs(72 * 60 * 60);
/// Ceiling on the per-task jitter offset: fifteen minutes.
const MAX_JITTER: Duration = Duration::from_secs(900);
30
31// ─── Public Types ────────────────────────────────────────────────────────────
32
/// A pending scheduled task, as stored inside the scheduler.
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    /// Short identifier; the key accepted by `CronScheduler::cancel_task`.
    pub id: String,
    /// Prompt text delivered to the session each time the task fires.
    pub prompt: String,
    /// Base period between fires; a deterministic jitter is added on top.
    pub interval: Duration,
    /// `false` for one-shot tasks, which are removed after their first fire.
    pub recurring: bool,
    /// When the task was created; recurring tasks expire relative to this.
    pub created_at: Instant,
    /// Earliest instant at which the ticker will fire this task.
    pub next_fire: Instant,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
}
43
/// A task fire event delivered to `AgentSession` through the scheduler's channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledFire {
    /// ID of the task that fired.
    pub task_id: String,
    /// Prompt text to run for this fire.
    pub prompt: String,
}
49
/// Snapshot of a task used for display (avoids exposing `Instant` in public API).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskInfo {
    /// Task identifier.
    pub id: String,
    /// Prompt text the task runs.
    pub prompt: String,
    /// Base interval in whole seconds (jitter not included).
    pub interval_secs: u64,
    /// Whether the task re-arms after firing.
    pub recurring: bool,
    /// Number of times the task has fired so far.
    pub fire_count: usize,
    /// Seconds until the next fire (0 if overdue).
    pub next_fire_in_secs: u64,
}
60
61// ─── Scheduler ───────────────────────────────────────────────────────────────
62
63struct CronSchedulerInner {
64    tasks: HashMap<String, ScheduledTask>,
65}
66
67/// Session-scoped prompt scheduler.
68///
69/// Create with [`CronScheduler::new`], then call [`CronScheduler::start`] to
70/// launch the background ticker.
71pub struct CronScheduler {
72    inner: Mutex<CronSchedulerInner>,
73    prompt_tx: mpsc::UnboundedSender<ScheduledFire>,
74}
75
76impl CronScheduler {
77    /// Create a new scheduler and its receiver channel.
78    ///
79    /// Call [`CronScheduler::start`] after wrapping in `Arc` to start the ticker.
80    pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<ScheduledFire>) {
81        let (tx, rx) = mpsc::unbounded_channel();
82        let scheduler = Arc::new(Self {
83            inner: Mutex::new(CronSchedulerInner {
84                tasks: HashMap::new(),
85            }),
86            prompt_tx: tx,
87        });
88        (scheduler, rx)
89    }
90
91    /// Spawn the background 1-second ticker.
92    ///
93    /// The task holds a `Weak` reference so it exits automatically when the
94    /// `Arc<CronScheduler>` is dropped (i.e., when the session is dropped).
95    pub fn start(scheduler: Arc<Self>) {
96        let weak = Arc::downgrade(&scheduler);
97        drop(scheduler);
98        tokio::spawn(async move {
99            let mut interval = tokio::time::interval(Duration::from_secs(1));
100            interval.tick().await; // consume the immediate first tick
101            loop {
102                interval.tick().await;
103                match weak.upgrade() {
104                    Some(s) => s.tick(),
105                    None => break, // session dropped — exit cleanly
106                }
107            }
108        });
109    }
110
111    /// Check all tasks and fire any that are due.
112    fn tick(&self) {
113        let now = Instant::now();
114        let mut to_fire: Vec<(String, String)> = Vec::new();
115        let mut to_remove: Vec<String> = Vec::new();
116
117        {
118            let inner = match self.inner.lock() {
119                Ok(g) => g,
120                Err(_) => return,
121            };
122            for (id, task) in &inner.tasks {
123                if now >= task.next_fire {
124                    to_fire.push((id.clone(), task.prompt.clone()));
125                    let age = now - task.created_at;
126                    if !task.recurring || age >= MAX_RECURRING_AGE {
127                        to_remove.push(id.clone());
128                    }
129                }
130            }
131        }
132
133        if to_fire.is_empty() {
134            return;
135        }
136
137        {
138            let mut inner = match self.inner.lock() {
139                Ok(g) => g,
140                Err(_) => return,
141            };
142            for (id, prompt) in &to_fire {
143                if let Some(task) = inner.tasks.get_mut(id) {
144                    task.fire_count += 1;
145                    if task.recurring && !to_remove.contains(id) {
146                        let jitter = compute_jitter(id, task.interval);
147                        task.next_fire = Instant::now() + task.interval + jitter;
148                    }
149                }
150                let _ = self.prompt_tx.send(ScheduledFire {
151                    task_id: id.clone(),
152                    prompt: prompt.clone(),
153                });
154            }
155            for id in &to_remove {
156                inner.tasks.remove(id);
157            }
158        }
159    }
160
161    /// Schedule a new task. Returns the task ID on success.
162    pub fn create_task(
163        &self,
164        prompt: String,
165        interval: Duration,
166        recurring: bool,
167    ) -> Result<String, String> {
168        let mut inner = self
169            .inner
170            .lock()
171            .map_err(|_| "scheduler lock poisoned".to_string())?;
172        if inner.tasks.len() >= MAX_TASKS {
173            return Err(format!(
174                "maximum of {MAX_TASKS} scheduled tasks reached; cancel one with /cron-cancel"
175            ));
176        }
177        let id = new_task_id();
178        let jitter = compute_jitter(&id, interval);
179        let now = Instant::now();
180        inner.tasks.insert(
181            id.clone(),
182            ScheduledTask {
183                id: id.clone(),
184                prompt,
185                interval,
186                recurring,
187                created_at: now,
188                next_fire: now + interval + jitter,
189                fire_count: 0,
190            },
191        );
192        Ok(id)
193    }
194
195    /// List all active tasks sorted by ID.
196    pub fn list_tasks(&self) -> Vec<ScheduledTaskInfo> {
197        let inner = match self.inner.lock() {
198            Ok(g) => g,
199            Err(_) => return vec![],
200        };
201        let now = Instant::now();
202        let mut tasks: Vec<_> = inner
203            .tasks
204            .values()
205            .map(|t| ScheduledTaskInfo {
206                id: t.id.clone(),
207                prompt: t.prompt.clone(),
208                interval_secs: t.interval.as_secs(),
209                recurring: t.recurring,
210                fire_count: t.fire_count,
211                next_fire_in_secs: if t.next_fire > now {
212                    (t.next_fire - now).as_secs()
213                } else {
214                    0
215                },
216            })
217            .collect();
218        tasks.sort_by(|a, b| a.id.cmp(&b.id));
219        tasks
220    }
221
222    /// Cancel a task by ID. Returns `true` if it existed.
223    pub fn cancel_task(&self, id: &str) -> bool {
224        self.inner
225            .lock()
226            .ok()
227            .map(|mut g| g.tasks.remove(id).is_some())
228            .unwrap_or(false)
229    }
230
231    /// Number of active tasks.
232    pub fn task_count(&self) -> usize {
233        self.inner.lock().map(|g| g.tasks.len()).unwrap_or(0)
234    }
235}
236
237// ─── Helpers ─────────────────────────────────────────────────────────────────
238
239/// Generate a short random task ID (8 hex chars via UUID v4).
240fn new_task_id() -> String {
241    let id = uuid::Uuid::new_v4().to_string();
242    id[..8].to_string()
243}
244
245/// Deterministic jitter: 0–10% of `interval`, capped at 15 min.
246///
247/// Derived from the task ID hash so the same task always fires with the
248/// same relative offset (avoids thundering herd when many tasks share the
249/// same interval).
250fn compute_jitter(id: &str, interval: Duration) -> Duration {
251    use std::collections::hash_map::DefaultHasher;
252    use std::hash::{Hash, Hasher};
253    let mut h = DefaultHasher::new();
254    id.hash(&mut h);
255    let fraction = (h.finish() % 1000) as f64 / 10000.0; // 0.000 .. 0.099
256    let raw_secs = (interval.as_secs_f64() * fraction) as u64;
257    Duration::from_secs(raw_secs.min(MAX_JITTER.as_secs()))
258}
259
260// ─── Interval / Arg Parsing ──────────────────────────────────────────────────
261
/// Parse a duration string like `"5m"`, `"30s"`, `"2h"`, `"1d"`.
///
/// The input is an unsigned integer (as accepted by `u64::from_str`) followed
/// by exactly one unit character: `s` (seconds), `m` (minutes), `h` (hours)
/// or `d` (days).
///
/// Returns `Some(Duration)` on success. Returns `None` if the format is not
/// recognized or the total number of seconds does not fit in a `u64`. It never
/// panics, including on non-ASCII input.
pub fn parse_interval(s: &str) -> Option<Duration> {
    // Split off the last *character*, not the last byte. Slicing at
    // `s.len() - 1` panics when the input ends in a multi-byte character
    // (e.g. user input such as "5分").
    let unit = s.chars().next_back()?;
    let num_part = &s[..s.len() - unit.len_utf8()];
    let n: u64 = num_part.parse().ok()?;
    let secs_per_unit: u64 = match unit {
        's' => 1,
        'm' => 60,
        'h' => 3_600,
        'd' => 86_400,
        _ => return None,
    };
    // User-supplied values can be huge. Reject overflow instead of panicking
    // (debug builds) or silently wrapping (release builds).
    n.checked_mul(secs_per_unit).map(Duration::from_secs)
}
279
280/// Parse the argument string from `/loop <args>`.
281///
282/// Supports three forms:
283/// - Leading interval: `/loop 5m check the build` → (5m, "check the build")
284/// - Trailing clause:  `/loop check the build every 2h` → (2h, "check the build")
285/// - No interval:      `/loop check the build` → (10m default, "check the build")
286///
287/// Returns `(interval, prompt)`.
288pub fn parse_loop_args(args: &str) -> (Duration, String) {
289    let args = args.trim();
290
291    // Try leading interval: first whitespace-separated token is a valid interval
292    if let Some(space) = args.find(char::is_whitespace) {
293        let first = &args[..space];
294        let rest = args[space..].trim();
295        if let Some(interval) = parse_interval(first) {
296            if !rest.is_empty() {
297                return (interval, rest.to_string());
298            }
299        }
300    }
301
302    // Try trailing "every <interval>": last " every <token>" in the string.
303    // `find_every_clause` returns the byte position of the leading space before "every".
304    const EVERY_NEEDLE: &str = " every ";
305    if let Some(every_pos) = find_every_clause(args) {
306        let prompt_part = args[..every_pos].trim();
307        let interval_token = args[every_pos + EVERY_NEEDLE.len()..]
308            .split_whitespace()
309            .next()
310            .unwrap_or("");
311        if let Some(interval) = parse_interval(interval_token) {
312            if !prompt_part.is_empty() {
313                return (interval, prompt_part.to_string());
314            }
315        }
316    }
317
318    // Default: 10 minutes, full args as prompt
319    (Duration::from_secs(DEFAULT_INTERVAL_SECS), args.to_string())
320}
321
322/// Find the position of the last ` every <valid-interval>` clause in `s`.
323fn find_every_clause(s: &str) -> Option<usize> {
324    let needle = " every ";
325    let mut best: Option<usize> = None;
326    let mut search_from = 0;
327    while let Some(rel) = s[search_from..].find(needle) {
328        let abs = search_from + rel;
329        let after = s[abs + needle.len()..]
330            .split_whitespace()
331            .next()
332            .unwrap_or("");
333        if parse_interval(after).is_some() {
334            best = Some(abs);
335        }
336        search_from = abs + 1;
337    }
338    best
339}
340
/// Render a number of seconds as a short human-readable string.
///
/// Zero becomes `"now"`. Otherwise non-zero day/hour/minute components are
/// joined with spaces (`"1d 1h"`, `"1h 1m"`). Leftover seconds are shown only
/// when the whole duration is under a minute (`"30s"`). For longer durations
/// they are dropped (90 → `"1m"`).
pub fn format_duration(secs: u64) -> String {
    if secs == 0 {
        return "now".to_string();
    }
    let units = [
        (secs / 86_400, 'd'),
        (secs % 86_400 / 3_600, 'h'),
        (secs % 3_600 / 60, 'm'),
    ];
    let mut pieces: Vec<String> = units
        .iter()
        .filter(|&&(amount, _)| amount > 0)
        .map(|&(amount, unit)| format!("{amount}{unit}"))
        .collect();
    if pieces.is_empty() {
        // Under one minute (and non-zero): fall back to whole seconds.
        pieces.push(format!("{}s", secs % 60));
    }
    pieces.join(" ")
}
365
366// ─── Tests ───────────────────────────────────────────────────────────────────
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `Duration::from_secs`.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    /// Assert that `/loop <args>` parses to the given interval and prompt.
    fn assert_loop(args: &str, interval_secs: u64, prompt: &str) {
        let (interval, parsed) = parse_loop_args(args);
        assert_eq!(interval, secs(interval_secs), "interval for {args:?}");
        assert_eq!(parsed, prompt, "prompt for {args:?}");
    }

    #[test]
    fn test_parse_interval() {
        let cases: &[(&str, Option<Duration>)] = &[
            ("30s", Some(secs(30))),
            ("5m", Some(secs(300))),
            ("2h", Some(secs(7200))),
            ("1d", Some(secs(86400))),
            ("bad", None),
            ("", None),
            ("m", None),
            ("0m", Some(secs(0))),
        ];
        for &(input, expected) in cases {
            assert_eq!(parse_interval(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_parse_loop_args_leading_interval() {
        assert_loop("5m check the deployment", 300, "check the deployment");
    }

    #[test]
    fn test_parse_loop_args_trailing_every() {
        assert_loop("monitor memory usage every 2h", 7200, "monitor memory usage");
    }

    #[test]
    fn test_parse_loop_args_default_interval() {
        assert_loop("check the build", 600, "check the build");
    }

    #[test]
    fn test_parse_loop_args_leading_seconds() {
        assert_loop("30s ping the server", 30, "ping the server");
    }

    #[test]
    fn test_format_duration() {
        let cases: &[(u64, &str)] = &[
            (0, "now"),
            (30, "30s"),
            (300, "5m"),
            (3600, "1h"),
            (3660, "1h 1m"),
            (86400, "1d"),
            (90000, "1d 1h"),
        ];
        for &(input, expected) in cases {
            assert_eq!(format_duration(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_create_and_list_tasks() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("hello".to_string(), secs(60), true)
            .expect("creating a single task must succeed");
        assert_eq!(id.len(), 8);

        let listed = scheduler.list_tasks();
        assert_eq!(listed.len(), 1);
        let task = &listed[0];
        assert_eq!(task.prompt, "hello");
        assert_eq!(task.interval_secs, 60);
        assert!(task.recurring);
        assert_eq!(task.fire_count, 0);
    }

    #[test]
    fn test_cancel_task() {
        let (scheduler, _rx) = CronScheduler::new();
        let id = scheduler
            .create_task("test".to_string(), secs(60), true)
            .expect("creating a single task must succeed");
        assert!(scheduler.cancel_task(&id));
        assert!(
            !scheduler.cancel_task(&id),
            "a second cancel should report the task as already gone"
        );
        assert!(scheduler.list_tasks().is_empty());
    }

    #[test]
    fn test_max_tasks_limit() {
        let (scheduler, _rx) = CronScheduler::new();
        (0..MAX_TASKS).for_each(|i| {
            scheduler
                .create_task(format!("task {i}"), secs(60), true)
                .expect("tasks up to the limit must be accepted");
        });
        let err = scheduler
            .create_task("overflow".to_string(), secs(60), true)
            .unwrap_err();
        assert!(err.contains("maximum"));
    }

    #[test]
    fn test_compute_jitter_deterministic() {
        assert_eq!(
            compute_jitter("abc123", secs(600)),
            compute_jitter("abc123", secs(600))
        );
    }

    #[test]
    fn test_compute_jitter_bounded() {
        // 10% of a 600 s interval is 60 s.
        let ten_percent = secs(60);
        assert!(["aaa", "bbb", "ccc", "ddd"]
            .iter()
            .all(|id| compute_jitter(id, secs(600)) <= ten_percent));
        // With a very large interval, jitter is capped at 15 min.
        assert!(compute_jitter("aaa", secs(10 * 3600)) <= MAX_JITTER);
    }
}