1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names treated as dedicated search tools by `LoopDetector::is_search_tool`.
const SEARCH_TOOLS: &[&str] = &[
    "ctx_search",
    "ctx_semantic_search",
];

/// Shell command prefixes (command word plus a trailing space) recognised as
/// search commands by `LoopDetector::is_search_shell_command`.
const SEARCH_SHELL_PREFIXES: &[&str] = &[
    "grep ",
    "rg ",
    "find ",
    "fd ",
    "ag ",
    "ack ",
];
9
/// Detects repetitive tool-call patterns ("loops") and decides how strongly
/// each new call should be throttled.
///
/// It tracks three kinds of activity:
/// - identical calls (same tool and argument fingerprint) inside a sliding time window;
/// - cumulative per-tool call counts;
/// - search activity, meaning search calls in the window plus recently used search patterns.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    // Call timestamps keyed by "{tool}:{fingerprint}"; pruned to `window`.
    call_history: HashMap<String, Vec<Instant>>,
    // Most recent in-window count per "{tool}:{fingerprint}" key; reported by `stats`.
    duplicate_counts: HashMap<String, u32>,
    // Cumulative calls per tool name; not pruned by the time window.
    tool_total_counts: HashMap<String, u32>,
    // Optional per-tool caps on `tool_total_counts`, copied from config.
    tool_total_limits: HashMap<String, u32>,
    // Timestamps of search calls regardless of pattern; pruned to `window` by `prune_window`.
    search_group_history: Vec<Instant>,
    // Up to 15 most recent non-empty search patterns, oldest first.
    recent_search_patterns: Vec<String>,
    // Repeat count above which identical calls get a `Reduced` "Note".
    normal_threshold: u32,
    // Repeat count above which identical calls get a `Reduced` "Warning"; searches with
    // at least this many similar earlier patterns are also `Reduced`.
    reduced_threshold: u32,
    // Repeat count above which identical calls are `Blocked`; also the similar-pattern
    // count at which searches are blocked. 0 disables blocking.
    blocked_threshold: u32,
    // Length of the sliding window used for repeat counting.
    window: Duration,
    // Search calls per window above which searches are blocked (`u32::MAX` when blocking is disabled).
    search_group_limit: u32,
}
25
/// How strongly a recorded call should be throttled.
///
/// `Eq` is derived alongside `PartialEq` because the enum has no fields and
/// equality is total; this also lets it be used where `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// No repetition concern detected.
    Normal,
    /// Repetition or excessive use detected, but below any blocking threshold.
    Reduced,
    /// A loop was detected while blocking is enabled; the attached message says the call is blocked.
    Blocked,
}
33
/// Outcome of recording a call with `LoopDetector`.
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// Throttle decision for this call.
    pub level: ThrottleLevel,
    /// The count that drove the decision. This is one of:
    /// - the in-window repeats of the same tool and fingerprint;
    /// - the tool's total call count, when a per-tool limit was exceeded;
    /// - a search count, for search-specific decisions.
    pub call_count: u32,
    /// Human-readable guidance. It is always set for `Blocked`. For `Reduced` it is set
    /// only when protocol meta output is visible. It is `None` for `Normal`.
    pub message: Option<String>,
}
41
42impl Default for LoopDetector {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl LoopDetector {
49 pub fn new() -> Self {
51 Self::with_config(&LoopDetectionConfig::default())
52 }
53
54 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
57 Self {
58 call_history: HashMap::new(),
59 duplicate_counts: HashMap::new(),
60 tool_total_counts: HashMap::new(),
61 tool_total_limits: cfg.tool_total_limits.clone(),
62 search_group_history: Vec::new(),
63 recent_search_patterns: Vec::new(),
64 normal_threshold: cfg.normal_threshold.max(1),
65 reduced_threshold: cfg.reduced_threshold.max(2),
66 blocked_threshold: cfg.blocked_threshold,
67 window: Duration::from_secs(cfg.window_secs),
68 search_group_limit: if cfg.blocked_threshold == 0 {
69 u32::MAX
70 } else {
71 cfg.search_group_limit.max(3)
72 },
73 }
74 }
75
76 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
78 let now = Instant::now();
79 self.prune_window(now);
80
81 let total = self.tool_total_counts.entry(tool.to_string()).or_insert(0);
83 *total += 1;
84 let total_count = *total;
85
86 if let Some(&limit) = self.tool_total_limits.get(tool) {
87 if total_count > limit {
88 let msg = if crate::core::protocol::meta_visible() {
89 Some(format!(
90 "Warning: {tool} called {total_count}x total (limit: {limit}). \
91 Consider ctx_compress or narrowing scope."
92 ))
93 } else {
94 None
95 };
96 return ThrottleResult {
97 level: ThrottleLevel::Reduced,
98 call_count: total_count,
99 message: msg,
100 };
101 }
102 }
103
104 let key = format!("{tool}:{args_fingerprint}");
105 let entries = self.call_history.entry(key.clone()).or_default();
106 entries.push(now);
107 let count = entries.len() as u32;
108 *self.duplicate_counts.entry(key).or_default() = count;
109
110 if self.blocked_threshold > 0 && count > self.blocked_threshold {
111 return ThrottleResult {
112 level: ThrottleLevel::Blocked,
113 call_count: count,
114 message: Some(self.block_message(tool, count)),
115 };
116 }
117 if count > self.reduced_threshold {
118 if !crate::core::protocol::meta_visible() {
119 return ThrottleResult {
120 level: ThrottleLevel::Reduced,
121 call_count: count,
122 message: None,
123 };
124 }
125 return ThrottleResult {
126 level: ThrottleLevel::Reduced,
127 call_count: count,
128 message: Some(format!(
129 "Warning: {tool} called {count}x with same args. \
130 Results reduced. Try a different approach or narrow your scope."
131 )),
132 };
133 }
134 if count > self.normal_threshold {
135 if !crate::core::protocol::meta_visible() {
136 return ThrottleResult {
137 level: ThrottleLevel::Reduced,
138 call_count: count,
139 message: None,
140 };
141 }
142 return ThrottleResult {
143 level: ThrottleLevel::Reduced,
144 call_count: count,
145 message: Some(format!(
146 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
147 )),
148 };
149 }
150 ThrottleResult {
151 level: ThrottleLevel::Normal,
152 call_count: count,
153 message: None,
154 }
155 }
156
157 pub fn record_search(
160 &mut self,
161 tool: &str,
162 args_fingerprint: &str,
163 search_pattern: Option<&str>,
164 ) -> ThrottleResult {
165 let now = Instant::now();
166
167 self.search_group_history.push(now);
168 let search_count = self.search_group_history.len() as u32;
169
170 let similar_count = if let Some(pat) = search_pattern {
171 let sc = self.count_similar_patterns(pat);
172 if !pat.is_empty() {
173 self.recent_search_patterns.push(pat.to_string());
174 if self.recent_search_patterns.len() > 15 {
175 self.recent_search_patterns.remove(0);
176 }
177 }
178 sc
179 } else {
180 0
181 };
182
183 if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
185 return ThrottleResult {
186 level: ThrottleLevel::Blocked,
187 call_count: similar_count,
188 message: Some(self.search_block_message(similar_count)),
189 };
190 }
191
192 if self.blocked_threshold > 0 && search_count > self.search_group_limit {
194 return ThrottleResult {
195 level: ThrottleLevel::Blocked,
196 call_count: search_count,
197 message: Some(self.search_group_block_message(search_count)),
198 };
199 }
200
201 if similar_count >= self.reduced_threshold {
202 if !crate::core::protocol::meta_visible() {
203 return ThrottleResult {
204 level: ThrottleLevel::Reduced,
205 call_count: similar_count,
206 message: None,
207 };
208 }
209 return ThrottleResult {
210 level: ThrottleLevel::Reduced,
211 call_count: similar_count,
212 message: Some(format!(
213 "Warning: You've searched for similar patterns {similar_count}x. \
214 Narrow your search with the 'path' parameter or try ctx_tree first."
215 )),
216 };
217 }
218
219 if search_count > self.search_group_limit.saturating_sub(3) {
220 let per_fp = self.record_call(tool, args_fingerprint);
221 if per_fp.level != ThrottleLevel::Normal {
222 return per_fp;
223 }
224 if !crate::core::protocol::meta_visible() {
225 return ThrottleResult {
226 level: ThrottleLevel::Reduced,
227 call_count: search_count,
228 message: None,
229 };
230 }
231 return ThrottleResult {
232 level: ThrottleLevel::Reduced,
233 call_count: search_count,
234 message: Some(format!(
235 "Note: {search_count} search calls in the last {}s. \
236 Use ctx_tree to orient first, then scope searches with 'path'.",
237 self.window.as_secs()
238 )),
239 };
240 }
241
242 self.record_call(tool, args_fingerprint)
243 }
244
245 pub fn is_search_tool(tool: &str) -> bool {
247 SEARCH_TOOLS.contains(&tool)
248 }
249
250 pub fn is_search_shell_command(command: &str) -> bool {
252 let cmd = command.trim_start();
253 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
254 }
255
256 pub fn fingerprint(args: &serde_json::Value) -> String {
258 use std::collections::hash_map::DefaultHasher;
259 use std::hash::{Hash, Hasher};
260
261 let canonical = canonical_json(args);
262 let mut hasher = DefaultHasher::new();
263 canonical.hash(&mut hasher);
264 format!("{:016x}", hasher.finish())
265 }
266
267 pub fn stats(&self) -> Vec<(String, u32)> {
269 let mut entries: Vec<(String, u32)> = self
270 .duplicate_counts
271 .iter()
272 .filter(|(_, &count)| count > 1)
273 .map(|(k, &v)| (k.clone(), v))
274 .collect();
275 entries.sort_by_key(|x| std::cmp::Reverse(x.1));
276 entries
277 }
278
279 pub fn reset(&mut self) {
281 self.call_history.clear();
282 self.duplicate_counts.clear();
283 self.search_group_history.clear();
284 self.recent_search_patterns.clear();
285 }
286
287 fn prune_window(&mut self, now: Instant) {
288 for entries in self.call_history.values_mut() {
289 entries.retain(|t| now.duration_since(*t) < self.window);
290 }
291 self.search_group_history
292 .retain(|t| now.duration_since(*t) < self.window);
293 }
294
295 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
296 let new_lower = new_pattern.to_lowercase();
297 let new_root = extract_alpha_root(&new_lower);
298
299 let mut count = 0u32;
300 for existing in &self.recent_search_patterns {
301 let existing_lower = existing.to_lowercase();
302 if patterns_are_similar(&new_lower, &existing_lower) {
303 count += 1;
304 } else if new_root.len() >= 4 {
305 let existing_root = extract_alpha_root(&existing_lower);
306 if existing_root.len() >= 4
307 && (new_root.starts_with(&existing_root)
308 || existing_root.starts_with(&new_root))
309 {
310 count += 1;
311 }
312 }
313 }
314 count
315 }
316
317 fn block_message(&self, tool: &str, count: u32) -> String {
318 if Self::is_search_tool(tool) {
319 self.search_block_message(count)
320 } else {
321 format!(
322 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
323 Call blocked. Change your approach — the current strategy is not working."
324 )
325 }
326 }
327
328 #[allow(clippy::unused_self)]
329 fn search_block_message(&self, count: u32) -> String {
330 format!(
331 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
332 1) Use ctx_tree to understand the project structure first. \
333 2) Narrow your search with the 'path' parameter to a specific directory. \
334 3) Use ctx_read with mode='map' to understand a file before searching more."
335 )
336 }
337
338 fn search_group_block_message(&self, count: u32) -> String {
339 format!(
340 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
341 1) Use ctx_tree to map the project structure. \
342 2) Pick ONE specific directory and search there with the 'path' parameter. \
343 3) Read files with ctx_read mode='map' instead of searching blindly.",
344 self.window.as_secs()
345 )
346 }
347}
348
/// Returns the leading run of alphanumeric characters of `pattern`.
///
/// For example, `"compress.*data"` yields `"compress"`. The result is empty
/// when `pattern` starts with a non-alphanumeric character.
fn extract_alpha_root(pattern: &str) -> String {
    let root_end = pattern
        .char_indices()
        .find(|&(_, ch)| !ch.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..root_end].to_owned()
}
355
/// Heuristic similarity check between two search patterns.
///
/// Patterns are similar when any of these holds:
/// - they are equal;
/// - one contains the other;
/// - after stripping non-alphanumeric characters, both have at least three
///   bytes and one contains the other.
///
/// The comparison is case-sensitive; callers in this file lower-case both
/// sides first. An empty pattern is only similar to another empty pattern.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    // `str::contains("")` is always true, so without this guard an empty
    // pattern would be reported as similar to every other pattern.
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
}
373
374fn canonical_json(value: &serde_json::Value) -> String {
375 match value {
376 serde_json::Value::Object(map) => {
377 let mut keys: Vec<&String> = map.keys().collect();
378 keys.sort();
379 let entries: Vec<String> = keys
380 .iter()
381 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
382 .collect();
383 format!("{{{}}}", entries.join(","))
384 }
385 serde_json::Value::Array(arr) => {
386 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
387 format!("[{}]", entries.join(","))
388 }
389 _ => value.to_string(),
390 }
391}
392
#[cfg(test)]
mod tests {
    // Unit tests for `LoopDetector`. They cover per-fingerprint escalation,
    // configurable thresholds, argument fingerprinting, search-pattern
    // similarity and search-group limits.
    use super::*;

    // Builds a config with explicit per-fingerprint thresholds, a 300 s window,
    // a search-group limit of 10 and no per-tool total limits.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
            tool_total_limits: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        // A first call is Normal, counted once, and carries no message.
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        // Holds the shared env lock while toggling LEAN_CTX_META. That variable
        // presumably makes `meta_visible()` true, so the Reduced result carries a
        // message. NOTE(review): the variable is not removed if an assertion
        // panics; consider an RAII guard.
        let _lock = crate::core::data_dir::test_env_lock();
        std::env::set_var("LEAN_CTX_META", "1");
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        // Repeat the same call `normal_threshold` times; the next call exceeds it.
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
        std::env::remove_var("LEAN_CTX_META");
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        // Blocking is opt-in; enable it with a threshold of 6.
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        // The seventh identical call exceeds the threshold and is blocked.
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        let cfg = LoopDetectionConfig::default();
        // The default config disables blocking (threshold 0).
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..100 {
            // Heavy repetition of the exact same call.
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        // Repetition may reduce output but must never block when disabled.
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn different_args_tracked_separately() {
        // History is keyed by tool + fingerprint, so `fp_b` starts at count 1.
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        // The same input yields the same fingerprint within a process.
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_order_independent() {
        // Object key order must not affect the fingerprint.
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    #[test]
    fn stats_shows_duplicates() {
        // `stats()` only reports keys seen more than once: fp_a (5x), not fp_b (1x).
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        // After `reset()`, the per-fingerprint count starts over.
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        // With normal=1, reduced=2 and blocked=3, the second identical call
        // exceeds `normal` (Reduced) and the fourth exceeds `blocked` (Blocked).
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        // Equal, containment, and regex-decorated variants are similar;
        // unrelated or too-short patterns are not.
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        // Blocking enabled with a search-group limit of 5.
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6, ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        // Five searches are within the limit and must not block.
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        // The sixth exceeds the limit and gets the search-group block message.
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        // Blocking enabled with a threshold of 6.
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        // Every variant shares the "compress" alphanumeric root.
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        // Record `blocked_threshold` similar patterns first.
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        // The next similar pattern reaches the threshold and is blocked.
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        // Only the dedicated search tools count as search tools.
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        // Commands starting with a known search-command prefix are detected.
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        // Blocking enabled; repeated identical searches must eventually block.
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        // Either block message should point to ctx_tree, 'path' and ctx_read.
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}