Skip to main content

lean_ctx/core/
loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use super::config::LoopDetectionConfig;
5
/// MCP tool names that are treated as search calls.
const SEARCH_TOOLS: &[&str] = &["ctx_search", "ctx_semantic_search"];

/// Shell command prefixes (including the trailing space) that are treated as search calls.
const SEARCH_SHELL_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

/// Window for counting correction signals and detecting `fresh=true` re-reads (two minutes).
const CORRECTION_WINDOW: Duration = Duration::from_secs(2 * 60);
/// A `map`/`signatures` read followed by a `full` read of the same path within this window is a mode bounce.
const MODE_BOUNCE_WINDOW: Duration = Duration::from_secs(30);
/// Re-running the same shell command within this window is a correction signal (one minute).
const SHELL_RERUN_WINDOW: Duration = Duration::from_secs(60);
/// Number of initial correction-tracked calls that are recorded but never produce signals.
const COLD_START_CALLS: u32 = 3;
14
/// Classification of why an agent re-requested data it already had.
#[derive(Debug, Clone, PartialEq)]
pub enum CorrectionKind {
    /// A `fresh=true` read of a path that was already read within the correction window.
    FreshReRead,
    /// The same (normalized) shell command was run again within the re-run window.
    ShellReRun,
    /// A `map`/`signatures` read was quickly followed by a `full` read of the same path.
    ModeBounce,
}
22
23/// Tracks repeated tool calls within a time window to detect and throttle agent loops.
24#[derive(Debug, Clone)]
25pub struct LoopDetector {
26    call_history: HashMap<String, Vec<Instant>>,
27    duplicate_counts: HashMap<String, u32>,
28    tool_total_counts: HashMap<String, u32>,
29    tool_total_limits: HashMap<String, u32>,
30    search_group_history: Vec<Instant>,
31    recent_search_patterns: Vec<String>,
32    normal_threshold: u32,
33    reduced_threshold: u32,
34    blocked_threshold: u32,
35    window: Duration,
36    search_group_limit: u32,
37    // Correction-loop tracking (Fix A)
38    correction_signals: Vec<(Instant, CorrectionKind)>,
39    recent_reads: HashMap<String, (Instant, String)>,
40    recent_commands: HashMap<String, Instant>,
41    total_calls: u32,
42}
43
/// Severity of throttling applied to a repeated call: normal, reduced, or blocked.
#[derive(Debug, Clone, PartialEq)]
pub enum ThrottleLevel {
    /// The call proceeds unchanged.
    Normal,
    /// The call proceeds, but results should be reduced (optionally with a note/warning).
    Reduced,
    /// The call is refused; the accompanying message explains why.
    Blocked,
}
51
52/// Outcome of a loop detection check: throttle level, count, and optional warning.
53#[derive(Debug, Clone)]
54pub struct ThrottleResult {
55    pub level: ThrottleLevel,
56    pub call_count: u32,
57    pub message: Option<String>,
58}
59
60impl Default for LoopDetector {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl LoopDetector {
67    /// Creates a loop detector with default thresholds.
68    pub fn new() -> Self {
69        Self::with_config(&LoopDetectionConfig::default())
70    }
71
72    /// Creates a loop detector with custom thresholds from config.
73    /// Set blocked_threshold to 0 to disable blocking entirely (LeanCTX philosophy).
74    pub fn with_config(cfg: &LoopDetectionConfig) -> Self {
75        Self {
76            call_history: HashMap::new(),
77            duplicate_counts: HashMap::new(),
78            tool_total_counts: HashMap::new(),
79            tool_total_limits: cfg.tool_total_limits.clone(),
80            search_group_history: Vec::new(),
81            recent_search_patterns: Vec::new(),
82            normal_threshold: cfg.normal_threshold.max(1),
83            reduced_threshold: cfg.reduced_threshold.max(2),
84            blocked_threshold: cfg.blocked_threshold,
85            window: Duration::from_secs(cfg.window_secs),
86            search_group_limit: if cfg.blocked_threshold == 0 {
87                u32::MAX
88            } else {
89                cfg.search_group_limit.max(3)
90            },
91            correction_signals: Vec::new(),
92            recent_reads: HashMap::new(),
93            recent_commands: HashMap::new(),
94            total_calls: 0,
95        }
96    }
97
98    /// Records a tool call and returns the throttle result based on repetition count.
99    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
100        let now = Instant::now();
101        self.prune_window(now);
102
103        // Per-tool total count (regardless of args)
104        let total = self.tool_total_counts.entry(tool.to_string()).or_insert(0);
105        *total += 1;
106        let total_count = *total;
107
108        if let Some(&limit) = self.tool_total_limits.get(tool) {
109            if total_count > limit {
110                let msg = if crate::core::protocol::meta_visible() {
111                    Some(format!(
112                        "Warning: {tool} called {total_count}x total (limit: {limit}). \
113                         Consider ctx_compress or narrowing scope."
114                    ))
115                } else {
116                    None
117                };
118                return ThrottleResult {
119                    level: ThrottleLevel::Reduced,
120                    call_count: total_count,
121                    message: msg,
122                };
123            }
124        }
125
126        let key = format!("{tool}:{args_fingerprint}");
127        let entries = self.call_history.entry(key.clone()).or_default();
128        entries.push(now);
129        let count = entries.len() as u32;
130        *self.duplicate_counts.entry(key).or_default() = count;
131
132        if self.blocked_threshold > 0 && count > self.blocked_threshold {
133            return ThrottleResult {
134                level: ThrottleLevel::Blocked,
135                call_count: count,
136                message: Some(self.block_message(tool, count)),
137            };
138        }
139        if count > self.reduced_threshold {
140            if !crate::core::protocol::meta_visible() {
141                return ThrottleResult {
142                    level: ThrottleLevel::Reduced,
143                    call_count: count,
144                    message: None,
145                };
146            }
147            return ThrottleResult {
148                level: ThrottleLevel::Reduced,
149                call_count: count,
150                message: Some(format!(
151                    "Warning: {tool} called {count}x with same args. \
152                     Results reduced. Try a different approach or narrow your scope."
153                )),
154            };
155        }
156        if count > self.normal_threshold {
157            if !crate::core::protocol::meta_visible() {
158                return ThrottleResult {
159                    level: ThrottleLevel::Reduced,
160                    call_count: count,
161                    message: None,
162                };
163            }
164            return ThrottleResult {
165                level: ThrottleLevel::Reduced,
166                call_count: count,
167                message: Some(format!(
168                    "Note: {tool} called {count}x with similar args. Consider narrowing scope."
169                )),
170            };
171        }
172        ThrottleResult {
173            level: ThrottleLevel::Normal,
174            call_count: count,
175            message: None,
176        }
177    }
178
179    /// Record a search-category call and check the cross-tool search group limit.
180    /// `search_pattern` is the extracted query/regex the agent is looking for (if available).
181    pub fn record_search(
182        &mut self,
183        tool: &str,
184        args_fingerprint: &str,
185        search_pattern: Option<&str>,
186    ) -> ThrottleResult {
187        let now = Instant::now();
188
189        self.search_group_history.push(now);
190        let search_count = self.search_group_history.len() as u32;
191
192        let similar_count = if let Some(pat) = search_pattern {
193            let sc = self.count_similar_patterns(pat);
194            if !pat.is_empty() {
195                self.recent_search_patterns.push(pat.to_string());
196                if self.recent_search_patterns.len() > 15 {
197                    self.recent_search_patterns.remove(0);
198                }
199            }
200            sc
201        } else {
202            0
203        };
204
205        // blocked_threshold == 0 means blocking is disabled (LeanCTX default)
206        if self.blocked_threshold > 0 && similar_count >= self.blocked_threshold {
207            return ThrottleResult {
208                level: ThrottleLevel::Blocked,
209                call_count: similar_count,
210                message: Some(self.search_block_message(similar_count)),
211            };
212        }
213
214        // search_group_limit == u32::MAX when blocking is disabled
215        if self.blocked_threshold > 0 && search_count > self.search_group_limit {
216            return ThrottleResult {
217                level: ThrottleLevel::Blocked,
218                call_count: search_count,
219                message: Some(self.search_group_block_message(search_count)),
220            };
221        }
222
223        if similar_count >= self.reduced_threshold {
224            if !crate::core::protocol::meta_visible() {
225                return ThrottleResult {
226                    level: ThrottleLevel::Reduced,
227                    call_count: similar_count,
228                    message: None,
229                };
230            }
231            return ThrottleResult {
232                level: ThrottleLevel::Reduced,
233                call_count: similar_count,
234                message: Some(format!(
235                    "Warning: You've searched for similar patterns {similar_count}x. \
236                     Narrow your search with the 'path' parameter or try ctx_tree first."
237                )),
238            };
239        }
240
241        if search_count > self.search_group_limit.saturating_sub(3) {
242            let per_fp = self.record_call(tool, args_fingerprint);
243            if per_fp.level != ThrottleLevel::Normal {
244                return per_fp;
245            }
246            if !crate::core::protocol::meta_visible() {
247                return ThrottleResult {
248                    level: ThrottleLevel::Reduced,
249                    call_count: search_count,
250                    message: None,
251                };
252            }
253            return ThrottleResult {
254                level: ThrottleLevel::Reduced,
255                call_count: search_count,
256                message: Some(format!(
257                    "Note: {search_count} search calls in the last {}s. \
258                     Use ctx_tree to orient first, then scope searches with 'path'.",
259                    self.window.as_secs()
260                )),
261            };
262        }
263
264        self.record_call(tool, args_fingerprint)
265    }
266
267    /// Returns `true` if the tool name is a known search tool (ctx_search, etc.).
268    pub fn is_search_tool(tool: &str) -> bool {
269        SEARCH_TOOLS.contains(&tool)
270    }
271
272    /// Returns `true` if the shell command starts with a search tool (grep, rg, find, etc.).
273    pub fn is_search_shell_command(command: &str) -> bool {
274        let cmd = command.trim_start();
275        SEARCH_SHELL_PREFIXES.iter().any(|p| cmd.starts_with(p))
276    }
277
278    /// Computes a deterministic hash fingerprint of JSON tool arguments.
279    pub fn fingerprint(args: &serde_json::Value) -> String {
280        use std::collections::hash_map::DefaultHasher;
281        use std::hash::{Hash, Hasher};
282
283        let canonical = canonical_json(args);
284        let mut hasher = DefaultHasher::new();
285        canonical.hash(&mut hasher);
286        format!("{:016x}", hasher.finish())
287    }
288
289    /// Returns duplicate call entries sorted by count (descending), filtered to count > 1.
290    pub fn stats(&self) -> Vec<(String, u32)> {
291        let mut entries: Vec<(String, u32)> = self
292            .duplicate_counts
293            .iter()
294            .filter(|(_, &count)| count > 1)
295            .map(|(k, &v)| (k.clone(), v))
296            .collect();
297        entries.sort_by_key(|x| std::cmp::Reverse(x.1));
298        entries
299    }
300
301    /// Records a ctx_read call and detects correction signals:
302    /// - `fresh=true` re-read of a previously cached file
303    /// - Mode bounce: map/signatures followed by full within 30s
304    pub fn record_read_for_correction(&mut self, path: &str, mode: &str, fresh: bool) {
305        self.total_calls += 1;
306        let now = Instant::now();
307
308        if self.total_calls <= COLD_START_CALLS {
309            self.recent_reads
310                .insert(path.to_string(), (now, mode.to_string()));
311            return;
312        }
313
314        if fresh {
315            if let Some((prev_time, _)) = self.recent_reads.get(path) {
316                if now.duration_since(*prev_time) < CORRECTION_WINDOW {
317                    self.correction_signals
318                        .push((now, CorrectionKind::FreshReRead));
319                }
320            }
321        }
322
323        if mode == "full" {
324            if let Some((prev_time, prev_mode)) = self.recent_reads.get(path) {
325                let is_bounce = (prev_mode == "map" || prev_mode == "signatures")
326                    && now.duration_since(*prev_time) < MODE_BOUNCE_WINDOW;
327                if is_bounce {
328                    self.correction_signals
329                        .push((now, CorrectionKind::ModeBounce));
330                }
331            }
332        }
333
334        self.recent_reads
335            .insert(path.to_string(), (now, mode.to_string()));
336    }
337
338    /// Records a ctx_shell command and detects re-runs of the same command within 60s.
339    pub fn record_shell_for_correction(&mut self, command: &str) {
340        self.total_calls += 1;
341        let now = Instant::now();
342
343        if self.total_calls <= COLD_START_CALLS {
344            self.recent_commands.insert(command.to_string(), now);
345            return;
346        }
347
348        let key = normalize_shell_command(command);
349        if let Some(prev_time) = self.recent_commands.get(&key) {
350            if now.duration_since(*prev_time) < SHELL_RERUN_WINDOW {
351                self.correction_signals
352                    .push((now, CorrectionKind::ShellReRun));
353            }
354        }
355        self.recent_commands.insert(key, now);
356    }
357
358    /// Returns the number of correction signals in the sliding window.
359    pub fn correction_count(&self) -> u32 {
360        let now = Instant::now();
361        self.correction_signals
362            .iter()
363            .filter(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW)
364            .count() as u32
365    }
366
367    /// Returns the correction rate: signals per minute within the window.
368    pub fn correction_rate(&self) -> f64 {
369        let count = self.correction_count();
370        if count == 0 {
371            return 0.0;
372        }
373        let window_mins = CORRECTION_WINDOW.as_secs_f64() / 60.0;
374        f64::from(count) / window_mins
375    }
376
377    /// Prunes expired correction signals and stale read/command entries.
378    pub fn prune_corrections(&mut self) {
379        let now = Instant::now();
380        self.correction_signals
381            .retain(|(t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
382        self.recent_reads
383            .retain(|_, (t, _)| now.duration_since(*t) < CORRECTION_WINDOW);
384        self.recent_commands
385            .retain(|_, t| now.duration_since(*t) < CORRECTION_WINDOW);
386    }
387
388    /// Clears all tracking state (call history, search patterns, counters).
389    pub fn reset(&mut self) {
390        self.call_history.clear();
391        self.duplicate_counts.clear();
392        self.search_group_history.clear();
393        self.recent_search_patterns.clear();
394        self.correction_signals.clear();
395        self.recent_reads.clear();
396        self.recent_commands.clear();
397        self.total_calls = 0;
398    }
399
400    fn prune_window(&mut self, now: Instant) {
401        for entries in self.call_history.values_mut() {
402            entries.retain(|t| now.duration_since(*t) < self.window);
403        }
404        self.search_group_history
405            .retain(|t| now.duration_since(*t) < self.window);
406    }
407
408    fn count_similar_patterns(&self, new_pattern: &str) -> u32 {
409        let new_lower = new_pattern.to_lowercase();
410        let new_root = extract_alpha_root(&new_lower);
411
412        let mut count = 0u32;
413        for existing in &self.recent_search_patterns {
414            let existing_lower = existing.to_lowercase();
415            if patterns_are_similar(&new_lower, &existing_lower) {
416                count += 1;
417            } else if new_root.len() >= 4 {
418                let existing_root = extract_alpha_root(&existing_lower);
419                if existing_root.len() >= 4
420                    && (new_root.starts_with(&existing_root)
421                        || existing_root.starts_with(&new_root))
422                {
423                    count += 1;
424                }
425            }
426        }
427        count
428    }
429
430    fn block_message(&self, tool: &str, count: u32) -> String {
431        if Self::is_search_tool(tool) {
432            self.search_block_message(count)
433        } else {
434            format!(
435                "LOOP DETECTED: {tool} called {count}x with same/similar args. \
436                 Call blocked. Change your approach — the current strategy is not working."
437            )
438        }
439    }
440
441    #[allow(clippy::unused_self)]
442    fn search_block_message(&self, count: u32) -> String {
443        format!(
444            "LOOP DETECTED: You've searched {count}x with similar patterns. STOP searching and change strategy. \
445             1) Use ctx_tree to understand the project structure first. \
446             2) Narrow your search with the 'path' parameter to a specific directory. \
447             3) Use ctx_read with mode='map' to understand a file before searching more."
448        )
449    }
450
451    fn search_group_block_message(&self, count: u32) -> String {
452        format!(
453            "LOOP DETECTED: {count} search calls in {}s — too many. STOP and rethink. \
454             1) Use ctx_tree to map the project structure. \
455             2) Pick ONE specific directory and search there with the 'path' parameter. \
456             3) Read files with ctx_read mode='map' instead of searching blindly.",
457            self.window.as_secs()
458        )
459    }
460}
461
/// Normalizes a shell command for re-run detection: keeps the first five whitespace-separated
/// words, joins them with single spaces, and lowercases the result.
fn normalize_shell_command(cmd: &str) -> String {
    let mut normalized = String::new();
    for (index, word) in cmd.split_whitespace().take(5).enumerate() {
        if index > 0 {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized.to_lowercase()
}
469
/// Returns the leading run of alphanumeric characters of `pattern` (e.g. `"compress"` for
/// `"compress_output"`); empty if the first character is not alphanumeric.
fn extract_alpha_root(pattern: &str) -> String {
    let end = pattern
        .char_indices()
        .find(|&(_, ch)| !ch.is_alphanumeric())
        .map_or(pattern.len(), |(idx, _)| idx);
    pattern[..end].to_owned()
}
476
/// Returns `true` if two search patterns look like variations of the same query.
///
/// Patterns are similar when they are equal, when one contains the other, or when their
/// alphanumeric-only forms (each at least 3 chars) contain one another. An empty pattern is
/// only similar to another empty pattern: `""` is a substring of every string, so without the
/// guard it would match everything.
fn patterns_are_similar(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }
    if a.is_empty() || b.is_empty() {
        return false;
    }
    if a.contains(b) || b.contains(a) {
        return true;
    }
    let a_alpha: String = a.chars().filter(|c| c.is_alphanumeric()).collect();
    let b_alpha: String = b.chars().filter(|c| c.is_alphanumeric()).collect();
    a_alpha.len() >= 3
        && b_alpha.len() >= 3
        && (a_alpha.contains(&b_alpha) || b_alpha.contains(&a_alpha))
}
494
495fn canonical_json(value: &serde_json::Value) -> String {
496    match value {
497        serde_json::Value::Object(map) => {
498            let mut keys: Vec<&String> = map.keys().collect();
499            keys.sort();
500            let entries: Vec<String> = keys
501                .iter()
502                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
503                .collect();
504            format!("{{{}}}", entries.join(","))
505        }
506        serde_json::Value::Array(arr) => {
507            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
508            format!("[{}]", entries.join(","))
509        }
510        _ => value.to_string(),
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable and removes it again when dropped, so a failing assertion
    /// cannot leak the variable into tests that run afterwards.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.key);
        }
    }

    /// Builds a config with explicit thresholds, a 300s window and no per-tool limits.
    fn test_config(normal: u32, reduced: u32, blocked: u32) -> LoopDetectionConfig {
        LoopDetectionConfig {
            normal_threshold: normal,
            reduced_threshold: reduced,
            blocked_threshold: blocked,
            window_secs: 300,
            search_group_limit: 10,
            tool_total_limits: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut detector = LoopDetector::new();
        let r1 = detector.record_call("ctx_read", "abc123");
        assert_eq!(r1.level, ThrottleLevel::Normal);
        assert_eq!(r1.call_count, 1);
        assert!(r1.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let _lock = crate::core::data_dir::test_env_lock();
        // Declared after the lock, so it is dropped (variable removed) before the lock is released.
        let _meta = EnvVarGuard::set("LEAN_CTX_META", "1");
        let cfg = LoopDetectionConfig::default();
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.normal_threshold {
            detector.record_call("ctx_read", "same_fp");
        }
        let result = detector.record_call("ctx_read", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Reduced);
        assert!(result.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked_when_enabled() {
        // Blocking must be explicitly enabled (blocked_threshold > 0)
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for _ in 0..cfg.blocked_threshold {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        assert_eq!(result.level, ThrottleLevel::Blocked);
        assert!(result.message.unwrap().contains("LOOP DETECTED"));
    }

    #[test]
    fn blocking_disabled_by_default() {
        // Default config has blocked_threshold = 0, so blocking never happens
        let cfg = LoopDetectionConfig::default();
        assert_eq!(cfg.blocked_threshold, 0);
        let mut detector = LoopDetector::with_config(&cfg);
        // Even 100 calls should not block when blocking is disabled
        for _ in 0..100 {
            detector.record_call("ctx_shell", "same_fp");
        }
        let result = detector.record_call("ctx_shell", "same_fp");
        // Should be Reduced (warning) but never Blocked
        assert_ne!(result.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut detector = LoopDetector::new();
        for _ in 0..10 {
            detector.record_call("ctx_read", "fp_a");
        }
        let result = detector.record_call("ctx_read", "fp_b");
        assert_eq!(result.level, ThrottleLevel::Normal);
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let fp1 = LoopDetector::fingerprint(&args);
        let fp2 = LoopDetector::fingerprint(&args);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_order_independent() {
        let a = serde_json::json!({"mode": "full", "path": "test.rs"});
        let b = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(LoopDetector::fingerprint(&a), LoopDetector::fingerprint(&b));
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.record_call("ctx_shell", "fp_b");
        let stats = detector.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut detector = LoopDetector::new();
        for _ in 0..5 {
            detector.record_call("ctx_read", "fp_a");
        }
        detector.reset();
        let result = detector.record_call("ctx_read", "fp_a");
        assert_eq!(result.call_count, 1);
    }

    #[test]
    fn custom_thresholds_from_config() {
        let cfg = test_config(1, 2, 3);
        let mut detector = LoopDetector::with_config(&cfg);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Reduced);
        detector.record_call("ctx_read", "fp");
        let r = detector.record_call("ctx_read", "fp");
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn similar_patterns_detected() {
        assert!(patterns_are_similar("compress", "compress"));
        assert!(patterns_are_similar("compress", "compression"));
        assert!(patterns_are_similar("compress.*data", "compress"));
        assert!(!patterns_are_similar("foo", "bar"));
        assert!(!patterns_are_similar("ab", "cd"));
    }

    #[test]
    fn search_group_tracking_when_blocking_enabled() {
        // Blocking must be explicitly enabled for search group limits to block
        let cfg = LoopDetectionConfig {
            search_group_limit: 5,
            blocked_threshold: 6, // Enable blocking
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..5 {
            let fp = format!("fp_{i}");
            let r = detector.record_search("ctx_search", &fp, Some(&format!("pattern_{i}")));
            assert_ne!(r.level, ThrottleLevel::Blocked, "call {i} should not block");
        }
        let r = detector.record_search("ctx_search", "fp_5", Some("pattern_5"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        assert!(r.message.unwrap().contains("search calls"));
    }

    #[test]
    fn similar_search_patterns_trigger_block_when_enabled() {
        // Blocking must be explicitly enabled
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        let variants = [
            "compress",
            "compression",
            "compress.*data",
            "compress_output",
            "compressor",
            "compress_result",
            "compress_file",
        ];
        for (i, pat) in variants
            .iter()
            .enumerate()
            .take(cfg.blocked_threshold as usize)
        {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some(pat));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress_all"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
    }

    #[test]
    fn is_search_tool_detection() {
        assert!(LoopDetector::is_search_tool("ctx_search"));
        assert!(LoopDetector::is_search_tool("ctx_semantic_search"));
        assert!(!LoopDetector::is_search_tool("ctx_read"));
        assert!(!LoopDetector::is_search_tool("ctx_shell"));
    }

    #[test]
    fn is_search_shell_command_detection() {
        assert!(LoopDetector::is_search_shell_command("grep -r foo ."));
        assert!(LoopDetector::is_search_shell_command("rg pattern src/"));
        assert!(LoopDetector::is_search_shell_command("find . -name '*.rs'"));
        assert!(!LoopDetector::is_search_shell_command("cargo build"));
        assert!(!LoopDetector::is_search_shell_command("git status"));
    }

    #[test]
    fn correction_fresh_reread_detected() {
        let mut detector = LoopDetector::new();
        // First read (cold start period, skipped)
        detector.record_read_for_correction("src/main.rs", "full", false);
        detector.record_read_for_correction("src/lib.rs", "full", false);
        detector.record_read_for_correction("src/util.rs", "full", false);
        // 4th call: past cold start
        detector.record_read_for_correction("src/main.rs", "full", false);
        assert_eq!(detector.correction_count(), 0);
        // fresh=true re-read of previously read file = correction signal
        detector.record_read_for_correction("src/main.rs", "full", true);
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_mode_bounce_detected() {
        let mut detector = LoopDetector::new();
        // Cold start
        for i in 0..COLD_START_CALLS {
            detector.record_read_for_correction(&format!("f{i}.rs"), "full", false);
        }
        // Read with map mode
        detector.record_read_for_correction("src/cache.rs", "map", false);
        assert_eq!(detector.correction_count(), 0);
        // Immediately bounce to full mode = correction
        detector.record_read_for_correction("src/cache.rs", "full", false);
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_shell_rerun_detected() {
        let mut detector = LoopDetector::new();
        // Cold start
        for i in 0..COLD_START_CALLS {
            detector.record_shell_for_correction(&format!("echo {i}"));
        }
        // First run
        detector.record_shell_for_correction("cargo test --lib");
        assert_eq!(detector.correction_count(), 0);
        // Same command again within 60s = correction
        detector.record_shell_for_correction("cargo test --lib");
        assert_eq!(detector.correction_count(), 1);
    }

    #[test]
    fn correction_rate_calculation() {
        let mut detector = LoopDetector::new();
        for i in 0..COLD_START_CALLS {
            detector.record_shell_for_correction(&format!("init{i}"));
        }
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        // 2 corrections (first run doesn't count)
        assert_eq!(detector.correction_count(), 2);
        assert!(detector.correction_rate() > 0.0);
    }

    #[test]
    fn correction_cold_start_ignored() {
        let mut detector = LoopDetector::new();
        // During cold start, same-command re-runs are not counted
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        detector.record_shell_for_correction("cargo check");
        assert_eq!(detector.correction_count(), 0);
    }

    #[test]
    fn search_block_message_has_guidance_when_blocking_enabled() {
        // Blocking must be explicitly enabled to get block messages
        let cfg = LoopDetectionConfig {
            blocked_threshold: 6,
            search_group_limit: 8,
            ..Default::default()
        };
        let mut detector = LoopDetector::with_config(&cfg);
        for i in 0..10 {
            detector.record_search("ctx_search", &format!("fp_{i}"), Some("compress"));
        }
        let r = detector.record_search("ctx_search", "fp_new", Some("compress"));
        assert_eq!(r.level, ThrottleLevel::Blocked);
        let msg = r.message.unwrap();
        assert!(msg.contains("ctx_tree"));
        assert!(msg.contains("path"));
        assert!(msg.contains("ctx_read"));
    }
}