// baml_agent/loop_detect.rs

/// Detects repeated action patterns in agent loops.
///
/// The detector remembers only the most recent signature and how many times in a
/// row it has been seen, counting the first occurrence. A different signature
/// starts a new run with a count of 1.
///
/// Usage:
/// ```ignore
/// let mut detector = LoopDetector::new(6); // warn at 3, abort at 6 identical signatures in a row
/// for action in actions {
///     let sig = format!("tool:{}:{}", action.name, action.key_arg);
///     match detector.check(&sig) {
///         LoopStatus::Ok => { /* proceed */ }
///         LoopStatus::Warning(n) => { /* inject warning into context */ }
///         LoopStatus::Abort(n) => { /* stop the loop */ }
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Signature of the most recent step, or `None` before the first check and after a reset.
    last_signature: Option<String>,
    /// Number of consecutive steps that produced `last_signature`.
    repeat_count: usize,
    /// Count at which the detector reports `LoopStatus::Abort`.
    abort_threshold: usize,
    /// Count at which the detector reports `LoopStatus::Warning` (when below the abort threshold).
    warn_threshold: usize,
}

/// Outcome of [`LoopDetector::check`] for a single step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoopStatus {
    /// No loop detected: either a new signature, or a repeat still below the warn threshold.
    Ok,
    /// Repeat reached the warn threshold but not the abort threshold.
    /// Contains the number of consecutive occurrences.
    Warning(usize),
    /// Repeat reached the abort threshold; the caller should stop.
    /// Contains the number of consecutive occurrences.
    Abort(usize),
}

32impl LoopDetector {
33    /// Create detector. Warns at `abort_threshold / 2`, aborts at `abort_threshold`.
34    pub fn new(abort_threshold: usize) -> Self {
35        Self {
36            last_signature: None,
37            repeat_count: 0,
38            abort_threshold,
39            warn_threshold: abort_threshold / 2,
40        }
41    }
42
43    /// Create detector with explicit warn threshold.
44    pub fn with_thresholds(warn_threshold: usize, abort_threshold: usize) -> Self {
45        Self {
46            last_signature: None,
47            repeat_count: 0,
48            abort_threshold,
49            warn_threshold,
50        }
51    }
52
53    /// Check a combined signature for the current step's actions.
54    ///
55    /// `signature` should uniquely identify the action(s) being taken.
56    /// If multiple actions, join their signatures with `|`.
57    pub fn check(&mut self, signature: &str) -> LoopStatus {
58        if self.last_signature.as_deref() == Some(signature) {
59            self.repeat_count += 1;
60            if self.repeat_count >= self.abort_threshold {
61                return LoopStatus::Abort(self.repeat_count);
62            }
63            if self.repeat_count >= self.warn_threshold {
64                return LoopStatus::Warning(self.repeat_count);
65            }
66        } else {
67            self.repeat_count = 1;
68            self.last_signature = Some(signature.into());
69        }
70        LoopStatus::Ok
71    }
72
73    /// Reset detector state.
74    pub fn reset(&mut self) {
75        self.last_signature = None;
76        self.repeat_count = 0;
77    }
78
79    pub fn repeat_count(&self) -> usize {
80        self.repeat_count
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_loop_different_sigs() {
        let mut d = LoopDetector::new(6);
        assert_eq!(d.check("a"), LoopStatus::Ok);
        assert_eq!(d.check("b"), LoopStatus::Ok);
        assert_eq!(d.check("c"), LoopStatus::Ok);
    }

    #[test]
    fn warn_then_abort() {
        let mut d = LoopDetector::new(6);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Ok); // 2
        assert_eq!(d.check("x"), LoopStatus::Warning(3)); // warn at 3
        assert_eq!(d.check("x"), LoopStatus::Warning(4));
        assert_eq!(d.check("x"), LoopStatus::Warning(5));
        assert_eq!(d.check("x"), LoopStatus::Abort(6)); // abort at 6
    }

    #[test]
    fn reset_clears() {
        let mut d = LoopDetector::new(4);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Warning(2)); // warn at 4 / 2
        assert_eq!(d.check("x"), LoopStatus::Warning(3));
        d.reset();
        assert_eq!(d.repeat_count(), 0);
        assert_eq!(d.check("x"), LoopStatus::Ok); // fresh start
        assert_eq!(d.repeat_count(), 1);
    }

    #[test]
    fn different_sig_resets_count() {
        let mut d = LoopDetector::new(6);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Warning(3));
        assert_eq!(d.check("y"), LoopStatus::Ok); // reset
        assert_eq!(d.repeat_count(), 1);
    }

    #[test]
    fn explicit_thresholds() {
        let mut d = LoopDetector::with_thresholds(2, 4);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Warning(2));
        assert_eq!(d.check("x"), LoopStatus::Warning(3));
        assert_eq!(d.check("x"), LoopStatus::Abort(4));
    }

    #[test]
    fn warn_at_or_above_abort_never_warns() {
        let mut d = LoopDetector::with_thresholds(5, 3);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Abort(3));
    }

    #[test]
    fn abort_persists_while_repeating() {
        let mut d = LoopDetector::new(2);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Abort(2));
        assert_eq!(d.check("x"), LoopStatus::Abort(3));
    }

    #[test]
    fn shorter_signature_replaces_longer_one() {
        let mut d = LoopDetector::new(2);
        assert_eq!(d.check("abc"), LoopStatus::Ok);
        assert_eq!(d.check("ab"), LoopStatus::Ok);
        assert_eq!(d.check("ab"), LoopStatus::Abort(2));
        assert_eq!(d.check("abc"), LoopStatus::Ok);
        assert_eq!(d.repeat_count(), 1);
    }
}