lean_ctx/core/
loop_detection.rs1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
/// In-window call count above which a repeated call stops being `Normal`.
const NORMAL_THRESHOLD: u32 = 3;
/// In-window call count above which the stronger "results reduced" warning is used.
const REDUCED_THRESHOLD: u32 = 8;
/// In-window call count above which a repeated call is reported as `Blocked`.
const BLOCKED_THRESHOLD: u32 = 12;
/// Length, in seconds, of the sliding window over which identical calls are counted.
const WINDOW_SECS: u64 = 5 * 60;
8
/// Counts repeated tool invocations (same tool, same argument fingerprint)
/// over a sliding time window so that call loops can be throttled.
#[derive(Clone, Debug)]
pub struct LoopDetector {
    /// Timestamps of calls, grouped by `"{tool}:{fingerprint}"` key.
    call_history: HashMap<String, Vec<Instant>>,
    /// Most recent in-window call count observed for each key.
    duplicate_counts: HashMap<String, u32>,
}
14
/// How strongly a repeated tool call is throttled.
///
/// Equality is total (the enum has no payload), so `Eq` is derived alongside
/// `PartialEq`. This lets the level be used wherever `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleLevel {
    /// The call is within normal limits; no warning is attached.
    Normal,
    /// The call is repetitive; a warning message is attached.
    Reduced,
    /// The call has repeated so often that it is reported as blocked.
    Blocked,
}
21
22#[derive(Debug, Clone)]
23pub struct ThrottleResult {
24 pub level: ThrottleLevel,
25 pub call_count: u32,
26 pub message: Option<String>,
27}
28
29impl Default for LoopDetector {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl LoopDetector {
36 pub fn new() -> Self {
37 Self {
38 call_history: HashMap::new(),
39 duplicate_counts: HashMap::new(),
40 }
41 }
42
43 pub fn record_call(&mut self, tool: &str, args_fingerprint: &str) -> ThrottleResult {
44 let key = format!("{tool}:{args_fingerprint}");
45 let now = Instant::now();
46 let window = Duration::from_secs(WINDOW_SECS);
47
48 let entries = self.call_history.entry(key.clone()).or_default();
49 entries.retain(|t| now.duration_since(*t) < window);
50 entries.push(now);
51
52 let count = entries.len() as u32;
53 *self.duplicate_counts.entry(key).or_default() = count;
54
55 if count > BLOCKED_THRESHOLD {
56 ThrottleResult {
57 level: ThrottleLevel::Blocked,
58 call_count: count,
59 message: Some(format!(
60 "⚠ LOOP DETECTED: {tool} called {count}× with same args in {WINDOW_SECS}s. \
61 Call blocked. Use ctx_batch_execute or vary your approach."
62 )),
63 }
64 } else if count > REDUCED_THRESHOLD {
65 ThrottleResult {
66 level: ThrottleLevel::Reduced,
67 call_count: count,
68 message: Some(format!(
69 "⚠ Repetitive pattern: {tool} called {count}× with same args. \
70 Results reduced. Consider batching with ctx_batch_execute."
71 )),
72 }
73 } else if count > NORMAL_THRESHOLD {
74 ThrottleResult {
75 level: ThrottleLevel::Reduced,
76 call_count: count,
77 message: Some(format!(
78 "Note: {tool} called {count}× with similar args. Consider batching."
79 )),
80 }
81 } else {
82 ThrottleResult {
83 level: ThrottleLevel::Normal,
84 call_count: count,
85 message: None,
86 }
87 }
88 }
89
90 pub fn fingerprint(args: &serde_json::Value) -> String {
91 use std::collections::hash_map::DefaultHasher;
92 use std::hash::{Hash, Hasher};
93
94 let canonical = canonical_json(args);
95 let mut hasher = DefaultHasher::new();
96 canonical.hash(&mut hasher);
97 format!("{:016x}", hasher.finish())
98 }
99
100 pub fn stats(&self) -> Vec<(String, u32)> {
101 let mut entries: Vec<(String, u32)> = self
102 .duplicate_counts
103 .iter()
104 .filter(|(_, &count)| count > 1)
105 .map(|(k, &v)| (k.clone(), v))
106 .collect();
107 entries.sort_by(|a, b| b.1.cmp(&a.1));
108 entries
109 }
110
111 pub fn reset(&mut self) {
112 self.call_history.clear();
113 self.duplicate_counts.clear();
114 }
115}
116
117fn canonical_json(value: &serde_json::Value) -> String {
118 match value {
119 serde_json::Value::Object(map) => {
120 let mut keys: Vec<&String> = map.keys().collect();
121 keys.sort();
122 let entries: Vec<String> = keys
123 .iter()
124 .map(|k| format!("{}:{}", k, canonical_json(&map[*k])))
125 .collect();
126 format!("{{{}}}", entries.join(","))
127 }
128 serde_json::Value::Array(arr) => {
129 let entries: Vec<String> = arr.iter().map(canonical_json).collect();
130 format!("[{}]", entries.join(","))
131 }
132 _ => value.to_string(),
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` identical calls, discarding their results.
    fn repeat_call(detector: &mut LoopDetector, tool: &str, fp: &str, times: u32) {
        for _ in 0..times {
            detector.record_call(tool, fp);
        }
    }

    #[test]
    fn normal_calls_pass_through() {
        let mut det = LoopDetector::new();
        let first = det.record_call("ctx_read", "abc123");
        assert_eq!(first.level, ThrottleLevel::Normal);
        assert_eq!(first.call_count, 1);
        assert!(first.message.is_none());
    }

    #[test]
    fn repeated_calls_trigger_reduced() {
        let mut det = LoopDetector::new();
        repeat_call(&mut det, "ctx_read", "same_fp", NORMAL_THRESHOLD);
        let next = det.record_call("ctx_read", "same_fp");
        assert_eq!(next.level, ThrottleLevel::Reduced);
        assert!(next.message.is_some());
    }

    #[test]
    fn excessive_calls_get_blocked() {
        let mut det = LoopDetector::new();
        repeat_call(&mut det, "ctx_shell", "same_fp", BLOCKED_THRESHOLD);
        let next = det.record_call("ctx_shell", "same_fp");
        assert_eq!(next.level, ThrottleLevel::Blocked);
        let text = next.message.unwrap();
        assert!(text.contains("LOOP DETECTED"));
    }

    #[test]
    fn different_args_tracked_separately() {
        let mut det = LoopDetector::new();
        repeat_call(&mut det, "ctx_read", "fp_a", 10);
        let other = det.record_call("ctx_read", "fp_b");
        assert_eq!(other.level, ThrottleLevel::Normal);
        assert_eq!(other.call_count, 1);
    }

    #[test]
    fn fingerprint_deterministic() {
        let payload = serde_json::json!({"path": "test.rs", "mode": "full"});
        let first = LoopDetector::fingerprint(&payload);
        let second = LoopDetector::fingerprint(&payload);
        assert_eq!(first, second);
    }

    #[test]
    fn fingerprint_order_independent() {
        let mode_first = serde_json::json!({"mode": "full", "path": "test.rs"});
        let path_first = serde_json::json!({"path": "test.rs", "mode": "full"});
        assert_eq!(
            LoopDetector::fingerprint(&mode_first),
            LoopDetector::fingerprint(&path_first)
        );
    }

    #[test]
    fn stats_shows_duplicates() {
        let mut det = LoopDetector::new();
        repeat_call(&mut det, "ctx_read", "fp_a", 5);
        det.record_call("ctx_shell", "fp_b");
        let report = det.stats();
        assert_eq!(report.len(), 1);
        assert_eq!(report[0].1, 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut det = LoopDetector::new();
        repeat_call(&mut det, "ctx_read", "fp_a", 5);
        det.reset();
        let fresh = det.record_call("ctx_read", "fp_a");
        assert_eq!(fresh.call_count, 1);
    }
}
219}