use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use lellm_core::ToolCall;
/// Identity of a single tool call as seen by the loop detector: the tool name
/// together with a normalized rendering of its arguments.
///
/// Two fingerprints are equal exactly when both fields are equal, and hashing
/// covers the same two fields in the same order, so the type can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolCallFingerprint {
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Normalized string form of the call's arguments (built by `from_call`).
    pub normalized_args: String,
}

impl Hash for ToolCallFingerprint {
    /// Feeds `tool_name` and then `normalized_args` into `state`, mirroring the
    /// fields compared by `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (&self.tool_name, &self.normalized_args).hash(state);
    }
}
impl ToolCallFingerprint {
    /// Builds the fingerprint of `call` from its tool name and its arguments
    /// rendered by [`ToolCallFingerprint::normalize_json`].
    pub fn from_call(call: &ToolCall) -> Self {
        Self {
            tool_name: call.name.clone(),
            normalized_args: Self::normalize_json(&call.arguments),
        }
    }

    /// Renders `value` as a canonical string so that argument sets differing
    /// only in object key order, or in whitespace inside string values, produce
    /// identical text.
    ///
    /// * Objects become `{"key":value,...}` with keys sorted.
    /// * Arrays become `[value,...]`, preserving element order.
    /// * Strings have all whitespace removed and are then emitted as quoted,
    ///   escaped JSON string literals.
    /// * Numbers, booleans and `null` use their JSON text.
    ///
    /// Bracketing objects and quoting keys/strings keeps the encoding
    /// unambiguous: e.g. `{"a":{"b":1}}` vs `{"a":"b:1"}`, or `{"x":1}` vs
    /// `{"x":"1"}`, no longer collapse into the same fingerprint (which would
    /// otherwise be reported as a spurious loop).
    fn normalize_json(value: &serde_json::Value) -> String {
        match value {
            serde_json::Value::Object(map) => {
                let mut entries: Vec<_> = map.iter().collect();
                // Keys are unique within an object, so an unstable sort is still
                // deterministic; sorting by the borrowed key avoids cloning it.
                entries.sort_unstable_by_key(|&(key, _)| key);
                let parts: Vec<String> = entries
                    .into_iter()
                    .map(|(key, val)| format!("{}:{}", Self::quote(key), Self::normalize_json(val)))
                    .collect();
                format!("{{{}}}", parts.join(","))
            }
            serde_json::Value::Array(arr) => {
                let parts: Vec<String> = arr.iter().map(Self::normalize_json).collect();
                format!("[{}]", parts.join(","))
            }
            // Whitespace inside string values is deliberately ignored for matching.
            serde_json::Value::String(s) => Self::quote(&s.replace(char::is_whitespace, "")),
            // Numbers, booleans and null never contain whitespace in their JSON text.
            other => other.to_string(),
        }
    }

    /// Renders `s` as a JSON string literal (quoted, with escapes), so string
    /// contents can never be mistaken for the `{`, `:`, `,` structure emitted by
    /// `normalize_json` or for a non-string scalar.
    fn quote(s: &str) -> String {
        serde_json::Value::String(s.to_owned()).to_string()
    }
}
/// Action suggested to the caller when a tool-call loop is suspected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopIntervention {
    /// Inject the contained hint text (presumably into the model's context).
    InjectHint(String),
    /// Break out of the loop entirely; exact semantics are up to the caller.
    Break,
}
/// Tracks fingerprints of recorded tool calls and reports when the same call
/// keeps repeating.
#[derive(Debug, Clone)]
pub struct LoopDetector {
    /// Fingerprints of every call recorded since creation or the last reset,
    /// oldest first.
    history: Vec<ToolCallFingerprint>,
    /// How many repeats of the same call, among the most recent calls, are
    /// treated as a loop.
    threshold: usize,
}
impl LoopDetector {
    /// Creates an empty detector.
    ///
    /// A loop is reported once the `threshold` most recent recorded calls are
    /// all identical. A `threshold` of 0 disables detection.
    pub fn new(threshold: usize) -> Self {
        Self {
            history: Vec::new(),
            threshold,
        }
    }

    /// Appends a fingerprint for each call in `calls`, preserving their order.
    pub fn record(&mut self, calls: &[ToolCall]) {
        self.history.extend(calls.iter().map(ToolCallFingerprint::from_call));
    }

    /// Checks whether the most recent `threshold` recorded calls are all the
    /// same (same tool name and same normalized arguments).
    ///
    /// Returns `Some(LoopIntervention::InjectHint(..))` carrying a hint that asks
    /// the model to try a different approach when they are. Returns `None`
    /// otherwise, including when fewer than `threshold` calls have been recorded
    /// or when `threshold` is 0.
    pub fn check(&self) -> Option<LoopIntervention> {
        // Only the trailing window decides the verdict: the presence of older,
        // unrelated calls must not change it. A zero threshold would make an
        // empty window count as a loop, so it is treated as "disabled".
        if self.threshold == 0 || self.history.len() < self.threshold {
            return None;
        }
        let recent = &self.history[self.history.len() - self.threshold..];
        let mut counts: HashMap<&ToolCallFingerprint, usize> =
            HashMap::with_capacity(recent.len());
        for fp in recent {
            *counts.entry(fp).or_insert(0) += 1;
        }
        let max_repeat = counts.values().copied().max().unwrap_or(0);
        if max_repeat >= self.threshold {
            // "You are repeatedly calling the same tool; please try a different approach."
            Some(LoopIntervention::InjectHint(
                "你正在重复调用相同的工具,请尝试不同方法".to_string(),
            ))
        } else {
            None
        }
    }

    /// Forgets every recorded call.
    pub fn reset(&mut self) {
        self.history.clear();
    }
}
impl Default for LoopDetector {
fn default() -> Self {
Self::new(3)
}
}