use crate::providers::ToolCall;
use std::collections::{HashMap, VecDeque};
/// Default iteration cap.
///
/// NOTE(review): not referenced in this part of the file — presumably the
/// limit for the main agent loop; confirm against callers.
pub const MAX_ITERATIONS_DEFAULT: u32 = 200;
/// Iteration cap for sub-agents.
///
/// NOTE(review): not referenced in this part of the file; note that it is a
/// `usize` while `MAX_ITERATIONS_DEFAULT` is a `u32` — confirm that is intended.
pub const MAX_SUB_AGENT_ITERATIONS: usize = 20;
/// Number of occurrences of one fingerprint inside the window at which
/// `LoopDetector::check` reports a loop.
const REPEAT_THRESHOLD: usize = 3;
/// Maximum number of fingerprints `LoopDetector::record` keeps in its
/// sliding window; the oldest entry is dropped once this is exceeded.
const WINDOW_SIZE: usize = 20;
/// Maximum number of tool names `LoopDetector::record` keeps in the history
/// returned by `LoopDetector::recent_names`.
const DISPLAY_RECENT: usize = 5;
/// Tracks recently recorded tool calls so that the same call being issued
/// over and over can be detected.
///
/// Both queues are bounded: `record` evicts the oldest entry whenever a queue
/// grows past its limit (`WINDOW_SIZE` and `DISPLAY_RECENT` respectively).
#[derive(Default)]
pub struct LoopDetector {
/// Fingerprints (`name:argument-prefix`) of the calls that
/// `crate::tools::is_mutating_tool` accepted, oldest first.
window: VecDeque<String>,
/// Names of the most recently recorded calls of any kind, oldest first.
recent: VecDeque<String>,
}
impl LoopDetector {
pub fn new() -> Self {
Self {
window: VecDeque::new(),
recent: VecDeque::new(),
}
}
pub fn record(&mut self, tool_calls: &[ToolCall]) -> Option<String> {
for tc in tool_calls {
let fp = fingerprint(&tc.function_name, &tc.arguments);
if crate::tools::is_mutating_tool(&tc.function_name) {
self.window.push_back(fp);
if self.window.len() > WINDOW_SIZE {
self.window.pop_front();
}
}
self.recent.push_back(tc.function_name.clone());
if self.recent.len() > DISPLAY_RECENT {
self.recent.pop_front();
}
}
self.check()
}
pub fn recent_names(&self) -> Vec<String> {
self.recent.iter().cloned().collect()
}
fn check(&self) -> Option<String> {
let mut counts: HashMap<&str, usize> = HashMap::new();
for fp in &self.window {
*counts.entry(fp.as_str()).or_insert(0) += 1;
}
counts
.into_iter()
.find(|(_, n)| *n >= REPEAT_THRESHOLD)
.map(|(fp, _)| fp.to_string())
}
}
/// Builds the loop-detection fingerprint for a tool call: `name`, a colon,
/// and a bounded prefix of `args`.
///
/// At most 200 bytes of `args` are kept. When byte 200 falls inside a
/// multi-byte UTF-8 character the cut is moved back to the previous character
/// boundary, so the prefix may be up to three bytes shorter; slicing at the raw
/// byte offset would panic on such input.
fn fingerprint(name: &str, args: &str) -> String {
    /// Upper bound, in bytes, on the argument prefix kept in a fingerprint.
    const MAX_ARG_PREFIX_BYTES: usize = 200;

    let mut end = args.len().min(MAX_ARG_PREFIX_BYTES);
    // Offset 0 is always a character boundary, so this loop terminates.
    while !args.is_char_boundary(end) {
        end -= 1;
    }
    let prefix = &args[..end];
    format!("{name}:{prefix}")
}
/// Decision produced by the prompt callback passed to `ask_continue_or_stop`.
///
/// Each variant maps to a number of additional iterations through
/// `LoopContinuation::extra_iterations`. Variant names are serialized and
/// deserialized using serde's `snake_case` renaming.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LoopContinuation {
/// Grant no further iterations.
Stop,
/// Grant 50 further iterations.
Continue50,
/// Grant 200 further iterations.
Continue200,
}
impl LoopContinuation {
pub fn extra_iterations(self) -> u32 {
match self {
Self::Stop => 0,
Self::Continue50 => 50,
Self::Continue200 => 200,
}
}
}
/// Asks `prompt_fn` how to proceed and converts its answer into a number of
/// additional iterations.
///
/// `cap` and `recent_names` are passed to `prompt_fn` unchanged. The returned
/// value is the chosen `LoopContinuation`'s `extra_iterations()`, which is `0`
/// when the choice is `Stop`.
pub fn ask_continue_or_stop(
    cap: u32,
    recent_names: &[String],
    prompt_fn: &dyn Fn(u32, &[String]) -> LoopContinuation,
) -> u32 {
    let decision = prompt_fn(cap, recent_names);
    decision.extra_iterations()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolCall` with the given name and raw arguments; the id is a
    /// fixed placeholder and no thought signature is attached.
    fn call(name: &str, args: &str) -> ToolCall {
        ToolCall {
            id: "x".into(),
            function_name: name.into(),
            arguments: args.into(),
            thought_signature: None,
        }
    }

    /// Three distinct calls never trip the detector.
    #[test]
    fn no_loop_on_unique_calls() {
        let mut detector = LoopDetector::new();
        let distinct = [
            ("Edit", r#"{"path":"a.rs"}"#),
            ("Edit", r#"{"path":"b.rs"}"#),
            ("Bash", r#"{"cmd":"ls"}"#),
        ];
        for (name, args) in distinct {
            assert!(detector.record(&[call(name, args)]).is_none());
        }
    }

    /// The third identical call is the one that gets reported.
    #[test]
    fn detects_repeated_identical_call() {
        let mut detector = LoopDetector::new();
        let batch = [call("Edit", r#"{"path":"src/main.rs"}"#)];
        assert!(detector.record(&batch).is_none());
        assert!(detector.record(&batch).is_none());
        assert!(detector.record(&batch).is_some());
    }

    /// The same tool with different arguments each time is not a loop.
    #[test]
    fn different_args_not_a_loop() {
        let mut detector = LoopDetector::new();
        for index in 0..10 {
            let batch = [call("Edit", &format!("{{\"path\":\"file{index}.rs\"}}"))];
            assert!(detector.record(&batch).is_none());
        }
    }

    /// Repeating a `Read` call is expected never to be reported.
    #[test]
    fn ignores_readonly_tools() {
        let mut detector = LoopDetector::new();
        let batch = [call("Read", r#"{"path":"src/main.rs"}"#)];
        for _ in 0..4 {
            assert!(detector.record(&batch).is_none());
        }
        assert!(detector.check().is_none());
    }

    /// Only the five most recent names survive, oldest first.
    #[test]
    fn recent_names_tracks_last_five() {
        let mut detector = LoopDetector::new();
        for index in 0..8 {
            detector.record(&[call(&format!("Tool{index}"), "{}")]);
        }
        let names = detector.recent_names();
        assert_eq!(names.len(), 5);
        assert_eq!(names.first().map(String::as_str), Some("Tool3"));
        assert_eq!(names.last().map(String::as_str), Some("Tool7"));
    }

    /// Arguments longer than 200 bytes contribute exactly 200 bytes.
    #[test]
    fn fingerprint_truncates_long_args() {
        let oversized = "x".repeat(500);
        let expected_len = "Bash:".len() + 200;
        assert_eq!(fingerprint("Bash", &oversized).len(), expected_len);
    }
}