1use crate::providers::ToolCall;
9use std::collections::{HashMap, VecDeque};
10
/// Default iteration cap for the agent loop.
/// NOTE(review): presumably the `cap` later handed to `ask_continue_or_stop` — confirm at call site.
pub const MAX_ITERATIONS_DEFAULT: u32 = 200;

/// Iteration cap for sub-agents.
pub const MAX_SUB_AGENT_ITERATIONS: usize = 20;

/// How many times the same mutating-call fingerprint must appear in the
/// `LoopDetector` window before it is reported as a loop.
const REPEAT_THRESHOLD: usize = 3;

/// How many of the most recent mutating-call fingerprints `LoopDetector` keeps.
const WINDOW_SIZE: usize = 20;

/// How many of the most recent tool names `LoopDetector::recent_names` returns.
const DISPLAY_RECENT: usize = 5;
25
26#[derive(Default)]
30pub struct LoopDetector {
31 window: VecDeque<String>,
33 recent: VecDeque<String>,
35}
36
37impl LoopDetector {
38 pub fn new() -> Self {
39 Self {
40 window: VecDeque::new(),
41 recent: VecDeque::new(),
42 }
43 }
44
45 pub fn record(&mut self, tool_calls: &[ToolCall]) -> Option<String> {
48 for tc in tool_calls {
49 let fp = fingerprint(&tc.function_name, &tc.arguments);
50
51 if crate::tools::is_mutating_tool(&tc.function_name) {
54 self.window.push_back(fp);
55 if self.window.len() > WINDOW_SIZE {
56 self.window.pop_front();
57 }
58 }
59
60 self.recent.push_back(tc.function_name.clone());
62 if self.recent.len() > DISPLAY_RECENT {
63 self.recent.pop_front();
64 }
65 }
66
67 self.check()
68 }
69
70 pub fn recent_names(&self) -> Vec<String> {
72 self.recent.iter().cloned().collect()
73 }
74
75 fn check(&self) -> Option<String> {
76 let mut counts: HashMap<&str, usize> = HashMap::new();
77 for fp in &self.window {
78 *counts.entry(fp.as_str()).or_insert(0) += 1;
79 }
80 counts
81 .into_iter()
82 .find(|(_, n)| *n >= REPEAT_THRESHOLD)
83 .map(|(fp, _)| fp.to_string())
84 }
85}
86
/// Builds the identity key for a tool call: `"{name}:{args-prefix}"`.
///
/// Only the first 200 bytes of `args` are kept. If byte 200 falls inside a
/// multi-byte UTF-8 character, the cut moves back to the previous character
/// boundary. Slicing a `str` off a boundary would panic, so the kept prefix may
/// be up to 3 bytes shorter than 200 for non-ASCII input.
fn fingerprint(name: &str, args: &str) -> String {
    const MAX_ARG_BYTES: usize = 200;
    let mut end = args.len().min(MAX_ARG_BYTES);
    // Index 0 is always a boundary, so this loop terminates.
    while !args.is_char_boundary(end) {
        end -= 1;
    }
    format!("{name}:{}", &args[..end])
}
92
93#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum LoopContinuation {
101 Stop,
102 Continue50,
103 Continue200,
104}
105
106impl LoopContinuation {
107 pub fn extra_iterations(self) -> u32 {
109 match self {
110 Self::Stop => 0,
111 Self::Continue50 => 50,
112 Self::Continue200 => 200,
113 }
114 }
115}
116
117pub fn ask_continue_or_stop(
123 cap: u32,
124 recent_names: &[String],
125 prompt_fn: &dyn Fn(u32, &[String]) -> LoopContinuation,
126) -> u32 {
127 prompt_fn(cap, recent_names).extra_iterations()
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool call with a fixed id and no thought signature.
    fn call(name: &str, args: &str) -> ToolCall {
        ToolCall {
            id: "x".into(),
            function_name: name.into(),
            arguments: args.into(),
            thought_signature: None,
        }
    }

    #[test]
    fn no_loop_on_unique_calls() {
        let mut detector = LoopDetector::new();
        let batches = [
            ("Edit", "{\"path\":\"a.rs\"}"),
            ("Edit", "{\"path\":\"b.rs\"}"),
            ("Bash", "{\"cmd\":\"ls\"}"),
        ];
        for &(name, args) in batches.iter() {
            assert!(detector.record(&[call(name, args)]).is_none());
        }
    }

    #[test]
    fn detects_repeated_identical_call() {
        let mut detector = LoopDetector::new();
        let repeated = call("Edit", "{\"path\":\"src/main.rs\"}");
        let batch = std::slice::from_ref(&repeated);
        // Below the threshold nothing is reported...
        for _ in 0..2 {
            assert!(detector.record(batch).is_none());
        }
        // ...and the third identical call trips it.
        assert!(detector.record(batch).is_some());
    }

    #[test]
    fn different_args_not_a_loop() {
        let mut detector = LoopDetector::new();
        for args in (0..10).map(|i| format!("{{\"path\":\"file{i}.rs\"}}")) {
            assert!(detector.record(&[call("Edit", &args)]).is_none());
        }
    }

    #[test]
    fn ignores_readonly_tools() {
        let mut detector = LoopDetector::new();
        let read = call("Read", "{\"path\":\"src/main.rs\"}");
        for _ in 0..4 {
            assert!(detector.record(std::slice::from_ref(&read)).is_none());
        }
        assert!(detector.check().is_none());
    }

    #[test]
    fn recent_names_tracks_last_five() {
        let mut detector = LoopDetector::new();
        for i in 0..8 {
            let _ = detector.record(&[call(&format!("Tool{i}"), "{}")]);
        }
        let names = detector.recent_names();
        assert_eq!(names.len(), 5);
        assert_eq!(names.first().map(String::as_str), Some("Tool3"));
        assert_eq!(names.last().map(String::as_str), Some("Tool7"));
    }

    #[test]
    fn fingerprint_truncates_long_args() {
        let fp = fingerprint("Bash", &"x".repeat(500));
        assert_eq!(fp.len(), "Bash:".len() + 200);
    }
}