// lean_ctx/core/loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
6const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];
7
8const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
10#[derive(Debug, Clone)]
11pub struct LoopDetector {
12    call_history: HashMap<String, Vec<Instant>>,
13    duplicate_counts: HashMap<String, u32>,
14    search_group_history: Vec<Instant>,
15    recent_search_patterns: Vec<String>,
16    normal_threshold: u32,
17    reduced_threshold: u32,
18    blocked_threshold: u32,
19    window: Duration,
20    search_group_limit: u32,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24pub enum ThrottleLevel {
25    Normal,
26    Reduced,
27    Blocked,
28}
29
30#[derive(Debug, Clone)]
31pub struct ThrottleResult {
32    pub level: ThrottleLevel,
33    pub call_count: u32,
34    pub message: Option<String>,
35}
36
37impl Default for LoopDetector {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl LoopDetector {
44    pub fn new() -> Self {
45        Self::with_config(&LoopDetectionConfig::default())
46    }
47
48    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
49        Self {
50            call_history: HashMap::new(),
51            duplicate_counts: HashMap::new(),
52            search_group_history: Vec::new(),
53            recent_search_patterns: Vec::new(),
54            normal_threshold: cfg.normal_threshold.max(1),
55            reduced_threshold: cfg.reduced_threshold.max(2),
56            blocked_threshold: cfg.blocked_threshold.max(3),
57            window: Duration::from_secs(cfg.window_secs),
58            search_group_limit: cfg.search_group_limit.max(3),
59        }
60    }
61
62    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
63        let now = Instant::now();
64        self.prune_window(now);
65
66        let key = format!("{tool}:{args_fingerprint}");
67        let entries = self.call_history.entry(key.clone()).or_default();
68        entries.push(now);
69        let count = entries.len() as u32;
70        *self.duplicate_counts.entry(key).or_default() = count;
71
72        if count > self.blocked_threshold {
73            return ThrottleResult {
74                level: ThrottleLevel::Blocked,
75                call_count: count,
76                message: Some(self.block_message(tool, count)),
77            };
78        }
79        if count > self.reduced_threshold {
80            return ThrottleResult {
81                level: ThrottleLevel::Reduced,
82                call_count: count,
83                message: Some(format!(
84                    "Warning: {tool} called {count}x with same args. \
85                     Results reduced. Try a different approach or narrow your scope."
86                )),
87            };
88        }
89        if count > self.normal_threshold {
90            return ThrottleResult {
91                level: ThrottleLevel::Reduced,
92                call_count: count,
93                message: Some(format!(
94                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
95                )),
96            };
97        }
98        ThrottleResult {
99            level: ThrottleLevel::Normal,
100            call_count: count,
101            message: None,
102        }
103    }
104
105    /// Record a search-category call and check the cross-tool search group limit.
106    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
107    pub fn record_search(
108        &mut self,
109        tool: &str,
110        args_fingerprint: &str,
111        search_pattern: Option<&str>,
112    ) -> ThrottleResult {
113        let now = Instant::now();
114
115        self.search_group_history.push(now);
116        let search_count = self.search_group_history.len() as u32;
117
118        let similar_count = if let Some(pat) = search_pattern {
119            let sc = self.count_similar_patterns(pat);
120            if !pat.is_empty() {
121                self.recent_search_patterns.push(pat.to_string());
122                if self.recent_search_patterns.len() > 15 {
123                    self.recent_search_patterns.remove(0);
124                }
125            }
126            sc
127        } else {
128            0
129        };
130
131        if similar_count >= self.blocked_threshold {
132            return ThrottleResult {
133                level: ThrottleLevel::Blocked,
134                call_count: similar_count,
135                message: Some(self.search_block_message(similar_count)),
136            };
137        }
138
139        if search_count > self.search_group_limit {
140            return ThrottleResult {
141                level: ThrottleLevel::Blocked,
142                call_count: search_count,
143                message: Some(self.search_group_block_message(search_count)),
144            };
145        }
146
147        if similar_count >= self.reduced_threshold {
148            return ThrottleResult {
149                level: ThrottleLevel::Reduced,
150                call_count: similar_count,
151                message: Some(format!(
152                    "Warning: You've searched for similar patterns {similar_count}x. \
153                     Narrow your search with the 'path' parameter or try ctx_tree first."
154                )),
155            };
156        }
157
158        if search_count > self.search_group_limit.saturating_sub(3) {
159            let per_fp = self.record_call(tool, args_fingerprint);
160            if per_fp.level != ThrottleLevel::Normal {
161                return per_fp;
162            }
163            return ThrottleResult {
164                level: ThrottleLevel::Reduced,
165                call_count: search_count,
166                message: Some(format!(
167                    "Note: {search_count} search calls in the last {}s. \
168                     Use ctx_tree to orient first, then scope searches with 'path'.",
169                    self.window.as_secs()
170                )),
171            };
172        }
173
174        self.record_call(tool, args_fingerprint)
175    }
176
177    pub fn is_search_tool(tool: &str) -> bool {
178        SEARCH_TOOLS.contains(&tool)
179    }
180
181    pub fn is_search_shell_command(command: &str) -> bool {
182        let cmd = command.trim_start();
183        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
184    }
185
186    pub fn fingerprint(args: &serde_json::Value) -> String {
187        use std::collections::hash_map::DefaultHasher;
188        use std::hash::{Hash, Hasher};
189
190        let canonical = canonical_json(args);
191        let mut hasher = DefaultHasher::new();
192        canonical.hash(&mut hasher);
193        format!("{:016x}", hasher.finish())
194    }
195
196    pub fn stats(&self) -> Vec<(String, u32)> {
197        let mut entries: Vec<(String, u32)> = self
198            .duplicate_counts
199            .iter()
200            .filter(|(_, &count)| count > 1)
201            .map(|(k, &v)| (k.clone(), v))
202            .collect();
203        entries.sort_by(|a, b| b.1.cmp(&a.1));
204        entries
205    }
206
207    pub fn reset(&mut self) {
208        self.call_history.clear();
209        self.duplicate_counts.clear();
210        self.search_group_history.clear();
211        self.recent_search_patterns.clear();
212    }
213
214    fn prune_window(&mut self, now: Instant) {
215        for entries in self.call_history.values_mut() {
216            entries.retain(|t| now.duration_since(*t) < self.window);
217        }
218        self.search_group_history
219            .retain(|t| now.duration_since(*t) < self.window);
220    }
221
222    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
223        let new_lower = new_pattern.to_lowercase();
224        let new_root = extract_alpha_root(&new_lower);
225
226        let mut count = 0u32;
227        for existing in &self.recent_search_patterns {
228            let existing_lower = existing.to_lowercase();
229            if patterns_are_similar(&new_lower, &existing_lower) {
230                count += 1;
231            } else if new_root.len() >= 4 {
232                let existing_root = extract_alpha_root(&existing_lower);
233                if existing_root.len() >= 4
234                    && (new_root.starts_with(&existing_root)
235                        || existing_root.starts_with(&new_root))
236                {
237                    count += 1;
238                }
239            }
240        }
241        count
242    }
243
244    fn block_message(&self, tool: &str, count: u32) -> String {
245        if Self::is_search_tool(tool) {
246            self.search_block_message(count)
247        } else {
248            format!(
249                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
250                 Call blocked. Change your approach — the current strategy is not working."
251            )
252        }
253    }
254
255    fn search_block_message(&self, count: u32) -> String {
256        format!(
257            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
258             1) Use ctx_tree to understand the project structure first. \
259             2) Narrow your search with the 'path' parameter to a specific directory. \
260             3) Use ctx_read with mode='map' to understand a file before searching more."
261        )
262    }
263
264    fn search_group_block_message(&self, count: u32) -> String {
265        format!(
266            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
267             1) Use ctx_tree to map the project structure. \
268             2) Pick ONE specific directory and search there with the 'path' parameter. \
269             3) Read files with ctx_read mode='map' instead of searching blindly.",
270            self.window.as_secs()
271        )
272    }
273}
274
275fn extract_alpha_root(pattern: &str) -> String {
276    pattern
277        .chars()
278        .take_while(|c| c.is_alphanumeric())
279        .collect()
280}
281
282fn patterns_are_similar(a: &str, b: &str) -> bool {
283    if a == b {
284        return true;
285    }
286    if a.contains(b) || b.contains(a) {
287        return true;
288    }
289    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
290    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
291    if a_alpha.len() >= 3
292        && b_alpha.len() >= 3
293        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
294    {
295        return true;
296    }
297    false
298}
299
300fn canonical_json(value: &serde_json::Value) -> String {
301    match value {
302        serde_json::Value::Object(map) => {
303            let mut keys: Vec<&String> = map.keys().collect();
304            keys.sort();
305            let entries: Vec<String> = keys
306                .iter()
307                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
308                .collect();
309            format!("{{{}}}", entries.join(","))
310        }
311        serde_json::Value::Array(arr) => {
312            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
313            format!("[{}]", entries.join(","))
314        }
315        _ => value.to_string(),
316    }
317}
318
319#[cfg(test)]
320mod tests {
321    use super::*;
322
323    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
324        LoopDetectionConfig {
325            normal_threshold: normal,
326            reduced_threshold: reduced,
327            blocked_threshold: blocked,
328            window_secs: 300,
329            search_group_limit: 10,
330        }
331    }
332
333    #[test]
334    fn normal_calls_pass_through() {
335        let mut detector = LoopDetector::new();
336        let r1 = detector.record_call("ctx_read", "abc123");
337        assert_eq!(r1.level, ThrottleLevel::Normal);
338        assert_eq!(r1.call_count, 1);
339        assert!(r1.message.is_none());
340    }
341
342    #[test]
343    fn repeated_calls_trigger_reduced() {
344        let cfg = LoopDetectionConfig::default();
345        let mut detector = LoopDetector::with_config(&cfg);
346        for _ in 0..cfg.normal_threshold {
347            detector.record_call("ctx_read", "same_fp");
348        }
349        let result = detector.record_call("ctx_read", "same_fp");
350        assert_eq!(result.level, ThrottleLevel::Reduced);
351        assert!(result.message.is_some());
352    }
353
354    #[test]
355    fn excessive_calls_get_blocked() {
356        let cfg = LoopDetectionConfig::default();
357        let mut detector = LoopDetector::with_config(&cfg);
358        for _ in 0..cfg.blocked_threshold {
359            detector.record_call("ctx_shell", "same_fp");
360        }
361        let result = detector.record_call("ctx_shell", "same_fp");
362        assert_eq!(result.level, ThrottleLevel::Blocked);
363        assert!(result.message.unwrap().contains("LOOP DETECTED"));
364    }
365
366    #[test]
367    fn different_args_tracked_separately() {
368        let mut detector = LoopDetector::new();
369        for _ in 0..10 {
370            detector.record_call("ctx_read", "fp_a");
371        }
372        let result = detector.record_call("ctx_read", "fp_b");
373        assert_eq!(result.level, ThrottleLevel::Normal);
374        assert_eq!(result.call_count, 1);
375    }
376
377    #[test]
378    fn fingerprint_deterministic() {
379        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
380        let fp1 = LoopDetector::fingerprint(&args);
381        let fp2 = LoopDetector::fingerprint(&args);
382        assert_eq!(fp1, fp2);
383    }
384
385    #[test]
386    fn fingerprint_order_independent() {
387        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
388        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
389        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
390    }
391
392    #[test]
393    fn stats_shows_duplicates() {
394        let mut detector = LoopDetector::new();
395        for _ in 0..5 {
396            detector.record_call("ctx_read", "fp_a");
397        }
398        detector.record_call("ctx_shell", "fp_b");
399        let stats = detector.stats();
400        assert_eq!(stats.len(), 1);
401        assert_eq!(stats[0].1, 5);
402    }
403
404    #[test]
405    fn reset_clears_state() {
406        let mut detector = LoopDetector::new();
407        for _ in 0..5 {
408            detector.record_call("ctx_read", "fp_a");
409        }
410        detector.reset();
411        let result = detector.record_call("ctx_read", "fp_a");
412        assert_eq!(result.call_count, 1);
413    }
414
415    #[test]
416    fn custom_thresholds_from_config() {
417        let cfg = test_config(1, 2, 3);
418        let mut detector = LoopDetector::with_config(&cfg);
419        detector.record_call("ctx_read", "fp");
420        let r = detector.record_call("ctx_read", "fp");
421        assert_eq!(r.level, ThrottleLevel::Reduced);
422        detector.record_call("ctx_read", "fp");
423        let r = detector.record_call("ctx_read", "fp");
424        assert_eq!(r.level, ThrottleLevel::Blocked);
425    }
426
427    #[test]
428    fn similar_patterns_detected() {
429        assert!(patterns_are_similar("compress", "compress"));
430        assert!(patterns_are_similar("compress", "compression"));
431        assert!(patterns_are_similar("compress.*data", "compress"));
432        assert!(!patterns_are_similar("foo", "bar"));
433        assert!(!patterns_are_similar("ab", "cd"));
434    }
435
436    #[test]
437    fn search_group_tracking() {
438        let cfg = LoopDetectionConfig {
439            search_group_limit: 5,
440            ..Default::default()
441        };
442        let mut detector = LoopDetector::with_config(&cfg);
443        for i in 0..5 {
444            let fp = format!("fp_{i}");
445            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
446            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
447        }
448        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
449        assert_eq!(r.level, ThrottleLevel::Blocked);
450        assert!(r.message.unwrap().contains("search calls"));
451    }
452
453    #[test]
454    fn similar_search_patterns_trigger_block() {
455        let cfg = LoopDetectionConfig::default();
456        let mut detector = LoopDetector::with_config(&cfg);
457        let variants = [
458            "compress",
459            "compression",
460            "compress.*data",
461            "compress_output",
462            "compressor",
463            "compress_result",
464            "compress_file",
465        ];
466        for (i, pat) in variants
467            .iter()
468            .enumerate()
469            .take(cfg.blocked_threshold as usize)
470        {
471            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
472        }
473        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
474        assert_eq!(r.level, ThrottleLevel::Blocked);
475    }
476
477    #[test]
478    fn is_search_tool_detection() {
479        assert!(LoopDetector::is_search_tool("ctx_search"));
480        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
481        assert!(!LoopDetector::is_search_tool("ctx_read"));
482        assert!(!LoopDetector::is_search_tool("ctx_shell"));
483    }
484
485    #[test]
486    fn is_search_shell_command_detection() {
487        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
488        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
489        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
490        assert!(!LoopDetector::is_search_shell_command("cargo build"));
491        assert!(!LoopDetector::is_search_shell_command("git status"));
492    }
493
494    #[test]
495    fn search_block_message_has_guidance() {
496        let mut detector = LoopDetector::new();
497        for i in 0..10 {
498            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
499        }
500        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
501        let msg = r.message.unwrap();
502        assert!(msg.contains("ctx_tree"));
503        assert!(msg.contains("path"));
504        assert!(msg.contains("ctx_read"));
505    }
506}