1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names that `LoopDetector::is_search_tool` classifies as search tools.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Command prefixes that `LoopDetector::is_search_shell_command` treats as search commands.
///
/// Every entry ends with a space, so a command only matches when the program
/// name is followed by a space (e.g. `"grep -r foo"` matches, a bare `"grep"` does not).
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
/// Detects repeated tool calls and repeated/similar searches inside a sliding
/// time window and reports how strongly the caller should be throttled.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Timestamps of calls recorded by `record_call`, keyed by
    /// `"{tool}:{args_fingerprint}"`. Timestamps older than `window` are
    /// dropped by `prune_window`.
    call_history: HashMap<String, Vec<Instant>>,
    /// For each key, the in-window call count observed at that key's most
    /// recent `record_call`. Read by `stats`.
    duplicate_counts: HashMap<String, u32>,
    /// Timestamps of all `record_search` calls, regardless of tool or pattern.
    search_group_history: Vec<Instant>,
    /// Most recent non-empty search patterns (capped at 15 entries, oldest
    /// dropped first), used for similarity counting.
    recent_search_patterns: Vec<String>,
    /// A per-fingerprint count above this yields a `Reduced` result with a "Note" message.
    normal_threshold: u32,
    /// A count above this (or a similar-pattern count at or above it) yields a
    /// `Reduced` result with a "Warning" message.
    reduced_threshold: u32,
    /// A count above this (or a similar-pattern count at or above it) yields
    /// `Blocked`. A value of `0` disables blocking.
    blocked_threshold: u32,
    /// Length of the sliding window applied to recorded timestamps.
    window: Duration,
    /// Maximum number of search calls tolerated within `window` before
    /// searches are blocked; `u32::MAX` when blocking is disabled.
    search_group_limit: u32,
}
23
/// Severity reported by [`LoopDetector`] for a recorded call.
///
/// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality
/// is total and the type can be used wherever a full equivalence relation is
/// required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// No repetition threshold was exceeded; the result carries no message.
    Normal,
    /// A repetition threshold was exceeded; the result carries a note or warning.
    Reduced,
    /// The blocking threshold was exceeded; the result carries a
    /// "LOOP DETECTED" message.
    Blocked,
}
31
/// Outcome of recording a call with `LoopDetector::record_call` or
/// `LoopDetector::record_search`.
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// Throttling severity for this call.
    pub level: ThrottleLevel,
    /// The count that determined `level`: the per-fingerprint call count, the
    /// similar-pattern count, or the search-group count, depending on which
    /// check produced the result.
    pub call_count: u32,
    /// Guidance text for the caller; `None` when `level` is `Normal`.
    pub message: Option<String>,
}
39
40impl Default for LoopDetector {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl LoopDetector {
47 pub fn new() -> Self {
49 Self::with_config(&LoopDetectionConfig::default())
50 }
51
52 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
55 Self {
56 call_history: HashMap::new(),
57 duplicate_counts: HashMap::new(),
58 search_group_history: Vec::new(),
59 recent_search_patterns: Vec::new(),
60 normal_threshold: cfg.normal_threshold.max(1),
61 reduced_threshold: cfg.reduced_threshold.max(2),
62 blocked_threshold: cfg.blocked_threshold,
64 window: Duration::from_secs(cfg.window_secs),
65 search_group_limit: if cfg.blocked_threshold == 0 {
66 u32::MAX
67 } else {
68 cfg.search_group_limit.max(3)
69 },
70 }
71 }
72
73 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
75 let now = Instant::now();
76 self.prune_window(now);
77
78 let key = format!("{tool}:{args_fingerprint}");
79 let entries = self.call_history.entry(key.clone()).or_default();
80 entries.push(now);
81 let count = entries.len() as u32;
82 *self.duplicate_counts.entry(key).or_default() = count;
83
84 if self.blocked_threshold > 0 && count > self.blocked_threshold {
86 return ThrottleResult {
87 level: ThrottleLevel::Blocked,
88 call_count: count,
89 message: Some(self.block_message(tool, count)),
90 };
91 }
92 if count > self.reduced_threshold {
93 return ThrottleResult {
94 level: ThrottleLevel::Reduced,
95 call_count: count,
96 message: Some(format!(
97 "Warning: {tool} called {count}x with same args. \
98 Results reduced. Try a different approach or narrow your scope."
99 )),
100 };
101 }
102 if count > self.normal_threshold {
103 return ThrottleResult {
104 level: ThrottleLevel::Reduced,
105 call_count: count,
106 message: Some(format!(
107 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
108 )),
109 };
110 }
111 ThrottleResult {
112 level: ThrottleLevel::Normal,
113 call_count: count,
114 message: None,
115 }
116 }
117
118 pub fn record_search(
121 &mut self,
122 tool: &str,
123 args_fingerprint: &str,
124 search_pattern: Option<&str>,
125 ) -> ThrottleResult {
126 let now = Instant::now();
127
128 self.search_group_history.push(now);
129 let search_count = self.search_group_history.len() as u32;
130
131 let similar_count = if let Some(pat) = search_pattern {
132 let sc = self.count_similar_patterns(pat);
133 if !pat.is_empty() {
134 self.recent_search_patterns.push(pat.to_string());
135 if self.recent_search_patterns.len() > 15 {
136 self.recent_search_patterns.remove(0);
137 }
138 }
139 sc
140 } else {
141 0
142 };
143
144 if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
146 return ThrottleResult {
147 level: ThrottleLevel::Blocked,
148 call_count: similar_count,
149 message: Some(self.search_block_message(similar_count)),
150 };
151 }
152
153 if self.blocked_threshold > 0 && search_count > self.search_group_limit {
155 return ThrottleResult {
156 level: ThrottleLevel::Blocked,
157 call_count: search_count,
158 message: Some(self.search_group_block_message(search_count)),
159 };
160 }
161
162 if similar_count >= self.reduced_threshold {
163 return ThrottleResult {
164 level: ThrottleLevel::Reduced,
165 call_count: similar_count,
166 message: Some(format!(
167 "Warning: You've searched for similar patterns {similar_count}x. \
168 Narrow your search with the 'path' parameter or try ctx_tree first."
169 )),
170 };
171 }
172
173 if search_count > self.search_group_limit.saturating_sub(3) {
174 let per_fp = self.record_call(tool, args_fingerprint);
175 if per_fp.level != ThrottleLevel::Normal {
176 return per_fp;
177 }
178 return ThrottleResult {
179 level: ThrottleLevel::Reduced,
180 call_count: search_count,
181 message: Some(format!(
182 "Note: {search_count} search calls in the last {}s. \
183 Use ctx_tree to orient first, then scope searches with 'path'.",
184 self.window.as_secs()
185 )),
186 };
187 }
188
189 self.record_call(tool, args_fingerprint)
190 }
191
192 pub fn is_search_tool(tool: &str) -> bool {
194 SEARCH_TOOLS.contains(&tool)
195 }
196
197 pub fn is_search_shell_command(command: &str) -> bool {
199 let cmd = command.trim_start();
200 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
201 }
202
203 pub fn fingerprint(args: &serde_json::Value) -> String {
205 use std::collections::hash_map::DefaultHasher;
206 use std::hash::{Hash, Hasher};
207
208 let canonical = canonical_json(args);
209 let mut hasher = DefaultHasher::new();
210 canonical.hash(&mut hasher);
211 format!("{:016x}", hasher.finish())
212 }
213
214 pub fn stats(&self) -> Vec<(String, u32)> {
216 let mut entries: Vec<(String, u32)> = self
217 .duplicate_counts
218 .iter()
219 .filter(|(_, &count)| count > 1)
220 .map(|(k, &v)| (k.clone(), v))
221 .collect();
222 entries.sort_by_key(|x| std::cmp::Reverse(x.1));
223 entries
224 }
225
226 pub fn reset(&mut self) {
228 self.call_history.clear();
229 self.duplicate_counts.clear();
230 self.search_group_history.clear();
231 self.recent_search_patterns.clear();
232 }
233
234 fn prune_window(&mut self, now: Instant) {
235 for entries in self.call_history.values_mut() {
236 entries.retain(|t| now.duration_since(*t) < self.window);
237 }
238 self.search_group_history
239 .retain(|t| now.duration_since(*t) < self.window);
240 }
241
242 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
243 let new_lower = new_pattern.to_lowercase();
244 let new_root = extract_alpha_root(&new_lower);
245
246 let mut count = 0u32;
247 for existing in &self.recent_search_patterns {
248 let existing_lower = existing.to_lowercase();
249 if patterns_are_similar(&new_lower, &existing_lower) {
250 count += 1;
251 } else if new_root.len() >= 4 {
252 let existing_root = extract_alpha_root(&existing_lower);
253 if existing_root.len() >= 4
254 && (new_root.starts_with(&existing_root)
255 || existing_root.starts_with(&new_root))
256 {
257 count += 1;
258 }
259 }
260 }
261 count
262 }
263
264 fn block_message(&self, tool: &str, count: u32) -> String {
265 if Self::is_search_tool(tool) {
266 self.search_block_message(count)
267 } else {
268 format!(
269 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
270 Call blocked. Change your approach — the current strategy is not working."
271 )
272 }
273 }
274
275 #[allow(clippy::unused_self)]
276 fn search_block_message(&self, count: u32) -> String {
277 format!(
278 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
279 1) Use ctx_tree to understand the project structure first. \
280 2) Narrow your search with the 'path' parameter to a specific directory. \
281 3) Use ctx_read with mode='map' to understand a file before searching more."
282 )
283 }
284
285 fn search_group_block_message(&self, count: u32) -> String {
286 format!(
287 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
288 1) Use ctx_tree to map the project structure. \
289 2) Pick ONE specific directory and search there with the 'path' parameter. \
290 3) Read files with ctx_read mode='map' instead of searching blindly.",
291 self.window.as_secs()
292 )
293 }
294}
295
/// Returns the leading run of alphanumeric characters of `pattern` as an
/// owned string (empty when the first character is not alphanumeric).
fn extract_alpha_root(pattern: &str) -> String {
    // `find` yields the byte offset of the first non-alphanumeric character,
    // which is always a valid slice boundary.
    let root = match pattern.find(|ch: char| !ch.is_alphanumeric()) {
        Some(end) => &pattern[..end],
        None => pattern,
    };
    root.to_string()
}
302
/// Returns `true` when two search patterns are considered similar.
///
/// Patterns are similar when they are equal, when one contains the other, or
/// when their alphanumeric-only forms are both at least 3 bytes long and one
/// contains the other. Callers are expected to pass already-lowercased input
/// if case-insensitive matching is wanted.
///
/// An empty pattern is only similar to another empty pattern.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    // `str::contains("")` is always true, so without this guard an empty
    // pattern would be reported as similar to every other pattern.
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }

    // Compare with punctuation/regex syntax stripped, e.g. "compress.*data"
    // versus "compress_data".
    let alphanumeric_only =
        |s: &str| -> String { s.chars().filter(|c| c.is_alphanumeric()).collect() };
    let a_alpha = alphanumeric_only(a);
    let b_alpha = alphanumeric_only(b);

    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(b_alpha.as_str()) || b_alpha.contains(a_alpha.as_str()))
}
320
321fn canonical_json(value: &serde_json::Value) -> String {
322 match value {
323 serde_json::Value::Object(map) => {
324 let mut keys: Vec<&String> = map.keys().collect();
325 keys.sort();
326 let entries: Vec<String> = keys
327 .iter()
328 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
329 .collect();
330 format!("{{{}}}", entries.join(","))
331 }
332 serde_json::Value::Array(arr) => {
333 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
334 format!("[{}]", entries.join(","))
335 }
336 _ => value.to_string(),
337 }
338}
339
// Unit tests for the loop detector: per-fingerprint throttling, search-group
// and similar-pattern throttling, fingerprinting, and the classification helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a config with the given thresholds, a 300 s window and a
    // search-group limit of 10.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
        }
    }

    // A first call is never throttled and carries no message.
    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    // The call after `normal_threshold` identical calls is reported as Reduced.
    #[test]
    fn repeated_calls_trigger_reduced() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    // With a non-zero blocked threshold, the call after `blocked_threshold`
    // identical calls is blocked.
    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    // The default config has `blocked_threshold == 0`, which must never block
    // no matter how often a call repeats.
    #[test]
    fn blocking_disabled_by_default() {
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    // Counts are kept per fingerprint: a new fingerprint starts at 1.
    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    // Fingerprinting the same value twice gives the same result.
    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    // Object key order must not influence the fingerprint.
    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    // `stats` only lists keys that were called more than once.
    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    // After `reset`, counting starts again from 1.
    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    // With thresholds (1, 2, 3): the 2nd call is Reduced and the 4th is Blocked.
    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    // Equality, containment and alphanumeric containment count as similar;
    // unrelated patterns do not.
    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    // With a group limit of 5, the first five searches are not blocked and
    // the sixth is blocked by the search-group check.
    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6, ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    // Once `blocked_threshold` similar patterns have been recorded, the next
    // similar search is blocked even though every fingerprint differs.
    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    // Only the two ctx search tools are classified as search tools.
    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    // Shell commands starting with a known search program are detected.
    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    // A blocked repeated search must mention the suggested alternative tools.
    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}