Skip to main content

lean_ctx/core/
loop_detection.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
/// Window-local repeat count above which `record_call` stops returning
/// `Normal` and attaches a soft "Note: … Consider batching." hint (`Reduced`).
const NORMAL_THRESHOLD: u32 = 3;
/// Repeat count above which the hint escalates to the "Repetitive pattern"
/// warning (level is still `Reduced`).
const REDUCED_THRESHOLD: u32 = 8;
/// Repeat count above which `record_call` reports `Blocked`.
const BLOCKED_THRESHOLD: u32 = 12;
/// Length of the sliding window over which repeats are counted, in seconds
/// (five minutes).
const WINDOW_SECS: u64 = 5 * 60;
8
/// Tracks repeated tool invocations (same tool name + same argument
/// fingerprint) inside a sliding time window so repetitive loops can be
/// reported and throttled.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Timestamps of recent calls, keyed by `"{tool}:{args_fingerprint}"`.
    call_history: HashMap<String, Vec<Instant>>,
    /// Windowed call count observed at each key's most recent call; this is
    /// what `stats` reports. It is not re-evaluated as time passes.
    duplicate_counts: HashMap<String, u32>,
}
14
/// Severity tier assigned to a call by `LoopDetector::record_call`.
///
/// This is a fieldless enum, so equality is total and `Eq` is derived
/// alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// Not (yet) repetitive; no message is attached.
    Normal,
    /// Repetitive; the accompanying message suggests batching or varying the approach.
    Reduced,
    /// Loop detected; the accompanying message states the call is blocked.
    Blocked,
}
21
22#[derive(Debug, Clone)]
23pub struct ThrottleResult {
24    pub level: ThrottleLevel,
25    pub call_count: u32,
26    pub message: Option<String>,
27}
28
29impl Default for LoopDetector {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl LoopDetector {
36    pub fn new() -> Self {
37        Self {
38            call_history: HashMap::new(),
39            duplicate_counts: HashMap::new(),
40        }
41    }
42
43    pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
44        let key = format!("{tool}:{args_fingerprint}");
45        let now = Instant::now();
46        let window = Duration::from_secs(WINDOW_SECS);
47
48        let entries = self.call_history.entry(key.clone()).or_default();
49        entries.retain(|t| now.duration_since(*t) < window);
50        entries.push(now);
51
52        let count = entries.len() as u32;
53        *self.duplicate_counts.entry(key).or_default() = count;
54
55        if count > BLOCKED_THRESHOLD {
56            ThrottleResult {
57                level: ThrottleLevel::Blocked,
58                call_count: count,
59                message: Some(format!(
60                    "⚠ LOOP DETECTED: {tool} called {count}× with same args in {WINDOW_SECS}s. \
61                     Call blocked. Use ctx_batch_execute or vary your approach."
62                )),
63            }
64        } else if count > REDUCED_THRESHOLD {
65            ThrottleResult {
66                level: ThrottleLevel::Reduced,
67                call_count: count,
68                message: Some(format!(
69                    "⚠ Repetitive pattern: {tool} called {count}× with same args. \
70                     Results reduced. Consider batching with ctx_batch_execute."
71                )),
72            }
73        } else if count > NORMAL_THRESHOLD {
74            ThrottleResult {
75                level: ThrottleLevel::Reduced,
76                call_count: count,
77                message: Some(format!(
78                    "Note: {tool} called {count}× with similar args. Consider batching."
79                )),
80            }
81        } else {
82            ThrottleResult {
83                level: ThrottleLevel::Normal,
84                call_count: count,
85                message: None,
86            }
87        }
88    }
89
90    pub fn fingerprint(args: &serde_json::Value) -> String {
91        use std::collections::hash_map::DefaultHasher;
92        use std::hash::{Hash, Hasher};
93
94        let canonical = canonical_json(args);
95        let mut hasher = DefaultHasher::new();
96        canonical.hash(&mut hasher);
97        format!("{:016x}", hasher.finish())
98    }
99
100    pub fn stats(&self) -> Vec<(String, u32)> {
101        let mut entries: Vec<(String, u32)> = self
102            .duplicate_counts
103            .iter()
104            .filter(|(_, &count)| count > 1)
105            .map(|(k, &v)| (k.clone(), v))
106            .collect();
107        entries.sort_by(|a, b| b.1.cmp(&a.1));
108        entries
109    }
110
111    pub fn reset(&mut self) {
112        self.call_history.clear();
113        self.duplicate_counts.clear();
114    }
115}
116
117fn canonical_json(value: &serde_json::Value) -> String {
118    match value {
119        serde_json::Value::Object(map) => {
120            let mut keys: Vec<&String> = map.keys().collect();
121            keys.sort();
122            let entries: Vec<String> = keys
123                .iter()
124                .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
125                .collect();
126            format!("{{{}}}", entries.join(","))
127        }
128        serde_json::Value::Array(arr) => {
129            let entries: Vec<String> = arr.iter().map(canonical_json).collect();
130            format!("[{}]", entries.join(","))
131        }
132        _ => value.to_string(),
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `n` identical calls, discarding their results.
    fn record_n(detector: &mut LoopDetector, tool: &str, fp: &str, n: u32) {
        for _ in 0..n {
            detector.record_call(tool, fp);
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut d = LoopDetector::new();
        let first = d.record_call("ctx_read", "abc123");
        assert_eq!(first.level, ThrottleLevel::Normal);
        assert_eq!(first.call_count, 1);
        assert!(first.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let mut d = LoopDetector::new();
        record_n(&mut d, "ctx_read", "same_fp", NORMAL_THRESHOLD);
        let next = d.record_call("ctx_read", "same_fp");
        assert_eq!(next.level, ThrottleLevel::Reduced);
        assert!(next.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked() {
        let mut d = LoopDetector::new();
        record_n(&mut d, "ctx_shell", "same_fp", BLOCKED_THRESHOLD);
        let next = d.record_call("ctx_shell", "same_fp");
        assert_eq!(next.level, ThrottleLevel::Blocked);
        let text = next.message.expect("blocked calls carry a message");
        assert!(text.contains("LOOP DETECTED"));
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut d = LoopDetector::new();
        record_n(&mut d, "ctx_read", "fp_a", 10);
        let other = d.record_call("ctx_read", "fp_b");
        assert_eq!(other.level, ThrottleLevel::Normal);
        assert_eq!(other.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let args = serde_json::json!({"path": "test.rs", "mode": "full"});
        let once = LoopDetector::fingerprint(&args);
        let twice = LoopDetector::fingerprint(&args);
        assert_eq!(once, twice);
    }

    #[test]
    fn fingerprint_order_independent() {
        let mode_first = serde_json::json!({"mode": "full", "path": "test.rs"});
        let path_first = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(
            LoopDetector::fingerprint(&mode_first),
            LoopDetector::fingerprint(&path_first)
        );
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut d = LoopDetector::new();
        record_n(&mut d, "ctx_read", "fp_a", 5);
        d.record_call("ctx_shell", "fp_b");
        let summary = d.stats();
        assert_eq!(summary.len(), 1);
        assert_eq!(summary[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut d = LoopDetector::new();
        record_n(&mut d, "ctx_read", "fp_a", 5);
        d.reset();
        let fresh = d.record_call("ctx_read", "fp_a");
        assert_eq!(fresh.call_count, 1);
    }
}