1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names treated as searches by [`LoopDetector::is_search_tool`].
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell command prefixes treated as searches by
/// [`LoopDetector::is_search_shell_command`]. Each prefix includes a trailing
/// space, so a bare command name with no arguments does not match.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
9
/// Tracks recent tool calls and flags repetition that suggests the caller is
/// stuck in a loop (identical calls, bursts of searches, or near-duplicate
/// search patterns).
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Call timestamps keyed by `"{tool}:{args_fingerprint}"`; entries older
    /// than `window` are dropped by `prune_window`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Most recent per-key call count computed by `record_call`; read by `stats()`.
    duplicate_counts: HashMap<String, u32>,
    /// Timestamps of calls recorded through `record_search`, regardless of args.
    search_group_history: Vec<Instant>,
    /// Recently seen non-empty search patterns (capped at 15, oldest dropped first).
    recent_search_patterns: Vec<String>,
    /// Per-key call counts above this are `Reduced` with a note.
    normal_threshold: u32,
    /// Per-key call counts above this are `Reduced` with a warning; in searches,
    /// at least this many similar recent patterns also yields `Reduced`.
    reduced_threshold: u32,
    /// Per-key call counts above this are `Blocked`; in searches, at least this
    /// many similar recent patterns also yields `Blocked`.
    blocked_threshold: u32,
    /// Sliding window used when pruning call and search timestamps.
    window: Duration,
    /// More than this many searches within the window yields `Blocked`.
    search_group_limit: u32,
}
22
/// Severity assigned to a recorded call.
#[derive(Debug, Clone, PartialEq)]
pub enum ThrottleLevel {
    /// No threshold exceeded; the result carries no message.
    Normal,
    /// Repetition detected; the result carries a note or warning message.
    Reduced,
    /// A blocking threshold was exceeded; the result carries a "LOOP DETECTED" message.
    Blocked,
}
29
/// Outcome of recording a call with [`LoopDetector`].
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// Severity for this call.
    pub level: ThrottleLevel,
    /// The count behind `level`: identical calls in the window, similar recent
    /// search patterns, or searches in the window, depending on which check fired.
    pub call_count: u32,
    /// Guidance text; `None` exactly when `level` is `ThrottleLevel::Normal`.
    pub message: Option<String>,
}
36
37impl Default for LoopDetector {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl LoopDetector {
44 pub fn new() -> Self {
45 Self::with_config(&LoopDetectionConfig::default())
46 }
47
48 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
49 Self {
50 call_history: HashMap::new(),
51 duplicate_counts: HashMap::new(),
52 search_group_history: Vec::new(),
53 recent_search_patterns: Vec::new(),
54 normal_threshold: cfg.normal_threshold.max(1),
55 reduced_threshold: cfg.reduced_threshold.max(2),
56 blocked_threshold: cfg.blocked_threshold.max(3),
57 window: Duration::from_secs(cfg.window_secs),
58 search_group_limit: cfg.search_group_limit.max(3),
59 }
60 }
61
62 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
63 let now = Instant::now();
64 self.prune_window(now);
65
66 let key = format!("{tool}:{args_fingerprint}");
67 let entries = self.call_history.entry(key.clone()).or_default();
68 entries.push(now);
69 let count = entries.len() as u32;
70 *self.duplicate_counts.entry(key).or_default() = count;
71
72 if count > self.blocked_threshold {
73 return ThrottleResult {
74 level: ThrottleLevel::Blocked,
75 call_count: count,
76 message: Some(self.block_message(tool, count)),
77 };
78 }
79 if count > self.reduced_threshold {
80 return ThrottleResult {
81 level: ThrottleLevel::Reduced,
82 call_count: count,
83 message: Some(format!(
84 "Warning: {tool} called {count}x with same args. \
85 Results reduced. Try a different approach or narrow your scope."
86 )),
87 };
88 }
89 if count > self.normal_threshold {
90 return ThrottleResult {
91 level: ThrottleLevel::Reduced,
92 call_count: count,
93 message: Some(format!(
94 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
95 )),
96 };
97 }
98 ThrottleResult {
99 level: ThrottleLevel::Normal,
100 call_count: count,
101 message: None,
102 }
103 }
104
105 pub fn record_search(
108 &mut self,
109 tool: &str,
110 args_fingerprint: &str,
111 search_pattern: Option<&str>,
112 ) -> ThrottleResult {
113 let now = Instant::now();
114
115 self.search_group_history.push(now);
116 let search_count = self.search_group_history.len() as u32;
117
118 let similar_count = if let Some(pat) = search_pattern {
119 let sc = self.count_similar_patterns(pat);
120 if !pat.is_empty() {
121 self.recent_search_patterns.push(pat.to_string());
122 if self.recent_search_patterns.len() > 15 {
123 self.recent_search_patterns.remove(0);
124 }
125 }
126 sc
127 } else {
128 0
129 };
130
131 if similar_count >= self.blocked_threshold {
132 return ThrottleResult {
133 level: ThrottleLevel::Blocked,
134 call_count: similar_count,
135 message: Some(self.search_block_message(similar_count)),
136 };
137 }
138
139 if search_count > self.search_group_limit {
140 return ThrottleResult {
141 level: ThrottleLevel::Blocked,
142 call_count: search_count,
143 message: Some(self.search_group_block_message(search_count)),
144 };
145 }
146
147 if similar_count >= self.reduced_threshold {
148 return ThrottleResult {
149 level: ThrottleLevel::Reduced,
150 call_count: similar_count,
151 message: Some(format!(
152 "Warning: You've searched for similar patterns {similar_count}x. \
153 Narrow your search with the 'path' parameter or try ctx_tree first."
154 )),
155 };
156 }
157
158 if search_count > self.search_group_limit.saturating_sub(3) {
159 let per_fp = self.record_call(tool, args_fingerprint);
160 if per_fp.level != ThrottleLevel::Normal {
161 return per_fp;
162 }
163 return ThrottleResult {
164 level: ThrottleLevel::Reduced,
165 call_count: search_count,
166 message: Some(format!(
167 "Note: {search_count} search calls in the last {}s. \
168 Use ctx_tree to orient first, then scope searches with 'path'.",
169 self.window.as_secs()
170 )),
171 };
172 }
173
174 self.record_call(tool, args_fingerprint)
175 }
176
177 pub fn is_search_tool(tool: &str) -> bool {
178 SEARCH_TOOLS.contains(&tool)
179 }
180
181 pub fn is_search_shell_command(command: &str) -> bool {
182 let cmd = command.trim_start();
183 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
184 }
185
186 pub fn fingerprint(args: &serde_json::Value) -> String {
187 use std::collections::hash_map::DefaultHasher;
188 use std::hash::{Hash, Hasher};
189
190 let canonical = canonical_json(args);
191 let mut hasher = DefaultHasher::new();
192 canonical.hash(&mut hasher);
193 format!("{:016x}", hasher.finish())
194 }
195
196 pub fn stats(&self) -> Vec<(String, u32)> {
197 let mut entries: Vec<(String, u32)> = self
198 .duplicate_counts
199 .iter()
200 .filter(|(_, &count)| count > 1)
201 .map(|(k, &v)| (k.clone(), v))
202 .collect();
203 entries.sort_by(|a, b| b.1.cmp(&a.1));
204 entries
205 }
206
207 pub fn reset(&mut self) {
208 self.call_history.clear();
209 self.duplicate_counts.clear();
210 self.search_group_history.clear();
211 self.recent_search_patterns.clear();
212 }
213
214 fn prune_window(&mut self, now: Instant) {
215 for entries in self.call_history.values_mut() {
216 entries.retain(|t| now.duration_since(*t) < self.window);
217 }
218 self.search_group_history
219 .retain(|t| now.duration_since(*t) < self.window);
220 }
221
222 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
223 let new_lower = new_pattern.to_lowercase();
224 let new_root = extract_alpha_root(&new_lower);
225
226 let mut count = 0u32;
227 for existing in &self.recent_search_patterns {
228 let existing_lower = existing.to_lowercase();
229 if patterns_are_similar(&new_lower, &existing_lower) {
230 count += 1;
231 } else if new_root.len() >= 4 {
232 let existing_root = extract_alpha_root(&existing_lower);
233 if existing_root.len() >= 4
234 && (new_root.starts_with(&existing_root)
235 || existing_root.starts_with(&new_root))
236 {
237 count += 1;
238 }
239 }
240 }
241 count
242 }
243
244 fn block_message(&self, tool: &str, count: u32) -> String {
245 if Self::is_search_tool(tool) {
246 self.search_block_message(count)
247 } else {
248 format!(
249 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
250 Call blocked. Change your approach — the current strategy is not working."
251 )
252 }
253 }
254
255 fn search_block_message(&self, count: u32) -> String {
256 format!(
257 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
258 1) Use ctx_tree to understand the project structure first. \
259 2) Narrow your search with the 'path' parameter to a specific directory. \
260 3) Use ctx_read with mode='map' to understand a file before searching more."
261 )
262 }
263
264 fn search_group_block_message(&self, count: u32) -> String {
265 format!(
266 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
267 1) Use ctx_tree to map the project structure. \
268 2) Pick ONE specific directory and search there with the 'path' parameter. \
269 3) Read files with ctx_read mode='map' instead of searching blindly.",
270 self.window.as_secs()
271 )
272 }
273}
274
/// Returns the leading run of alphanumeric characters in `pattern`
/// (e.g. `"compress_output"` -> `"compress"`). The result is empty if
/// `pattern` starts with a non-alphanumeric character.
fn extract_alpha_root(pattern: &str) -> String {
    // Byte offset of the first character that ends the root (or the end of input).
    let root_end = pattern
        .find(|ch: char| !ch.is_alphanumeric())
        .unwrap_or(pattern.len());
    pattern[..root_end].to_string()
}
281
/// Case-sensitive heuristic similarity test for two search patterns.
///
/// Two patterns are similar if either contains the other verbatim, or if,
/// after stripping every non-alphanumeric character, both are at least
/// 3 bytes long and one contains the other.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    // Verbatim containment in either direction (this also covers `a == b`).
    if a.contains(b) || b.contains(a) {
        return true;
    }

    // Fall back to comparing only the alphanumeric characters, so that
    // e.g. "foo-bar" and "foobar" still match.
    let alnum_only = |s: &str| s.chars().filter(|c| c.is_alphanumeric()).collect::<String>();
    let left = alnum_only(a);
    let right = alnum_only(b);

    let long_enough = left.len() >= 3 && right.len() >= 3;
    long_enough && (left.contains(&right) || right.contains(&left))
}
299
300fn canonical_json(value: &serde_json::Value) -> String {
301 match value {
302 serde_json::Value::Object(map) => {
303 let mut keys: Vec<&String> = map.keys().collect();
304 keys.sort();
305 let entries: Vec<String> = keys
306 .iter()
307 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
308 .collect();
309 format!("{{{}}}", entries.join(","))
310 }
311 serde_json::Value::Array(arr) => {
312 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
313 format!("[{}]", entries.join(","))
314 }
315 _ => value.to_string(),
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    // Config with explicit per-fingerprint thresholds, a 300 s window and a
    // search group limit of 10.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
        }
    }

    // A first call is never throttled and carries no message.
    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    // One call beyond `normal_threshold` with the same fingerprint is Reduced.
    #[test]
    fn repeated_calls_trigger_reduced() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    // One call beyond `blocked_threshold` is Blocked with a "LOOP DETECTED" message.
    #[test]
    fn excessive_calls_get_blocked() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    // Counts are per (tool, fingerprint): a new fingerprint starts at 1.
    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    // Object key order must not affect the fingerprint.
    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    // `stats()` lists only keys seen more than once.
    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    // With thresholds 1/2/3: call 2 is Reduced, call 4 is Blocked.
    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    // With a group limit of 5, five distinct searches pass and the sixth blocks.
    #[test]
    fn search_group_tracking() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    // `blocked_threshold` prior variants sharing the "compress" root make the
    // next variant count as similar to all of them, which blocks.
    #[test]
    fn similar_search_patterns_trigger_block() {
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    // Repeating one pattern until throttled; the resulting message must point
    // at ctx_tree, the 'path' parameter and ctx_read.
    // NOTE(review): assumes the default thresholds make the 11th identical
    // search a block (only the block messages mention ctx_read) — confirm
    // against LoopDetectionConfig::default().
    #[test]
    fn search_block_message_has_guidance() {
        let mut detector = LoopDetector::new();
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}