// agent_runtime/tool_loop_detection.rs
use std::collections::{BTreeMap, HashMap};

use serde_json::Value as JsonValue;

/// Settings for intercepting a tool call that repeats an earlier one exactly,
/// together with the text reported back when that happens.
#[derive(Clone, Debug)]
pub struct ToolLoopDetectionConfig {
    /// Whether loop detection is switched on. This type only stores the flag;
    /// acting on it is up to the code that consumes the config.
    pub enabled: bool,

    /// Optional template that replaces the built-in message. The placeholders
    /// `{tool_name}` and `{previous_result}` are filled in by
    /// [`ToolLoopDetectionConfig::get_message`].
    pub custom_message: Option<String>,
}

15impl Default for ToolLoopDetectionConfig {
16 fn default() -> Self {
17 Self {
18 enabled: true,
19 custom_message: None,
20 }
21 }
22}
23
24impl ToolLoopDetectionConfig {
25 pub fn enabled() -> Self {
27 Self {
28 enabled: true,
29 custom_message: None,
30 }
31 }
32
33 pub fn disabled() -> Self {
35 Self {
36 enabled: false,
37 custom_message: None,
38 }
39 }
40
41 pub fn with_message(message: impl Into<String>) -> Self {
43 Self {
44 enabled: true,
45 custom_message: Some(message.into()),
46 }
47 }
48
49 pub fn get_message(&self, tool_name: &str, previous_result: &JsonValue) -> String {
51 if let Some(custom) = &self.custom_message {
52 custom
54 .replace("{tool_name}", tool_name)
55 .replace("{previous_result}", &previous_result.to_string())
56 } else {
57 format!(
59 "You already called the tool '{}' with these exact parameters and received a response: {}. \
60 Please use the previous result instead of calling it again. \
61 If you need different information, try calling with different parameters.",
62 tool_name,
63 previous_result
64 )
65 }
66 }
67}
68
69#[derive(Debug, Clone, Default)]
71pub struct ToolCallTracker {
72 history: Vec<(String, String, JsonValue)>,
74}
75
76impl ToolCallTracker {
77 pub fn new() -> Self {
78 Self {
79 history: Vec::new(),
80 }
81 }
82
83 pub fn record_call(
85 &mut self,
86 tool_name: &str,
87 args: &HashMap<String, JsonValue>,
88 result: &JsonValue,
89 ) {
90 let args_hash = Self::hash_args(args);
91 self.history
92 .push((tool_name.to_string(), args_hash, result.clone()));
93 }
94
95 pub fn check_for_loop(
98 &self,
99 tool_name: &str,
100 args: &HashMap<String, JsonValue>,
101 ) -> Option<JsonValue> {
102 let args_hash = Self::hash_args(args);
103
104 self.history
106 .iter()
107 .find(|(name, hash, _)| name == tool_name && hash == &args_hash)
108 .map(|(_, _, result)| result.clone())
109 }
110
111 pub fn clear(&mut self) {
113 self.history.clear();
114 }
115
116 fn hash_args(args: &HashMap<String, JsonValue>) -> String {
118 let json = serde_json::to_string(args).unwrap_or_default();
121 format!("{:x}", md5::compute(json.as_bytes()))
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the single-entry argument map `{ "query": <value> }` used by the tracker tests.
    fn query_args(value: &str) -> HashMap<String, JsonValue> {
        let mut args = HashMap::new();
        args.insert("query".to_string(), json!(value));
        args
    }

    #[test]
    fn test_loop_detection_config_default_message() {
        let msg = ToolLoopDetectionConfig::default()
            .get_message("search", &json!({"data": "test"}));

        assert!(msg.contains("search"));
        assert!(msg.contains("already called"));
    }

    #[test]
    fn test_loop_detection_config_custom_message() {
        let template = "Stop calling {tool_name}! Previous result: {previous_result}";
        let msg = ToolLoopDetectionConfig::with_message(template)
            .get_message("search", &json!({"data": "test"}));

        assert!(msg.contains("Stop calling search!"));
        assert!(msg.contains("test"));
    }

    #[test]
    fn test_tracker_detects_loop() {
        let mut tracker = ToolCallTracker::new();
        let args = query_args("test");
        let outcome = json!({"found": false});

        // Nothing recorded yet, so there is no loop.
        assert!(tracker.check_for_loop("search", &args).is_none());

        tracker.record_call("search", &args, &outcome);

        // Same tool with the same arguments: the stored result comes back.
        assert_eq!(tracker.check_for_loop("search", &args), Some(outcome));
    }

    #[test]
    fn test_tracker_different_args_no_loop() {
        let mut tracker = ToolCallTracker::new();

        tracker.record_call("search", &query_args("test1"), &json!({}));

        assert!(tracker
            .check_for_loop("search", &query_args("test2"))
            .is_none());
    }

    #[test]
    fn test_tracker_clear() {
        let mut tracker = ToolCallTracker::new();
        let args = query_args("test");

        tracker.record_call("search", &args, &json!({}));
        tracker.clear();

        assert!(tracker.check_for_loop("search", &args).is_none());
    }
}