use std::collections::HashMap;
use std::hash::{Hash, Hasher};
/// Tuning parameters for `LoopDetector`.
#[derive(Debug, Clone)]
pub struct LoopDetectorConfig {
    /// When `false`, `LoopDetector::record` records nothing and always returns `Ok`.
    pub enabled: bool,
    /// How many of the most recent tool calls are kept when counting repeats.
    pub window_size: usize,
    /// Repeat count that triggers `Block`; a count above it triggers `Break`.
    pub max_repeats: usize,
}

impl Default for LoopDetectorConfig {
    /// Detection enabled, a 20-call window, and a threshold of 3 repeats.
    fn default() -> Self {
        let window_size = 20;
        let max_repeats = 3;
        LoopDetectorConfig {
            enabled: true,
            window_size,
            max_repeats,
        }
    }
}
/// Outcome of `LoopDetector::record`. Variants are listed in increasing order of the
/// repeat count that produces them; every non-`Ok` variant carries a human-readable message.
///
/// NOTE(review): how callers react to `Block` vs `Break` is not visible in this file;
/// the messages suggest "ask the agent to change approach" vs "terminate the agent loop".
#[derive(Debug, Clone, PartialEq)]
pub enum LoopDetectionResult {
/// The call has not repeated enough times within the window to report anything.
Ok,
/// The repeat count exceeds `max_repeats / 2` (integer division, only when `max_repeats >= 2`)
/// but is still below `max_repeats`.
Warning(String),
/// The repeat count equals `max_repeats`.
Block(String),
/// The repeat count exceeds `max_repeats`.
Break(String),
}
/// Identity of one tool call for repeat detection. Two calls are treated as the same
/// when the tool name and the hashes of their arguments and output are all equal.
/// Only hashes are stored, so distinct inputs whose hashes collide count as repeats.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ToolCallSignature {
/// Name of the invoked tool.
tool_name: String,
/// Hash of the JSON arguments, as computed by `hash_value`.
args_hash: u64,
/// Hash of the tool's textual output, as computed by `hash_str`.
output_hash: u64,
}
/// Detects an agent repeatedly issuing the same tool call and receiving the same output.
///
/// The detector keeps the signatures of the last `window_size` calls in a FIFO window.
/// It also keeps the number of times each signature occurs inside that window, so
/// sliding the window costs O(1) per recorded call.
#[derive(Debug)]
pub struct LoopDetector {
    config: LoopDetectorConfig,
    /// Most recent call signatures, oldest at the front.
    window: std::collections::VecDeque<ToolCallSignature>,
    /// Occurrences of each signature currently inside `window`; never stores zero.
    counts: HashMap<ToolCallSignature, usize>,
}

impl LoopDetector {
    /// Creates a detector with an empty call history.
    pub fn new(config: LoopDetectorConfig) -> Self {
        Self {
            config,
            window: std::collections::VecDeque::new(),
            counts: HashMap::new(),
        }
    }

    /// Records one tool call and classifies how often it repeats within the window.
    ///
    /// A call repeats an earlier one when the tool name, the serialized `args` and the
    /// `output` all hash equal. Let `n` be the number of such calls currently in the
    /// window, including this one, and `max = config.max_repeats`:
    ///
    /// - `n > max` returns `Break`
    /// - `n == max` returns `Block`
    /// - `n > max / 2` (integer division, only when `max >= 2`) returns `Warning`
    /// - otherwise returns `Ok`
    ///
    /// When `config.enabled` is `false`, nothing is recorded and `Ok` is returned.
    /// With `window_size == 0` the new call is evicted immediately, so its count is 0.
    pub fn record(
        &mut self,
        tool_name: &str,
        args: &serde_json::Value,
        output: &str,
    ) -> LoopDetectionResult {
        if !self.config.enabled {
            return LoopDetectionResult::Ok;
        }
        let sig = ToolCallSignature {
            tool_name: tool_name.to_string(),
            args_hash: hash_value(args),
            output_hash: hash_str(output),
        };
        self.window.push_back(sig.clone());
        *self.counts.entry(sig.clone()).or_insert(0) += 1;
        // Slide the window. The newest call is already in it, so the oldest entries are
        // dropped until the window is back within `window_size`.
        while self.window.len() > self.config.window_size {
            if let Some(oldest) = self.window.pop_front() {
                self.decrement_count(&oldest);
            }
        }
        let repeat_count = self.counts.get(&sig).copied().unwrap_or(0);
        let max = self.config.max_repeats;
        if repeat_count > max {
            let msg = format!(
                "工具 '{tool_name}' 产生了相同的参数和输出,连续重复 {repeat_count} 次(超过阈值 {max}),终止 agent loop。"
            );
            LoopDetectionResult::Break(msg)
        } else if repeat_count == max {
            let msg = format!(
                "工具 '{tool_name}' 重复调用 {repeat_count} 次,输出相同。请换一种方法完成任务,不要继续重复调用。"
            );
            LoopDetectionResult::Block(msg)
        } else if max >= 2 && repeat_count > max / 2 {
            let msg = format!(
                "[Loop Warning] 工具 '{tool_name}' 已重复产生相同输出 {repeat_count} 次。请调整策略,避免循环。"
            );
            LoopDetectionResult::Warning(msg)
        } else {
            LoopDetectionResult::Ok
        }
    }

    /// Decrements the in-window count for `sig`, removing the entry when it reaches zero
    /// so `counts` only ever holds signatures still present in `window`.
    fn decrement_count(&mut self, sig: &ToolCallSignature) {
        if let Some(count) = self.counts.get_mut(sig) {
            *count -= 1;
            if *count == 0 {
                self.counts.remove(sig);
            }
        }
    }
}
/// Hashes a string with the standard library's `DefaultHasher`.
///
/// `DefaultHasher::new()` always starts from the same state, so results are consistent
/// within one build of the program. The algorithm is not guaranteed across Rust
/// releases, so the value should not be persisted.
fn hash_str(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    state.finish()
}
/// Hashes a JSON value by hashing its compact serialized text (`Value::to_string`).
///
/// Two values hash equal exactly when they serialize to the same text (barring
/// collisions). Object key order in that text is whatever serde_json's `Display`
/// produces for the value's map.
fn hash_value(v: &serde_json::Value) -> u64 {
    let serialized = v.to_string();
    hash_str(serialized.as_str())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_repeat_returns_ok() {
        let mut detector = LoopDetector::new(LoopDetectorConfig::default());
        let result = detector.record("search", &serde_json::json!({"q": "rust"}), "result 1");
        assert_eq!(result, LoopDetectionResult::Ok);
    }

    #[test]
    fn repeat_triggers_warning_then_block_then_break() {
        let config = LoopDetectorConfig {
            enabled: true,
            window_size: 10,
            max_repeats: 3,
        };
        let mut detector = LoopDetector::new(config);
        let args = serde_json::json!({"q": "test"});
        let output = "same output";
        assert_eq!(
            detector.record("tool", &args, output),
            LoopDetectionResult::Ok
        );
        let r2 = detector.record("tool", &args, output);
        assert!(matches!(r2, LoopDetectionResult::Warning(_)));
        let r3 = detector.record("tool", &args, output);
        assert!(matches!(r3, LoopDetectionResult::Block(_)));
        let r4 = detector.record("tool", &args, output);
        assert!(matches!(r4, LoopDetectionResult::Break(_)));
    }

    #[test]
    fn different_args_no_loop() {
        let mut detector = LoopDetector::new(LoopDetectorConfig::default());
        for i in 0..5 {
            let args = serde_json::json!({"q": i});
            let result = detector.record("search", &args, "result");
            assert_eq!(result, LoopDetectionResult::Ok);
        }
    }

    #[test]
    fn disabled_always_returns_ok() {
        let config = LoopDetectorConfig {
            enabled: false,
            ..Default::default()
        };
        let mut detector = LoopDetector::new(config);
        let args = serde_json::json!({});
        for _ in 0..10 {
            assert_eq!(
                detector.record("tool", &args, "out"),
                LoopDetectionResult::Ok
            );
        }
    }

    /// A window smaller than `max_repeats` caps the in-window count, so evicted
    /// calls must stop counting and the result never escalates past Warning.
    #[test]
    fn window_size_caps_repeat_count() {
        let config = LoopDetectorConfig {
            enabled: true,
            window_size: 2,
            max_repeats: 3,
        };
        let mut detector = LoopDetector::new(config);
        let args = serde_json::json!({"q": "same"});
        assert_eq!(
            detector.record("tool", &args, "out"),
            LoopDetectionResult::Ok
        );
        for _ in 0..5 {
            let result = detector.record("tool", &args, "out");
            assert!(
                matches!(result, LoopDetectionResult::Warning(_)),
                "got {result:?}"
            );
        }
    }

    /// Alternating two different tools with a window of 2 keeps each signature's
    /// in-window count at 1, so eviction must decrement correctly every call.
    #[test]
    fn alternating_calls_with_small_window_stay_ok() {
        let config = LoopDetectorConfig {
            enabled: true,
            window_size: 2,
            max_repeats: 3,
        };
        let mut detector = LoopDetector::new(config);
        let args = serde_json::json!({"q": "x"});
        for _ in 0..5 {
            assert_eq!(detector.record("a", &args, "out"), LoopDetectionResult::Ok);
            assert_eq!(detector.record("b", &args, "out"), LoopDetectionResult::Ok);
        }
    }
}