Skip to main content

lean_ctx/core/
loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names recognised as search tools by `LoopDetector::is_search_tool`.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];
/// Shell command prefixes (each with its trailing space) that mark a shell command as a
/// search in `LoopDetector::is_search_shell_command`.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
/// Tracks repeated tool calls within a time window to detect and throttle agent loops.
///
/// Two signals are tracked:
/// * repetition of the exact same `(tool, args fingerprint)` pair (`record_call`), and
/// * cross-tool search activity: the number of search calls in the window and the
///   similarity between recent search patterns (`record_search`).
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Timestamps of recent calls keyed by `"{tool}:{args_fingerprint}"`; entries older
    /// than `window` are pruned whenever `record_call` runs.
    call_history: HashMap<String, Vec<Instant>>,
    /// Most recent in-window count observed per key, as reported by `stats`; only
    /// cleared by `reset` (window pruning does not touch it).
    duplicate_counts: HashMap<String, u32>,
    /// Timestamps of search-category calls, checked against `search_group_limit`.
    search_group_history: Vec<Instant>,
    /// Recent non-empty search patterns (capped at 15, oldest evicted first) used to
    /// detect near-duplicate searches.
    recent_search_patterns: Vec<String>,
    /// A count above this yields a `Reduced` result with a note.
    normal_threshold: u32,
    /// A count above this yields a `Reduced` result with a warning.
    reduced_threshold: u32,
    /// A count above this yields `Blocked`; 0 disables blocking entirely.
    blocked_threshold: u32,
    /// Sliding time window over which repetitions are counted.
    window: Duration,
    /// Search calls allowed within `window` before blocking; `u32::MAX` when blocking
    /// is disabled.
    search_group_limit: u32,
}
23
/// Severity of throttling applied to a repeated call: normal, reduced, or blocked.
///
/// This is a fieldless enum, so it is `Copy` and can be compared, hashed and passed
/// around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThrottleLevel {
    /// Within normal limits; no message is attached.
    Normal,
    /// Repeated often enough that results are reduced and a note or warning is attached.
    Reduced,
    /// Considered part of a loop: the call is blocked and a `LOOP DETECTED` message
    /// is attached.
    Blocked,
}
31
/// Outcome of a loop detection check: throttle level, count, and optional warning.
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// How strongly the caller should throttle this call.
    pub level: ThrottleLevel,
    /// The count that determined `level`. Depending on which check fired, this is the
    /// number of identical calls in the window, the number of earlier similar search
    /// patterns, or the number of search calls in the window.
    pub call_count: u32,
    /// Guidance text for the agent; `None` for `Normal` results.
    pub message: Option<String>,
}
39
40impl Default for LoopDetector {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl LoopDetector {
47    /// Creates a loop detector with default thresholds.
48    pub fn new() -> Self {
49        Self::with_config(&LoopDetectionConfig::default())
50    }
51
52    /// Creates a loop detector with custom thresholds from config.
53    /// Set blocked_threshold to 0 to disable blocking entirely (LeanCTX philosophy).
54    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
55        Self {
56            call_history: HashMap::new(),
57            duplicate_counts: HashMap::new(),
58            search_group_history: Vec::new(),
59            recent_search_patterns: Vec::new(),
60            normal_threshold: cfg.normal_threshold.max(1),
61            reduced_threshold: cfg.reduced_threshold.max(2),
62            // 0 = never block (LeanCTX default), >0 = block after N identical calls
63            blocked_threshold: cfg.blocked_threshold,
64            window: Duration::from_secs(cfg.window_secs),
65            search_group_limit: if cfg.blocked_threshold == 0 {
66                u32::MAX
67            } else {
68                cfg.search_group_limit.max(3)
69            },
70        }
71    }
72
73    /// Records a tool call and returns the throttle result based on repetition count.
74    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
75        let now = Instant::now();
76        self.prune_window(now);
77
78        let key = format!("{tool}:{args_fingerprint}");
79        let entries = self.call_history.entry(key.clone()).or_default();
80        entries.push(now);
81        let count = entries.len() as u32;
82        *self.duplicate_counts.entry(key).or_default() = count;
83
84        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
85        if self.blocked_threshold > 0 && count > self.blocked_threshold {
86            return ThrottleResult {
87                level: ThrottleLevel::Blocked,
88                call_count: count,
89                message: Some(self.block_message(tool, count)),
90            };
91        }
92        if count > self.reduced_threshold {
93            return ThrottleResult {
94                level: ThrottleLevel::Reduced,
95                call_count: count,
96                message: Some(format!(
97                    "Warning: {tool} called {count}x with same args. \
98                     Results reduced. Try a different approach or narrow your scope."
99                )),
100            };
101        }
102        if count > self.normal_threshold {
103            return ThrottleResult {
104                level: ThrottleLevel::Reduced,
105                call_count: count,
106                message: Some(format!(
107                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
108                )),
109            };
110        }
111        ThrottleResult {
112            level: ThrottleLevel::Normal,
113            call_count: count,
114            message: None,
115        }
116    }
117
118    /// Record a search-category call and check the cross-tool search group limit.
119    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
120    pub fn record_search(
121        &mut self,
122        tool: &str,
123        args_fingerprint: &str,
124        search_pattern: Option<&str>,
125    ) -> ThrottleResult {
126        let now = Instant::now();
127
128        self.search_group_history.push(now);
129        let search_count = self.search_group_history.len() as u32;
130
131        let similar_count = if let Some(pat) = search_pattern {
132            let sc = self.count_similar_patterns(pat);
133            if !pat.is_empty() {
134                self.recent_search_patterns.push(pat.to_string());
135                if self.recent_search_patterns.len() > 15 {
136                    self.recent_search_patterns.remove(0);
137                }
138            }
139            sc
140        } else {
141            0
142        };
143
144        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
145        if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
146            return ThrottleResult {
147                level: ThrottleLevel::Blocked,
148                call_count: similar_count,
149                message: Some(self.search_block_message(similar_count)),
150            };
151        }
152
153        // search_group_limit == u32::MAX when blocking is disabled
154        if self.blocked_threshold > 0 && search_count > self.search_group_limit {
155            return ThrottleResult {
156                level: ThrottleLevel::Blocked,
157                call_count: search_count,
158                message: Some(self.search_group_block_message(search_count)),
159            };
160        }
161
162        if similar_count >= self.reduced_threshold {
163            return ThrottleResult {
164                level: ThrottleLevel::Reduced,
165                call_count: similar_count,
166                message: Some(format!(
167                    "Warning: You've searched for similar patterns {similar_count}x. \
168                     Narrow your search with the 'path' parameter or try ctx_tree first."
169                )),
170            };
171        }
172
173        if search_count > self.search_group_limit.saturating_sub(3) {
174            let per_fp = self.record_call(tool, args_fingerprint);
175            if per_fp.level != ThrottleLevel::Normal {
176                return per_fp;
177            }
178            return ThrottleResult {
179                level: ThrottleLevel::Reduced,
180                call_count: search_count,
181                message: Some(format!(
182                    "Note: {search_count} search calls in the last {}s. \
183                     Use ctx_tree to orient first, then scope searches with 'path'.",
184                    self.window.as_secs()
185                )),
186            };
187        }
188
189        self.record_call(tool, args_fingerprint)
190    }
191
192    /// Returns `true` if the tool name is a known search tool (ctx_search, etc.).
193    pub fn is_search_tool(tool: &str) -> bool {
194        SEARCH_TOOLS.contains(&tool)
195    }
196
197    /// Returns `true` if the shell command starts with a search tool (grep, rg, find, etc.).
198    pub fn is_search_shell_command(command: &str) -> bool {
199        let cmd = command.trim_start();
200        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
201    }
202
203    /// Computes a deterministic hash fingerprint of JSON tool arguments.
204    pub fn fingerprint(args: &serde_json::Value) -> String {
205        use std::collections::hash_map::DefaultHasher;
206        use std::hash::{Hash, Hasher};
207
208        let canonical = canonical_json(args);
209        let mut hasher = DefaultHasher::new();
210        canonical.hash(&mut hasher);
211        format!("{:016x}", hasher.finish())
212    }
213
214    /// Returns duplicate call entries sorted by count (descending), filtered to count > 1.
215    pub fn stats(&self) -> Vec<(String, u32)> {
216        let mut entries: Vec<(String, u32)> = self
217            .duplicate_counts
218            .iter()
219            .filter(|(_, &count)| count > 1)
220            .map(|(k, &v)| (k.clone(), v))
221            .collect();
222        entries.sort_by_key(|x| std::cmp::Reverse(x.1));
223        entries
224    }
225
226    /// Clears all tracking state (call history, search patterns, counters).
227    pub fn reset(&mut self) {
228        self.call_history.clear();
229        self.duplicate_counts.clear();
230        self.search_group_history.clear();
231        self.recent_search_patterns.clear();
232    }
233
234    fn prune_window(&mut self, now: Instant) {
235        for entries in self.call_history.values_mut() {
236            entries.retain(|t| now.duration_since(*t) < self.window);
237        }
238        self.search_group_history
239            .retain(|t| now.duration_since(*t) < self.window);
240    }
241
242    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
243        let new_lower = new_pattern.to_lowercase();
244        let new_root = extract_alpha_root(&new_lower);
245
246        let mut count = 0u32;
247        for existing in &self.recent_search_patterns {
248            let existing_lower = existing.to_lowercase();
249            if patterns_are_similar(&new_lower, &existing_lower) {
250                count += 1;
251            } else if new_root.len() >= 4 {
252                let existing_root = extract_alpha_root(&existing_lower);
253                if existing_root.len() >= 4
254                    && (new_root.starts_with(&existing_root)
255                        || existing_root.starts_with(&new_root))
256                {
257                    count += 1;
258                }
259            }
260        }
261        count
262    }
263
264    fn block_message(&self, tool: &str, count: u32) -> String {
265        if Self::is_search_tool(tool) {
266            self.search_block_message(count)
267        } else {
268            format!(
269                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
270                 Call blocked. Change your approach — the current strategy is not working."
271            )
272        }
273    }
274
275    #[allow(clippy::unused_self)]
276    fn search_block_message(&self, count: u32) -> String {
277        format!(
278            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
279             1) Use ctx_tree to understand the project structure first. \
280             2) Narrow your search with the 'path' parameter to a specific directory. \
281             3) Use ctx_read with mode='map' to understand a file before searching more."
282        )
283    }
284
285    fn search_group_block_message(&self, count: u32) -> String {
286        format!(
287            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
288             1) Use ctx_tree to map the project structure. \
289             2) Pick ONE specific directory and search there with the 'path' parameter. \
290             3) Read files with ctx_read mode='map' instead of searching blindly.",
291            self.window.as_secs()
292        )
293    }
294}
295
/// Returns the leading run of alphanumeric characters in `pattern` (e.g. `"compress"`
/// for `"compress_data"`). Returns an empty string when `pattern` starts with any other
/// character.
fn extract_alpha_root(pattern: &str) -> String {
    let root_end = pattern
        .char_indices()
        .find(|&(_, ch)| !ch.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..root_end].to_owned()
}
302
/// Heuristically decides whether two search patterns target the same thing.
///
/// Patterns are similar when they are identical, when one contains the other, or when one
/// contains the other after every non-alphanumeric character is removed (so
/// `compress_data` ~ `compress-data`). Containment only counts when both sides are at
/// least `MIN_LEN` bytes long; otherwise the empty string or a single letter would be
/// "contained" in nearly every pattern. The comparison is case-sensitive
/// (`count_similar_patterns` lower-cases both sides before calling).
fn patterns_are_similar(a: &str, b: &str) -> bool {
    /// Minimum byte length of both sides for a containment match to count.
    const MIN_LEN: usize = 3;

    /// True when both strings are long enough and one contains the other.
    fn overlaps(x: &str, y: &str) -> bool {
        x.len() >= MIN_LEN && y.len() >= MIN_LEN && (x.contains(y) || y.contains(x))
    }

    if a == b || overlaps(a, b) {
        return true;
    }
    let alnum_only = |s: &str| -> String { s.chars().filter(|c| c.is_alphanumeric()).collect() };
    overlaps(&alnum_only(a), &alnum_only(b))
}
320
321fn canonical_json(value: &serde_json::Value) -> String {
322    match value {
323        serde_json::Value::Object(map) => {
324            let mut keys: Vec<&String> = map.keys().collect();
325            keys.sort();
326            let entries: Vec<String> = keys
327                .iter()
328                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
329                .collect();
330            format!("{{{}}}", entries.join(","))
331        }
332        serde_json::Value::Array(arr) => {
333            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
334            format!("[{}]", entries.join(","))
335        }
336        _ => value.to_string(),
337    }
338}
339
#[cfg(test)]
mod tests {
    // Unit tests for per-fingerprint throttling, search-group limits, pattern
    // similarity and argument fingerprinting.
    use super::*;

    // Builds a config with explicit per-fingerprint thresholds, a 300 s window and a
    // search-group limit of 10.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        // Record `normal_threshold` identical calls; the next one pushes the count past it.
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        // Blocking must be explicitly enabled (blocked_threshold > 0)
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        // After `blocked_threshold` calls, the next identical call exceeds it.
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        // Default config has blocked_threshold = 0, so blocking never happens
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        // Even 100 calls should not block when blocking is disabled
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        // Should be Reduced (warning) but never Blocked
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        // A different fingerprint starts its own count at 1.
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_order_independent() {
        // Same keys and values, different insertion order.
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        // A single call has count 1 and is filtered out of `stats`.
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        // normal=1, reduced=2, blocked=3: call 2 exceeds `normal`, call 4 exceeds `blocked`.
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        // Blocking must be explicitly enabled for search group limits to block
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6, // Enable blocking
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        // The `pattern_{i}` strings share the alpha root "pattern", so each counts as
        // similar to all earlier ones. That count stays below blocked_threshold (6) for
        // all six calls, which leaves the group limit (5) as the check that blocks call 6.
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        // Blocking must be explicitly enabled
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        // Record `blocked_threshold` variants; the next similar pattern reaches the threshold.
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        // Blocking must be explicitly enabled to get block messages
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        // Identical patterns: the similar-pattern check runs before the group-limit check,
        // so the final call is blocked with the search-specific guidance message.
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}