1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names that `LoopDetector::is_search_tool` treats as search tools.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Command prefixes that `LoopDetector::is_search_shell_command` treats as
/// search invocations. Each prefix ends with a space, so a command only
/// matches when the program name is directly followed by a space.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
#[derive(Debug, Clone)]
/// Detects repetitive tool-call and search behaviour inside a sliding time
/// window and reports a throttle level for each recorded call.
pub struct LoopDetector {
    /// Timestamps of recorded calls, keyed by `"{tool}:{args_fingerprint}"`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Most recent in-window call count observed for each key; read by `stats`.
    duplicate_counts: HashMap<String, u32>,
    /// Timestamps of every `record_search` invocation, across all fingerprints.
    search_group_history: Vec<Instant>,
    /// Recently searched non-empty patterns, oldest first (capped by `record_search`).
    recent_search_patterns: Vec<String>,
    /// A count strictly above this yields a `Reduced` result with a "Note" message.
    normal_threshold: u32,
    /// A count strictly above this yields a `Reduced` result with a "Warning" message.
    reduced_threshold: u32,
    /// A count strictly above this yields a `Blocked` result.
    blocked_threshold: u32,
    /// Entries at least this old are dropped from the histories when pruning.
    window: Duration,
    /// More search calls than this within the window yields a `Blocked` result.
    search_group_limit: u32,
}
23
#[derive(Debug, Clone, PartialEq, Eq)]
/// Severity reported for a recorded call.
///
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
pub enum ThrottleLevel {
    /// No throttling applies to the call.
    Normal,
    /// The call is allowed but flagged as repetitive; a message accompanies it.
    Reduced,
    /// The call is considered part of a loop and should be blocked.
    Blocked,
}
31
#[derive(Debug, Clone)]
/// Outcome of recording a call with a `LoopDetector`.
pub struct ThrottleResult {
    /// Throttle level decided for this call.
    pub level: ThrottleLevel,
    /// The count that drove the decision: the per-key call count, the number
    /// of similar search patterns, or the number of searches in the window,
    /// depending on which check produced the result.
    pub call_count: u32,
    /// Guidance text for the caller; `None` when the level is `Normal`.
    pub message: Option<String>,
}
39
40impl Default for LoopDetector {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl LoopDetector {
47 pub fn new() -> Self {
49 Self::with_config(&LoopDetectionConfig::default())
50 }
51
52 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
54 Self {
55 call_history: HashMap::new(),
56 duplicate_counts: HashMap::new(),
57 search_group_history: Vec::new(),
58 recent_search_patterns: Vec::new(),
59 normal_threshold: cfg.normal_threshold.max(1),
60 reduced_threshold: cfg.reduced_threshold.max(2),
61 blocked_threshold: cfg.blocked_threshold.max(3),
62 window: Duration::from_secs(cfg.window_secs),
63 search_group_limit: cfg.search_group_limit.max(3),
64 }
65 }
66
67 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
69 let now = Instant::now();
70 self.prune_window(now);
71
72 let key = format!("{tool}:{args_fingerprint}");
73 let entries = self.call_history.entry(key.clone()).or_default();
74 entries.push(now);
75 let count = entries.len() as u32;
76 *self.duplicate_counts.entry(key).or_default() = count;
77
78 if count > self.blocked_threshold {
79 return ThrottleResult {
80 level: ThrottleLevel::Blocked,
81 call_count: count,
82 message: Some(self.block_message(tool, count)),
83 };
84 }
85 if count > self.reduced_threshold {
86 return ThrottleResult {
87 level: ThrottleLevel::Reduced,
88 call_count: count,
89 message: Some(format!(
90 "Warning: {tool} called {count}x with same args. \
91 Results reduced. Try a different approach or narrow your scope."
92 )),
93 };
94 }
95 if count > self.normal_threshold {
96 return ThrottleResult {
97 level: ThrottleLevel::Reduced,
98 call_count: count,
99 message: Some(format!(
100 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
101 )),
102 };
103 }
104 ThrottleResult {
105 level: ThrottleLevel::Normal,
106 call_count: count,
107 message: None,
108 }
109 }
110
111 pub fn record_search(
114 &mut self,
115 tool: &str,
116 args_fingerprint: &str,
117 search_pattern: Option<&str>,
118 ) -> ThrottleResult {
119 let now = Instant::now();
120
121 self.search_group_history.push(now);
122 let search_count = self.search_group_history.len() as u32;
123
124 let similar_count = if let Some(pat) = search_pattern {
125 let sc = self.count_similar_patterns(pat);
126 if !pat.is_empty() {
127 self.recent_search_patterns.push(pat.to_string());
128 if self.recent_search_patterns.len() > 15 {
129 self.recent_search_patterns.remove(0);
130 }
131 }
132 sc
133 } else {
134 0
135 };
136
137 if similar_count >= self.blocked_threshold {
138 return ThrottleResult {
139 level: ThrottleLevel::Blocked,
140 call_count: similar_count,
141 message: Some(self.search_block_message(similar_count)),
142 };
143 }
144
145 if search_count > self.search_group_limit {
146 return ThrottleResult {
147 level: ThrottleLevel::Blocked,
148 call_count: search_count,
149 message: Some(self.search_group_block_message(search_count)),
150 };
151 }
152
153 if similar_count >= self.reduced_threshold {
154 return ThrottleResult {
155 level: ThrottleLevel::Reduced,
156 call_count: similar_count,
157 message: Some(format!(
158 "Warning: You've searched for similar patterns {similar_count}x. \
159 Narrow your search with the 'path' parameter or try ctx_tree first."
160 )),
161 };
162 }
163
164 if search_count > self.search_group_limit.saturating_sub(3) {
165 let per_fp = self.record_call(tool, args_fingerprint);
166 if per_fp.level != ThrottleLevel::Normal {
167 return per_fp;
168 }
169 return ThrottleResult {
170 level: ThrottleLevel::Reduced,
171 call_count: search_count,
172 message: Some(format!(
173 "Note: {search_count} search calls in the last {}s. \
174 Use ctx_tree to orient first, then scope searches with 'path'.",
175 self.window.as_secs()
176 )),
177 };
178 }
179
180 self.record_call(tool, args_fingerprint)
181 }
182
183 pub fn is_search_tool(tool: &str) -> bool {
185 SEARCH_TOOLS.contains(&tool)
186 }
187
188 pub fn is_search_shell_command(command: &str) -> bool {
190 let cmd = command.trim_start();
191 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
192 }
193
194 pub fn fingerprint(args: &serde_json::Value) -> String {
196 use std::collections::hash_map::DefaultHasher;
197 use std::hash::{Hash, Hasher};
198
199 let canonical = canonical_json(args);
200 let mut hasher = DefaultHasher::new();
201 canonical.hash(&mut hasher);
202 format!("{:016x}", hasher.finish())
203 }
204
205 pub fn stats(&self) -> Vec<(String, u32)> {
207 let mut entries: Vec<(String, u32)> = self
208 .duplicate_counts
209 .iter()
210 .filter(|(_, &count)| count > 1)
211 .map(|(k, &v)| (k.clone(), v))
212 .collect();
213 entries.sort_by_key(|x| std::cmp::Reverse(x.1));
214 entries
215 }
216
217 pub fn reset(&mut self) {
219 self.call_history.clear();
220 self.duplicate_counts.clear();
221 self.search_group_history.clear();
222 self.recent_search_patterns.clear();
223 }
224
225 fn prune_window(&mut self, now: Instant) {
226 for entries in self.call_history.values_mut() {
227 entries.retain(|t| now.duration_since(*t) < self.window);
228 }
229 self.search_group_history
230 .retain(|t| now.duration_since(*t) < self.window);
231 }
232
233 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
234 let new_lower = new_pattern.to_lowercase();
235 let new_root = extract_alpha_root(&new_lower);
236
237 let mut count = 0u32;
238 for existing in &self.recent_search_patterns {
239 let existing_lower = existing.to_lowercase();
240 if patterns_are_similar(&new_lower, &existing_lower) {
241 count += 1;
242 } else if new_root.len() >= 4 {
243 let existing_root = extract_alpha_root(&existing_lower);
244 if existing_root.len() >= 4
245 && (new_root.starts_with(&existing_root)
246 || existing_root.starts_with(&new_root))
247 {
248 count += 1;
249 }
250 }
251 }
252 count
253 }
254
255 fn block_message(&self, tool: &str, count: u32) -> String {
256 if Self::is_search_tool(tool) {
257 self.search_block_message(count)
258 } else {
259 format!(
260 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
261 Call blocked. Change your approach — the current strategy is not working."
262 )
263 }
264 }
265
266 #[allow(clippy::unused_self)]
267 fn search_block_message(&self, count: u32) -> String {
268 format!(
269 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
270 1) Use ctx_tree to understand the project structure first. \
271 2) Narrow your search with the 'path' parameter to a specific directory. \
272 3) Use ctx_read with mode='map' to understand a file before searching more."
273 )
274 }
275
276 fn search_group_block_message(&self, count: u32) -> String {
277 format!(
278 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
279 1) Use ctx_tree to map the project structure. \
280 2) Pick ONE specific directory and search there with the 'path' parameter. \
281 3) Read files with ctx_read mode='map' instead of searching blindly.",
282 self.window.as_secs()
283 )
284 }
285}
286
/// Returns the leading run of alphanumeric characters of `pattern`
/// (empty when the first character is not alphanumeric).
fn extract_alpha_root(pattern: &str) -> String {
    let root_end = pattern
        .find(|c: char| !c.is_alphanumeric())
        .unwrap_or(pattern.len());
    pattern[..root_end].to_owned()
}
293
/// Heuristically decides whether two search patterns target the same thing.
///
/// Patterns are similar when they are equal, when one contains the other, or
/// when their alphanumeric-only forms (each at least 3 bytes long) contain one
/// another. Callers are expected to pass already case-normalised strings.
///
/// An empty pattern is only similar to another empty pattern: `str::contains`
/// reports the empty string as a substring of everything, which would
/// otherwise make an empty pattern match every other pattern.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }

    // Compare with regex punctuation, separators and whitespace stripped.
    let alnum_only = |s: &str| -> String { s.chars().filter(|c| c.is_alphanumeric()).collect() };
    let a_alpha = alnum_only(a);
    let b_alpha = alnum_only(b);

    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(b_alpha.as_str()) || b_alpha.contains(a_alpha.as_str()))
}
311
312fn canonical_json(value: &serde_json::Value) -> String {
313 match value {
314 serde_json::Value::Object(map) => {
315 let mut keys: Vec<&String> = map.keys().collect();
316 keys.sort();
317 let entries: Vec<String> = keys
318 .iter()
319 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
320 .collect();
321 format!("{{{}}}", entries.join(","))
322 }
323 serde_json::Value::Array(arr) => {
324 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
325 format!("[{}]", entries.join(","))
326 }
327 _ => value.to_string(),
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given per-fingerprint thresholds, a 300s
    /// window and a search-group limit of 10.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            window_secs: 300,
            search_group_limit: 10,
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
        }
    }

    /// Records the same call `times` times (at least once) and returns the
    /// result of the final call.
    fn repeat_call(
        detector: &mut LoopDetector,
        tool: &str,
        fp: &str,
        times: u32,
    ) -> ThrottleResult {
        let mut last = detector.record_call(tool, fp);
        for _ in 1..times {
            last = detector.record_call(tool, fp);
        }
        last
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let first = detector.record_call("ctx_read", "abc123");
        assert_eq!(first.level, ThrottleLevel::Normal);
        assert_eq!(first.call_count, 1);
        assert!(first.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        // One call more than the normal threshold allows.
        let last = repeat_call(&mut detector, "ctx_read", "same_fp", cfg.normal_threshold + 1);
        assert_eq!(last.level, ThrottleLevel::Reduced);
        assert!(last.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        // One call more than the blocked threshold allows.
        let last = repeat_call(
            &mut detector,
            "ctx_shell",
            "same_fp",
            cfg.blocked_threshold + 1,
        );
        assert_eq!(last.level, ThrottleLevel::Blocked);
        let msg = last.message.unwrap();
        assert!(msg.contains("LOOP DETECTED"));
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        repeat_call(&mut detector, "ctx_read", "fp_a", 10);
        let other = detector.record_call("ctx_read", "fp_b");
        assert_eq!(other.level, ThrottleLevel::Normal);
        assert_eq!(other.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let first = LoopDetector::fingerprint(&args);
        let second = LoopDetector::fingerprint(&args);
        assert_eq!(first, second);
    }

    #[test]
    fn fingerprint_order_independent() {
        let mode_first = serde_json::json!({"mode": "full", "path": "test.rs"});
        let path_first = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(
            LoopDetector::fingerprint(&mode_first),
            LoopDetector::fingerprint(&path_first)
        );
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        repeat_call(&mut detector, "ctx_read", "fp_a", 5);
        detector.record_call("ctx_shell", "fp_b");

        let stats = detector.stats();
        // Only the repeated key is reported; the single call is filtered out.
        assert_eq!(stats.len(), 1);
        let (_, count) = &stats[0];
        assert_eq!(*count, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        repeat_call(&mut detector, "ctx_read", "fp_a", 5);
        detector.reset();
        let after_reset = detector.record_call("ctx_read", "fp_a");
        assert_eq!(after_reset.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        let levels: Vec<ThrottleLevel> = (0..4)
            .map(|_| detector.record_call("ctx_read", "fp").level)
            .collect();
        // Second call exceeds the normal threshold, fourth exceeds blocked.
        assert_eq!(levels[1], ThrottleLevel::Reduced);
        assert_eq!(levels[3], ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        let cases = [
            ("compress", "compress", true),
            ("compress", "compression", true),
            ("compress.*data", "compress", true),
            ("foo", "bar", false),
            ("ab", "cd", false),
        ];
        for &(a, b, expected) in &cases {
            assert_eq!(patterns_are_similar(a, b), expected, "{a:?} vs {b:?}");
        }
    }

    #[test]
    fn search_group_tracking() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let pattern = format!("pattern_{i}");
            let outcome = detector.record_search("ctx_search", &fp, Some(pattern.as_str()));
            assert_ne!(
                outcome.level,
                ThrottleLevel::Blocked,
                "call {i} should not block"
            );
        }
        let over_limit = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(over_limit.level, ThrottleLevel::Blocked);
        assert!(over_limit.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        // Record at most `blocked_threshold` variants before the probe search.
        let warmup = variants.len().min(cfg.blocked_threshold as usize);
        for (i, pat) in variants[..warmup].iter().enumerate() {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let probe = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(probe.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        for tool in ["ctx_search", "ctx_semantic_search"].iter() {
            assert!(LoopDetector::is_search_tool(tool));
        }
        for tool in ["ctx_read", "ctx_shell"].iter() {
            assert!(!LoopDetector::is_search_tool(tool));
        }
    }

    #[test]
    fn is_search_shell_command_detection() {
        let search_commands = ["grep -r foo .", "rg pattern src/", "find . -name '*.rs'"];
        for cmd in search_commands.iter() {
            assert!(LoopDetector::is_search_shell_command(cmd));
        }
        let other_commands = ["cargo build", "git status"];
        for cmd in other_commands.iter() {
            assert!(!LoopDetector::is_search_shell_command(cmd));
        }
    }

    #[test]
    fn search_block_message_has_guidance() {
        let mut detector = LoopDetector::new();
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let blocked = detector.record_search("ctx_search", "fp_new", Some("compress"));
        let msg = blocked.message.unwrap();
        for needle in ["ctx_tree", "path", "ctx_read"].iter() {
            assert!(msg.contains(needle));
        }
    }
}