Skip to main content

lean_ctx/core/
loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names classified as searches by `LoopDetector::is_search_tool`.
const SEARCH_TOOLS: &[&str] = &[
    "ctx_search",
    "ctx_semantic_search",
];

/// Leading command words (each including its trailing space) that mark a shell
/// command as a search in `LoopDetector::is_search_shell_command`.
const SEARCH_SHELL_PREFIXES: &[&str] = &[
    "grep ",
    "rg ",
    "find ",
    "fd ",
    "ag ",
    "ack ",
];
9
/// Loop detector for agent tool calls.
///
/// Remembers when each `tool:fingerprint` pair was invoked inside a sliding time
/// window and escalates from normal to reduced to (optionally) blocked as the same
/// call keeps repeating. Search tools are additionally tracked as one group and by
/// pattern similarity.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Timestamps of each `tool:fingerprint` call still inside the window.
    call_history: HashMap<String, Vec<Instant>>,
    /// Windowed count observed at the most recent call for each key (feeds `stats`).
    duplicate_counts: HashMap<String, u32>,
    /// Timestamps of every search-category call, regardless of tool or arguments.
    search_group_history: Vec<Instant>,
    /// Most recent non-empty search patterns, oldest first (capped at 15 entries).
    recent_search_patterns: Vec<String>,
    /// Calls above this count are reported as `Reduced` with a note.
    normal_threshold: u32,
    /// Calls above this count are reported as `Reduced` with a warning.
    reduced_threshold: u32,
    /// Calls above this count are `Blocked`; 0 disables blocking.
    blocked_threshold: u32,
    /// Length of the sliding window used when counting repeats.
    window: Duration,
    /// Maximum number of search calls allowed within the window before blocking.
    search_group_limit: u32,
}
23
/// Severity of throttling applied to a repeated call: normal, reduced, or blocked.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so levels can be used as map
/// or set keys and compared with total equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ThrottleLevel {
    /// The call is within normal repetition limits.
    Normal,
    /// The call repeats often enough that results should be reduced.
    Reduced,
    /// The call is blocked (only when blocking is enabled).
    Blocked,
}
31
32/// Outcome of a loop detection check: throttle level, count, and optional warning.
33#[derive(Debug, Clone)]
34pub struct ThrottleResult {
35    pub level: ThrottleLevel,
36    pub call_count: u32,
37    pub message: Option<String>,
38}
39
40impl Default for LoopDetector {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl LoopDetector {
47    /// Creates a loop detector with default thresholds.
48    pub fn new() -> Self {
49        Self::with_config(&LoopDetectionConfig::default())
50    }
51
52    /// Creates a loop detector with custom thresholds from config.
53    /// Set blocked_threshold to 0 to disable blocking entirely (LeanCTX philosophy).
54    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
55        Self {
56            call_history: HashMap::new(),
57            duplicate_counts: HashMap::new(),
58            search_group_history: Vec::new(),
59            recent_search_patterns: Vec::new(),
60            normal_threshold: cfg.normal_threshold.max(1),
61            reduced_threshold: cfg.reduced_threshold.max(2),
62            // 0 = never block (LeanCTX default), >0 = block after N identical calls
63            blocked_threshold: cfg.blocked_threshold,
64            window: Duration::from_secs(cfg.window_secs),
65            search_group_limit: if cfg.blocked_threshold == 0 {
66                u32::MAX
67            } else {
68                cfg.search_group_limit.max(3)
69            },
70        }
71    }
72
73    /// Records a tool call and returns the throttle result based on repetition count.
74    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
75        let now = Instant::now();
76        self.prune_window(now);
77
78        let key = format!("{tool}:{args_fingerprint}");
79        let entries = self.call_history.entry(key.clone()).or_default();
80        entries.push(now);
81        let count = entries.len() as u32;
82        *self.duplicate_counts.entry(key).or_default() = count;
83
84        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
85        if self.blocked_threshold > 0 && count > self.blocked_threshold {
86            return ThrottleResult {
87                level: ThrottleLevel::Blocked,
88                call_count: count,
89                message: Some(self.block_message(tool, count)),
90            };
91        }
92        if count > self.reduced_threshold {
93            if !crate::core::protocol::meta_visible() {
94                return ThrottleResult {
95                    level: ThrottleLevel::Reduced,
96                    call_count: count,
97                    message: None,
98                };
99            }
100            return ThrottleResult {
101                level: ThrottleLevel::Reduced,
102                call_count: count,
103                message: Some(format!(
104                    "Warning: {tool} called {count}x with same args. \
105                     Results reduced. Try a different approach or narrow your scope."
106                )),
107            };
108        }
109        if count > self.normal_threshold {
110            if !crate::core::protocol::meta_visible() {
111                return ThrottleResult {
112                    level: ThrottleLevel::Reduced,
113                    call_count: count,
114                    message: None,
115                };
116            }
117            return ThrottleResult {
118                level: ThrottleLevel::Reduced,
119                call_count: count,
120                message: Some(format!(
121                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
122                )),
123            };
124        }
125        ThrottleResult {
126            level: ThrottleLevel::Normal,
127            call_count: count,
128            message: None,
129        }
130    }
131
132    /// Record a search-category call and check the cross-tool search group limit.
133    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
134    pub fn record_search(
135        &mut self,
136        tool: &str,
137        args_fingerprint: &str,
138        search_pattern: Option<&str>,
139    ) -> ThrottleResult {
140        let now = Instant::now();
141
142        self.search_group_history.push(now);
143        let search_count = self.search_group_history.len() as u32;
144
145        let similar_count = if let Some(pat) = search_pattern {
146            let sc = self.count_similar_patterns(pat);
147            if !pat.is_empty() {
148                self.recent_search_patterns.push(pat.to_string());
149                if self.recent_search_patterns.len() > 15 {
150                    self.recent_search_patterns.remove(0);
151                }
152            }
153            sc
154        } else {
155            0
156        };
157
158        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
159        if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
160            return ThrottleResult {
161                level: ThrottleLevel::Blocked,
162                call_count: similar_count,
163                message: Some(self.search_block_message(similar_count)),
164            };
165        }
166
167        // search_group_limit == u32::MAX when blocking is disabled
168        if self.blocked_threshold > 0 && search_count > self.search_group_limit {
169            return ThrottleResult {
170                level: ThrottleLevel::Blocked,
171                call_count: search_count,
172                message: Some(self.search_group_block_message(search_count)),
173            };
174        }
175
176        if similar_count >= self.reduced_threshold {
177            if !crate::core::protocol::meta_visible() {
178                return ThrottleResult {
179                    level: ThrottleLevel::Reduced,
180                    call_count: similar_count,
181                    message: None,
182                };
183            }
184            return ThrottleResult {
185                level: ThrottleLevel::Reduced,
186                call_count: similar_count,
187                message: Some(format!(
188                    "Warning: You've searched for similar patterns {similar_count}x. \
189                     Narrow your search with the 'path' parameter or try ctx_tree first."
190                )),
191            };
192        }
193
194        if search_count > self.search_group_limit.saturating_sub(3) {
195            let per_fp = self.record_call(tool, args_fingerprint);
196            if per_fp.level != ThrottleLevel::Normal {
197                return per_fp;
198            }
199            if !crate::core::protocol::meta_visible() {
200                return ThrottleResult {
201                    level: ThrottleLevel::Reduced,
202                    call_count: search_count,
203                    message: None,
204                };
205            }
206            return ThrottleResult {
207                level: ThrottleLevel::Reduced,
208                call_count: search_count,
209                message: Some(format!(
210                    "Note: {search_count} search calls in the last {}s. \
211                     Use ctx_tree to orient first, then scope searches with 'path'.",
212                    self.window.as_secs()
213                )),
214            };
215        }
216
217        self.record_call(tool, args_fingerprint)
218    }
219
220    /// Returns `true` if the tool name is a known search tool (ctx_search, etc.).
221    pub fn is_search_tool(tool: &str) -> bool {
222        SEARCH_TOOLS.contains(&tool)
223    }
224
225    /// Returns `true` if the shell command starts with a search tool (grep, rg, find, etc.).
226    pub fn is_search_shell_command(command: &str) -> bool {
227        let cmd = command.trim_start();
228        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
229    }
230
231    /// Computes a deterministic hash fingerprint of JSON tool arguments.
232    pub fn fingerprint(args: &serde_json::Value) -> String {
233        use std::collections::hash_map::DefaultHasher;
234        use std::hash::{Hash, Hasher};
235
236        let canonical = canonical_json(args);
237        let mut hasher = DefaultHasher::new();
238        canonical.hash(&mut hasher);
239        format!("{:016x}", hasher.finish())
240    }
241
242    /// Returns duplicate call entries sorted by count (descending), filtered to count > 1.
243    pub fn stats(&self) -> Vec<(String, u32)> {
244        let mut entries: Vec<(String, u32)> = self
245            .duplicate_counts
246            .iter()
247            .filter(|(_, &count)| count > 1)
248            .map(|(k, &v)| (k.clone(), v))
249            .collect();
250        entries.sort_by_key(|x| std::cmp::Reverse(x.1));
251        entries
252    }
253
254    /// Clears all tracking state (call history, search patterns, counters).
255    pub fn reset(&mut self) {
256        self.call_history.clear();
257        self.duplicate_counts.clear();
258        self.search_group_history.clear();
259        self.recent_search_patterns.clear();
260    }
261
262    fn prune_window(&mut self, now: Instant) {
263        for entries in self.call_history.values_mut() {
264            entries.retain(|t| now.duration_since(*t) < self.window);
265        }
266        self.search_group_history
267            .retain(|t| now.duration_since(*t) < self.window);
268    }
269
270    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
271        let new_lower = new_pattern.to_lowercase();
272        let new_root = extract_alpha_root(&new_lower);
273
274        let mut count = 0u32;
275        for existing in &self.recent_search_patterns {
276            let existing_lower = existing.to_lowercase();
277            if patterns_are_similar(&new_lower, &existing_lower) {
278                count += 1;
279            } else if new_root.len() >= 4 {
280                let existing_root = extract_alpha_root(&existing_lower);
281                if existing_root.len() >= 4
282                    && (new_root.starts_with(&existing_root)
283                        || existing_root.starts_with(&new_root))
284                {
285                    count += 1;
286                }
287            }
288        }
289        count
290    }
291
292    fn block_message(&self, tool: &str, count: u32) -> String {
293        if Self::is_search_tool(tool) {
294            self.search_block_message(count)
295        } else {
296            format!(
297                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
298                 Call blocked. Change your approach — the current strategy is not working."
299            )
300        }
301    }
302
303    #[allow(clippy::unused_self)]
304    fn search_block_message(&self, count: u32) -> String {
305        format!(
306            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
307             1) Use ctx_tree to understand the project structure first. \
308             2) Narrow your search with the 'path' parameter to a specific directory. \
309             3) Use ctx_read with mode='map' to understand a file before searching more."
310        )
311    }
312
313    fn search_group_block_message(&self, count: u32) -> String {
314        format!(
315            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
316             1) Use ctx_tree to map the project structure. \
317             2) Pick ONE specific directory and search there with the 'path' parameter. \
318             3) Read files with ctx_read mode='map' instead of searching blindly.",
319            self.window.as_secs()
320        )
321    }
322}
323
/// Returns the leading run of alphanumeric characters of `pattern`
/// (e.g. `"compress_all"` -> `"compress"`); empty if the first char is not alphanumeric.
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .find(|c: char| !c.is_alphanumeric())
        .unwrap_or(pattern.len());
    pattern[..end].to_owned()
}
330
/// Heuristic similarity between two (already lowercased) search patterns.
///
/// Similar when the patterns are equal, when one contains the other, or when their
/// alphanumeric-only forms (each at least 3 bytes) contain one another. An empty pattern
/// is only similar to another empty pattern — otherwise it would match everything,
/// since `""` is a substring of every string.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
}
348
349fn canonical_json(value: &serde_json::Value) -> String {
350    match value {
351        serde_json::Value::Object(map) => {
352            let mut keys: Vec<&String> = map.keys().collect();
353            keys.sort();
354            let entries: Vec<String> = keys
355                .iter()
356                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
357                .collect();
358            format!("{{{}}}", entries.join(","))
359        }
360        serde_json::Value::Array(arr) => {
361            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
362            format!("[{}]", entries.join(","))
363        }
364        _ => value.to_string(),
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for one test and restores its previous value on
    /// drop, so a failing assertion cannot leak the override into other tests.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        // Declared after the lock so it is dropped first: the env is restored while the
        // lock is still held, even if an assertion below panics.
        let _lock = crate::core::data_dir::test_env_lock();
        let _meta = EnvVarGuard::set("LEAN_CTX_META", "1");
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        // Blocking must be explicitly enabled (blocked_threshold > 0)
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        // Default config has blocked_threshold = 0, so blocking never happens
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        // Even 100 calls should not block when blocking is disabled
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        // Should be Reduced (warning) but never Blocked
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        // Blocking must be explicitly enabled for search group limits to block
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6, // Enable blocking
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        // Blocking must be explicitly enabled
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        // Blocking must be explicitly enabled to get block messages
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}
586        assert!(msg.contains("ctx_tree"));
587        assert!(msg.contains("path"));
588        assert!(msg.contains("ctx_read"));
589    }
590}