1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names recognised by `LoopDetector::is_search_tool`.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Command prefixes recognised by `LoopDetector::is_search_shell_command`.
/// The trailing space means a bare program name with no arguments does not match.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

/// How long correction signals, remembered reads and remembered commands stay relevant.
const CORRECTION_WINDOW: Duration = Duration::from_secs(2 * 60);
/// Maximum gap between a `map`/`signatures` read and a `full` read of the same
/// path for the pair to count as a mode bounce.
const MODE_BOUNCE_WINDOW: Duration = Duration::from_secs(30);
/// Maximum gap between two runs of the same normalized shell command for the
/// second one to count as a re-run.
const SHELL_RERUN_WINDOW: Duration = Duration::from_secs(60);
/// Number of initial correction-tracked calls that only record state.
const COLD_START_CALLS: u32 = 3;
14
/// Kind of self-correction signal recorded by the loop detector.
#[derive(Clone, Debug, PartialEq)]
pub enum CorrectionKind {
    /// A path was re-read with `fresh` set shortly after a previous read of it.
    FreshReRead,
    /// The same normalized shell command was run again shortly after.
    ShellReRun,
    /// A path read as `map`/`signatures` was re-read as `full` shortly after.
    ModeBounce,
}
22
/// Tracks repeated tool calls, bursts of searches and self-correction signals,
/// and decides how strongly each newly recorded call should be throttled.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Call timestamps per `"{tool}:{fingerprint}"` key, pruned to `window`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Last in-window count observed per `"{tool}:{fingerprint}"` key; read by `stats`.
    duplicate_counts: HashMap<String, u32>,
    /// Running number of calls per tool name (not limited to `window`).
    tool_total_counts: HashMap<String, u32>,
    /// Configured per-tool call totals; exceeding one makes the call `Reduced`.
    tool_total_limits: HashMap<String, u32>,
    /// Timestamps of search calls, pruned to `window`.
    search_group_history: Vec<Instant>,
    /// Most recent non-empty search patterns (at most 15), oldest first.
    recent_search_patterns: Vec<String>,
    /// Per-key count above which a call becomes `Reduced`.
    normal_threshold: u32,
    /// Per-key count above which a `Reduced` call gets the stronger warning;
    /// also the similar-pattern count at which a search becomes `Reduced`.
    reduced_threshold: u32,
    /// Per-key count above which a call is `Blocked`; 0 disables blocking.
    blocked_threshold: u32,
    /// Sliding window used for duplicate and search-group counting.
    window: Duration,
    /// Searches allowed per `window` before blocking; `u32::MAX` when blocking is disabled.
    search_group_limit: u32,
    /// Timestamped correction signals.
    correction_signals: Vec<(Instant, CorrectionKind)>,
    /// Last read time and read mode per path, for correction tracking.
    recent_reads: HashMap<String, (Instant, String)>,
    /// Last run time per shell-command key, for correction tracking.
    recent_commands: HashMap<String, Instant>,
    /// Number of correction-tracked reads and shell commands; gates the cold-start phase.
    total_calls: u32,
}
43
/// Severity the loop detector assigns to a recorded call.
#[derive(Clone, Debug, PartialEq)]
pub enum ThrottleLevel {
    /// No repetition detected.
    Normal,
    /// Repetition detected; the caller is nudged to narrow its scope.
    Reduced,
    /// Loop detected; the call is reported as blocked.
    Blocked,
}
51
/// Outcome of recording a call: the throttle level, the count that drove the
/// decision, and an optional explanation for the caller.
#[derive(Debug, Clone)]
pub struct ThrottleResult {
    /// Severity assigned to this call.
    pub level: ThrottleLevel,
    /// The count behind `level`. Depending on the check that fired, this is a
    /// per-fingerprint count, a per-tool total, a similar-pattern count or a
    /// search-group count.
    pub call_count: u32,
    /// Explanation text. Always present for `Blocked`; present for `Reduced`
    /// only when protocol metadata is visible; absent for `Normal`.
    pub message: Option<String>,
}
59
60impl Default for ThrottleResult {
61 fn default() -> Self {
62 Self {
63 level: ThrottleLevel::Normal,
64 call_count: 0,
65 message: None,
66 }
67 }
68}
69
70impl Default for LoopDetector {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl LoopDetector {
77 pub fn new() -> Self {
79 Self::with_config(&LoopDetectionConfig::default())
80 }
81
82 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
85 Self {
86 call_history: HashMap::new(),
87 duplicate_counts: HashMap::new(),
88 tool_total_counts: HashMap::new(),
89 tool_total_limits: cfg.tool_total_limits.clone(),
90 search_group_history: Vec::new(),
91 recent_search_patterns: Vec::new(),
92 normal_threshold: cfg.normal_threshold.max(1),
93 reduced_threshold: cfg.reduced_threshold.max(2),
94 blocked_threshold: cfg.blocked_threshold,
95 window: Duration::from_secs(cfg.window_secs),
96 search_group_limit: if cfg.blocked_threshold == 0 {
97 u32::MAX
98 } else {
99 cfg.search_group_limit.max(3)
100 },
101 correction_signals: Vec::new(),
102 recent_reads: HashMap::new(),
103 recent_commands: HashMap::new(),
104 total_calls: 0,
105 }
106 }
107
108 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
110 let now = Instant::now();
111 self.prune_window(now);
112
113 let total = self.tool_total_counts.entry(tool.to_string()).or_insert(0);
115 *total += 1;
116 let total_count = *total;
117
118 if let Some(&limit) = self.tool_total_limits.get(tool) {
119 if total_count > limit {
120 let msg = if crate::core::protocol::meta_visible() {
121 Some(format!(
122 "Warning: {tool} called {total_count}x total (limit: {limit}). \
123 Consider ctx_compress or narrowing scope."
124 ))
125 } else {
126 None
127 };
128 return ThrottleResult {
129 level: ThrottleLevel::Reduced,
130 call_count: total_count,
131 message: msg,
132 };
133 }
134 }
135
136 let key = format!("{tool}:{args_fingerprint}");
137 let entries = self.call_history.entry(key.clone()).or_default();
138 entries.push(now);
139 let count = entries.len() as u32;
140 *self.duplicate_counts.entry(key).or_default() = count;
141
142 if self.blocked_threshold > 0 && count > self.blocked_threshold {
143 return ThrottleResult {
144 level: ThrottleLevel::Blocked,
145 call_count: count,
146 message: Some(self.block_message(tool, count)),
147 };
148 }
149 if count > self.reduced_threshold {
150 if !crate::core::protocol::meta_visible() {
151 return ThrottleResult {
152 level: ThrottleLevel::Reduced,
153 call_count: count,
154 message: None,
155 };
156 }
157 return ThrottleResult {
158 level: ThrottleLevel::Reduced,
159 call_count: count,
160 message: Some(format!(
161 "Warning: {tool} called {count}x with same args. \
162 Results reduced. Try a different approach or narrow your scope."
163 )),
164 };
165 }
166 if count > self.normal_threshold {
167 if !crate::core::protocol::meta_visible() {
168 return ThrottleResult {
169 level: ThrottleLevel::Reduced,
170 call_count: count,
171 message: None,
172 };
173 }
174 return ThrottleResult {
175 level: ThrottleLevel::Reduced,
176 call_count: count,
177 message: Some(format!(
178 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
179 )),
180 };
181 }
182 ThrottleResult {
183 level: ThrottleLevel::Normal,
184 call_count: count,
185 message: None,
186 }
187 }
188
189 pub fn record_error_outcome(&mut self, tool: &str, args_fingerprint: &str) {
192 let key = format!("{tool}:{args_fingerprint}");
193 if let Some(entries) = self.call_history.get_mut(&key) {
194 entries.pop();
195 let count = entries.len() as u32;
196 self.duplicate_counts.insert(key, count);
197 }
198 }
199
200 pub fn record_search(
203 &mut self,
204 tool: &str,
205 args_fingerprint: &str,
206 search_pattern: Option<&str>,
207 ) -> ThrottleResult {
208 let now = Instant::now();
209
210 self.search_group_history.push(now);
211 let search_count = self.search_group_history.len() as u32;
212
213 let similar_count = if let Some(pat) = search_pattern {
214 let sc = self.count_similar_patterns(pat);
215 if !pat.is_empty() {
216 self.recent_search_patterns.push(pat.to_string());
217 if self.recent_search_patterns.len() > 15 {
218 self.recent_search_patterns.remove(0);
219 }
220 }
221 sc
222 } else {
223 0
224 };
225
226 if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
228 return ThrottleResult {
229 level: ThrottleLevel::Blocked,
230 call_count: similar_count,
231 message: Some(self.search_block_message(similar_count)),
232 };
233 }
234
235 if self.blocked_threshold > 0 && search_count > self.search_group_limit {
237 return ThrottleResult {
238 level: ThrottleLevel::Blocked,
239 call_count: search_count,
240 message: Some(self.search_group_block_message(search_count)),
241 };
242 }
243
244 if similar_count >= self.reduced_threshold {
245 if !crate::core::protocol::meta_visible() {
246 return ThrottleResult {
247 level: ThrottleLevel::Reduced,
248 call_count: similar_count,
249 message: None,
250 };
251 }
252 return ThrottleResult {
253 level: ThrottleLevel::Reduced,
254 call_count: similar_count,
255 message: Some(format!(
256 "Warning: You've searched for similar patterns {similar_count}x. \
257 Narrow your search with the 'path' parameter or try ctx_tree first."
258 )),
259 };
260 }
261
262 if search_count > self.search_group_limit.saturating_sub(3) {
263 let per_fp = self.record_call(tool, args_fingerprint);
264 if per_fp.level != ThrottleLevel::Normal {
265 return per_fp;
266 }
267 if !crate::core::protocol::meta_visible() {
268 return ThrottleResult {
269 level: ThrottleLevel::Reduced,
270 call_count: search_count,
271 message: None,
272 };
273 }
274 return ThrottleResult {
275 level: ThrottleLevel::Reduced,
276 call_count: search_count,
277 message: Some(format!(
278 "Note: {search_count} search calls in the last {}s. \
279 Use ctx_tree to orient first, then scope searches with 'path'.",
280 self.window.as_secs()
281 )),
282 };
283 }
284
285 self.record_call(tool, args_fingerprint)
286 }
287
288 pub fn is_search_tool(tool: &str) -> bool {
290 SEARCH_TOOLS.contains(&tool)
291 }
292
293 pub fn is_search_shell_command(command: &str) -> bool {
295 let cmd = command.trim_start();
296 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
297 }
298
299 pub fn fingerprint(args: &serde_json::Value) -> String {
301 use std::collections::hash_map::DefaultHasher;
302 use std::hash::{Hash, Hasher};
303
304 let canonical = canonical_json(args);
305 let mut hasher = DefaultHasher::new();
306 canonical.hash(&mut hasher);
307 format!("{:016x}", hasher.finish())
308 }
309
310 pub fn stats(&self) -> Vec<(String, u32)> {
312 let mut entries: Vec<(String, u32)> = self
313 .duplicate_counts
314 .iter()
315 .filter(|(_, &count)| count > 1)
316 .map(|(k, &v)| (k.clone(), v))
317 .collect();
318 entries.sort_by_key(|x| std::cmp::Reverse(x.1));
319 entries
320 }
321
322 pub fn record_read_for_correction(&mut self, path: &str, mode: &str, fresh: bool) {
326 self.total_calls += 1;
327 let now = Instant::now();
328
329 if self.total_calls <= COLD_START_CALLS {
330 self.recent_reads
331 .insert(path.to_string(), (now, mode.to_string()));
332 return;
333 }
334
335 if fresh {
336 if let Some((prev_time, _)) = self.recent_reads.get(path) {
337 if now.duration_since(*prev_time) < CORRECTION_WINDOW {
338 self.correction_signals
339 .push((now, CorrectionKind::FreshReRead));
340 }
341 }
342 }
343
344 if mode == "full" {
345 if let Some((prev_time, prev_mode)) = self.recent_reads.get(path) {
346 let is_bounce = (prev_mode == "map" || prev_mode == "signatures")
347 && now.duration_since(*prev_time) < MODE_BOUNCE_WINDOW;
348 if is_bounce {
349 self.correction_signals
350 .push((now, CorrectionKind::ModeBounce));
351 }
352 }
353 }
354
355 self.recent_reads
356 .insert(path.to_string(), (now, mode.to_string()));
357 }
358
359 pub fn record_shell_for_correction(&mut self, command: &str) {
361 self.total_calls += 1;
362 let now = Instant::now();
363
364 if self.total_calls <= COLD_START_CALLS {
365 self.recent_commands.insert(command.to_string(), now);
366 return;
367 }
368
369 let key = normalize_shell_command(command);
370 if let Some(prev_time) = self.recent_commands.get(&key) {
371 if now.duration_since(*prev_time) < SHELL_RERUN_WINDOW {
372 self.correction_signals
373 .push((now, CorrectionKind::ShellReRun));
374 }
375 }
376 self.recent_commands.insert(key, now);
377 }
378
379 pub fn correction_count(&self) -> u32 {
381 let now = Instant::now();
382 self.correction_signals
383 .iter()
384 .filter(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW)
385 .count() as u32
386 }
387
388 pub fn correction_rate(&self) -> f64 {
390 let count = self.correction_count();
391 if count == 0 {
392 return 0.0;
393 }
394 let window_mins = CORRECTION_WINDOW.as_secs_f64() / 60.0;
395 f64::from(count) / window_mins
396 }
397
398 pub fn prune_corrections(&mut self) {
400 let now = Instant::now();
401 self.correction_signals
402 .retain(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
403 self.recent_reads
404 .retain(|_, (t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
405 self.recent_commands
406 .retain(|_, t| now.duration_since(*t) < CORRECTION_WINDOW);
407 }
408
409 pub fn reset(&mut self) {
411 self.call_history.clear();
412 self.duplicate_counts.clear();
413 self.search_group_history.clear();
414 self.recent_search_patterns.clear();
415 self.correction_signals.clear();
416 self.recent_reads.clear();
417 self.recent_commands.clear();
418 self.total_calls = 0;
419 }
420
421 fn prune_window(&mut self, now: Instant) {
422 for entries in self.call_history.values_mut() {
423 entries.retain(|t| now.duration_since(*t) < self.window);
424 }
425 self.search_group_history
426 .retain(|t| now.duration_since(*t) < self.window);
427 }
428
429 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
430 let new_lower = new_pattern.to_lowercase();
431 let new_root = extract_alpha_root(&new_lower);
432
433 let mut count = 0u32;
434 for existing in &self.recent_search_patterns {
435 let existing_lower = existing.to_lowercase();
436 if patterns_are_similar(&new_lower, &existing_lower) {
437 count += 1;
438 } else if new_root.len() >= 4 {
439 let existing_root = extract_alpha_root(&existing_lower);
440 if existing_root.len() >= 4
441 && (new_root.starts_with(&existing_root)
442 || existing_root.starts_with(&new_root))
443 {
444 count += 1;
445 }
446 }
447 }
448 count
449 }
450
451 fn block_message(&self, tool: &str, count: u32) -> String {
452 if Self::is_search_tool(tool) {
453 self.search_block_message(count)
454 } else {
455 format!(
456 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
457 Call blocked. Change your approach — the current strategy is not working."
458 )
459 }
460 }
461
462 #[allow(clippy::unused_self)]
463 fn search_block_message(&self, count: u32) -> String {
464 format!(
465 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
466 1) Use ctx_tree to understand the project structure first. \
467 2) Narrow your search with the 'path' parameter to a specific directory. \
468 3) Use ctx_read with mode='map' to understand a file before searching more."
469 )
470 }
471
472 fn search_group_block_message(&self, count: u32) -> String {
473 format!(
474 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
475 1) Use ctx_tree to map the project structure. \
476 2) Pick ONE specific directory and search there with the 'path' parameter. \
477 3) Read files with ctx_read mode='map' instead of searching blindly.",
478 self.window.as_secs()
479 )
480 }
481}
482
/// Reduces a shell command to a comparison key. The key is the command's first
/// five whitespace-separated tokens, joined by single spaces and lowercased.
fn normalize_shell_command(cmd: &str) -> String {
    let mut key = String::new();
    for (index, token) in cmd.split_whitespace().take(5).enumerate() {
        if index > 0 {
            key.push(' ');
        }
        key.push_str(token);
    }
    key.to_lowercase()
}
490
/// Returns the leading run of alphanumeric characters in `pattern`, e.g.
/// `"compress_output"` gives `"compress"`. The result is empty when `pattern`
/// does not start with an alphanumeric character.
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .char_indices()
        .find(|&(_, c)| !c.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..end].to_owned()
}
497
/// Heuristic similarity between two search patterns. Two patterns match when
/// one contains the other, or when their alphanumeric-only forms (each at least
/// 3 bytes long) contain one another.
///
/// The comparison is case-sensitive; `count_similar_patterns` lowercases both
/// sides before calling this.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    let contains_either = |x: &str, y: &str| x.contains(y) || y.contains(x);
    if a == b || contains_either(a, b) {
        return true;
    }
    let strip = |s: &str| -> String { s.chars().filter(|c| c.is_alphanumeric()).collect() };
    let (a_alnum, b_alnum) = (strip(a), strip(b));
    a_alnum.len() >= 3 && b_alnum.len() >= 3 && contains_either(&a_alnum, &b_alnum)
}
515
516fn canonical_json(value: &serde_json::Value) -> String {
517 match value {
518 serde_json::Value::Object(map) => {
519 let mut keys: Vec<&String> = map.keys().collect();
520 keys.sort();
521 let entries: Vec<String> = keys
522 .iter()
523 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
524 .collect();
525 format!("{{{}}}", entries.join(","))
526 }
527 serde_json::Value::Array(arr) => {
528 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
529 format!("[{}]", entries.join(","))
530 }
531 _ => value.to_string(),
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with explicit thresholds, a 300 s window, a search-group
    /// limit of 10 and no per-tool total limits.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
            tool_total_limits: std::collections::HashMap::new(),
        }
    }

    /// Sets an environment variable and removes it again on drop. A failing
    /// assertion therefore cannot leak the variable into tests that run later.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            Self(key)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let _lock = crate::core::data_dir::test_env_lock();
        // Declared after the lock so it drops first: the variable is removed
        // while the lock is still held, even if an assertion below panics.
        let _meta = EnvVarGuard::set("LEAN_CTX_META", "1");
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn correction_fresh_reread_detected() {
        let mut detector = LoopDetector::new();
        detector.record_read_for_correction("src/main.rs", "full", false);
        detector.record_read_for_correction("src/lib.rs", "full", false);
        detector.record_read_for_correction("src/util.rs", "full", false);
        detector.record_read_for_correction("src/main.rs", "full", false);
        assert_eq!(detector.correction_count(), 0);
        detector.record_read_for_correction("src/main.rs", "full", true);
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_mode_bounce_detected() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_read_for_correction(&format!("f{i}.rs"), "full", false);
        }
        detector.record_read_for_correction("src/cache.rs", "map", false);
        assert_eq!(detector.correction_count(), 0);
        detector.record_read_for_correction("src/cache.rs", "full", false);
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_shell_rerun_detected() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_shell_for_correction(&format!("echo {i}"));
        }
        detector.record_shell_for_correction("cargo test --lib");
        assert_eq!(detector.correction_count(), 0);
        detector.record_shell_for_correction("cargo test --lib");
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_rate_calculation() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_shell_for_correction(&format!("init{i}"));
        }
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        assert_eq!(detector.correction_count(), 2);
        assert!(detector.correction_rate() > 0.0);
    }

    #[test]
    fn correction_cold_start_ignored() {
        let mut detector = LoopDetector::new();
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        assert_eq!(detector.correction_count(), 0);
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }

    #[test]
    fn error_outcome_undoes_pre_dispatch_count() {
        let cfg = test_config(2, 4, 0);
        let mut detector = LoopDetector::with_config(&cfg);

        detector.record_call("ctx_read", "fp1");
        detector.record_call("ctx_read", "fp1");
        detector.record_error_outcome("ctx_read", "fp1");

        let r = detector.record_call("ctx_read", "fp1");
        assert_eq!(r.call_count, 2, "error should have undone one count");
        assert_eq!(r.level, ThrottleLevel::Normal);
    }

    #[test]
    fn repeated_errors_dont_trigger_reduced() {
        let cfg = test_config(2, 4, 0);
        let mut detector = LoopDetector::with_config(&cfg);

        for _ in 0..5 {
            detector.record_call("ctx_read", "fp1");
            detector.record_error_outcome("ctx_read", "fp1");
        }

        let r = detector.record_call("ctx_read", "fp1");
        assert_eq!(
            r.level,
            ThrottleLevel::Normal,
            "5 failed retries should not throttle"
        );
    }

    #[test]
    fn mixed_success_and_error_correct_count() {
        let cfg = test_config(2, 4, 0);
        let mut detector = LoopDetector::with_config(&cfg);

        detector.record_call("ctx_read", "fp1");
        detector.record_error_outcome("ctx_read", "fp1");
        detector.record_call("ctx_read", "fp1");
        detector.record_error_outcome("ctx_read", "fp1");
        detector.record_call("ctx_read", "fp1");
        assert_eq!(detector.record_call("ctx_read", "fp1").call_count, 2);
    }

    #[test]
    fn error_outcome_on_nonexistent_key_is_noop() {
        let mut detector = LoopDetector::new();
        detector.record_error_outcome("ctx_read", "never_called");
        let r = detector.record_call("ctx_read", "never_called");
        assert_eq!(r.call_count, 1);
    }

    #[test]
    fn error_outcome_doesnt_go_negative() {
        let mut detector = LoopDetector::new();
        detector.record_call("ctx_read", "fp1");
        detector.record_error_outcome("ctx_read", "fp1");
        detector.record_error_outcome("ctx_read", "fp1");
        let r = detector.record_call("ctx_read", "fp1");
        assert_eq!(r.call_count, 1, "count should never go below 0");
    }

    #[test]
    fn error_in_tool_a_doesnt_affect_tool_b() {
        let cfg = test_config(2, 4, 0);
        let mut detector = LoopDetector::with_config(&cfg);

        for _ in 0..5 {
            detector.record_call("ctx_read", "fp1");
            detector.record_error_outcome("ctx_read", "fp1");
        }

        let r = detector.record_call("ctx_shell", "fp_shell");
        assert_eq!(r.call_count, 1);
        assert_eq!(r.level, ThrottleLevel::Normal);
    }

    #[test]
    fn different_fingerprints_independent_after_errors() {
        let cfg = test_config(2, 4, 0);
        let mut detector = LoopDetector::with_config(&cfg);

        detector.record_call("ctx_read", "fp_a");
        detector.record_error_outcome("ctx_read", "fp_a");

        detector.record_call("ctx_read", "fp_b");
        let r = detector.record_call("ctx_read", "fp_b");
        assert_eq!(r.call_count, 2);

        let r_a = detector.record_call("ctx_read", "fp_a");
        assert_eq!(r_a.call_count, 1, "fp_a count should be reset to 0 then +1");
    }

    #[test]
    fn correction_degrade_recovery_after_prune() {
        let mut detector = LoopDetector::new();
        for i in 0..4u32 {
            detector.record_read_for_correction(&format!("warmup{i}.rs"), "full", false);
        }
        detector.record_read_for_correction("target.rs", "full", false);
        detector.record_read_for_correction("target.rs", "full", true);
        assert!(detector.correction_count() > 0);
        // Signals recorded just now are still inside the window, so pruning keeps them.
        detector.prune_corrections();
        assert!(detector.correction_count() >= 1);
    }

    #[test]
    fn success_after_errors_resets_to_normal() {
        let cfg = test_config(2, 4, 0);
        let mut detector = LoopDetector::with_config(&cfg);

        for _ in 0..3 {
            detector.record_call("ctx_read", "fp1");
            detector.record_error_outcome("ctx_read", "fp1");
        }

        let r = detector.record_call("ctx_read", "fp1");
        assert_eq!(r.level, ThrottleLevel::Normal);
        assert_eq!(r.call_count, 1);
    }

    #[test]
    fn throttle_result_default_is_normal() {
        let r = ThrottleResult::default();
        assert_eq!(r.level, ThrottleLevel::Normal);
        assert_eq!(r.call_count, 0);
        assert!(r.message.is_none());
    }
}
960}