Skip to main content

lean_ctx/core/
loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// Tool names recognized by `LoopDetector::is_search_tool`.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell command prefixes (each with a trailing space) treated as search commands.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

/// Lookback for correction signals and for `fresh` re-read detection (two minutes).
const CORRECTION_WINDOW: Duration = Duration::from_secs(2 * 60);
/// Max gap between a `map`/`signatures` read and a `full` read of the same path.
const MODE_BOUNCE_WINDOW: Duration = Duration::from_secs(30);
/// Max gap between two identical (normalized) shell commands to count as a re-run.
const SHELL_RERUN_WINDOW: Duration = Duration::from_secs(60);
/// Number of initial correction-tracked calls that never produce correction signals.
const COLD_START_CALLS: u32 = 3;
14
/// Classification of why an agent re-requested data it already had.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorrectionKind {
    /// A `fresh` re-read of a path that had already been read shortly before.
    FreshReRead,
    /// The same (normalized) shell command was run again shortly after the previous run.
    ShellReRun,
    /// A path read in `map` or `signatures` mode was re-read in `full` mode shortly after.
    ModeBounce,
}
22
23/// Tracks repeated tool calls within a time window to detect and throttle agent loops.
24#[derive(Debug, Clone)]
25pub struct LoopDetector {
26    call_history: HashMap<String, Vec<Instant>>,
27    duplicate_counts: HashMap<String, u32>,
28    tool_total_counts: HashMap<String, u32>,
29    tool_total_limits: HashMap<String, u32>,
30    search_group_history: Vec<Instant>,
31    recent_search_patterns: Vec<String>,
32    normal_threshold: u32,
33    reduced_threshold: u32,
34    blocked_threshold: u32,
35    window: Duration,
36    search_group_limit: u32,
37    // Correction-loop tracking (Fix A)
38    correction_signals: Vec<(Instant, CorrectionKind)>,
39    recent_reads: HashMap<String, (Instant, String)>,
40    recent_commands: HashMap<String, Instant>,
41    total_calls: u32,
42}
43
/// Severity of throttling applied to a repeated call: normal, reduced, or blocked.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// No repetition concern; the call passes through.
    Normal,
    /// Repetition detected; the result is flagged as reduced (how output is reduced is
    /// decided by the caller).
    Reduced,
    /// Loop detected and the call is blocked. Only produced when blocking is enabled
    /// (`blocked_threshold > 0`).
    Blocked,
}
51
52/// Outcome of a loop detection check: throttle level, count, and optional warning.
53#[derive(Debug, Clone)]
54pub struct ThrottleResult {
55    pub level: ThrottleLevel,
56    pub call_count: u32,
57    pub message: Option<String>,
58}
59
60impl Default for ThrottleResult {
61    fn default() -> Self {
62        Self {
63            level: ThrottleLevel::Normal,
64            call_count: 0,
65            message: None,
66        }
67    }
68}
69
70impl Default for LoopDetector {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl LoopDetector {
77    /// Creates a loop detector with default thresholds.
78    pub fn new() -> Self {
79        Self::with_config(&LoopDetectionConfig::default())
80    }
81
82    /// Creates a loop detector with custom thresholds from config.
83    /// Set blocked_threshold to 0 to disable blocking entirely (LeanCTX philosophy).
84    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
85        Self {
86            call_history: HashMap::new(),
87            duplicate_counts: HashMap::new(),
88            tool_total_counts: HashMap::new(),
89            tool_total_limits: cfg.tool_total_limits.clone(),
90            search_group_history: Vec::new(),
91            recent_search_patterns: Vec::new(),
92            normal_threshold: cfg.normal_threshold.max(1),
93            reduced_threshold: cfg.reduced_threshold.max(2),
94            blocked_threshold: cfg.blocked_threshold,
95            window: Duration::from_secs(cfg.window_secs),
96            search_group_limit: if cfg.blocked_threshold == 0 {
97                u32::MAX
98            } else {
99                cfg.search_group_limit.max(3)
100            },
101            correction_signals: Vec::new(),
102            recent_reads: HashMap::new(),
103            recent_commands: HashMap::new(),
104            total_calls: 0,
105        }
106    }
107
108    /// Records a tool call and returns the throttle result based on repetition count.
109    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
110        let now = Instant::now();
111        self.prune_window(now);
112
113        // Per-tool total count (regardless of args)
114        let total = self.tool_total_counts.entry(tool.to_string()).or_insert(0);
115        *total += 1;
116        let total_count = *total;
117
118        if let Some(&limit) = self.tool_total_limits.get(tool) {
119            if total_count > limit {
120                let msg = if crate::core::protocol::meta_visible() {
121                    Some(format!(
122                        "Warning: {tool} called {total_count}x total (limit: {limit}). \
123                         Consider ctx_compress or narrowing scope."
124                    ))
125                } else {
126                    None
127                };
128                return ThrottleResult {
129                    level: ThrottleLevel::Reduced,
130                    call_count: total_count,
131                    message: msg,
132                };
133            }
134        }
135
136        let key = format!("{tool}:{args_fingerprint}");
137        let entries = self.call_history.entry(key.clone()).or_default();
138        entries.push(now);
139        let count = entries.len() as u32;
140        *self.duplicate_counts.entry(key).or_default() = count;
141
142        if self.blocked_threshold > 0 && count > self.blocked_threshold {
143            return ThrottleResult {
144                level: ThrottleLevel::Blocked,
145                call_count: count,
146                message: Some(self.block_message(tool, count)),
147            };
148        }
149        if count > self.reduced_threshold {
150            if !crate::core::protocol::meta_visible() {
151                return ThrottleResult {
152                    level: ThrottleLevel::Reduced,
153                    call_count: count,
154                    message: None,
155                };
156            }
157            return ThrottleResult {
158                level: ThrottleLevel::Reduced,
159                call_count: count,
160                message: Some(format!(
161                    "Warning: {tool} called {count}x with same args. \
162                     Results reduced. Try a different approach or narrow your scope."
163                )),
164            };
165        }
166        if count > self.normal_threshold {
167            if !crate::core::protocol::meta_visible() {
168                return ThrottleResult {
169                    level: ThrottleLevel::Reduced,
170                    call_count: count,
171                    message: None,
172                };
173            }
174            return ThrottleResult {
175                level: ThrottleLevel::Reduced,
176                call_count: count,
177                message: Some(format!(
178                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
179                )),
180            };
181        }
182        ThrottleResult {
183            level: ThrottleLevel::Normal,
184            call_count: count,
185            message: None,
186        }
187    }
188
189    /// Undo the pre-dispatch count for a call that resulted in an error.
190    /// Prevents failed retries from triggering throttling prematurely.
191    pub fn record_error_outcome(&mut self, tool: &str, args_fingerprint: &str) {
192        let key = format!("{tool}:{args_fingerprint}");
193        if let Some(entries) = self.call_history.get_mut(&key) {
194            entries.pop();
195            let count = entries.len() as u32;
196            self.duplicate_counts.insert(key, count);
197        }
198    }
199
200    /// Record a search-category call and check the cross-tool search group limit.
201    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
202    pub fn record_search(
203        &mut self,
204        tool: &str,
205        args_fingerprint: &str,
206        search_pattern: Option<&str>,
207    ) -> ThrottleResult {
208        let now = Instant::now();
209
210        self.search_group_history.push(now);
211        let search_count = self.search_group_history.len() as u32;
212
213        let similar_count = if let Some(pat) = search_pattern {
214            let sc = self.count_similar_patterns(pat);
215            if !pat.is_empty() {
216                self.recent_search_patterns.push(pat.to_string());
217                if self.recent_search_patterns.len() > 15 {
218                    self.recent_search_patterns.remove(0);
219                }
220            }
221            sc
222        } else {
223            0
224        };
225
226        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
227        if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
228            return ThrottleResult {
229                level: ThrottleLevel::Blocked,
230                call_count: similar_count,
231                message: Some(self.search_block_message(similar_count)),
232            };
233        }
234
235        // search_group_limit == u32::MAX when blocking is disabled
236        if self.blocked_threshold > 0 && search_count > self.search_group_limit {
237            return ThrottleResult {
238                level: ThrottleLevel::Blocked,
239                call_count: search_count,
240                message: Some(self.search_group_block_message(search_count)),
241            };
242        }
243
244        if similar_count >= self.reduced_threshold {
245            if !crate::core::protocol::meta_visible() {
246                return ThrottleResult {
247                    level: ThrottleLevel::Reduced,
248                    call_count: similar_count,
249                    message: None,
250                };
251            }
252            return ThrottleResult {
253                level: ThrottleLevel::Reduced,
254                call_count: similar_count,
255                message: Some(format!(
256                    "Warning: You've searched for similar patterns {similar_count}x. \
257                     Narrow your search with the 'path' parameter or try ctx_tree first."
258                )),
259            };
260        }
261
262        if search_count > self.search_group_limit.saturating_sub(3) {
263            let per_fp = self.record_call(tool, args_fingerprint);
264            if per_fp.level != ThrottleLevel::Normal {
265                return per_fp;
266            }
267            if !crate::core::protocol::meta_visible() {
268                return ThrottleResult {
269                    level: ThrottleLevel::Reduced,
270                    call_count: search_count,
271                    message: None,
272                };
273            }
274            return ThrottleResult {
275                level: ThrottleLevel::Reduced,
276                call_count: search_count,
277                message: Some(format!(
278                    "Note: {search_count} search calls in the last {}s. \
279                     Use ctx_tree to orient first, then scope searches with 'path'.",
280                    self.window.as_secs()
281                )),
282            };
283        }
284
285        self.record_call(tool, args_fingerprint)
286    }
287
288    /// Returns `true` if the tool name is a known search tool (ctx_search, etc.).
289    pub fn is_search_tool(tool: &str) -> bool {
290        SEARCH_TOOLS.contains(&tool)
291    }
292
293    /// Returns `true` if the shell command starts with a search tool (grep, rg, find, etc.).
294    pub fn is_search_shell_command(command: &str) -> bool {
295        let cmd = command.trim_start();
296        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
297    }
298
299    /// Computes a deterministic hash fingerprint of JSON tool arguments.
300    pub fn fingerprint(args: &serde_json::Value) -> String {
301        use std::collections::hash_map::DefaultHasher;
302        use std::hash::{Hash, Hasher};
303
304        let canonical = canonical_json(args);
305        let mut hasher = DefaultHasher::new();
306        canonical.hash(&mut hasher);
307        format!("{:016x}", hasher.finish())
308    }
309
310    /// Returns duplicate call entries sorted by count (descending), filtered to count > 1.
311    pub fn stats(&self) -> Vec<(String, u32)> {
312        let mut entries: Vec<(String, u32)> = self
313            .duplicate_counts
314            .iter()
315            .filter(|(_, &count)| count > 1)
316            .map(|(k, &v)| (k.clone(), v))
317            .collect();
318        entries.sort_by_key(|x| std::cmp::Reverse(x.1));
319        entries
320    }
321
322    /// Records a ctx_read call and detects correction signals:
323    /// - `fresh=true` re-read of a previously cached file
324    /// - Mode bounce: map/signatures followed by full within 30s
325    pub fn record_read_for_correction(&mut self, path: &str, mode: &str, fresh: bool) {
326        self.total_calls += 1;
327        let now = Instant::now();
328
329        if self.total_calls <= COLD_START_CALLS {
330            self.recent_reads
331                .insert(path.to_string(), (now, mode.to_string()));
332            return;
333        }
334
335        if fresh {
336            if let Some((prev_time, _)) = self.recent_reads.get(path) {
337                if now.duration_since(*prev_time) < CORRECTION_WINDOW {
338                    self.correction_signals
339                        .push((now, CorrectionKind::FreshReRead));
340                }
341            }
342        }
343
344        if mode == "full" {
345            if let Some((prev_time, prev_mode)) = self.recent_reads.get(path) {
346                let is_bounce = (prev_mode == "map" || prev_mode == "signatures")
347                    && now.duration_since(*prev_time) < MODE_BOUNCE_WINDOW;
348                if is_bounce {
349                    self.correction_signals
350                        .push((now, CorrectionKind::ModeBounce));
351                }
352            }
353        }
354
355        self.recent_reads
356            .insert(path.to_string(), (now, mode.to_string()));
357    }
358
359    /// Records a ctx_shell command and detects re-runs of the same command within 60s.
360    pub fn record_shell_for_correction(&mut self, command: &str) {
361        self.total_calls += 1;
362        let now = Instant::now();
363
364        if self.total_calls <= COLD_START_CALLS {
365            self.recent_commands.insert(command.to_string(), now);
366            return;
367        }
368
369        let key = normalize_shell_command(command);
370        if let Some(prev_time) = self.recent_commands.get(&key) {
371            if now.duration_since(*prev_time) < SHELL_RERUN_WINDOW {
372                self.correction_signals
373                    .push((now, CorrectionKind::ShellReRun));
374            }
375        }
376        self.recent_commands.insert(key, now);
377    }
378
379    /// Returns the number of correction signals in the sliding window.
380    pub fn correction_count(&self) -> u32 {
381        let now = Instant::now();
382        self.correction_signals
383            .iter()
384            .filter(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW)
385            .count() as u32
386    }
387
388    /// Returns the correction rate: signals per minute within the window.
389    pub fn correction_rate(&self) -> f64 {
390        let count = self.correction_count();
391        if count == 0 {
392            return 0.0;
393        }
394        let window_mins = CORRECTION_WINDOW.as_secs_f64() / 60.0;
395        f64::from(count) / window_mins
396    }
397
398    /// Prunes expired correction signals and stale read/command entries.
399    pub fn prune_corrections(&mut self) {
400        let now = Instant::now();
401        self.correction_signals
402            .retain(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
403        self.recent_reads
404            .retain(|_, (t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
405        self.recent_commands
406            .retain(|_, t| now.duration_since(*t) < CORRECTION_WINDOW);
407    }
408
409    /// Clears all tracking state (call history, search patterns, counters).
410    pub fn reset(&mut self) {
411        self.call_history.clear();
412        self.duplicate_counts.clear();
413        self.search_group_history.clear();
414        self.recent_search_patterns.clear();
415        self.correction_signals.clear();
416        self.recent_reads.clear();
417        self.recent_commands.clear();
418        self.total_calls = 0;
419    }
420
421    fn prune_window(&mut self, now: Instant) {
422        for entries in self.call_history.values_mut() {
423            entries.retain(|t| now.duration_since(*t) < self.window);
424        }
425        self.search_group_history
426            .retain(|t| now.duration_since(*t) < self.window);
427    }
428
429    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
430        let new_lower = new_pattern.to_lowercase();
431        let new_root = extract_alpha_root(&new_lower);
432
433        let mut count = 0u32;
434        for existing in &self.recent_search_patterns {
435            let existing_lower = existing.to_lowercase();
436            if patterns_are_similar(&new_lower, &existing_lower) {
437                count += 1;
438            } else if new_root.len() >= 4 {
439                let existing_root = extract_alpha_root(&existing_lower);
440                if existing_root.len() >= 4
441                    && (new_root.starts_with(&existing_root)
442                        || existing_root.starts_with(&new_root))
443                {
444                    count += 1;
445                }
446            }
447        }
448        count
449    }
450
451    fn block_message(&self, tool: &str, count: u32) -> String {
452        if Self::is_search_tool(tool) {
453            self.search_block_message(count)
454        } else {
455            format!(
456                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
457                 Call blocked. Change your approach — the current strategy is not working."
458            )
459        }
460    }
461
462    #[allow(clippy::unused_self)]
463    fn search_block_message(&self, count: u32) -> String {
464        format!(
465            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
466             1) Use ctx_tree to understand the project structure first. \
467             2) Narrow your search with the 'path' parameter to a specific directory. \
468             3) Use ctx_read with mode='map' to understand a file before searching more."
469        )
470    }
471
472    fn search_group_block_message(&self, count: u32) -> String {
473        format!(
474            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
475             1) Use ctx_tree to map the project structure. \
476             2) Pick ONE specific directory and search there with the 'path' parameter. \
477             3) Read files with ctx_read mode='map' instead of searching blindly.",
478            self.window.as_secs()
479        )
480    }
481}
482
/// Reduces a shell command to a comparison key: its first five whitespace-separated words,
/// joined by single spaces and lowercased.
fn normalize_shell_command(cmd: &str) -> String {
    let mut key = String::with_capacity(cmd.len());
    for (idx, word) in cmd.split_whitespace().take(5).enumerate() {
        if idx > 0 {
            key.push(' ');
        }
        key.push_str(word);
    }
    key.to_lowercase()
}
490
/// Returns the leading run of alphanumeric characters in `pattern`, e.g.
/// `"compress.*data"` -> `"compress"`, or an empty string if it starts with anything else.
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .char_indices()
        .find(|(_, c)| !c.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..end].to_string()
}
497
/// Heuristic similarity between two search patterns (case-sensitive; the caller lowercases).
///
/// Patterns are similar when any of these holds:
/// - they are equal;
/// - one contains the other;
/// - their alphanumeric-only forms (each at least 3 bytes) contain one another, e.g.
///   `"compress.*data"` vs `"compress"`.
///
/// An empty pattern is never similar to a non-empty one. `str::contains("")` is always true,
/// which previously made `""` match every pattern.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(b_alpha.as_str()) || b_alpha.contains(a_alpha.as_str()))
}
515
516fn canonical_json(value: &serde_json::Value) -> String {
517    match value {
518        serde_json::Value::Object(map) => {
519            let mut keys: Vec<&String> = map.keys().collect();
520            keys.sort();
521            let entries: Vec<String> = keys
522                .iter()
523                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
524                .collect();
525            format!("{{{}}}", entries.join(","))
526        }
527        serde_json::Value::Array(arr) => {
528            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
529            format!("[{}]", entries.join(","))
530        }
531        _ => value.to_string(),
532    }
533}
534
535#[cfg(test)]
536mod tests {
537    use super::*;
538
539    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
540        LoopDetectionConfig {
541            normal_threshold: normal,
542            reduced_threshold: reduced,
543            blocked_threshold: blocked,
544            window_secs: 300,
545            search_group_limit: 10,
546            tool_total_limits: std::collections::HashMap::new(),
547        }
548    }
549
550    #[test]
551    fn normal_calls_pass_through() {
552        let mut detector = LoopDetector::new();
553        let r1 = detector.record_call("ctx_read", "abc123");
554        assert_eq!(r1.level, ThrottleLevel::Normal);
555        assert_eq!(r1.call_count, 1);
556        assert!(r1.message.is_none());
557    }
558
559    #[test]
560    fn repeated_calls_trigger_reduced() {
561        let _lock = crate::core::data_dir::test_env_lock();
562        std::env::set_var("LEAN_CTX_META", "1");
563        let cfg = LoopDetectionConfig::default();
564        let mut detector = LoopDetector::with_config(&cfg);
565        for _ in 0..cfg.normal_threshold {
566            detector.record_call("ctx_read", "same_fp");
567        }
568        let result = detector.record_call("ctx_read", "same_fp");
569        assert_eq!(result.level, ThrottleLevel::Reduced);
570        assert!(result.message.is_some());
571        std::env::remove_var("LEAN_CTX_META");
572    }
573
574    #[test]
575    fn excessive_calls_get_blocked_when_enabled() {
576        // Blocking must be explicitly enabled (blocked_threshold > 0)
577        let cfg = LoopDetectionConfig {
578            blocked_threshold: 6,
579            ..Default::default()
580        };
581        let mut detector = LoopDetector::with_config(&cfg);
582        for _ in 0..cfg.blocked_threshold {
583            detector.record_call("ctx_shell", "same_fp");
584        }
585        let result = detector.record_call("ctx_shell", "same_fp");
586        assert_eq!(result.level, ThrottleLevel::Blocked);
587        assert!(result.message.unwrap().contains("LOOP DETECTED"));
588    }
589
590    #[test]
591    fn blocking_disabled_by_default() {
592        // Default config has blocked_threshold = 0, so blocking never happens
593        let cfg = LoopDetectionConfig::default();
594        assert_eq!(cfg.blocked_threshold, 0);
595        let mut detector = LoopDetector::with_config(&cfg);
596        // Even 100 calls should not block when blocking is disabled
597        for _ in 0..100 {
598            detector.record_call("ctx_shell", "same_fp");
599        }
600        let result = detector.record_call("ctx_shell", "same_fp");
601        // Should be Reduced (warning) but never Blocked
602        assert_ne!(result.level, ThrottleLevel::Blocked);
603    }
604
605    #[test]
606    fn different_args_tracked_separately() {
607        let mut detector = LoopDetector::new();
608        for _ in 0..10 {
609            detector.record_call("ctx_read", "fp_a");
610        }
611        let result = detector.record_call("ctx_read", "fp_b");
612        assert_eq!(result.level, ThrottleLevel::Normal);
613        assert_eq!(result.call_count, 1);
614    }
615
616    #[test]
617    fn fingerprint_deterministic() {
618        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
619        let fp1 = LoopDetector::fingerprint(&args);
620        let fp2 = LoopDetector::fingerprint(&args);
621        assert_eq!(fp1, fp2);
622    }
623
624    #[test]
625    fn fingerprint_order_independent() {
626        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
627        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
628        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
629    }
630
631    #[test]
632    fn stats_shows_duplicates() {
633        let mut detector = LoopDetector::new();
634        for _ in 0..5 {
635            detector.record_call("ctx_read", "fp_a");
636        }
637        detector.record_call("ctx_shell", "fp_b");
638        let stats = detector.stats();
639        assert_eq!(stats.len(), 1);
640        assert_eq!(stats[0].1, 5);
641    }
642
643    #[test]
644    fn reset_clears_state() {
645        let mut detector = LoopDetector::new();
646        for _ in 0..5 {
647            detector.record_call("ctx_read", "fp_a");
648        }
649        detector.reset();
650        let result = detector.record_call("ctx_read", "fp_a");
651        assert_eq!(result.call_count, 1);
652    }
653
654    #[test]
655    fn custom_thresholds_from_config() {
656        let cfg = test_config(1, 2, 3);
657        let mut detector = LoopDetector::with_config(&cfg);
658        detector.record_call("ctx_read", "fp");
659        let r = detector.record_call("ctx_read", "fp");
660        assert_eq!(r.level, ThrottleLevel::Reduced);
661        detector.record_call("ctx_read", "fp");
662        let r = detector.record_call("ctx_read", "fp");
663        assert_eq!(r.level, ThrottleLevel::Blocked);
664    }
665
666    #[test]
667    fn similar_patterns_detected() {
668        assert!(patterns_are_similar("compress", "compress"));
669        assert!(patterns_are_similar("compress", "compression"));
670        assert!(patterns_are_similar("compress.*data", "compress"));
671        assert!(!patterns_are_similar("foo", "bar"));
672        assert!(!patterns_are_similar("ab", "cd"));
673    }
674
675    #[test]
676    fn search_group_tracking_when_blocking_enabled() {
677        // Blocking must be explicitly enabled for search group limits to block
678        let cfg = LoopDetectionConfig {
679            search_group_limit: 5,
680            blocked_threshold: 6, // Enable blocking
681            ..Default::default()
682        };
683        let mut detector = LoopDetector::with_config(&cfg);
684        for i in 0..5 {
685            let fp = format!("fp_{i}");
686            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
687            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
688        }
689        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
690        assert_eq!(r.level, ThrottleLevel::Blocked);
691        assert!(r.message.unwrap().contains("search calls"));
692    }
693
694    #[test]
695    fn similar_search_patterns_trigger_block_when_enabled() {
696        // Blocking must be explicitly enabled
697        let cfg = LoopDetectionConfig {
698            blocked_threshold: 6,
699            ..Default::default()
700        };
701        let mut detector = LoopDetector::with_config(&cfg);
702        let variants = [
703            "compress",
704            "compression",
705            "compress.*data",
706            "compress_output",
707            "compressor",
708            "compress_result",
709            "compress_file",
710        ];
711        for (i, pat) in variants
712            .iter()
713            .enumerate()
714            .take(cfg.blocked_threshold as usize)
715        {
716            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
717        }
718        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
719        assert_eq!(r.level, ThrottleLevel::Blocked);
720    }
721
722    #[test]
723    fn is_search_tool_detection() {
724        assert!(LoopDetector::is_search_tool("ctx_search"));
725        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
726        assert!(!LoopDetector::is_search_tool("ctx_read"));
727        assert!(!LoopDetector::is_search_tool("ctx_shell"));
728    }
729
730    #[test]
731    fn is_search_shell_command_detection() {
732        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
733        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
734        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
735        assert!(!LoopDetector::is_search_shell_command("cargo build"));
736        assert!(!LoopDetector::is_search_shell_command("git status"));
737    }
738
739    #[test]
740    fn correction_fresh_reread_detected() {
741        let mut detector = LoopDetector::new();
742        // First read (cold start period, skipped)
743        detector.record_read_for_correction("src/main.rs", "full", false);
744        detector.record_read_for_correction("src/lib.rs", "full", false);
745        detector.record_read_for_correction("src/util.rs", "full", false);
746        // 4th call: past cold start
747        detector.record_read_for_correction("src/main.rs", "full", false);
748        assert_eq!(detector.correction_count(), 0);
749        // fresh=true re-read of previously read file = correction signal
750        detector.record_read_for_correction("src/main.rs", "full", true);
751        assert_eq!(detector.correction_count(), 1);
752    }
753
754    #[test]
755    fn correction_mode_bounce_detected() {
756        let mut detector = LoopDetector::new();
757        // Cold start
758        for i in 0..COLD_START_CALLS {
759            detector.record_read_for_correction(&format!("f{i}.rs"), "full", false);
760        }
761        // Read with map mode
762        detector.record_read_for_correction("src/cache.rs", "map", false);
763        assert_eq!(detector.correction_count(), 0);
764        // Immediately bounce to full mode = correction
765        detector.record_read_for_correction("src/cache.rs", "full", false);
766        assert_eq!(detector.correction_count(), 1);
767    }
768
769    #[test]
770    fn correction_shell_rerun_detected() {
771        let mut detector = LoopDetector::new();
772        // Cold start
773        for i in 0..COLD_START_CALLS {
774            detector.record_shell_for_correction(&format!("echo {i}"));
775        }
776        // First run
777        detector.record_shell_for_correction("cargo test --lib");
778        assert_eq!(detector.correction_count(), 0);
779        // Same command again within 60s = correction
780        detector.record_shell_for_correction("cargo test --lib");
781        assert_eq!(detector.correction_count(), 1);
782    }
783
784    #[test]
785    fn correction_rate_calculation() {
786        let mut detector = LoopDetector::new();
787        for i in 0..COLD_START_CALLS {
788            detector.record_shell_for_correction(&format!("init{i}"));
789        }
790        detector.record_shell_for_correction("cargo check");
791        detector.record_shell_for_correction("cargo check");
792        detector.record_shell_for_correction("cargo check");
793        // 2 corrections (first run doesn't count)
794        assert_eq!(detector.correction_count(), 2);
795        assert!(detector.correction_rate() > 0.0);
796    }
797
798    #[test]
799    fn correction_cold_start_ignored() {
800        let mut detector = LoopDetector::new();
801        // During cold start, same-command re-runs are not counted
802        detector.record_shell_for_correction("cargo check");
803        detector.record_shell_for_correction("cargo check");
804        detector.record_shell_for_correction("cargo check");
805        assert_eq!(detector.correction_count(), 0);
806    }
807
808    #[test]
809    fn search_block_message_has_guidance_when_blocking_enabled() {
810        // Blocking must be explicitly enabled to get block messages
811        let cfg = LoopDetectionConfig {
812            blocked_threshold: 6,
813            search_group_limit: 8,
814            ..Default::default()
815        };
816        let mut detector = LoopDetector::with_config(&cfg);
817        for i in 0..10 {
818            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
819        }
820        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
821        assert_eq!(r.level, ThrottleLevel::Blocked);
822        let msg = r.message.unwrap();
823        assert!(msg.contains("ctx_tree"));
824        assert!(msg.contains("path"));
825        assert!(msg.contains("ctx_read"));
826    }
827
828    #[test]
829    fn error_outcome_undoes_pre_dispatch_count() {
830        let cfg = test_config(2, 4, 0);
831        let mut detector = LoopDetector::with_config(&cfg);
832
833        detector.record_call("ctx_read", "fp1");
834        detector.record_call("ctx_read", "fp1");
835        detector.record_error_outcome("ctx_read", "fp1");
836
837        let r = detector.record_call("ctx_read", "fp1");
838        assert_eq!(r.call_count, 2, "error should have undone one count");
839        assert_eq!(r.level, ThrottleLevel::Normal);
840    }
841
842    #[test]
843    fn repeated_errors_dont_trigger_reduced() {
844        let cfg = test_config(2, 4, 0);
845        let mut detector = LoopDetector::with_config(&cfg);
846
847        for _ in 0..5 {
848            detector.record_call("ctx_read", "fp1");
849            detector.record_error_outcome("ctx_read", "fp1");
850        }
851
852        let r = detector.record_call("ctx_read", "fp1");
853        assert_eq!(
854            r.level,
855            ThrottleLevel::Normal,
856            "5 failed retries should not throttle"
857        );
858    }
859
860    #[test]
861    fn mixed_success_and_error_correct_count() {
862        let cfg = test_config(2, 4, 0);
863        let mut detector = LoopDetector::with_config(&cfg);
864
865        detector.record_call("ctx_read", "fp1");
866        detector.record_error_outcome("ctx_read", "fp1");
867        detector.record_call("ctx_read", "fp1");
868        detector.record_error_outcome("ctx_read", "fp1");
869        detector.record_call("ctx_read", "fp1");
870        // 3 pre-dispatch, 2 error undos -> effective count = 1
871        assert_eq!(detector.record_call("ctx_read", "fp1").call_count, 2);
872    }
873
874    #[test]
875    fn error_outcome_on_nonexistent_key_is_noop() {
876        let mut detector = LoopDetector::new();
877        detector.record_error_outcome("ctx_read", "never_called");
878        let r = detector.record_call("ctx_read", "never_called");
879        assert_eq!(r.call_count, 1);
880    }
881
882    #[test]
883    fn error_outcome_doesnt_go_negative() {
884        let mut detector = LoopDetector::new();
885        detector.record_call("ctx_read", "fp1");
886        detector.record_error_outcome("ctx_read", "fp1");
887        detector.record_error_outcome("ctx_read", "fp1");
888        let r = detector.record_call("ctx_read", "fp1");
889        assert_eq!(r.call_count, 1, "count should never go below 0");
890    }
891
892    #[test]
893    fn error_in_tool_a_doesnt_affect_tool_b() {
894        let cfg = test_config(2, 4, 0);
895        let mut detector = LoopDetector::with_config(&cfg);
896
897        for _ in 0..5 {
898            detector.record_call("ctx_read", "fp1");
899            detector.record_error_outcome("ctx_read", "fp1");
900        }
901
902        let r = detector.record_call("ctx_shell", "fp_shell");
903        assert_eq!(r.call_count, 1);
904        assert_eq!(r.level, ThrottleLevel::Normal);
905    }
906
907    #[test]
908    fn different_fingerprints_independent_after_errors() {
909        let cfg = test_config(2, 4, 0);
910        let mut detector = LoopDetector::with_config(&cfg);
911
912        detector.record_call("ctx_read", "fp_a");
913        detector.record_error_outcome("ctx_read", "fp_a");
914
915        detector.record_call("ctx_read", "fp_b");
916        let r = detector.record_call("ctx_read", "fp_b");
917        assert_eq!(r.call_count, 2);
918
919        let r_a = detector.record_call("ctx_read", "fp_a");
920        assert_eq!(r_a.call_count, 1, "fp_a count should be reset to 0 then +1");
921    }
922
923    #[test]
924    fn correction_degrade_recovery_after_prune() {
925        let mut detector = LoopDetector::new();
926        for i in 0..4u32 {
927            detector.record_read_for_correction(&format!("warmup{i}.rs"), "full", false);
928        }
929        detector.record_read_for_correction("target.rs", "full", false);
930        detector.record_read_for_correction("target.rs", "full", true);
931        assert!(detector.correction_count() > 0);
932        detector.prune_corrections();
933        // After prune, count is still > 0 because signal is within window
934        // but the mechanism works: once window expires, count drops
935        assert!(detector.correction_count() >= 1);
936    }
937
938    #[test]
939    fn success_after_errors_resets_to_normal() {
940        let cfg = test_config(2, 4, 0);
941        let mut detector = LoopDetector::with_config(&cfg);
942
943        for _ in 0..3 {
944            detector.record_call("ctx_read", "fp1");
945            detector.record_error_outcome("ctx_read", "fp1");
946        }
947
948        let r = detector.record_call("ctx_read", "fp1");
949        assert_eq!(r.level, ThrottleLevel::Normal);
950        assert_eq!(r.call_count, 1);
951    }
952
953    #[test]
954    fn throttle_result_default_is_normal() {
955        let r = ThrottleResult::default();
956        assert_eq!(r.level, ThrottleLevel::Normal);
957        assert_eq!(r.call_count, 0);
958        assert!(r.message.is_none());
959    }
960}