Skip to main content

lean_ctx/core/
loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names that `LoopDetector::is_search_tool` treats as search tools.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell command prefixes that `LoopDetector::is_search_shell_command` treats as searches.
/// Each prefix includes a trailing space, so a bare command name (e.g. `"grep"`) does not match.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
/// Tracks repeated tool calls within a time window to detect and throttle agent loops.
///
/// Three signals are tracked:
/// - identical calls (same tool + args fingerprint) inside a sliding time window,
/// - total calls per tool, checked against optional per-tool limits,
/// - search activity across all search tools, by volume and by pattern similarity.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Timestamps of calls still inside the window, keyed by `"tool:args_fingerprint"`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Last observed windowed call count per `"tool:args_fingerprint"` key; reported by `stats()`.
    duplicate_counts: HashMap<String, u32>,
    /// Total calls per tool name regardless of args; not pruned by the time window.
    tool_total_counts: HashMap<String, u32>,
    /// Optional per-tool caps on `tool_total_counts`, copied from the config.
    tool_total_limits: HashMap<String, u32>,
    /// Timestamps of search-category calls (across all search tools).
    search_group_history: Vec<Instant>,
    /// Most recent non-empty search patterns, oldest first, bounded in length.
    recent_search_patterns: Vec<String>,
    /// Identical calls above this count are throttled to `Reduced`.
    normal_threshold: u32,
    /// Identical calls above this count are `Reduced` with a stronger warning.
    reduced_threshold: u32,
    /// Identical calls above this count are `Blocked`; 0 disables blocking.
    blocked_threshold: u32,
    /// Length of the sliding window used when pruning timestamps.
    window: Duration,
    /// Max search calls per window before blocking; `u32::MAX` when blocking is disabled.
    search_group_limit: u32,
}
25
/// Severity of throttling applied to a repeated call: normal, reduced, or blocked.
///
/// A fieldless enum with total equality, so it also derives `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// No repetition concern; the call passes through.
    Normal,
    /// Repetition detected; results are expected to be reduced and a warning may be attached.
    Reduced,
    /// Loop detected; the call is reported as blocked.
    Blocked,
}
33
/// Outcome of a loop detection check: throttle level, count, and optional warning.
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// How strongly the call is throttled.
    pub level: ThrottleLevel,
    /// The count behind `level`. Depending on which check fired, this is the number of
    /// identical calls in the window, the tool's total calls, the number of similar
    /// prior search patterns, or the number of search calls in the window.
    pub call_count: u32,
    /// Guidance for the agent. `Blocked` results always carry one. `Reduced` results
    /// carry one only when protocol metadata is visible. `Normal` results never do.
    pub message: Option<String>,
}
41
42impl Default for LoopDetector {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl LoopDetector {
49    /// Creates a loop detector with default thresholds.
50    pub fn new() -> Self {
51        Self::with_config(&LoopDetectionConfig::default())
52    }
53
54    /// Creates a loop detector with custom thresholds from config.
55    /// Set blocked_threshold to 0 to disable blocking entirely (LeanCTX philosophy).
56    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
57        Self {
58            call_history: HashMap::new(),
59            duplicate_counts: HashMap::new(),
60            tool_total_counts: HashMap::new(),
61            tool_total_limits: cfg.tool_total_limits.clone(),
62            search_group_history: Vec::new(),
63            recent_search_patterns: Vec::new(),
64            normal_threshold: cfg.normal_threshold.max(1),
65            reduced_threshold: cfg.reduced_threshold.max(2),
66            blocked_threshold: cfg.blocked_threshold,
67            window: Duration::from_secs(cfg.window_secs),
68            search_group_limit: if cfg.blocked_threshold == 0 {
69                u32::MAX
70            } else {
71                cfg.search_group_limit.max(3)
72            },
73        }
74    }
75
76    /// Records a tool call and returns the throttle result based on repetition count.
77    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
78        let now = Instant::now();
79        self.prune_window(now);
80
81        // Per-tool total count (regardless of args)
82        let total = self.tool_total_counts.entry(tool.to_string()).or_insert(0);
83        *total += 1;
84        let total_count = *total;
85
86        if let Some(&limit) = self.tool_total_limits.get(tool) {
87            if total_count > limit {
88                let msg = if crate::core::protocol::meta_visible() {
89                    Some(format!(
90                        "Warning: {tool} called {total_count}x total (limit: {limit}). \
91                         Consider ctx_compress or narrowing scope."
92                    ))
93                } else {
94                    None
95                };
96                return ThrottleResult {
97                    level: ThrottleLevel::Reduced,
98                    call_count: total_count,
99                    message: msg,
100                };
101            }
102        }
103
104        let key = format!("{tool}:{args_fingerprint}");
105        let entries = self.call_history.entry(key.clone()).or_default();
106        entries.push(now);
107        let count = entries.len() as u32;
108        *self.duplicate_counts.entry(key).or_default() = count;
109
110        if self.blocked_threshold > 0 && count > self.blocked_threshold {
111            return ThrottleResult {
112                level: ThrottleLevel::Blocked,
113                call_count: count,
114                message: Some(self.block_message(tool, count)),
115            };
116        }
117        if count > self.reduced_threshold {
118            if !crate::core::protocol::meta_visible() {
119                return ThrottleResult {
120                    level: ThrottleLevel::Reduced,
121                    call_count: count,
122                    message: None,
123                };
124            }
125            return ThrottleResult {
126                level: ThrottleLevel::Reduced,
127                call_count: count,
128                message: Some(format!(
129                    "Warning: {tool} called {count}x with same args. \
130                     Results reduced. Try a different approach or narrow your scope."
131                )),
132            };
133        }
134        if count > self.normal_threshold {
135            if !crate::core::protocol::meta_visible() {
136                return ThrottleResult {
137                    level: ThrottleLevel::Reduced,
138                    call_count: count,
139                    message: None,
140                };
141            }
142            return ThrottleResult {
143                level: ThrottleLevel::Reduced,
144                call_count: count,
145                message: Some(format!(
146                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
147                )),
148            };
149        }
150        ThrottleResult {
151            level: ThrottleLevel::Normal,
152            call_count: count,
153            message: None,
154        }
155    }
156
157    /// Record a search-category call and check the cross-tool search group limit.
158    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
159    pub fn record_search(
160        &mut self,
161        tool: &str,
162        args_fingerprint: &str,
163        search_pattern: Option<&str>,
164    ) -> ThrottleResult {
165        let now = Instant::now();
166
167        self.search_group_history.push(now);
168        let search_count = self.search_group_history.len() as u32;
169
170        let similar_count = if let Some(pat) = search_pattern {
171            let sc = self.count_similar_patterns(pat);
172            if !pat.is_empty() {
173                self.recent_search_patterns.push(pat.to_string());
174                if self.recent_search_patterns.len() > 15 {
175                    self.recent_search_patterns.remove(0);
176                }
177            }
178            sc
179        } else {
180            0
181        };
182
183        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
184        if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
185            return ThrottleResult {
186                level: ThrottleLevel::Blocked,
187                call_count: similar_count,
188                message: Some(self.search_block_message(similar_count)),
189            };
190        }
191
192        // search_group_limit == u32::MAX when blocking is disabled
193        if self.blocked_threshold > 0 && search_count > self.search_group_limit {
194            return ThrottleResult {
195                level: ThrottleLevel::Blocked,
196                call_count: search_count,
197                message: Some(self.search_group_block_message(search_count)),
198            };
199        }
200
201        if similar_count >= self.reduced_threshold {
202            if !crate::core::protocol::meta_visible() {
203                return ThrottleResult {
204                    level: ThrottleLevel::Reduced,
205                    call_count: similar_count,
206                    message: None,
207                };
208            }
209            return ThrottleResult {
210                level: ThrottleLevel::Reduced,
211                call_count: similar_count,
212                message: Some(format!(
213                    "Warning: You've searched for similar patterns {similar_count}x. \
214                     Narrow your search with the 'path' parameter or try ctx_tree first."
215                )),
216            };
217        }
218
219        if search_count > self.search_group_limit.saturating_sub(3) {
220            let per_fp = self.record_call(tool, args_fingerprint);
221            if per_fp.level != ThrottleLevel::Normal {
222                return per_fp;
223            }
224            if !crate::core::protocol::meta_visible() {
225                return ThrottleResult {
226                    level: ThrottleLevel::Reduced,
227                    call_count: search_count,
228                    message: None,
229                };
230            }
231            return ThrottleResult {
232                level: ThrottleLevel::Reduced,
233                call_count: search_count,
234                message: Some(format!(
235                    "Note: {search_count} search calls in the last {}s. \
236                     Use ctx_tree to orient first, then scope searches with 'path'.",
237                    self.window.as_secs()
238                )),
239            };
240        }
241
242        self.record_call(tool, args_fingerprint)
243    }
244
245    /// Returns `true` if the tool name is a known search tool (ctx_search, etc.).
246    pub fn is_search_tool(tool: &str) -> bool {
247        SEARCH_TOOLS.contains(&tool)
248    }
249
250    /// Returns `true` if the shell command starts with a search tool (grep, rg, find, etc.).
251    pub fn is_search_shell_command(command: &str) -> bool {
252        let cmd = command.trim_start();
253        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
254    }
255
256    /// Computes a deterministic hash fingerprint of JSON tool arguments.
257    pub fn fingerprint(args: &serde_json::Value) -> String {
258        use std::collections::hash_map::DefaultHasher;
259        use std::hash::{Hash, Hasher};
260
261        let canonical = canonical_json(args);
262        let mut hasher = DefaultHasher::new();
263        canonical.hash(&mut hasher);
264        format!("{:016x}", hasher.finish())
265    }
266
267    /// Returns duplicate call entries sorted by count (descending), filtered to count > 1.
268    pub fn stats(&self) -> Vec<(String, u32)> {
269        let mut entries: Vec<(String, u32)> = self
270            .duplicate_counts
271            .iter()
272            .filter(|(_, &count)| count > 1)
273            .map(|(k, &v)| (k.clone(), v))
274            .collect();
275        entries.sort_by_key(|x| std::cmp::Reverse(x.1));
276        entries
277    }
278
279    /// Clears all tracking state (call history, search patterns, counters).
280    pub fn reset(&mut self) {
281        self.call_history.clear();
282        self.duplicate_counts.clear();
283        self.search_group_history.clear();
284        self.recent_search_patterns.clear();
285    }
286
287    fn prune_window(&mut self, now: Instant) {
288        for entries in self.call_history.values_mut() {
289            entries.retain(|t| now.duration_since(*t) < self.window);
290        }
291        self.search_group_history
292            .retain(|t| now.duration_since(*t) < self.window);
293    }
294
295    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
296        let new_lower = new_pattern.to_lowercase();
297        let new_root = extract_alpha_root(&new_lower);
298
299        let mut count = 0u32;
300        for existing in &self.recent_search_patterns {
301            let existing_lower = existing.to_lowercase();
302            if patterns_are_similar(&new_lower, &existing_lower) {
303                count += 1;
304            } else if new_root.len() >= 4 {
305                let existing_root = extract_alpha_root(&existing_lower);
306                if existing_root.len() >= 4
307                    && (new_root.starts_with(&existing_root)
308                        || existing_root.starts_with(&new_root))
309                {
310                    count += 1;
311                }
312            }
313        }
314        count
315    }
316
317    fn block_message(&self, tool: &str, count: u32) -> String {
318        if Self::is_search_tool(tool) {
319            self.search_block_message(count)
320        } else {
321            format!(
322                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
323                 Call blocked. Change your approach — the current strategy is not working."
324            )
325        }
326    }
327
328    #[allow(clippy::unused_self)]
329    fn search_block_message(&self, count: u32) -> String {
330        format!(
331            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
332             1) Use ctx_tree to understand the project structure first. \
333             2) Narrow your search with the 'path' parameter to a specific directory. \
334             3) Use ctx_read with mode='map' to understand a file before searching more."
335        )
336    }
337
338    fn search_group_block_message(&self, count: u32) -> String {
339        format!(
340            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
341             1) Use ctx_tree to map the project structure. \
342             2) Pick ONE specific directory and search there with the 'path' parameter. \
343             3) Read files with ctx_read mode='map' instead of searching blindly.",
344            self.window.as_secs()
345        )
346    }
347}
348
/// Returns the leading run of alphanumeric characters in `pattern`
/// (e.g. `"compress.*data"` -> `"compress"`). Returns an empty string if the
/// pattern starts with a non-alphanumeric character.
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .find(|c: char| !c.is_alphanumeric())
        .unwrap_or(pattern.len());
    pattern[..end].to_owned()
}
355
/// Heuristic similarity between two search patterns.
///
/// Two patterns are similar when any of the following holds:
/// - they are equal;
/// - one contains the other;
/// - after stripping non-alphanumeric characters, both are at least 3 bytes long
///   and one contains the other.
///
/// An empty pattern is only similar to another empty pattern. Otherwise it would match
/// everything via `contains("")`.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
}
373
374fn canonical_json(value: &serde_json::Value) -> String {
375    match value {
376        serde_json::Value::Object(map) => {
377            let mut keys: Vec<&String> = map.keys().collect();
378            keys.sort();
379            let entries: Vec<String> = keys
380                .iter()
381                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
382                .collect();
383            format!("{{{}}}", entries.join(","))
384        }
385        serde_json::Value::Array(arr) => {
386            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
387            format!("[{}]", entries.join(","))
388        }
389        _ => value.to_string(),
390    }
391}
392
#[cfg(test)]
mod tests {
    // Unit tests for `LoopDetector`. Tests that assert on `Reduced` messages set
    // LEAN_CTX_META (presumably read by `meta_visible()` — confirm) while holding the
    // shared test env lock, so parallel tests don't race on the environment.
    use super::*;

    // Config with explicit per-args thresholds, a 300s window, a search group limit of 10,
    // and no per-tool total limits.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
            tool_total_limits: std::collections::HashMap::new(),
        }
    }

    // A first call is always `Normal`, and `Normal` results never carry a message,
    // so no env lock is needed here.
    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        // The message is only attached when meta output is visible.
        // NOTE(review): if an assertion below panics, LEAN_CTX_META is not removed.
        let _lock = crate::core::data_dir::test_env_lock();
        std::env::set_var("LEAN_CTX_META", "1");
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        // One call past `normal_threshold` is the first to be throttled.
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
        std::env::remove_var("LEAN_CTX_META");
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        // Blocking must be explicitly enabled (blocked_threshold > 0)
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        // The call after `blocked_threshold` identical calls is blocked.
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        // Default config has blocked_threshold = 0, so blocking never happens
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        // Even 100 calls should not block when blocking is disabled
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        // Should be Reduced (warning) but never Blocked
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    // Repetition is tracked per (tool, fingerprint) pair, so a new fingerprint starts at 1.
    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    // Keys are sorted during canonicalization, so insertion order must not matter.
    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    // `stats()` only reports keys seen more than once; the single ctx_shell call is omitted.
    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    // normal=1, reduced=2, blocked=3: call 2 exceeds `normal` (Reduced) and call 4
    // exceeds `blocked` (Blocked).
    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        // Blocking must be explicitly enabled for search group limits to block
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6, // Enable blocking
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        // The sixth search exceeds `search_group_limit` (5). Only 5 similar prior patterns
        // exist (< blocked_threshold 6), so the group check is what blocks it.
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        // Blocking must be explicitly enabled
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        // Record the first `blocked_threshold` (6) variants.
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        // The next search sees 6 similar prior patterns, reaching `blocked_threshold`.
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        // Blocking must be explicitly enabled to get block messages
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        // Every search repeats the same pattern. The similar-pattern check therefore runs
        // before the group-limit check and returns its guidance message.
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}