lellm_agent/tools/
loop_detector.rs1use std::collections::HashMap;
4use std::hash::{Hash, Hasher};
5
6use lellm_core::ToolCall;
7
/// Identity of one tool invocation, used to spot repeated calls: the tool's name
/// together with a normalized string rendering of its JSON arguments.
#[derive(Debug, Clone)]
pub struct ToolCallFingerprint {
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Normalized rendering of the call's arguments (as built by
    /// `ToolCallFingerprint::from_call`).
    pub normalized_args: String,
}
14
15impl Hash for ToolCallFingerprint {
16 fn hash<H: Hasher>(&self, state: &mut H) {
17 self.tool_name.hash(state);
18 self.normalized_args.hash(state);
19 }
20}
21
22impl PartialEq for ToolCallFingerprint {
23 fn eq(&self, other: &Self) -> bool {
24 self.tool_name == other.tool_name && self.normalized_args == other.normalized_args
25 }
26}
27
28impl Eq for ToolCallFingerprint {}
29
30impl ToolCallFingerprint {
31 pub fn from_call(call: &ToolCall) -> Self {
33 let normalized = Self::normalize_json(&call.arguments);
34 Self {
35 tool_name: call.name.clone(),
36 normalized_args: normalized,
37 }
38 }
39
40 fn normalize_json(value: &serde_json::Value) -> String {
42 match value {
43 serde_json::Value::Object(map) => {
44 let mut entries: Vec<_> = map.iter().collect();
45 entries.sort_by_key(|(k, _)| (*k).clone());
46 let parts: Vec<_> = entries
47 .iter()
48 .map(|(k, v)| format!("{}:{}", k, Self::normalize_json(v)))
49 .collect();
50 parts.join(",")
51 }
52 serde_json::Value::Array(arr) => {
53 let parts: Vec<_> = arr.iter().map(Self::normalize_json).collect();
54 format!("[{}]", parts.join(","))
55 }
56 serde_json::Value::String(s) => s.replace(char::is_whitespace, ""),
57 other => other.to_string().replace(char::is_whitespace, ""),
58 }
59 }
60}
61
/// Action the detector suggests to its caller when repeated tool calls are seen.
#[derive(Clone, Debug)]
pub enum LoopIntervention {
    /// Carries a hint message for the agent (presumably injected into the
    /// conversation by the caller — confirm at the call site).
    InjectHint(String),
    /// Presumably asks the caller to stop the agent loop; note that
    /// `LoopDetector::check` in this file never returns this variant.
    Break,
}
70
71pub struct LoopDetector {
73 history: Vec<ToolCallFingerprint>,
74 threshold: usize,
75}
76
77impl LoopDetector {
78 pub fn new(threshold: usize) -> Self {
79 Self {
80 history: Vec::new(),
81 threshold,
82 }
83 }
84
85 pub fn record(&mut self, calls: &[ToolCall]) {
87 let fingerprints: Vec<_> = calls.iter().map(ToolCallFingerprint::from_call).collect();
88 self.history.extend(fingerprints);
89 }
90
91 pub fn check(&self) -> Option<LoopIntervention> {
93 if self.history.len() < self.threshold * 2 {
94 return None;
95 }
96
97 let recent_len = self.threshold.min(self.history.len());
98 let recent = &self.history[self.history.len() - recent_len..];
99
100 let mut counts: HashMap<&ToolCallFingerprint, usize> = HashMap::new();
102 for fp in recent {
103 *counts.entry(fp).or_insert(0) += 1;
104 }
105
106 let max_repeat = counts.values().max().copied().unwrap_or(0);
107 if max_repeat >= self.threshold {
108 Some(LoopIntervention::InjectHint(
109 "你正在重复调用相同的工具,请尝试不同方法".to_string(),
110 ))
111 } else {
112 None
113 }
114 }
115
116 pub fn reset(&mut self) {
118 self.history.clear();
119 }
120}
121
122impl Default for LoopDetector {
123 fn default() -> Self {
124 Self::new(3)
125 }
126}