Skip to main content

baml_agent/
loop_detect.rs

//! Loop detection for agent loops.
//!
//! Detects three types of repetitive behavior:
//!
//! 1. **Exact repetition** — identical action signatures (catches trivial loops)
//! 2. **Semantic repetition** — normalized signatures via [`normalize_signature`]
//!    (catches loops where the agent retries the same intent with different flags,
//!    quotes, or fallback chains)
//! 3. **Output stagnation** — identical tool outputs despite varied commands
//!    (catches loops where the agent tries different approaches but gets the same result)
//!
//! Each signal tracks consecutive matches independently. The worst signal
//! determines the returned [`LoopStatus`].

15use std::collections::hash_map::DefaultHasher;
16use std::hash::{Hash, Hasher};
17
// ---------------------------------------------------------------------------
// Signature normalization
// ---------------------------------------------------------------------------

/// Reduce an action signature to its semantic category.
///
/// Bash commands are stripped of syntactic noise so that an agent retrying the
/// same intent with minor variations (different flags, different quoting, an
/// appended fallback chain) produces the same category string each time.
///
/// # Rules
///
/// For `bash:...` signatures:
/// 1. Cut the command at chain/fallback operators (`||`, `&&`, `;`, `|`), but
///    only where the operator has a space on both sides
/// 2. Drop every token that starts with `-` (`-n`, `-i`, `--long-flag`)
/// 3. Strip surrounding quotes (`'`, `"`) and trailing slashes from the
///    remaining arguments, discarding any that become empty
/// 4. Search tools (`rg`, `grep`, `ag`, `ack`, `fgrep`, `egrep`) → `bash-search:args`
/// 5. Other commands → `bash:cmd:args`, or `bash:cmd` when no arguments remain
///    (an empty command yields `bash:`)
///
/// Non-bash signatures pass through unchanged.
///
/// # Examples
///
/// ```
/// use baml_agent::loop_detect::normalize_signature;
///
/// // All these normalize to the same category:
/// let a = normalize_signature("bash:rg -n 'TODO|FIXME' crates/src/");
/// let b = normalize_signature("bash:rg -Hn \"TODO|FIXME\" crates/src/");
/// let c = normalize_signature("bash:grep -rnE 'TODO|FIXME' crates/src/ || echo 'not found'");
/// assert_eq!(a, b);
/// assert_eq!(b, c);
/// assert_eq!(a, "bash-search:TODO|FIXME crates/src");
///
/// // Non-bash unchanged
/// assert_eq!(normalize_signature("read:src/main.rs"), "read:src/main.rs");
/// ```
pub fn normalize_signature(sig: &str) -> String {
    match sig.strip_prefix("bash:") {
        Some(command) => normalize_bash(command),
        None => sig.to_owned(),
    }
}

/// Binaries treated as "search" tools. They all map to the `bash-search:`
/// category, so switching between e.g. `rg` and `grep` does not hide a loop.
const SEARCH_BINS: &[&str] = &["rg", "grep", "ag", "ack", "fgrep", "egrep"];

/// Normalize the command part of a `bash:` signature (prefix already removed).
fn normalize_bash(cmd: &str) -> String {
    // Each operator, in this order, truncates the command at its first
    // occurrence. The surrounding spaces keep alternations inside quoted
    // patterns such as 'TODO|FIXME' intact.
    const CHAIN_OPERATORS: [&str; 4] = [" || ", " && ", " ; ", " | "];

    let mut primary = cmd;
    for op in CHAIN_OPERATORS {
        if let Some(end) = primary.find(op) {
            primary = &primary[..end];
        }
    }

    let mut words = primary.split_whitespace();
    let Some(program) = words.next() else {
        return String::from("bash:");
    };

    // Keep positional arguments only, with quotes and trailing slashes removed.
    let positional: Vec<&str> = words
        .filter(|word| !word.starts_with('-'))
        .map(|word| {
            word.trim_matches(|c: char| matches!(c, '\'' | '"'))
                .trim_end_matches('/')
        })
        .filter(|word| !word.is_empty())
        .collect();

    if SEARCH_BINS.contains(&program) {
        return format!("bash-search:{}", positional.join(" "));
    }

    match positional.as_slice() {
        [] => format!("bash:{program}"),
        args => format!("bash:{program}:{}", args.join(" ")),
    }
}

// ---------------------------------------------------------------------------
// Internal trackers
// ---------------------------------------------------------------------------

/// Counts how many times in a row the same string value has been recorded.
///
/// Recording a value different from the previous one restarts the run at 1.
struct ConsecutiveTracker {
    /// Most recently recorded value; `None` before the first record or after a reset.
    previous: Option<String>,
    /// Length of the current run of identical values (0 when nothing is recorded).
    streak: usize,
}

impl ConsecutiveTracker {
    /// Empty tracker: no previous value, run length zero.
    fn new() -> Self {
        Self {
            previous: None,
            streak: 0,
        }
    }

    /// Record `value` and return the length of the current run (always ≥ 1).
    fn record(&mut self, value: &str) -> usize {
        let continues_run = matches!(&self.previous, Some(prev) if prev == value);
        if !continues_run {
            // A new value starts a fresh run; the increment below makes it 1.
            self.previous = Some(value.to_owned());
            self.streak = 0;
        }
        self.streak += 1;
        self.streak
    }

    /// Forget the previous value and zero the run length.
    fn reset(&mut self) {
        *self = Self::new();
    }

    /// Length of the current run (0 if nothing has been recorded since creation/reset).
    fn count(&self) -> usize {
        self.streak
    }
}

/// Counts how many times in a row the same string has been recorded, comparing
/// only a 64-bit hash so that large values (such as tool output) are not retained.
///
/// Two distinct strings whose hashes collide would be treated as a repeat.
struct HashTracker {
    /// Hash of the most recently recorded value; `None` before the first record or after a reset.
    previous_digest: Option<u64>,
    /// Length of the current run of equal hashes (0 when nothing is recorded).
    streak: usize,
}

impl HashTracker {
    /// Empty tracker: no previous hash, run length zero.
    fn new() -> Self {
        Self {
            previous_digest: None,
            streak: 0,
        }
    }

    /// Hash `value` with a freshly constructed `DefaultHasher`.
    fn digest(value: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Record `value` and return the length of the current run (always ≥ 1).
    fn record(&mut self, value: &str) -> usize {
        let digest = Self::digest(value);
        if self.previous_digest != Some(digest) {
            // Different output: start a new run, which the increment makes 1.
            self.previous_digest = Some(digest);
            self.streak = 0;
        }
        self.streak += 1;
        self.streak
    }

    /// Forget the previous hash and zero the run length.
    fn reset(&mut self) {
        *self = Self::new();
    }

    /// Length of the current run (0 if nothing has been recorded since creation/reset).
    fn count(&self) -> usize {
        self.streak
    }
}

182// ---------------------------------------------------------------------------
183// LoopDetector
184// ---------------------------------------------------------------------------
185
186/// Detects repeated action patterns in agent loops.
187///
188/// Three independent signals, each tracking consecutive repetitions:
189///
190/// | Signal   | Tracks                          | Catches                                |
191/// |----------|---------------------------------|----------------------------------------|
192/// | Exact    | Identical signatures            | Trivial loops (same tool, same args)   |
193/// | Category | Normalized signatures           | Semantic loops (same intent, diff syntax)|
194/// | Output   | Identical tool output (by hash) | Stagnation (different tools, same result)|
195///
196/// Usage:
197/// ```ignore
198/// let mut detector = LoopDetector::new(6);
199///
200/// // Per step: check action signatures
201/// let sig = "bash:rg -n 'TODO' src/";
202/// let cat = normalize_signature(sig);
203/// match detector.check_with_category(sig, &cat) {
204///     LoopStatus::Abort(n) => { /* stop */ }
205///     LoopStatus::Warning(n) => { /* inject system message */ }
206///     LoopStatus::Ok => { /* proceed */ }
207/// }
208///
209/// // Per action execution: check tool output
210/// match detector.record_output("No matches found") {
211///     LoopStatus::Warning(n) => { /* nudge model */ }
212///     _ => {}
213/// }
214/// ```
215pub struct LoopDetector {
216    /// Tier 1: exact signature repetition.
217    exact: ConsecutiveTracker,
218    /// Tier 2: normalized category repetition.
219    category: ConsecutiveTracker,
220    /// Tier 3: tool output repetition (by hash).
221    output: HashTracker,
222    abort_threshold: usize,
223    warn_threshold: usize,
224}
225
/// Outcome of feeding one observation (an action or a tool output) to a
/// [`LoopDetector`].
///
/// The payload of `Warning` and `Abort` is the consecutive repeat count that
/// produced the status. The type is a small `Copy` value, so callers can store,
/// compare and hash statuses freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoopStatus {
    /// No loop detected.
    Ok,
    /// Repeat count reached the warn threshold but is below the abort threshold.
    /// Contains the repeat count.
    Warning(usize),
    /// Repeat count reached the abort threshold; the caller should stop.
    /// Contains the repeat count.
    Abort(usize),
}

236impl LoopDetector {
237    /// Create detector. Warns at `⌈abort_threshold/2⌉`, aborts at `abort_threshold`.
238    pub fn new(abort_threshold: usize) -> Self {
239        Self {
240            exact: ConsecutiveTracker::new(),
241            category: ConsecutiveTracker::new(),
242            output: HashTracker::new(),
243            abort_threshold,
244            warn_threshold: abort_threshold.div_ceil(2),
245        }
246    }
247
248    /// Create detector with explicit warn threshold.
249    pub fn with_thresholds(warn_threshold: usize, abort_threshold: usize) -> Self {
250        Self {
251            exact: ConsecutiveTracker::new(),
252            category: ConsecutiveTracker::new(),
253            output: HashTracker::new(),
254            abort_threshold,
255            warn_threshold,
256        }
257    }
258
259    /// Check action signature only (backward-compatible).
260    ///
261    /// Uses `signature` as both exact match and category.
262    /// For semantic loop detection, use [`check_with_category`] instead.
263    pub fn check(&mut self, signature: &str) -> LoopStatus {
264        self.check_with_category(signature, signature)
265    }
266
267    /// Check action with separate exact signature and normalized category.
268    ///
269    /// Returns the worst status across exact and category signals.
270    pub fn check_with_category(&mut self, signature: &str, category: &str) -> LoopStatus {
271        let exact_n = self.exact.record(signature);
272        let cat_n = self.category.record(category);
273        let max_n = exact_n.max(cat_n);
274
275        if max_n >= self.abort_threshold {
276            LoopStatus::Abort(max_n)
277        } else if max_n >= self.warn_threshold {
278            LoopStatus::Warning(max_n)
279        } else {
280            LoopStatus::Ok
281        }
282    }
283
284    /// Record a tool output and check for output stagnation.
285    ///
286    /// Call after each action execution. Returns [`LoopStatus::Warning`] or
287    /// [`LoopStatus::Abort`] if the same output has been seen too many
288    /// consecutive times — the model is retrying a command that keeps
289    /// giving the same result.
290    pub fn record_output(&mut self, output: &str) -> LoopStatus {
291        let n = self.output.record(output);
292        if n >= self.abort_threshold {
293            LoopStatus::Abort(n)
294        } else if n >= self.warn_threshold {
295            LoopStatus::Warning(n)
296        } else {
297            LoopStatus::Ok
298        }
299    }
300
301    /// Reset all detector state.
302    pub fn reset(&mut self) {
303        self.exact.reset();
304        self.category.reset();
305        self.output.reset();
306    }
307
308    /// Current repeat count (max across all signals).
309    pub fn repeat_count(&self) -> usize {
310        self.exact
311            .count()
312            .max(self.category.count())
313            .max(self.output.count())
314    }
315
316    /// Exact signature repeat count.
317    pub fn exact_count(&self) -> usize {
318        self.exact.count()
319    }
320
321    /// Normalized category repeat count.
322    pub fn category_count(&self) -> usize {
323        self.category.count()
324    }
325
326    /// Output stagnation repeat count.
327    pub fn output_count(&self) -> usize {
328        self.output.count()
329    }
330}
331
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // --- normalize_signature ---

    #[test]
    fn normalize_bash_search_strips_flags_and_quotes() {
        assert_eq!(
            normalize_signature("bash:rg -n 'TODO|FIXME' crates/src/"),
            "bash-search:TODO|FIXME crates/src"
        );
    }

    #[test]
    fn normalize_bash_search_double_quotes() {
        assert_eq!(
            normalize_signature("bash:rg -Hn \"TODO|FIXME\" crates/src/"),
            "bash-search:TODO|FIXME crates/src"
        );
    }

    #[test]
    fn normalize_bash_search_strips_fallback() {
        assert_eq!(
            normalize_signature("bash:rg 'TODO' dir/ || echo 'not found'"),
            "bash-search:TODO dir"
        );
    }

    #[test]
    fn normalize_bash_grep_same_as_rg() {
        assert_eq!(
            normalize_signature("bash:grep -rnE 'TODO|FIXME' src/"),
            "bash-search:TODO|FIXME src"
        );
    }

    #[test]
    fn normalize_bash_complex_fallback() {
        assert_eq!(
            normalize_signature("bash:rg 'TODO' dir/ || (echo 'fail' && ls -la dir/)"),
            "bash-search:TODO dir"
        );
    }

    #[test]
    fn normalize_non_bash_unchanged() {
        assert_eq!(normalize_signature("read:src/main.rs"), "read:src/main.rs");
        assert_eq!(
            normalize_signature("write:config.toml"),
            "write:config.toml"
        );
        assert_eq!(normalize_signature("edit:src/lib.rs"), "edit:src/lib.rs");
    }

    #[test]
    fn normalize_bash_non_search_command() {
        assert_eq!(normalize_signature("bash:cargo test"), "bash:cargo:test");
        assert_eq!(normalize_signature("bash:ls -la /tmp"), "bash:ls:/tmp");
        assert_eq!(normalize_signature("bash:cat file.rs"), "bash:cat:file.rs");
    }

    #[test]
    fn normalize_bash_empty_bare_and_piped_commands() {
        // Nothing left after the prefix.
        assert_eq!(normalize_signature("bash:"), "bash:");
        // Command with no positional arguments.
        assert_eq!(normalize_signature("bash:pwd"), "bash:pwd");
        // A spaced pipe isolates the first command.
        assert_eq!(normalize_signature("bash:cat a.txt | wc -l"), "bash:cat:a.txt");
    }

    #[test]
    fn normalize_all_rg_variants_equal() {
        let variants = [
            "bash:rg -n 'TODO|FIXME' crates/baml-agent/src/",
            "bash:rg 'TODO|FIXME' crates/baml-agent/src/",
            "bash:rg -i 'TODO|FIXME' crates/baml-agent/src/",
            "bash:rg -Hn \"TODO|FIXME\" crates/baml-agent/src/",
            "bash:rg -n \"TODO|FIXME\" crates/baml-agent/src/ || echo 'No matches'",
            "bash:rg 'TODO|FIXME' crates/baml-agent/src/ || (echo 'fail' && ls -la)",
        ];
        let expected = "bash-search:TODO|FIXME crates/baml-agent/src";
        for (i, variant) in variants.iter().enumerate() {
            assert_eq!(
                normalize_signature(variant),
                expected,
                "variant {} failed: {}",
                i,
                variant
            );
        }
    }

    // --- Exact repetition (backward compat) ---

    #[test]
    fn no_loop_different_sigs() {
        let mut d = LoopDetector::new(6);
        assert_eq!(d.check("a"), LoopStatus::Ok);
        assert_eq!(d.check("b"), LoopStatus::Ok);
        assert_eq!(d.check("c"), LoopStatus::Ok);
    }

    #[test]
    fn warn_then_abort() {
        let mut d = LoopDetector::new(6);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Ok); // 2
        assert_eq!(d.check("x"), LoopStatus::Warning(3)); // warn at ceil(6/2)=3
        assert_eq!(d.check("x"), LoopStatus::Warning(4));
        assert_eq!(d.check("x"), LoopStatus::Warning(5));
        assert_eq!(d.check("x"), LoopStatus::Abort(6)); // abort at 6
    }

    #[test]
    fn explicit_thresholds() {
        let mut d = LoopDetector::with_thresholds(2, 3);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Warning(2));
        assert_eq!(d.check("x"), LoopStatus::Abort(3));
    }

    #[test]
    fn abort_takes_precedence_over_warn() {
        // warn >= abort: a warning is never emitted.
        let mut d = LoopDetector::with_thresholds(3, 2);
        assert_eq!(d.check("x"), LoopStatus::Ok);
        assert_eq!(d.check("x"), LoopStatus::Abort(2));
    }

    #[test]
    fn reset_clears() {
        // warn at 2, abort at 4
        let mut d = LoopDetector::new(4);
        d.check("x");
        d.check("x");
        assert_eq!(d.check("x"), LoopStatus::Warning(3));
        d.reset();
        assert_eq!(d.repeat_count(), 0);
        assert_eq!(d.check("x"), LoopStatus::Ok); // fresh start
    }

    #[test]
    fn different_sig_resets_count() {
        let mut d = LoopDetector::new(6);
        d.check("x");
        d.check("x");
        d.check("x"); // 3 = warning
        assert_eq!(d.check("y"), LoopStatus::Ok); // reset
        assert_eq!(d.repeat_count(), 1);
    }

    // --- Category (semantic) detection ---

    #[test]
    fn category_catches_semantic_loop() {
        // warn at 2, abort at 4
        let mut d = LoopDetector::new(4);
        // Different exact signatures, same normalized category.
        let sigs = [
            "bash:rg -n 'TODO' src/",
            "bash:rg 'TODO' src/",
            "bash:rg -i 'TODO' src/",
            "bash:grep -rn 'TODO' src/",
        ];

        let results: Vec<LoopStatus> = sigs
            .iter()
            .map(|sig| {
                let cat = normalize_signature(sig);
                d.check_with_category(sig, &cat)
            })
            .collect();

        // All exact sigs differ → exact count stays at 1.
        // All categories same → category count 1, 2, 3, 4.
        // max(exact, category) determines result.
        assert_eq!(results[0], LoopStatus::Ok); // max(1,1) = 1 < 2
        assert_eq!(results[1], LoopStatus::Warning(2)); // max(1,2) = 2
        assert_eq!(results[2], LoopStatus::Warning(3)); // max(1,3) = 3
        assert_eq!(results[3], LoopStatus::Abort(4)); // max(1,4) = 4
        assert_eq!(d.exact_count(), 1);
        assert_eq!(d.category_count(), 4);
    }

    #[test]
    fn different_categories_reset() {
        let mut d = LoopDetector::new(4);
        d.check_with_category("bash:rg 'A' src/", "bash-search:A src");
        d.check_with_category("bash:rg 'A' src/", "bash-search:A src");
        assert_eq!(d.category_count(), 2);
        // A different category restarts the streak.
        d.check_with_category("bash:cargo test", "bash:cargo:test");
        assert_eq!(d.category_count(), 1);
    }

    // --- Output stagnation ---

    #[test]
    fn output_stagnation_detected() {
        let mut d = LoopDetector::new(4); // warn at 2
        assert_eq!(d.record_output("No matches found"), LoopStatus::Ok);
        assert_eq!(d.record_output("No matches found"), LoopStatus::Warning(2));
        assert_eq!(d.record_output("No matches found"), LoopStatus::Warning(3));
        assert_eq!(d.record_output("No matches found"), LoopStatus::Abort(4));
    }

    #[test]
    fn output_different_resets() {
        let mut d = LoopDetector::new(4);
        d.record_output("result A");
        d.record_output("result A"); // 2 = warning
        assert_eq!(d.record_output("result B"), LoopStatus::Ok); // reset to 1
    }

    #[test]
    fn output_signal_is_independent_of_action_signals() {
        let mut d = LoopDetector::new(4);
        d.record_output("same");
        d.record_output("same");
        assert_eq!(d.output_count(), 2);
        assert_eq!(d.exact_count(), 0);
        assert_eq!(d.category_count(), 0);
        // Action checks leave the output streak untouched.
        assert_eq!(d.check("a"), LoopStatus::Ok);
        assert_eq!(d.output_count(), 2);
        assert_eq!(d.repeat_count(), 2);
    }

    // --- Combined: real-world scenario ---

    #[test]
    fn semantic_loop_caught_within_threshold() {
        // Simulates the actual TODO/FIXME loop from testing.
        // 6 steps, each with different flags/quotes but same intent.
        let mut d = LoopDetector::new(6); // warn at 3, abort at 6

        let steps: Vec<(&str, &str)> = vec![
            ("bash:rg \"TODO|FIXME\" crates/baml-agent/src/", ""),
            ("bash:rg -n 'TODO|FIXME' crates/baml-agent/src/", ""),
            (
                "bash:rg -n \"TODO|FIXME\" crates/baml-agent/src/ || echo 'No'",
                "No TODO or FIXME found",
            ),
            (
                "bash:rg 'TODO|FIXME' crates/baml-agent/src/ || (echo && ls)",
                "Search failed...",
            ),
            (
                "bash:rg 'TODO|FIXME' crates/baml-agent/src/",
                "No TODO or FIXME found",
            ),
            (
                "bash:rg -n 'TODO|FIXME' crates/baml-agent/src/ || echo 'No'",
                "No TODO or FIXME found",
            ),
        ];

        let mut first_warning = None;
        let mut abort_at = None;

        for (i, (sig, output)) in steps.iter().enumerate() {
            let cat = normalize_signature(sig);
            match d.check_with_category(sig, &cat) {
                LoopStatus::Abort(_) => {
                    abort_at = Some(i + 1);
                    break;
                }
                LoopStatus::Warning(_) => {
                    first_warning.get_or_insert(i + 1);
                }
                LoopStatus::Ok => {}
            }
            d.record_output(output);
        }

        assert_eq!(first_warning, Some(3), "should warn at step 3");
        assert_eq!(abort_at, Some(6), "should abort at step 6");
    }
}