Skip to main content

lellm_agent/tools/
loop_detector.rs

//! Loop detector: fingerprint-based de-duplication of tool calls with a trigger threshold.
2
3use std::collections::HashMap;
4use std::hash::{Hash, Hasher};
5
6use lellm_core::ToolCall;
7
/// Fingerprint of a single tool call, taken after its arguments have been
/// normalized.
///
/// Two calls with the same fingerprint are treated as "the same call" for the
/// purpose of loop detection.
#[derive(Debug, Clone)]
pub struct ToolCallFingerprint {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Normalized string form of the call's arguments.
    pub normalized_args: String,
}
14
15impl Hash for ToolCallFingerprint {
16    fn hash<H: Hasher>(&self, state: &mut H) {
17        self.tool_name.hash(state);
18        self.normalized_args.hash(state);
19    }
20}
21
22impl PartialEq for ToolCallFingerprint {
23    fn eq(&self, other: &Self) -> bool {
24        self.tool_name == other.tool_name && self.normalized_args == other.normalized_args
25    }
26}
27
28impl Eq for ToolCallFingerprint {}
29
30impl ToolCallFingerprint {
31    /// 从 ToolCall 生成指纹
32    pub fn from_call(call: &ToolCall) -> Self {
33        let normalized = Self::normalize_json(&call.arguments);
34        Self {
35            tool_name: call.name.clone(),
36            normalized_args: normalized,
37        }
38    }
39
40    /// JSON 键排序 + 空白去除
41    fn normalize_json(value: &serde_json::Value) -> String {
42        match value {
43            serde_json::Value::Object(map) => {
44                let mut entries: Vec<_> = map.iter().collect();
45                entries.sort_by_key(|(k, _)| (*k).clone());
46                let parts: Vec<_> = entries
47                    .iter()
48                    .map(|(k, v)| format!("{}:{}", k, Self::normalize_json(v)))
49                    .collect();
50                parts.join(",")
51            }
52            serde_json::Value::Array(arr) => {
53                let parts: Vec<_> = arr.iter().map(Self::normalize_json).collect();
54                format!("[{}]", parts.join(","))
55            }
56            serde_json::Value::String(s) => s.replace(char::is_whitespace, ""),
57            other => other.to_string().replace(char::is_whitespace, ""),
58        }
59    }
60}
61
/// How to intervene once a loop has been detected.
#[derive(Clone, Debug)]
pub enum LoopIntervention {
    /// Inject a system hint carrying the given message.
    InjectHint(String),
    /// Break out of the loop.
    Break,
}
70
71/// 循环检测器
72pub struct LoopDetector {
73    history: Vec<ToolCallFingerprint>,
74    threshold: usize,
75}
76
77impl LoopDetector {
78    pub fn new(threshold: usize) -> Self {
79        Self {
80            history: Vec::new(),
81            threshold,
82        }
83    }
84
85    /// 记录一轮 tool_calls 的指纹
86    pub fn record(&mut self, calls: &[ToolCall]) {
87        let fingerprints: Vec<_> = calls.iter().map(ToolCallFingerprint::from_call).collect();
88        self.history.extend(fingerprints);
89    }
90
91    /// 检查是否检测到循环
92    pub fn check(&self) -> Option<LoopIntervention> {
93        if self.history.len() < self.threshold * 2 {
94            return None;
95        }
96
97        let recent_len = self.threshold.min(self.history.len());
98        let recent = &self.history[self.history.len() - recent_len..];
99
100        // 检查最近 N 个指纹是否高度重复
101        let mut counts: HashMap<&ToolCallFingerprint, usize> = HashMap::new();
102        for fp in recent {
103            *counts.entry(fp).or_insert(0) += 1;
104        }
105
106        let max_repeat = counts.values().max().copied().unwrap_or(0);
107        if max_repeat >= self.threshold {
108            Some(LoopIntervention::InjectHint(
109                "你正在重复调用相同的工具,请尝试不同方法".to_string(),
110            ))
111        } else {
112            None
113        }
114    }
115
116    /// 重置检测器
117    pub fn reset(&mut self) {
118        self.history.clear();
119    }
120}
121
122impl Default for LoopDetector {
123    fn default() -> Self {
124        Self::new(3)
125    }
126}