Skip to main content

agent_runtime/
tool_loop_detection.rs

1use serde_json::Value as JsonValue;
2use std::collections::HashMap;
3
/// Settings that control how repeated (looping) tool calls are handled.
#[derive(Clone, Debug)]
pub struct ToolLoopDetectionConfig {
    /// `true` when loop detection should be active.
    pub enabled: bool,

    /// Optional template for the message injected when a loop is detected.
    /// When `None`, a built-in default message is used instead.
    pub custom_message: Option<String>,
}

impl Default for ToolLoopDetectionConfig {
    /// Detection is on by default and no custom message is set.
    fn default() -> Self {
        ToolLoopDetectionConfig {
            custom_message: None,
            enabled: true,
        }
    }
}
23
24impl ToolLoopDetectionConfig {
25    /// Create with loop detection enabled and default message
26    pub fn enabled() -> Self {
27        Self {
28            enabled: true,
29            custom_message: None,
30        }
31    }
32
33    /// Create with loop detection disabled
34    pub fn disabled() -> Self {
35        Self {
36            enabled: false,
37            custom_message: None,
38        }
39    }
40
41    /// Create with a custom message
42    pub fn with_message(message: impl Into<String>) -> Self {
43        Self {
44            enabled: true,
45            custom_message: Some(message.into()),
46        }
47    }
48
49    /// Get the message to use when a loop is detected
50    pub fn get_message(&self, tool_name: &str, previous_result: &JsonValue) -> String {
51        if let Some(custom) = &self.custom_message {
52            // Replace placeholders in custom message
53            custom
54                .replace("{tool_name}", tool_name)
55                .replace("{previous_result}", &previous_result.to_string())
56        } else {
57            // Default message
58            format!(
59                "You already called the tool '{}' with these exact parameters and received a response: {}. \
60                Please use the previous result instead of calling it again. \
61                If you need different information, try calling with different parameters.",
62                tool_name,
63                previous_result
64            )
65        }
66    }
67}
68
69/// Tracks tool calls to detect loops
70#[derive(Debug, Clone, Default)]
71pub struct ToolCallTracker {
72    /// History of (tool_name, args_hash, result) tuples
73    history: Vec<(String, String, JsonValue)>,
74}
75
76impl ToolCallTracker {
77    pub fn new() -> Self {
78        Self {
79            history: Vec::new(),
80        }
81    }
82
83    /// Record a tool call and its result
84    pub fn record_call(
85        &mut self,
86        tool_name: &str,
87        args: &HashMap<String, JsonValue>,
88        result: &JsonValue,
89    ) {
90        let args_hash = Self::hash_args(args);
91        self.history
92            .push((tool_name.to_string(), args_hash, result.clone()));
93    }
94
95    /// Check if this exact tool call (name + args) was made before
96    /// Returns Some(previous_result) if found, None otherwise
97    pub fn check_for_loop(
98        &self,
99        tool_name: &str,
100        args: &HashMap<String, JsonValue>,
101    ) -> Option<JsonValue> {
102        let args_hash = Self::hash_args(args);
103
104        // Look for previous call with same tool + args
105        self.history
106            .iter()
107            .find(|(name, hash, _)| name == tool_name && hash == &args_hash)
108            .map(|(_, _, result)| result.clone())
109    }
110
111    /// Clear the history (e.g., at start of new agent execution)
112    pub fn clear(&mut self) {
113        self.history.clear();
114    }
115
116    /// Create a simple hash of arguments for comparison
117    fn hash_args(args: &HashMap<String, JsonValue>) -> String {
118        // Serialize to JSON for consistent comparison
119        // Sort keys to ensure deterministic ordering
120        let json = serde_json::to_string(args).unwrap_or_default();
121        format!("{:x}", md5::compute(json.as_bytes()))
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the single-entry argument map `{"query": <query>}` shared by
    /// the tracker tests.
    fn query_args(query: &str) -> HashMap<String, JsonValue> {
        let mut map = HashMap::new();
        map.insert("query".to_string(), json!(query));
        map
    }

    /// The built-in message names the tool and says it was already called.
    #[test]
    fn test_loop_detection_config_default_message() {
        let previous = json!({"data": "test"});
        let text = ToolLoopDetectionConfig::default().get_message("search", &previous);

        assert!(text.contains("search"));
        assert!(text.contains("already called"));
    }

    /// Both placeholders of a custom template are substituted.
    #[test]
    fn test_loop_detection_config_custom_message() {
        let template = "Stop calling {tool_name}! Previous result: {previous_result}";
        let config = ToolLoopDetectionConfig::with_message(template);

        let previous = json!({"data": "test"});
        let text = config.get_message("search", &previous);

        assert!(text.contains("Stop calling search!"));
        assert!(text.contains("test"));
    }

    /// A repeated call with identical arguments yields the stored result.
    #[test]
    fn test_tracker_detects_loop() {
        let args = query_args("test");
        let expected = json!({"found": false});
        let mut tracker = ToolCallTracker::new();

        // Nothing has been recorded yet, so there is no loop.
        assert!(tracker.check_for_loop("search", &args).is_none());

        tracker.record_call("search", &args, &expected);

        // The same call again is reported together with the earlier result.
        assert_eq!(tracker.check_for_loop("search", &args), Some(expected));
    }

    /// Calls to the same tool with different arguments are not a loop.
    #[test]
    fn test_tracker_different_args_no_loop() {
        let recorded = query_args("test1");
        let probe = query_args("test2");
        let mut tracker = ToolCallTracker::new();

        tracker.record_call("search", &recorded, &json!({}));

        assert!(tracker.check_for_loop("search", &probe).is_none());
    }

    /// Clearing the tracker forgets previously recorded calls.
    #[test]
    fn test_tracker_clear() {
        let args = query_args("test");
        let mut tracker = ToolCallTracker::new();

        tracker.record_call("search", &args, &json!({}));
        tracker.clear();

        assert!(tracker.check_for_loop("search", &args).is_none());
    }
}