1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names that `LoopDetector::is_search_tool` reports as search tools.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell command prefixes that `LoopDetector::is_search_shell_command` matches
/// against the start of a command. Each prefix ends with a space, so a bare
/// command name with no arguments does not match.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
/// Tracks recent tool calls and searches and decides how strongly to throttle
/// a caller that keeps repeating itself.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Timestamps of recorded calls, keyed by `"{tool}:{args_fingerprint}"`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Count stored for each key the last time it was recorded; read by `stats`.
    duplicate_counts: HashMap<String, u32>,
    /// One timestamp per `record_search` call, regardless of tool or arguments.
    search_group_history: Vec<Instant>,
    /// Most recent non-empty search patterns, capped at 15 entries.
    recent_search_patterns: Vec<String>,
    /// Per-key call count above which `record_call` starts returning `Reduced`.
    normal_threshold: u32,
    /// Count above which `record_call` uses the stronger "Warning" message.
    reduced_threshold: u32,
    /// Count that triggers `Blocked`; `0` disables blocking entirely.
    blocked_threshold: u32,
    /// Age beyond which recorded timestamps are dropped by `prune_window`.
    window: Duration,
    /// Number of searches in the group history above which searches are blocked.
    search_group_limit: u32,
}
23
/// Severity attached to a recorded call.
#[derive(Debug, Clone, PartialEq)]
pub enum ThrottleLevel {
    /// No threshold was exceeded.
    Normal,
    /// A soft threshold was exceeded; the detector may attach a hint message.
    /// NOTE(review): how callers reduce output for this level is not visible here.
    Reduced,
    /// The blocked threshold or search group limit was exceeded while blocking
    /// is enabled.
    Blocked,
}
31
/// Outcome of recording a call with [`LoopDetector`].
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// Throttle decision for this call.
    pub level: ThrottleLevel,
    /// The count that drove the decision: a per-key call count, a
    /// similar-pattern count, or a search-group count depending on which
    /// check produced the result.
    pub call_count: u32,
    /// Optional guidance text for the caller; `None` when nothing is to be shown.
    pub message: Option<String>,
}
39
40impl Default for LoopDetector {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl LoopDetector {
47 pub fn new() -> Self {
49 Self::with_config(&LoopDetectionConfig::default())
50 }
51
52 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
55 Self {
56 call_history: HashMap::new(),
57 duplicate_counts: HashMap::new(),
58 search_group_history: Vec::new(),
59 recent_search_patterns: Vec::new(),
60 normal_threshold: cfg.normal_threshold.max(1),
61 reduced_threshold: cfg.reduced_threshold.max(2),
62 blocked_threshold: cfg.blocked_threshold,
64 window: Duration::from_secs(cfg.window_secs),
65 search_group_limit: if cfg.blocked_threshold == 0 {
66 u32::MAX
67 } else {
68 cfg.search_group_limit.max(3)
69 },
70 }
71 }
72
73 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
75 let now = Instant::now();
76 self.prune_window(now);
77
78 let key = format!("{tool}:{args_fingerprint}");
79 let entries = self.call_history.entry(key.clone()).or_default();
80 entries.push(now);
81 let count = entries.len() as u32;
82 *self.duplicate_counts.entry(key).or_default() = count;
83
84 if self.blocked_threshold > 0 && count > self.blocked_threshold {
86 return ThrottleResult {
87 level: ThrottleLevel::Blocked,
88 call_count: count,
89 message: Some(self.block_message(tool, count)),
90 };
91 }
92 if count > self.reduced_threshold {
93 if !crate::core::protocol::meta_visible() {
94 return ThrottleResult {
95 level: ThrottleLevel::Reduced,
96 call_count: count,
97 message: None,
98 };
99 }
100 return ThrottleResult {
101 level: ThrottleLevel::Reduced,
102 call_count: count,
103 message: Some(format!(
104 "Warning: {tool} called {count}x with same args. \
105 Results reduced. Try a different approach or narrow your scope."
106 )),
107 };
108 }
109 if count > self.normal_threshold {
110 if !crate::core::protocol::meta_visible() {
111 return ThrottleResult {
112 level: ThrottleLevel::Reduced,
113 call_count: count,
114 message: None,
115 };
116 }
117 return ThrottleResult {
118 level: ThrottleLevel::Reduced,
119 call_count: count,
120 message: Some(format!(
121 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
122 )),
123 };
124 }
125 ThrottleResult {
126 level: ThrottleLevel::Normal,
127 call_count: count,
128 message: None,
129 }
130 }
131
132 pub fn record_search(
135 &mut self,
136 tool: &str,
137 args_fingerprint: &str,
138 search_pattern: Option<&str>,
139 ) -> ThrottleResult {
140 let now = Instant::now();
141
142 self.search_group_history.push(now);
143 let search_count = self.search_group_history.len() as u32;
144
145 let similar_count = if let Some(pat) = search_pattern {
146 let sc = self.count_similar_patterns(pat);
147 if !pat.is_empty() {
148 self.recent_search_patterns.push(pat.to_string());
149 if self.recent_search_patterns.len() > 15 {
150 self.recent_search_patterns.remove(0);
151 }
152 }
153 sc
154 } else {
155 0
156 };
157
158 if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
160 return ThrottleResult {
161 level: ThrottleLevel::Blocked,
162 call_count: similar_count,
163 message: Some(self.search_block_message(similar_count)),
164 };
165 }
166
167 if self.blocked_threshold > 0 && search_count > self.search_group_limit {
169 return ThrottleResult {
170 level: ThrottleLevel::Blocked,
171 call_count: search_count,
172 message: Some(self.search_group_block_message(search_count)),
173 };
174 }
175
176 if similar_count >= self.reduced_threshold {
177 if !crate::core::protocol::meta_visible() {
178 return ThrottleResult {
179 level: ThrottleLevel::Reduced,
180 call_count: similar_count,
181 message: None,
182 };
183 }
184 return ThrottleResult {
185 level: ThrottleLevel::Reduced,
186 call_count: similar_count,
187 message: Some(format!(
188 "Warning: You've searched for similar patterns {similar_count}x. \
189 Narrow your search with the 'path' parameter or try ctx_tree first."
190 )),
191 };
192 }
193
194 if search_count > self.search_group_limit.saturating_sub(3) {
195 let per_fp = self.record_call(tool, args_fingerprint);
196 if per_fp.level != ThrottleLevel::Normal {
197 return per_fp;
198 }
199 if !crate::core::protocol::meta_visible() {
200 return ThrottleResult {
201 level: ThrottleLevel::Reduced,
202 call_count: search_count,
203 message: None,
204 };
205 }
206 return ThrottleResult {
207 level: ThrottleLevel::Reduced,
208 call_count: search_count,
209 message: Some(format!(
210 "Note: {search_count} search calls in the last {}s. \
211 Use ctx_tree to orient first, then scope searches with 'path'.",
212 self.window.as_secs()
213 )),
214 };
215 }
216
217 self.record_call(tool, args_fingerprint)
218 }
219
220 pub fn is_search_tool(tool: &str) -> bool {
222 SEARCH_TOOLS.contains(&tool)
223 }
224
225 pub fn is_search_shell_command(command: &str) -> bool {
227 let cmd = command.trim_start();
228 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
229 }
230
231 pub fn fingerprint(args: &serde_json::Value) -> String {
233 use std::collections::hash_map::DefaultHasher;
234 use std::hash::{Hash, Hasher};
235
236 let canonical = canonical_json(args);
237 let mut hasher = DefaultHasher::new();
238 canonical.hash(&mut hasher);
239 format!("{:016x}", hasher.finish())
240 }
241
242 pub fn stats(&self) -> Vec<(String, u32)> {
244 let mut entries: Vec<(String, u32)> = self
245 .duplicate_counts
246 .iter()
247 .filter(|(_, &count)| count > 1)
248 .map(|(k, &v)| (k.clone(), v))
249 .collect();
250 entries.sort_by_key(|x| std::cmp::Reverse(x.1));
251 entries
252 }
253
254 pub fn reset(&mut self) {
256 self.call_history.clear();
257 self.duplicate_counts.clear();
258 self.search_group_history.clear();
259 self.recent_search_patterns.clear();
260 }
261
262 fn prune_window(&mut self, now: Instant) {
263 for entries in self.call_history.values_mut() {
264 entries.retain(|t| now.duration_since(*t) < self.window);
265 }
266 self.search_group_history
267 .retain(|t| now.duration_since(*t) < self.window);
268 }
269
270 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
271 let new_lower = new_pattern.to_lowercase();
272 let new_root = extract_alpha_root(&new_lower);
273
274 let mut count = 0u32;
275 for existing in &self.recent_search_patterns {
276 let existing_lower = existing.to_lowercase();
277 if patterns_are_similar(&new_lower, &existing_lower) {
278 count += 1;
279 } else if new_root.len() >= 4 {
280 let existing_root = extract_alpha_root(&existing_lower);
281 if existing_root.len() >= 4
282 && (new_root.starts_with(&existing_root)
283 || existing_root.starts_with(&new_root))
284 {
285 count += 1;
286 }
287 }
288 }
289 count
290 }
291
292 fn block_message(&self, tool: &str, count: u32) -> String {
293 if Self::is_search_tool(tool) {
294 self.search_block_message(count)
295 } else {
296 format!(
297 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
298 Call blocked. Change your approach — the current strategy is not working."
299 )
300 }
301 }
302
303 #[allow(clippy::unused_self)]
304 fn search_block_message(&self, count: u32) -> String {
305 format!(
306 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
307 1) Use ctx_tree to understand the project structure first. \
308 2) Narrow your search with the 'path' parameter to a specific directory. \
309 3) Use ctx_read with mode='map' to understand a file before searching more."
310 )
311 }
312
313 fn search_group_block_message(&self, count: u32) -> String {
314 format!(
315 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
316 1) Use ctx_tree to map the project structure. \
317 2) Pick ONE specific directory and search there with the 'path' parameter. \
318 3) Read files with ctx_read mode='map' instead of searching blindly.",
319 self.window.as_secs()
320 )
321 }
322}
323
/// Returns the leading run of alphanumeric characters in `pattern`, i.e.
/// everything before the first non-alphanumeric character (possibly empty).
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .find(|c: char| !c.is_alphanumeric())
        .unwrap_or(pattern.len());
    pattern[..end].to_string()
}
330
/// Decides whether two search patterns look like variations of each other.
///
/// Patterns are similar when they are equal, when one contains the other, or
/// when their alphanumeric-only forms are both at least 3 bytes long and one
/// contains the other. An empty pattern is similar only to another empty
/// pattern.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    // `str::contains("")` is always true, so without this guard an empty
    // pattern would be reported as similar to every other pattern.
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    // Compare again with punctuation and regex syntax stripped out.
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(b_alpha.as_str()) || b_alpha.contains(a_alpha.as_str()))
}
348
349fn canonical_json(value: &serde_json::Value) -> String {
350 match value {
351 serde_json::Value::Object(map) => {
352 let mut keys: Vec<&String> = map.keys().collect();
353 keys.sort();
354 let entries: Vec<String> = keys
355 .iter()
356 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
357 .collect();
358 format!("{{{}}}", entries.join(","))
359 }
360 serde_json::Value::Array(arr) => {
361 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
362 format!("[{}]", entries.join(","))
363 }
364 _ => value.to_string(),
365 }
366}
367
// Unit tests for the loop detector and its pattern/fingerprint helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a config with the given thresholds, a 300-second window and a
    // search group limit of 10.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
        }
    }

    // A first call is never throttled and carries no message.
    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    // One call past the normal threshold yields `Reduced` with a message.
    // NOTE(review): LEAN_CTX_META=1 presumably makes `meta_visible()` return
    // true so the message is attached; confirm against `core::protocol`.
    #[test]
    fn repeated_calls_trigger_reduced() {
        let _lock = crate::core::data_dir::test_env_lock();
        std::env::set_var("LEAN_CTX_META", "1");
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
        std::env::remove_var("LEAN_CTX_META");
    }

    // With blocking enabled, one call past the blocked threshold is blocked.
    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    // The default config has `blocked_threshold == 0`, so even 101 identical
    // calls are never blocked.
    #[test]
    fn blocking_disabled_by_default() {
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    // Counts are kept per fingerprint: heavy use of one does not affect another.
    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    // Fingerprinting the same value twice gives the same string.
    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    // Object key order does not influence the fingerprint.
    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    // `stats` lists only keys recorded more than once.
    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    // After `reset`, counting for a previously seen key starts again at 1.
    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    // Thresholds 1/2/3: the second call is `Reduced`, the fourth is `Blocked`.
    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    // Equality and containment count as similar; unrelated patterns do not.
    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    // With a group limit of 5, the sixth search is blocked by the group check
    // (its message mentions "search calls").
    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    // After `blocked_threshold` variants of the same root have been searched,
    // one more similar pattern is blocked even though every fingerprint differs.
    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    // Only the two names in `SEARCH_TOOLS` are recognised as search tools.
    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    // Commands starting with a known search prefix are detected; others are not.
    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    // The block message for repeated identical searches names the suggested
    // alternatives: ctx_tree, the 'path' parameter and ctx_read.
    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}