1use crate::providers::ToolCall;
9use std::collections::{HashMap, VecDeque};
10
// --- Loop-detection tuning ---------------------------------------------------

/// Number of occurrences of one mutating-call fingerprint inside the detection
/// window at which `LoopDetector` reports a loop.
const REPEAT_THRESHOLD: usize = 3;

/// Maximum number of mutating-call fingerprints the detector retains; the
/// oldest is discarded first once this is exceeded.
const WINDOW_SIZE: usize = 20;

/// Number of most-recent tool names kept for display.
const DISPLAY_RECENT: usize = 5;

// --- Iteration caps ----------------------------------------------------------

/// Default iteration cap.
/// NOTE(review): not referenced in the code shown here; presumably the default
/// budget for the main agent loop — confirm at call sites.
pub const MAX_ITERATIONS_DEFAULT: u32 = 200;

/// Iteration cap for sub-agents.
/// NOTE(review): typed `usize`, unlike the `u32` default above — presumably
/// compared against a `usize` loop counter; confirm at call sites.
pub const MAX_SUB_AGENT_ITERATIONS: usize = 20;
25
26#[derive(Default)]
30pub struct LoopDetector {
31 window: VecDeque<String>,
33 recent: VecDeque<String>,
35}
36
37impl LoopDetector {
38 pub fn new() -> Self {
39 Self {
40 window: VecDeque::new(),
41 recent: VecDeque::new(),
42 }
43 }
44
45 pub fn record(&mut self, tool_calls: &[ToolCall]) -> Option<String> {
48 for tc in tool_calls {
49 let fp = fingerprint(&tc.function_name, &tc.arguments);
50
51 if is_mutating_tool(&tc.function_name) {
54 self.window.push_back(fp);
55 if self.window.len() > WINDOW_SIZE {
56 self.window.pop_front();
57 }
58 }
59
60 self.recent.push_back(tc.function_name.clone());
62 if self.recent.len() > DISPLAY_RECENT {
63 self.recent.pop_front();
64 }
65 }
66
67 self.check()
68 }
69
70 pub fn recent_names(&self) -> Vec<String> {
72 self.recent.iter().cloned().collect()
73 }
74
75 fn check(&self) -> Option<String> {
76 let mut counts: HashMap<&str, usize> = HashMap::new();
77 for fp in &self.window {
78 *counts.entry(fp.as_str()).or_insert(0) += 1;
79 }
80 counts
81 .into_iter()
82 .find(|(_, n)| *n >= REPEAT_THRESHOLD)
83 .map(|(fp, _)| fp.to_string())
84 }
85}
86
/// Builds the loop-detection fingerprint for a tool call: `"{name}:{prefix}"`, where
/// `prefix` is at most the first 200 bytes of `args`.
///
/// If byte 200 falls inside a multi-byte UTF-8 character, the cut backs off to the
/// preceding character boundary. Slicing a `str` off a boundary would panic.
fn fingerprint(name: &str, args: &str) -> String {
    const MAX_PREFIX_BYTES: usize = 200;
    let mut end = args.len().min(MAX_PREFIX_BYTES);
    // Index 0 is always a char boundary, so this loop terminates.
    while !args.is_char_boundary(end) {
        end -= 1;
    }
    format!("{name}:{}", &args[..end])
}
92
/// Returns `true` for tools treated as mutating (state-changing) by the loop detector.
///
/// Only calls to these tools are fingerprinted and counted toward loop detection. The
/// match is exact and case-sensitive.
fn is_mutating_tool(name: &str) -> bool {
    const MUTATING_TOOLS: [&str; 7] = [
        "Bash",
        "Edit",
        "Write",
        "Delete",
        "MemoryWrite",
        "CreateAgent",
        "InvokeAgent",
    ];
    MUTATING_TOOLS.contains(&name)
}
101
/// Choice returned by the `prompt_fn` passed to `ask_continue_or_stop`: stop, or
/// continue for 50 or 200 more iterations (see `LoopContinuation::extra_iterations`).
///
/// Serialized with serde using snake_case variant names (`"stop"`, `"continue50"`,
/// `"continue200"`). Renaming or removing variants changes that wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LoopContinuation {
    /// Grant no extra iterations.
    Stop,
    /// Grant 50 extra iterations.
    Continue50,
    /// Grant 200 extra iterations.
    Continue200,
}
114
115impl LoopContinuation {
116 pub fn extra_iterations(self) -> u32 {
118 match self {
119 Self::Stop => 0,
120 Self::Continue50 => 50,
121 Self::Continue200 => 200,
122 }
123 }
124}
125
126pub fn ask_continue_or_stop(
132 cap: u32,
133 recent_names: &[String],
134 prompt_fn: &dyn Fn(u32, &[String]) -> LoopContinuation,
135) -> u32 {
136 prompt_fn(cap, recent_names).extra_iterations()
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolCall` with a fixed id and no thought signature.
    fn call(name: &str, args: &str) -> ToolCall {
        ToolCall {
            id: "x".into(),
            function_name: name.into(),
            arguments: args.into(),
            thought_signature: None,
        }
    }

    #[test]
    fn no_loop_on_unique_calls() {
        let mut d = LoopDetector::new();
        assert!(d.record(&[call("Edit", "{\"path\":\"a.rs\"}")]).is_none());
        assert!(d.record(&[call("Edit", "{\"path\":\"b.rs\"}")]).is_none());
        assert!(d.record(&[call("Bash", "{\"cmd\":\"ls\"}")]).is_none());
    }

    #[test]
    fn detects_repeated_identical_call() {
        let mut d = LoopDetector::new();
        let tc = call("Edit", "{\"path\":\"src/main.rs\"}");
        assert!(d.record(std::slice::from_ref(&tc)).is_none());
        assert!(d.record(std::slice::from_ref(&tc)).is_none());
        assert!(d.record(std::slice::from_ref(&tc)).is_some());
    }

    #[test]
    fn reports_fingerprint_of_repeated_call() {
        let mut d = LoopDetector::new();
        let tc = call("Edit", "{\"path\":\"src/main.rs\"}");
        d.record(std::slice::from_ref(&tc));
        d.record(std::slice::from_ref(&tc));
        assert_eq!(
            d.record(std::slice::from_ref(&tc)),
            Some("Edit:{\"path\":\"src/main.rs\"}".to_string())
        );
    }

    #[test]
    fn detects_repeats_within_one_batch() {
        let mut d = LoopDetector::new();
        let batch = [
            call("Bash", "{\"cmd\":\"make\"}"),
            call("Bash", "{\"cmd\":\"make\"}"),
            call("Bash", "{\"cmd\":\"make\"}"),
        ];
        assert!(d.record(&batch).is_some());
    }

    #[test]
    fn different_args_not_a_loop() {
        let mut d = LoopDetector::new();
        for i in 0..10 {
            let args = format!("{{\"path\":\"file{i}.rs\"}}");
            assert!(d.record(&[call("Edit", &args)]).is_none());
        }
    }

    #[test]
    fn repeats_evicted_from_window_are_not_a_loop() {
        let mut d = LoopDetector::new();
        let tc = call("Edit", "{\"path\":\"a.rs\"}");
        for round in 0..3 {
            assert!(d.record(std::slice::from_ref(&tc)).is_none());
            // WINDOW_SIZE distinct mutating calls push the previous `tc` out.
            for i in 0..WINDOW_SIZE {
                let args = format!("{{\"path\":\"filler{round}_{i}.rs\"}}");
                assert!(d.record(&[call("Edit", &args)]).is_none());
            }
        }
    }

    #[test]
    fn ignores_readonly_tools() {
        let mut d = LoopDetector::new();
        let tc = call("Read", "{\"path\":\"src/main.rs\"}");
        assert!(d.record(std::slice::from_ref(&tc)).is_none());
        assert!(d.record(std::slice::from_ref(&tc)).is_none());
        assert!(d.record(std::slice::from_ref(&tc)).is_none());
        assert!(d.record(std::slice::from_ref(&tc)).is_none());
        assert!(d.check().is_none());
    }

    #[test]
    fn recent_names_tracks_last_five() {
        let mut d = LoopDetector::new();
        for i in 0..8 {
            let name = format!("Tool{i}");
            d.record(&[call(&name, "{}")]);
        }
        let names = d.recent_names();
        assert_eq!(names.len(), 5);
        assert_eq!(names[0], "Tool3");
        assert_eq!(names[4], "Tool7");
    }

    #[test]
    fn fingerprint_truncates_long_args() {
        let long_args = "x".repeat(500);
        let fp = fingerprint("Bash", &long_args);
        assert_eq!(fp.len(), "Bash:".len() + 200);
    }

    #[test]
    fn continuation_extra_iterations() {
        assert_eq!(LoopContinuation::Stop.extra_iterations(), 0);
        assert_eq!(LoopContinuation::Continue50.extra_iterations(), 50);
        assert_eq!(LoopContinuation::Continue200.extra_iterations(), 200);
    }

    #[test]
    fn ask_continue_or_stop_forwards_arguments_and_maps_choice() {
        let names = vec!["Edit".to_string(), "Bash".to_string()];
        let prompt = |cap: u32, recent: &[String]| -> LoopContinuation {
            assert_eq!(cap, 7);
            assert_eq!(recent, names.as_slice());
            LoopContinuation::Continue50
        };
        assert_eq!(ask_continue_or_stop(7, &names, &prompt), 50);
        assert_eq!(
            ask_continue_or_stop(7, &names, &|_: u32, _: &[String]| LoopContinuation::Stop),
            0
        );
        assert_eq!(
            ask_continue_or_stop(7, &names, &|_: u32, _: &[String]| {
                LoopContinuation::Continue200
            }),
            200
        );
    }
}