// lean_ctx/core/loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names that `LoopDetector::is_search_tool` treats as search tools.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell programs (each written with a trailing space) that
/// `LoopDetector::is_search_shell_command` treats as search commands.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
/// Tracks repeated tool calls within a time window to detect and throttle agent loops.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Call timestamps inside the sliding window, keyed by `"{tool}:{args_fingerprint}"`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Last observed in-window call count per `"{tool}:{args_fingerprint}"` key;
    /// reported by `stats`.
    duplicate_counts: HashMap<String, u32>,
    /// Timestamps of search-category calls, shared across all search tools.
    search_group_history: Vec<Instant>,
    /// Bounded list of recently seen search patterns, used for similarity checks.
    recent_search_patterns: Vec<String>,
    /// Per-key counts above this produce a `Reduced` result with a soft note.
    normal_threshold: u32,
    /// Per-key counts above this produce a `Reduced` result with a warning.
    reduced_threshold: u32,
    /// Per-key counts above this produce a `Blocked` result.
    blocked_threshold: u32,
    /// Length of the sliding window used to age out timestamps.
    window: Duration,
    /// Maximum number of search calls in the window before searches are blocked.
    search_group_limit: u32,
}
23
/// Severity of throttling applied to a repeated call: normal, reduced, or blocked.
///
/// Variants are listed in increasing order of severity.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// No repetition problem detected; no message is attached.
    Normal,
    /// Repetition detected; the attached message warns the caller
    /// (the warning text states that results may be reduced).
    Reduced,
    /// Loop detected; the attached message tells the caller the call was blocked.
    Blocked,
}
31
/// Outcome of a loop detection check: throttle level, count, and optional warning.
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// How strongly the call is throttled.
    pub level: ThrottleLevel,
    /// The count that decided `level`: repetitions of the same tool+arguments,
    /// the number of similar search patterns, or the number of searches in the
    /// window, depending on which check fired.
    pub call_count: u32,
    /// Guidance for the agent; `None` exactly when `level` is `Normal`.
    pub message: Option<String>,
}
39
40impl Default for LoopDetector {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl LoopDetector {
47    /// Creates a loop detector with default thresholds.
48    pub fn new() -> Self {
49        Self::with_config(&LoopDetectionConfig::default())
50    }
51
52    /// Creates a loop detector with custom thresholds from config.
53    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
54        Self {
55            call_history: HashMap::new(),
56            duplicate_counts: HashMap::new(),
57            search_group_history: Vec::new(),
58            recent_search_patterns: Vec::new(),
59            normal_threshold: cfg.normal_threshold.max(1),
60            reduced_threshold: cfg.reduced_threshold.max(2),
61            blocked_threshold: cfg.blocked_threshold.max(3),
62            window: Duration::from_secs(cfg.window_secs),
63            search_group_limit: cfg.search_group_limit.max(3),
64        }
65    }
66
67    /// Records a tool call and returns the throttle result based on repetition count.
68    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
69        let now = Instant::now();
70        self.prune_window(now);
71
72        let key = format!("{tool}:{args_fingerprint}");
73        let entries = self.call_history.entry(key.clone()).or_default();
74        entries.push(now);
75        let count = entries.len() as u32;
76        *self.duplicate_counts.entry(key).or_default() = count;
77
78        if count > self.blocked_threshold {
79            return ThrottleResult {
80                level: ThrottleLevel::Blocked,
81                call_count: count,
82                message: Some(self.block_message(tool, count)),
83            };
84        }
85        if count > self.reduced_threshold {
86            return ThrottleResult {
87                level: ThrottleLevel::Reduced,
88                call_count: count,
89                message: Some(format!(
90                    "Warning: {tool} called {count}x with same args. \
91                     Results reduced. Try a different approach or narrow your scope."
92                )),
93            };
94        }
95        if count > self.normal_threshold {
96            return ThrottleResult {
97                level: ThrottleLevel::Reduced,
98                call_count: count,
99                message: Some(format!(
100                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
101                )),
102            };
103        }
104        ThrottleResult {
105            level: ThrottleLevel::Normal,
106            call_count: count,
107            message: None,
108        }
109    }
110
111    /// Record a search-category call and check the cross-tool search group limit.
112    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
113    pub fn record_search(
114        &mut self,
115        tool: &str,
116        args_fingerprint: &str,
117        search_pattern: Option<&str>,
118    ) -> ThrottleResult {
119        let now = Instant::now();
120
121        self.search_group_history.push(now);
122        let search_count = self.search_group_history.len() as u32;
123
124        let similar_count = if let Some(pat) = search_pattern {
125            let sc = self.count_similar_patterns(pat);
126            if !pat.is_empty() {
127                self.recent_search_patterns.push(pat.to_string());
128                if self.recent_search_patterns.len() > 15 {
129                    self.recent_search_patterns.remove(0);
130                }
131            }
132            sc
133        } else {
134            0
135        };
136
137        if similar_count >= self.blocked_threshold {
138            return ThrottleResult {
139                level: ThrottleLevel::Blocked,
140                call_count: similar_count,
141                message: Some(self.search_block_message(similar_count)),
142            };
143        }
144
145        if search_count > self.search_group_limit {
146            return ThrottleResult {
147                level: ThrottleLevel::Blocked,
148                call_count: search_count,
149                message: Some(self.search_group_block_message(search_count)),
150            };
151        }
152
153        if similar_count >= self.reduced_threshold {
154            return ThrottleResult {
155                level: ThrottleLevel::Reduced,
156                call_count: similar_count,
157                message: Some(format!(
158                    "Warning: You've searched for similar patterns {similar_count}x. \
159                     Narrow your search with the 'path' parameter or try ctx_tree first."
160                )),
161            };
162        }
163
164        if search_count > self.search_group_limit.saturating_sub(3) {
165            let per_fp = self.record_call(tool, args_fingerprint);
166            if per_fp.level != ThrottleLevel::Normal {
167                return per_fp;
168            }
169            return ThrottleResult {
170                level: ThrottleLevel::Reduced,
171                call_count: search_count,
172                message: Some(format!(
173                    "Note: {search_count} search calls in the last {}s. \
174                     Use ctx_tree to orient first, then scope searches with 'path'.",
175                    self.window.as_secs()
176                )),
177            };
178        }
179
180        self.record_call(tool, args_fingerprint)
181    }
182
183    /// Returns `true` if the tool name is a known search tool (ctx_search, etc.).
184    pub fn is_search_tool(tool: &str) -> bool {
185        SEARCH_TOOLS.contains(&tool)
186    }
187
188    /// Returns `true` if the shell command starts with a search tool (grep, rg, find, etc.).
189    pub fn is_search_shell_command(command: &str) -> bool {
190        let cmd = command.trim_start();
191        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
192    }
193
194    /// Computes a deterministic hash fingerprint of JSON tool arguments.
195    pub fn fingerprint(args: &serde_json::Value) -> String {
196        use std::collections::hash_map::DefaultHasher;
197        use std::hash::{Hash, Hasher};
198
199        let canonical = canonical_json(args);
200        let mut hasher = DefaultHasher::new();
201        canonical.hash(&mut hasher);
202        format!("{:016x}", hasher.finish())
203    }
204
205    /// Returns duplicate call entries sorted by count (descending), filtered to count > 1.
206    pub fn stats(&self) -> Vec<(String, u32)> {
207        let mut entries: Vec<(String, u32)> = self
208            .duplicate_counts
209            .iter()
210            .filter(|(_, &count)| count > 1)
211            .map(|(k, &v)| (k.clone(), v))
212            .collect();
213        entries.sort_by_key(|x| std::cmp::Reverse(x.1));
214        entries
215    }
216
217    /// Clears all tracking state (call history, search patterns, counters).
218    pub fn reset(&mut self) {
219        self.call_history.clear();
220        self.duplicate_counts.clear();
221        self.search_group_history.clear();
222        self.recent_search_patterns.clear();
223    }
224
225    fn prune_window(&mut self, now: Instant) {
226        for entries in self.call_history.values_mut() {
227            entries.retain(|t| now.duration_since(*t) < self.window);
228        }
229        self.search_group_history
230            .retain(|t| now.duration_since(*t) < self.window);
231    }
232
233    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
234        let new_lower = new_pattern.to_lowercase();
235        let new_root = extract_alpha_root(&new_lower);
236
237        let mut count = 0u32;
238        for existing in &self.recent_search_patterns {
239            let existing_lower = existing.to_lowercase();
240            if patterns_are_similar(&new_lower, &existing_lower) {
241                count += 1;
242            } else if new_root.len() >= 4 {
243                let existing_root = extract_alpha_root(&existing_lower);
244                if existing_root.len() >= 4
245                    && (new_root.starts_with(&existing_root)
246                        || existing_root.starts_with(&new_root))
247                {
248                    count += 1;
249                }
250            }
251        }
252        count
253    }
254
255    fn block_message(&self, tool: &str, count: u32) -> String {
256        if Self::is_search_tool(tool) {
257            self.search_block_message(count)
258        } else {
259            format!(
260                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
261                 Call blocked. Change your approach — the current strategy is not working."
262            )
263        }
264    }
265
266    #[allow(clippy::unused_self)]
267    fn search_block_message(&self, count: u32) -> String {
268        format!(
269            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
270             1) Use ctx_tree to understand the project structure first. \
271             2) Narrow your search with the 'path' parameter to a specific directory. \
272             3) Use ctx_read with mode='map' to understand a file before searching more."
273        )
274    }
275
276    fn search_group_block_message(&self, count: u32) -> String {
277        format!(
278            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
279             1) Use ctx_tree to map the project structure. \
280             2) Pick ONE specific directory and search there with the 'path' parameter. \
281             3) Read files with ctx_read mode='map' instead of searching blindly.",
282            self.window.as_secs()
283        )
284    }
285}
286
/// Returns the leading run of alphanumeric characters of `pattern`
/// (e.g. `"compress_data"` → `"compress"`); empty if the first char is not alphanumeric.
fn extract_alpha_root(pattern: &str) -> String {
    // Byte offset of the first non-alphanumeric char, or the whole string if none.
    let end = pattern
        .char_indices()
        .find(|&(_, ch)| !ch.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..end].to_string()
}
293
/// Returns `true` if two search patterns look like variations of the same search.
///
/// Patterns are similar when they are equal, when one contains the other, or when
/// their alphanumeric-only forms (each at least 3 bytes) contain one another.
/// An empty pattern is only similar to another empty pattern.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    // The empty string is a substring of every string; without this guard an empty
    // pattern would be reported as similar to everything.
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
}
311
312fn canonical_json(value: &serde_json::Value) -> String {
313    match value {
314        serde_json::Value::Object(map) => {
315            let mut keys: Vec<&String> = map.keys().collect();
316            keys.sort();
317            let entries: Vec<String> = keys
318                .iter()
319                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
320                .collect();
321            format!("{{{}}}", entries.join(","))
322        }
323        serde_json::Value::Array(arr) => {
324            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
325            format!("[{}]", entries.join(","))
326        }
327        _ => value.to_string(),
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with explicit per-fingerprint thresholds, a 300 s window and a
    /// search group limit of 10.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
        }
    }

    /// Records the same `(tool, fp)` call `times` times, discarding the results.
    fn repeat_call(detector: &mut LoopDetector, tool: &str, fp: &str, times: u32) {
        for _ in 0..times {
            detector.record_call(tool, fp);
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let first = detector.record_call("ctx_read", "abc123");
        assert_eq!(first.level, ThrottleLevel::Normal);
        assert_eq!(first.call_count, 1);
        assert!(first.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        repeat_call(&mut detector, "ctx_read", "same_fp", cfg.normal_threshold);
        let over_limit = detector.record_call("ctx_read", "same_fp");
        assert_eq!(over_limit.level, ThrottleLevel::Reduced);
        assert!(over_limit.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        repeat_call(&mut detector, "ctx_shell", "same_fp", cfg.blocked_threshold);
        let blocked = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(blocked.level, ThrottleLevel::Blocked);
        let message = blocked.message.expect("a blocked call carries a message");
        assert!(message.contains("LOOP DETECTED"));
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        repeat_call(&mut detector, "ctx_read", "fp_a", 10);
        let other = detector.record_call("ctx_read", "fp_b");
        assert_eq!(other.level, ThrottleLevel::Normal);
        assert_eq!(other.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(
            LoopDetector::fingerprint(&args),
            LoopDetector::fingerprint(&args)
        );
    }

    #[test]
    fn fingerprint_order_independent() {
        let mode_first = serde_json::json!({"mode": "full", "path": "test.rs"});
        let path_first = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(
            LoopDetector::fingerprint(&mode_first),
            LoopDetector::fingerprint(&path_first)
        );
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        repeat_call(&mut detector, "ctx_read", "fp_a", 5);
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        let (_, top_count) = &stats[0];
        assert_eq!(*top_count, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        repeat_call(&mut detector, "ctx_read", "fp_a", 5);
        detector.reset();
        assert_eq!(detector.record_call("ctx_read", "fp_a").call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        let mut detector = LoopDetector::with_config(&test_config(1, 2, 3));
        detector.record_call("ctx_read", "fp");
        assert_eq!(
            detector.record_call("ctx_read", "fp").level,
            ThrottleLevel::Reduced
        );
        detector.record_call("ctx_read", "fp");
        assert_eq!(
            detector.record_call("ctx_read", "fp").level,
            ThrottleLevel::Blocked
        );
    }

    #[test]
    fn similar_patterns_detected() {
        let similar = [
            ("compress", "compress"),
            ("compress", "compression"),
            ("compress.*data", "compress"),
        ];
        for (a, b) in similar {
            assert!(patterns_are_similar(a, b), "{a:?} should be similar to {b:?}");
        }
        let dissimilar = [("foo", "bar"), ("ab", "cd")];
        for (a, b) in dissimilar {
            assert!(!patterns_are_similar(a, b), "{a:?} should differ from {b:?}");
        }
    }

    #[test]
    fn search_group_tracking() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let pattern = format!("pattern_{i}");
            let outcome = detector.record_search("ctx_search", &fp, Some(pattern.as_str()));
            assert_ne!(outcome.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let sixth = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(sixth.level, ThrottleLevel::Blocked);
        assert!(sixth.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        let warmup = cfg.blocked_threshold as usize;
        for (i, pat) in variants.iter().take(warmup).enumerate() {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(*pat));
        }
        let verdict = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(verdict.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        for tool in ["ctx_search", "ctx_semantic_search"] {
            assert!(LoopDetector::is_search_tool(tool), "{tool} is a search tool");
        }
        for tool in ["ctx_read", "ctx_shell"] {
            assert!(!LoopDetector::is_search_tool(tool), "{tool} is not a search tool");
        }
    }

    #[test]
    fn is_search_shell_command_detection() {
        let searches = ["grep -r foo .", "rg pattern src/", "find . -name '*.rs'"];
        for cmd in searches {
            assert!(LoopDetector::is_search_shell_command(cmd), "{cmd:?} is a search");
        }
        for cmd in ["cargo build", "git status"] {
            assert!(!LoopDetector::is_search_shell_command(cmd), "{cmd:?} is not a search");
        }
    }

    #[test]
    fn search_block_message_has_guidance() {
        let mut detector = LoopDetector::new();
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let message = detector
            .record_search("ctx_search", "fp_new", Some("compress"))
            .message
            .unwrap();
        for needle in ["ctx_tree", "path", "ctx_read"] {
            assert!(message.contains(needle), "message should mention {needle:?}");
        }
    }
}