1use crate::LoopAction;
11use serde_json::Value;
12use std::collections::VecDeque;
13
/// Tunable thresholds and budgets used by [`ToolLoopDetector`].
///
/// The repeat and similarity settings decide when recent tool calls are
/// treated as a loop; the two budget settings cap token and cost usage.
#[derive(Debug, Clone)]
pub struct LoopDetectionConfig {
    /// Number of most-recent calls that must share the same tool name and the
    /// same serialized arguments before the exact-repeat check fires.
    /// `0` disables the check.
    pub exact_repeat_threshold: usize,
    /// Minimum character-bigram Jaccard similarity (compared with `>=`) for
    /// two adjacent calls' serialized arguments to count as highly similar.
    pub similarity_threshold: f64,
    /// Number of most-recent calls inspected by the similarity check.
    /// Values below `2` disable the check.
    pub similarity_window: usize,
    /// Token budget for a run; usage strictly above it forces completion.
    /// `0` disables the budget.
    pub max_tokens_per_run: u64,
    /// Cost budget for a run, in microdollars; usage strictly above it forces
    /// completion. `0` disables the budget.
    pub max_cost_microdollars_per_run: u64,
}
28
29impl Default for LoopDetectionConfig {
30 fn default() -> Self {
31 Self {
32 exact_repeat_threshold: 3,
33 similarity_threshold: 0.9,
34 similarity_window: 5,
35 max_tokens_per_run: 0,
36 max_cost_microdollars_per_run: 0,
37 }
38 }
39}
40
/// One recorded tool invocation, as stored in the detector's history.
#[derive(Debug, Clone)]
struct ToolCall {
    /// Name of the tool that was invoked.
    name: String,
    /// The call's arguments serialized to a JSON string; used for both the
    /// equality and the similarity comparisons.
    args_str: String,
}
47
/// Stateful detector that inspects successive tool calls for repetition and
/// for token/cost budget overruns.
#[derive(Debug)]
pub struct ToolLoopDetector {
    /// Thresholds and budgets applied by every check.
    config: LoopDetectionConfig,
    /// Most recent calls, oldest at the front; trimmed on every `check` to the
    /// larger of the similarity window and the exact-repeat threshold.
    history: VecDeque<ToolCall>,
}
54
55impl ToolLoopDetector {
56 pub fn new(config: LoopDetectionConfig) -> Self {
57 Self {
58 config,
59 history: VecDeque::new(),
60 }
61 }
62
63 pub fn check(
65 &mut self,
66 tool_name: &str,
67 args: &Value,
68 tokens_used: u64,
69 cost_microdollars: u64,
70 ) -> LoopAction {
71 let args_str = serde_json::to_string(args).unwrap_or_default();
72 let call = ToolCall {
73 name: tool_name.to_string(),
74 args_str,
75 };
76
77 self.history.push_back(call);
79 let max_window = self
80 .config
81 .similarity_window
82 .max(self.config.exact_repeat_threshold);
83 while self.history.len() > max_window {
84 self.history.pop_front();
85 }
86
87 let mut worst = LoopAction::Continue;
90
91 if let Some(action) = self.check_exact_repeat() {
93 worst = action;
94 }
95
96 if let Some(action) = self.check_similarity() {
98 if severity(&action) > severity(&worst) {
99 worst = action;
100 }
101 }
102
103 if let Some(action) = self.check_cost_runaway(tokens_used, cost_microdollars) {
105 if severity(&action) > severity(&worst) {
106 worst = action;
107 }
108 }
109
110 worst
111 }
112
113 fn check_exact_repeat(&self) -> Option<LoopAction> {
114 let threshold = self.config.exact_repeat_threshold;
115 if threshold == 0 || self.history.len() < threshold {
116 return None;
117 }
118
119 let recent: Vec<_> = self.history.iter().rev().take(threshold).collect();
120 let first = &recent[0];
121 let all_same = recent
122 .iter()
123 .all(|c| c.name == first.name && c.args_str == first.args_str);
124
125 if all_same {
126 Some(LoopAction::InjectMessage(format!(
127 "You have called '{}' with identical arguments {} times in a row. \
128 Try a different approach or different parameters.",
129 first.name, threshold
130 )))
131 } else {
132 None
133 }
134 }
135
136 fn check_similarity(&self) -> Option<LoopAction> {
137 let window = self.config.similarity_window;
138 if window < 2 || self.history.len() < window {
139 return None;
140 }
141
142 let recent: Vec<_> = self.history.iter().rev().take(window).collect();
143
144 let first = &recent[0];
146 let all_same_tool = recent.iter().all(|c| c.name == first.name);
147 if !all_same_tool {
148 return None;
149 }
150
151 let mut high_similarity_count = 0;
153 let total_pairs = recent.len() - 1;
154
155 for i in 0..total_pairs {
156 let sim = jaccard_bigram_similarity(&recent[i].args_str, &recent[i + 1].args_str);
157 if sim >= self.config.similarity_threshold {
158 high_similarity_count += 1;
159 }
160 }
161
162 if high_similarity_count > total_pairs / 2 {
164 Some(LoopAction::RestrictTools(vec![first.name.clone()]))
165 } else {
166 None
167 }
168 }
169
170 fn check_cost_runaway(&self, tokens_used: u64, cost_microdollars: u64) -> Option<LoopAction> {
171 if self.config.max_tokens_per_run > 0 && tokens_used > self.config.max_tokens_per_run {
172 return Some(LoopAction::ForceComplete(format!(
173 "Token budget exceeded: {} tokens used (limit: {})",
174 tokens_used, self.config.max_tokens_per_run
175 )));
176 }
177 if self.config.max_cost_microdollars_per_run > 0
178 && cost_microdollars > self.config.max_cost_microdollars_per_run
179 {
180 return Some(LoopAction::ForceComplete(format!(
181 "Cost budget exceeded: {} microdollars (limit: {})",
182 cost_microdollars, self.config.max_cost_microdollars_per_run
183 )));
184 }
185 None
186 }
187}
188
189pub fn severity(action: &LoopAction) -> u8 {
191 match action {
192 LoopAction::Continue => 0,
193 LoopAction::InjectMessage(_) => 1,
194 LoopAction::RestrictTools(_) => 2,
195 LoopAction::ForceComplete(_) => 3,
196 }
197}
198
/// Computes the Jaccard similarity of the character-bigram sets of `a` and `b`.
///
/// Returns a value in `0.0..=1.0`:
/// - two empty strings are fully similar (`1.0`);
/// - exactly one empty string gives `0.0`;
/// - if neither string yields any bigram (both are a single character) the
///   result is `1.0`;
/// - otherwise it is `|A ∩ B| / |A ∪ B|` over the two bigram sets.
fn jaccard_bigram_similarity(a: &str, b: &str) -> f64 {
    use std::collections::HashSet;

    match (a.is_empty(), b.is_empty()) {
        (true, true) => return 1.0,
        (true, false) | (false, true) => return 0.0,
        (false, false) => {}
    }

    // Set of adjacent character pairs in a string.
    let bigrams_of = |text: &str| -> HashSet<(char, char)> {
        text.chars().zip(text.chars().skip(1)).collect()
    };
    let left = bigrams_of(a);
    let right = bigrams_of(b);

    let shared = left.intersection(&right).count();
    // Size of the union by inclusion–exclusion.
    let combined = left.len() + right.len() - shared;

    if combined == 0 {
        // Neither side produced a bigram.
        1.0
    } else {
        shared as f64 / combined as f64
    }
}
226
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a detector whose only non-default setting is the exact-repeat threshold.
    fn detector_with_repeat_threshold(threshold: usize) -> ToolLoopDetector {
        ToolLoopDetector::new(LoopDetectionConfig {
            exact_repeat_threshold: threshold,
            ..Default::default()
        })
    }

    /// True when `value` is within `f64::EPSILON` of `expected`.
    fn nearly(value: f64, expected: f64) -> bool {
        (value - expected).abs() < f64::EPSILON
    }

    #[test]
    fn exact_repeat_triggers_after_threshold() {
        let mut det = detector_with_repeat_threshold(3);
        let payload = json!({"path": "/tmp/foo"});

        // The first two identical calls stay below the threshold.
        for _ in 0..2 {
            assert_eq!(
                det.check("read_file", &payload, 0, 0),
                LoopAction::Continue
            );
        }

        // The third identical call reaches it.
        match det.check("read_file", &payload, 0, 0) {
            LoopAction::InjectMessage(text) => {
                assert!(text.contains("read_file"));
                assert!(text.contains("3 times"));
            }
            other => panic!("expected InjectMessage, got {other:?}"),
        }
    }

    #[test]
    fn exact_repeat_resets_on_different_call() {
        let mut det = detector_with_repeat_threshold(3);
        let payload = json!({"path": "/tmp/foo"});

        for _ in 0..2 {
            det.check("read_file", &payload, 0, 0);
        }
        // A different call interrupts the run of identical calls.
        det.check("write_file", &json!({}), 0, 0);

        let outcome = det.check("read_file", &payload, 0, 0);
        assert_eq!(outcome, LoopAction::Continue);
    }

    #[test]
    fn exact_repeat_different_args_no_trigger() {
        let mut det = detector_with_repeat_threshold(3);

        det.check("read_file", &json!({"path": "/a"}), 0, 0);
        det.check("read_file", &json!({"path": "/b"}), 0, 0);

        let outcome = det.check("read_file", &json!({"path": "/c"}), 0, 0);
        assert_eq!(outcome, LoopAction::Continue);
    }

    #[test]
    fn similarity_detector_triggers_on_similar_args() {
        let mut det = ToolLoopDetector::new(LoopDetectionConfig {
            // High enough that the exact-repeat check cannot fire here.
            exact_repeat_threshold: 10,
            similarity_threshold: 0.8,
            similarity_window: 4,
            ..Default::default()
        });

        // Fill the window with near-identical queries.
        let earlier_queries = [
            "rust async programming guide",
            "rust async programming guide 2",
            "rust async programming guide 3",
        ];
        for query in earlier_queries.iter() {
            det.check("search", &json!({ "query": *query }), 0, 0);
        }

        let outcome = det.check(
            "search",
            &json!({"query": "rust async programming guide 4"}),
            0,
            0,
        );

        match outcome {
            LoopAction::RestrictTools(tools) => assert!(tools.contains(&"search".to_string())),
            other => panic!("expected RestrictTools, got {other:?}"),
        }
    }

    #[test]
    fn similarity_detector_no_trigger_on_different_args() {
        let mut det = ToolLoopDetector::new(LoopDetectionConfig {
            exact_repeat_threshold: 10,
            similarity_threshold: 0.9,
            similarity_window: 3,
            ..Default::default()
        });

        det.check("search", &json!({"query": "rust"}), 0, 0);
        det.check("search", &json!({"query": "python machine learning"}), 0, 0);

        let outcome = det.check("search", &json!({"query": "go concurrency patterns"}), 0, 0);
        assert_eq!(outcome, LoopAction::Continue);
    }

    #[test]
    fn cost_runaway_triggers_on_token_limit() {
        let mut det = ToolLoopDetector::new(LoopDetectionConfig {
            max_tokens_per_run: 1000,
            ..Default::default()
        });

        // Under budget.
        let under = det.check("tool", &json!({}), 500, 0);
        assert_eq!(under, LoopAction::Continue);

        // Over budget.
        match det.check("tool", &json!({}), 1500, 0) {
            LoopAction::ForceComplete(text) => assert!(text.contains("Token budget")),
            other => panic!("expected ForceComplete, got {other:?}"),
        }
    }

    #[test]
    fn cost_runaway_triggers_on_cost_limit() {
        let mut det = ToolLoopDetector::new(LoopDetectionConfig {
            max_cost_microdollars_per_run: 5000,
            ..Default::default()
        });

        // Under budget.
        let under = det.check("tool", &json!({}), 0, 3000);
        assert_eq!(under, LoopAction::Continue);

        // Over budget.
        match det.check("tool", &json!({}), 0, 6000) {
            LoopAction::ForceComplete(text) => assert!(text.contains("Cost budget")),
            other => panic!("expected ForceComplete, got {other:?}"),
        }
    }

    #[test]
    fn cost_runaway_disabled_when_zero() {
        let mut det = ToolLoopDetector::new(LoopDetectionConfig {
            max_tokens_per_run: 0,
            max_cost_microdollars_per_run: 0,
            ..Default::default()
        });

        // Zero budgets mean even very large usage is not flagged.
        let outcome = det.check("tool", &json!({}), 999_999, 999_999);
        assert_eq!(outcome, LoopAction::Continue);
    }

    #[test]
    fn highest_severity_wins() {
        let mut det = ToolLoopDetector::new(LoopDetectionConfig {
            // The second identical call trips exact-repeat while the token
            // budget is exceeded at the same time.
            exact_repeat_threshold: 2,
            max_tokens_per_run: 100,
            ..Default::default()
        });
        let payload = json!({"x": 1});

        det.check("tool", &payload, 50, 0);

        match det.check("tool", &payload, 200, 0) {
            // ForceComplete outranks InjectMessage.
            LoopAction::ForceComplete(_) => {}
            other => panic!("expected ForceComplete (highest severity), got {other:?}"),
        }
    }

    #[test]
    fn jaccard_similarity_identical_strings() {
        let score = jaccard_bigram_similarity("hello", "hello");
        assert!(nearly(score, 1.0));
    }

    #[test]
    fn jaccard_similarity_completely_different() {
        let score = jaccard_bigram_similarity("abc", "xyz");
        assert!(score < 0.1);
    }

    #[test]
    fn jaccard_similarity_empty_strings() {
        assert!(nearly(jaccard_bigram_similarity("", ""), 1.0));
        assert!(nearly(jaccard_bigram_similarity("abc", ""), 0.0));
    }

    #[test]
    fn jaccard_similarity_single_char_strings() {
        // Neither string has a bigram, so the function reports full similarity.
        assert!(nearly(jaccard_bigram_similarity("a", "b"), 1.0));
    }
}