1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names whose invocations are treated as searches.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell command prefixes (each including its trailing space) that mark a search command.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

/// Lifetime of correction signals and of remembered reads/commands (two minutes).
const CORRECTION_WINDOW: Duration = Duration::from_secs(2 * 60);
/// A `full` read this soon after a `map`/`signatures` read of the same file is a mode bounce.
const MODE_BOUNCE_WINDOW: Duration = Duration::from_secs(30);
/// Re-running a shell command within this window (one minute) is a correction signal.
const SHELL_RERUN_WINDOW: Duration = Duration::from_secs(60);
/// Number of initial correction-tracked calls that never emit correction signals.
const COLD_START_CALLS: u32 = 3;
14
/// A behavioural signal suggesting the agent is correcting or redoing its own work.
#[derive(Debug, Clone, PartialEq)]
pub enum CorrectionKind {
    /// A path was re-read with `fresh` set shortly after a previous read of it.
    FreshReRead,
    /// The same shell command was run again shortly after the previous run.
    ShellReRun,
    /// A file was read in `full` mode shortly after a `map`/`signatures` read of it.
    ModeBounce,
}
22
23#[derive(Debug, Clone)]
25pub struct LoopDetector {
26 call_history: HashMap<String, Vec<Instant>>,
27 duplicate_counts: HashMap<String, u32>,
28 tool_total_counts: HashMap<String, u32>,
29 tool_total_limits: HashMap<String, u32>,
30 search_group_history: Vec<Instant>,
31 recent_search_patterns: Vec<String>,
32 normal_threshold: u32,
33 reduced_threshold: u32,
34 blocked_threshold: u32,
35 window: Duration,
36 search_group_limit: u32,
37 correction_signals: Vec<(Instant, CorrectionKind)>,
39 recent_reads: HashMap<String, (Instant, String)>,
40 recent_commands: HashMap<String, Instant>,
41 total_calls: u32,
42}
43
/// How strongly a recorded call is being throttled.
#[derive(Debug, Clone, PartialEq)]
pub enum ThrottleLevel {
    /// No repetition concern.
    Normal,
    /// Repetition detected; a warning may be attached to the result.
    Reduced,
    /// Repetition exceeded the blocking threshold; the call should not proceed.
    Blocked,
}
51
52#[derive(Debug, Clone)]
54pub struct ThrottleResult {
55 pub level: ThrottleLevel,
56 pub call_count: u32,
57 pub message: Option<String>,
58}
59
60impl Default for LoopDetector {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl LoopDetector {
67 pub fn new() -> Self {
69 Self::with_config(&LoopDetectionConfig::default())
70 }
71
72 pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
75 Self {
76 call_history: HashMap::new(),
77 duplicate_counts: HashMap::new(),
78 tool_total_counts: HashMap::new(),
79 tool_total_limits: cfg.tool_total_limits.clone(),
80 search_group_history: Vec::new(),
81 recent_search_patterns: Vec::new(),
82 normal_threshold: cfg.normal_threshold.max(1),
83 reduced_threshold: cfg.reduced_threshold.max(2),
84 blocked_threshold: cfg.blocked_threshold,
85 window: Duration::from_secs(cfg.window_secs),
86 search_group_limit: if cfg.blocked_threshold == 0 {
87 u32::MAX
88 } else {
89 cfg.search_group_limit.max(3)
90 },
91 correction_signals: Vec::new(),
92 recent_reads: HashMap::new(),
93 recent_commands: HashMap::new(),
94 total_calls: 0,
95 }
96 }
97
98 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
100 let now = Instant::now();
101 self.prune_window(now);
102
103 let total = self.tool_total_counts.entry(tool.to_string()).or_insert(0);
105 *total += 1;
106 let total_count = *total;
107
108 if let Some(&limit) = self.tool_total_limits.get(tool) {
109 if total_count > limit {
110 let msg = if crate::core::protocol::meta_visible() {
111 Some(format!(
112 "Warning: {tool} called {total_count}x total (limit: {limit}). \
113 Consider ctx_compress or narrowing scope."
114 ))
115 } else {
116 None
117 };
118 return ThrottleResult {
119 level: ThrottleLevel::Reduced,
120 call_count: total_count,
121 message: msg,
122 };
123 }
124 }
125
126 let key = format!("{tool}:{args_fingerprint}");
127 let entries = self.call_history.entry(key.clone()).or_default();
128 entries.push(now);
129 let count = entries.len() as u32;
130 *self.duplicate_counts.entry(key).or_default() = count;
131
132 if self.blocked_threshold > 0 && count > self.blocked_threshold {
133 return ThrottleResult {
134 level: ThrottleLevel::Blocked,
135 call_count: count,
136 message: Some(self.block_message(tool, count)),
137 };
138 }
139 if count > self.reduced_threshold {
140 if !crate::core::protocol::meta_visible() {
141 return ThrottleResult {
142 level: ThrottleLevel::Reduced,
143 call_count: count,
144 message: None,
145 };
146 }
147 return ThrottleResult {
148 level: ThrottleLevel::Reduced,
149 call_count: count,
150 message: Some(format!(
151 "Warning: {tool} called {count}x with same args. \
152 Results reduced. Try a different approach or narrow your scope."
153 )),
154 };
155 }
156 if count > self.normal_threshold {
157 if !crate::core::protocol::meta_visible() {
158 return ThrottleResult {
159 level: ThrottleLevel::Reduced,
160 call_count: count,
161 message: None,
162 };
163 }
164 return ThrottleResult {
165 level: ThrottleLevel::Reduced,
166 call_count: count,
167 message: Some(format!(
168 "Note: {tool} called {count}x with similar args. Consider narrowing scope."
169 )),
170 };
171 }
172 ThrottleResult {
173 level: ThrottleLevel::Normal,
174 call_count: count,
175 message: None,
176 }
177 }
178
179 pub fn record_search(
182 &mut self,
183 tool: &str,
184 args_fingerprint: &str,
185 search_pattern: Option<&str>,
186 ) -> ThrottleResult {
187 let now = Instant::now();
188
189 self.search_group_history.push(now);
190 let search_count = self.search_group_history.len() as u32;
191
192 let similar_count = if let Some(pat) = search_pattern {
193 let sc = self.count_similar_patterns(pat);
194 if !pat.is_empty() {
195 self.recent_search_patterns.push(pat.to_string());
196 if self.recent_search_patterns.len() > 15 {
197 self.recent_search_patterns.remove(0);
198 }
199 }
200 sc
201 } else {
202 0
203 };
204
205 if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
207 return ThrottleResult {
208 level: ThrottleLevel::Blocked,
209 call_count: similar_count,
210 message: Some(self.search_block_message(similar_count)),
211 };
212 }
213
214 if self.blocked_threshold > 0 && search_count > self.search_group_limit {
216 return ThrottleResult {
217 level: ThrottleLevel::Blocked,
218 call_count: search_count,
219 message: Some(self.search_group_block_message(search_count)),
220 };
221 }
222
223 if similar_count >= self.reduced_threshold {
224 if !crate::core::protocol::meta_visible() {
225 return ThrottleResult {
226 level: ThrottleLevel::Reduced,
227 call_count: similar_count,
228 message: None,
229 };
230 }
231 return ThrottleResult {
232 level: ThrottleLevel::Reduced,
233 call_count: similar_count,
234 message: Some(format!(
235 "Warning: You've searched for similar patterns {similar_count}x. \
236 Narrow your search with the 'path' parameter or try ctx_tree first."
237 )),
238 };
239 }
240
241 if search_count > self.search_group_limit.saturating_sub(3) {
242 let per_fp = self.record_call(tool, args_fingerprint);
243 if per_fp.level != ThrottleLevel::Normal {
244 return per_fp;
245 }
246 if !crate::core::protocol::meta_visible() {
247 return ThrottleResult {
248 level: ThrottleLevel::Reduced,
249 call_count: search_count,
250 message: None,
251 };
252 }
253 return ThrottleResult {
254 level: ThrottleLevel::Reduced,
255 call_count: search_count,
256 message: Some(format!(
257 "Note: {search_count} search calls in the last {}s. \
258 Use ctx_tree to orient first, then scope searches with 'path'.",
259 self.window.as_secs()
260 )),
261 };
262 }
263
264 self.record_call(tool, args_fingerprint)
265 }
266
267 pub fn is_search_tool(tool: &str) -> bool {
269 SEARCH_TOOLS.contains(&tool)
270 }
271
272 pub fn is_search_shell_command(command: &str) -> bool {
274 let cmd = command.trim_start();
275 SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
276 }
277
278 pub fn fingerprint(args: &serde_json::Value) -> String {
280 use std::collections::hash_map::DefaultHasher;
281 use std::hash::{Hash, Hasher};
282
283 let canonical = canonical_json(args);
284 let mut hasher = DefaultHasher::new();
285 canonical.hash(&mut hasher);
286 format!("{:016x}", hasher.finish())
287 }
288
289 pub fn stats(&self) -> Vec<(String, u32)> {
291 let mut entries: Vec<(String, u32)> = self
292 .duplicate_counts
293 .iter()
294 .filter(|(_, &count)| count > 1)
295 .map(|(k, &v)| (k.clone(), v))
296 .collect();
297 entries.sort_by_key(|x| std::cmp::Reverse(x.1));
298 entries
299 }
300
301 pub fn record_read_for_correction(&mut self, path: &str, mode: &str, fresh: bool) {
305 self.total_calls += 1;
306 let now = Instant::now();
307
308 if self.total_calls <= COLD_START_CALLS {
309 self.recent_reads
310 .insert(path.to_string(), (now, mode.to_string()));
311 return;
312 }
313
314 if fresh {
315 if let Some((prev_time, _)) = self.recent_reads.get(path) {
316 if now.duration_since(*prev_time) < CORRECTION_WINDOW {
317 self.correction_signals
318 .push((now, CorrectionKind::FreshReRead));
319 }
320 }
321 }
322
323 if mode == "full" {
324 if let Some((prev_time, prev_mode)) = self.recent_reads.get(path) {
325 let is_bounce = (prev_mode == "map" || prev_mode == "signatures")
326 && now.duration_since(*prev_time) < MODE_BOUNCE_WINDOW;
327 if is_bounce {
328 self.correction_signals
329 .push((now, CorrectionKind::ModeBounce));
330 }
331 }
332 }
333
334 self.recent_reads
335 .insert(path.to_string(), (now, mode.to_string()));
336 }
337
338 pub fn record_shell_for_correction(&mut self, command: &str) {
340 self.total_calls += 1;
341 let now = Instant::now();
342
343 if self.total_calls <= COLD_START_CALLS {
344 self.recent_commands.insert(command.to_string(), now);
345 return;
346 }
347
348 let key = normalize_shell_command(command);
349 if let Some(prev_time) = self.recent_commands.get(&key) {
350 if now.duration_since(*prev_time) < SHELL_RERUN_WINDOW {
351 self.correction_signals
352 .push((now, CorrectionKind::ShellReRun));
353 }
354 }
355 self.recent_commands.insert(key, now);
356 }
357
358 pub fn correction_count(&self) -> u32 {
360 let now = Instant::now();
361 self.correction_signals
362 .iter()
363 .filter(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW)
364 .count() as u32
365 }
366
367 pub fn correction_rate(&self) -> f64 {
369 let count = self.correction_count();
370 if count == 0 {
371 return 0.0;
372 }
373 let window_mins = CORRECTION_WINDOW.as_secs_f64() / 60.0;
374 f64::from(count) / window_mins
375 }
376
377 pub fn prune_corrections(&mut self) {
379 let now = Instant::now();
380 self.correction_signals
381 .retain(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
382 self.recent_reads
383 .retain(|_, (t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
384 self.recent_commands
385 .retain(|_, t| now.duration_since(*t) < CORRECTION_WINDOW);
386 }
387
388 pub fn reset(&mut self) {
390 self.call_history.clear();
391 self.duplicate_counts.clear();
392 self.search_group_history.clear();
393 self.recent_search_patterns.clear();
394 self.correction_signals.clear();
395 self.recent_reads.clear();
396 self.recent_commands.clear();
397 self.total_calls = 0;
398 }
399
400 fn prune_window(&mut self, now: Instant) {
401 for entries in self.call_history.values_mut() {
402 entries.retain(|t| now.duration_since(*t) < self.window);
403 }
404 self.search_group_history
405 .retain(|t| now.duration_since(*t) < self.window);
406 }
407
408 fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
409 let new_lower = new_pattern.to_lowercase();
410 let new_root = extract_alpha_root(&new_lower);
411
412 let mut count = 0u32;
413 for existing in &self.recent_search_patterns {
414 let existing_lower = existing.to_lowercase();
415 if patterns_are_similar(&new_lower, &existing_lower) {
416 count += 1;
417 } else if new_root.len() >= 4 {
418 let existing_root = extract_alpha_root(&existing_lower);
419 if existing_root.len() >= 4
420 && (new_root.starts_with(&existing_root)
421 || existing_root.starts_with(&new_root))
422 {
423 count += 1;
424 }
425 }
426 }
427 count
428 }
429
430 fn block_message(&self, tool: &str, count: u32) -> String {
431 if Self::is_search_tool(tool) {
432 self.search_block_message(count)
433 } else {
434 format!(
435 "LOOP DETECTED: {tool} called {count}x with same/similar args. \
436 Call blocked. Change your approach — the current strategy is not working."
437 )
438 }
439 }
440
441 #[allow(clippy::unused_self)]
442 fn search_block_message(&self, count: u32) -> String {
443 format!(
444 "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
445 1) Use ctx_tree to understand the project structure first. \
446 2) Narrow your search with the 'path' parameter to a specific directory. \
447 3) Use ctx_read with mode='map' to understand a file before searching more."
448 )
449 }
450
451 fn search_group_block_message(&self, count: u32) -> String {
452 format!(
453 "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
454 1) Use ctx_tree to map the project structure. \
455 2) Pick ONE specific directory and search there with the 'path' parameter. \
456 3) Read files with ctx_read mode='map' instead of searching blindly.",
457 self.window.as_secs()
458 )
459 }
460}
461
/// Builds a comparison key for a shell command. The key is the command's first five
/// whitespace-separated words, joined by single spaces and lowercased.
fn normalize_shell_command(cmd: &str) -> String {
    let mut key = String::with_capacity(cmd.len());
    for (idx, word) in cmd.split_whitespace().take(5).enumerate() {
        if idx > 0 {
            key.push(' ');
        }
        key.push_str(word);
    }
    key.to_lowercase()
}
469
/// Returns the leading run of alphanumeric characters in `pattern`,
/// e.g. `"compress_all"` → `"compress"`.
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .char_indices()
        .find(|&(_, ch)| !ch.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..end].to_owned()
}
476
/// Heuristic similarity between two (already lowercased) search patterns.
///
/// Patterns are similar when they are equal, when one contains the other, or when their
/// alphanumeric-only forms (each at least 3 bytes long) contain one another. An empty
/// pattern is only similar to another empty pattern. Because `""` is a substring of every
/// string, the containment test would otherwise call it similar to everything.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
}
494
495fn canonical_json(value: &serde_json::Value) -> String {
496 match value {
497 serde_json::Value::Object(map) => {
498 let mut keys: Vec<&String> = map.keys().collect();
499 keys.sort();
500 let entries: Vec<String> = keys
501 .iter()
502 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
503 .collect();
504 format!("{{{}}}", entries.join(","))
505 }
506 serde_json::Value::Array(arr) => {
507 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
508 format!("[{}]", entries.join(","))
509 }
510 _ => value.to_string(),
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets `LEAN_CTX_META=1` and removes it again when dropped. Cleanup therefore also
    /// happens when an assertion fails, so the variable cannot leak into other tests that
    /// share the environment lock.
    struct MetaEnvGuard;

    impl MetaEnvGuard {
        fn enable() -> Self {
            std::env::set_var("LEAN_CTX_META", "1");
            MetaEnvGuard
        }
    }

    impl Drop for MetaEnvGuard {
        fn drop(&mut self) {
            std::env::remove_var("LEAN_CTX_META");
        }
    }

    /// Builds a config with explicit thresholds, a 300s window, a search-group limit of 10
    /// and no per-tool total limits.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
            tool_total_limits: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let _lock = crate::core::data_dir::test_env_lock();
        // Declared after `_lock`, so it is dropped first, while the lock is still held.
        // Presumably LEAN_CTX_META is what makes `meta_visible()` true — verify.
        let _meta = MetaEnvGuard::enable();
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn correction_fresh_reread_detected() {
        let mut detector = LoopDetector::new();
        // The first three reads fall inside the cold-start grace period.
        detector.record_read_for_correction("src/main.rs", "full", false);
        detector.record_read_for_correction("src/lib.rs", "full", false);
        detector.record_read_for_correction("src/util.rs", "full", false);
        detector.record_read_for_correction("src/main.rs", "full", false);
        assert_eq!(detector.correction_count(), 0);
        detector.record_read_for_correction("src/main.rs", "full", true);
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_mode_bounce_detected() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_read_for_correction(&format!("f{i}.rs"), "full", false);
        }
        detector.record_read_for_correction("src/cache.rs", "map", false);
        assert_eq!(detector.correction_count(), 0);
        detector.record_read_for_correction("src/cache.rs", "full", false);
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_shell_rerun_detected() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_shell_for_correction(&format!("echo {i}"));
        }
        detector.record_shell_for_correction("cargo test --lib");
        assert_eq!(detector.correction_count(), 0);
        detector.record_shell_for_correction("cargo test --lib");
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_rate_calculation() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_shell_for_correction(&format!("init{i}"));
        }
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        assert_eq!(detector.correction_count(), 2);
        assert!(detector.correction_rate() > 0.0);
    }

    #[test]
    fn correction_cold_start_ignored() {
        let mut detector = LoopDetector::new();
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        assert_eq!(detector.correction_count(), 0);
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}