1use std::collections::{BTreeMap, HashMap, HashSet};
2use sha2::{Digest, Sha256};
3
4pub fn canonical_tool_args(args: &serde_json::Value) -> String {
5 fn canonicalize(v: &serde_json::Value) -> serde_json::Value {
6 match v {
7 serde_json::Value::Object(map) => {
8 let mut sorted = BTreeMap::new();
9 for (k, val) in map {
10 sorted.insert(k.clone(), canonicalize(val));
11 }
12 serde_json::Value::Object(sorted.into_iter().collect())
13 }
14 serde_json::Value::Array(arr) => {
15 serde_json::Value::Array(arr.iter().map(canonicalize).collect())
16 }
17 _ => v.clone(),
18 }
19 }
20 serde_json::to_string(&canonicalize(args)).unwrap_or_default()
21}
22
23#[derive(Debug, Clone, serde::Serialize)]
24pub struct ToolCallSignature {
25 pub tool_name: String,
26 pub args_hash: String,
27}
28
29impl ToolCallSignature {
30 pub fn from_call(tool_name: &str, args: Option<&serde_json::Value>) -> Self {
31 let default_val = serde_json::Value::Object(serde_json::Map::new());
32 let val = args.unwrap_or(&default_val);
33 let canonical = canonical_tool_args(val);
34 let mut hasher = Sha256::new();
35 hasher.update(canonical.as_bytes());
36 let result = hasher.finalize();
37 let args_hash = format!("{:x}", result);
38 Self {
39 tool_name: tool_name.to_string(),
40 args_hash,
41 }
42 }
43
44 pub fn to_metadata(&self) -> serde_json::Value {
45 serde_json::json!({
46 "tool_name": self.tool_name,
47 "args_hash": self.args_hash,
48 })
49 }
50}
51
52#[derive(Debug, Clone, serde::Serialize)]
53pub struct ToolGuardrailDecision {
54 pub action: String, pub code: String,
56 pub message: String,
57 pub tool_name: String,
58 pub count: usize,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub signature: Option<ToolCallSignature>,
61}
62
63impl ToolGuardrailDecision {
64 pub fn allows_execution(&self) -> bool {
65 self.action == "allow" || self.action == "warn"
66 }
67
68 pub fn should_halt(&self) -> bool {
69 self.action == "block" || self.action == "halt"
70 }
71
72 pub fn to_metadata(&self) -> serde_json::Value {
73 let mut map = serde_json::json!({
74 "action": self.action,
75 "code": self.code,
76 "message": self.message,
77 "tool_name": self.tool_name,
78 "count": self.count,
79 });
80 if let Some(ref sig) = self.signature {
81 map.as_object_mut().unwrap().insert("signature".to_string(), sig.to_metadata());
82 }
83 map
84 }
85}
86
/// Thresholds and tool lists that drive [`ToolCallGuardrailController`].
#[derive(Clone)]
#[derive(Debug)]
pub struct ToolCallGuardrailConfig {
    /// Emit `warn` decisions once a warning threshold is reached.
    pub warnings_enabled: bool,
    /// Permit `block`/`halt` decisions. When false, `before_call` always allows
    /// and `after_call` never halts.
    pub hard_stop_enabled: bool,
    /// Repeated identical-argument failures at which `after_call` warns.
    pub exact_failure_warn_after: usize,
    /// Repeated identical-argument failures at which `before_call` blocks.
    pub exact_failure_block_after: usize,
    /// Failures of one tool (any arguments) at which `after_call` warns.
    pub same_tool_failure_warn_after: usize,
    /// Failures of one tool (any arguments) at which `after_call` halts.
    pub same_tool_failure_halt_after: usize,
    /// Repeated identical results from an idempotent call at which `after_call` warns.
    pub no_progress_warn_after: usize,
    /// Repeated identical results from an idempotent call at which `before_call` blocks.
    pub no_progress_block_after: usize,
    /// Tools treated as read-only for no-progress detection.
    pub idempotent_tools: HashSet<String>,
    /// Tools never treated as idempotent, even if also listed in `idempotent_tools`.
    pub mutating_tools: HashSet<String>,
}
100
101impl Default for ToolCallGuardrailConfig {
102 fn default() -> Self {
103 let idempotent: HashSet<String> = [
104 "read_file",
105 "search_files",
106 "web_search",
107 "web_extract",
108 "session_search",
109 "browser_snapshot",
110 "browser_console",
111 "browser_get_images",
112 "mcp_filesystem_read_file",
113 "mcp_filesystem_read_text_file",
114 "mcp_filesystem_read_multiple_files",
115 "mcp_filesystem_list_directory",
116 "mcp_filesystem_list_directory_with_sizes",
117 "mcp_filesystem_directory_tree",
118 "mcp_filesystem_get_file_info",
119 "mcp_filesystem_search_files",
120 ]
121 .iter()
122 .map(|s| s.to_string())
123 .collect();
124
125 let mutating: HashSet<String> = [
126 "terminal",
127 "execute_code",
128 "write_file",
129 "patch",
130 "todo",
131 "memory",
132 "skill_manage",
133 "browser_click",
134 "browser_type",
135 "browser_press",
136 "browser_scroll",
137 "browser_navigate",
138 "send_message",
139 "cronjob",
140 "delegate_task",
141 "process",
142 ]
143 .iter()
144 .map(|s| s.to_string())
145 .collect();
146
147 Self {
148 warnings_enabled: true,
149 hard_stop_enabled: false,
150 exact_failure_warn_after: 2,
151 exact_failure_block_after: 5,
152 same_tool_failure_warn_after: 3,
153 same_tool_failure_halt_after: 8,
154 no_progress_warn_after: 2,
155 no_progress_block_after: 5,
156 idempotent_tools: idempotent,
157 mutating_tools: mutating,
158 }
159 }
160}
161
162pub fn file_mutation_result_landed(tool_name: &str, result: &str) -> bool {
163 if tool_name != "write_file" && tool_name != "patch" {
164 return false;
165 }
166 if let Ok(data) = serde_json::from_str::<serde_json::Value>(result.trim()) {
167 if let Some(obj) = data.as_object() {
168 if obj.contains_key("error") && !obj["error"].is_null() && obj["error"] != false {
169 return false;
170 }
171 if tool_name == "write_file" {
172 return obj.contains_key("bytes_written");
173 }
174 if tool_name == "patch" {
175 return obj.get("success").and_then(|v| v.as_bool()) == Some(true);
176 }
177 }
178 }
179 false
180}
181
182pub fn classify_tool_failure(tool_name: &str, result: Option<&str>) -> (bool, String) {
183 let result_str = match result {
184 None => return (false, String::new()),
185 Some(r) => r,
186 };
187 if file_mutation_result_landed(tool_name, result_str) {
188 return (false, String::new());
189 }
190
191 if tool_name == "terminal" {
192 if let Ok(data) = serde_json::from_str::<serde_json::Value>(result_str.trim()) {
193 if let Some(obj) = data.as_object() {
194 if let Some(exit_code) = obj.get("exit_code").and_then(|v| v.as_i64()) {
195 if exit_code != 0 {
196 return (true, format!(" [exit {}]", exit_code));
197 }
198 }
199 }
200 }
201 return (false, String::new());
202 }
203
204 if tool_name == "memory" {
205 if let Ok(data) = serde_json::from_str::<serde_json::Value>(result_str.trim()) {
206 if let Some(obj) = data.as_object() {
207 if obj.get("success").and_then(|v| v.as_bool()) == Some(false) {
208 let err_str = obj.get("error").and_then(|v| v.as_str()).unwrap_or("");
209 if err_str.contains("exceed the limit") {
210 return (true, " [full]".to_string());
211 }
212 }
213 }
214 }
215 }
216
217 let limit = 500.min(result_str.len());
218 let lower = result_str[..limit].to_lowercase();
219 if lower.contains("\"error\"") || lower.contains("\"failed\"") || result_str.starts_with("Error") {
220 return (true, " [error]".to_string());
221 }
222
223 (false, String::new())
224}
225
/// Last observed result of one idempotent call and how many times in a row it
/// has been returned.
#[derive(Clone)]
#[derive(Debug)]
struct ProgressRecord {
    /// SHA-256 hex digest of the result (JSON objects are canonicalized first).
    result_hash: String,
    /// Number of consecutive calls that produced `result_hash`.
    repeat_count: usize,
}
231
232pub struct ToolCallGuardrailController {
233 pub config: ToolCallGuardrailConfig,
234 exact_failure_counts: HashMap<String, usize>,
235 same_tool_failure_counts: HashMap<String, usize>,
236 no_progress: HashMap<String, ProgressRecord>,
237 active_halt_decision: Option<ToolGuardrailDecision>,
238}
239
240impl ToolCallGuardrailController {
241 pub fn new(config: Option<ToolCallGuardrailConfig>) -> Self {
242 Self {
243 config: config.unwrap_or_default(),
244 exact_failure_counts: HashMap::new(),
245 same_tool_failure_counts: HashMap::new(),
246 no_progress: HashMap::new(),
247 active_halt_decision: None,
248 }
249 }
250
251 pub fn reset_for_turn(&mut self) {
252 self.exact_failure_counts.clear();
253 self.same_tool_failure_counts.clear();
254 self.no_progress.clear();
255 self.active_halt_decision = None;
256 }
257
258 pub fn halt_decision(&self) -> Option<&ToolGuardrailDecision> {
259 self.active_halt_decision.as_ref()
260 }
261
262 pub fn before_call(&mut self, tool_name: &str, args: Option<&serde_json::Value>) -> ToolGuardrailDecision {
263 let sig = ToolCallSignature::from_call(tool_name, args);
264 if !self.config.hard_stop_enabled {
265 return ToolGuardrailDecision {
266 action: "allow".to_string(),
267 code: "allow".to_string(),
268 message: String::new(),
269 tool_name: tool_name.to_string(),
270 count: 0,
271 signature: Some(sig),
272 };
273 }
274
275 let exact_count = *self.exact_failure_counts.get(&sig.args_hash).unwrap_or(&0);
276 if exact_count >= self.config.exact_failure_block_after {
277 let dec = ToolGuardrailDecision {
278 action: "block".to_string(),
279 code: "repeated_exact_failure_block".to_string(),
280 message: format!(
281 "Blocked {}: the same tool call failed {} times with identical arguments. Stop retrying it unchanged; change strategy or explain the blocker.",
282 tool_name, exact_count
283 ),
284 tool_name: tool_name.to_string(),
285 count: exact_count,
286 signature: Some(sig),
287 };
288 self.active_halt_decision = Some(dec.clone());
289 return dec;
290 }
291
292 if self.is_idempotent(tool_name) {
293 if let Some(record) = self.no_progress.get(&sig.args_hash) {
294 let repeat_count = record.repeat_count;
295 if repeat_count >= self.config.no_progress_block_after {
296 let dec = ToolGuardrailDecision {
297 action: "block".to_string(),
298 code: "idempotent_no_progress_block".to_string(),
299 message: format!(
300 "Blocked {}: this read-only call returned the same result {} times. Stop repeating it unchanged; use the result already provided or try a different query.",
301 tool_name, repeat_count
302 ),
303 tool_name: tool_name.to_string(),
304 count: repeat_count,
305 signature: Some(sig),
306 };
307 self.active_halt_decision = Some(dec.clone());
308 return dec;
309 }
310 }
311 }
312
313 ToolGuardrailDecision {
314 action: "allow".to_string(),
315 code: "allow".to_string(),
316 message: String::new(),
317 tool_name: tool_name.to_string(),
318 count: 0,
319 signature: Some(sig),
320 }
321 }
322
323 pub fn after_call(
324 &mut self,
325 tool_name: &str,
326 args: Option<&serde_json::Value>,
327 result: Option<&str>,
328 failed: Option<bool>,
329 ) -> ToolGuardrailDecision {
330 let sig = ToolCallSignature::from_call(tool_name, args);
331 let is_failed = failed.unwrap_or_else(|| {
332 let (f, _) = classify_tool_failure(tool_name, result);
333 f
334 });
335
336 if is_failed {
337 let exact_count = self.exact_failure_counts.entry(sig.args_hash.clone()).or_insert(0);
338 *exact_count += 1;
339 let exact_val = *exact_count;
340 self.no_progress.remove(&sig.args_hash);
341
342 let same_count = self.same_tool_failure_counts.entry(tool_name.to_string()).or_insert(0);
343 *same_count += 1;
344 let same_val = *same_count;
345
346 if self.config.hard_stop_enabled && same_val >= self.config.same_tool_failure_halt_after {
347 let dec = ToolGuardrailDecision {
348 action: "halt".to_string(),
349 code: "same_tool_failure_halt".to_string(),
350 message: format!(
351 "Stopped {}: it failed {} times this turn. Stop retrying the same failing tool path and choose a different approach.",
352 tool_name, same_val
353 ),
354 tool_name: tool_name.to_string(),
355 count: same_val,
356 signature: Some(sig),
357 };
358 self.active_halt_decision = Some(dec.clone());
359 return dec;
360 }
361
362 if self.config.warnings_enabled && exact_val >= self.config.exact_failure_warn_after {
363 return ToolGuardrailDecision {
364 action: "warn".to_string(),
365 code: "repeated_exact_failure_warning".to_string(),
366 message: format!(
367 "{} has failed {} times with identical arguments. This looks like a loop; inspect the error and change strategy instead of retrying it unchanged.",
368 tool_name, exact_val
369 ),
370 tool_name: tool_name.to_string(),
371 count: exact_val,
372 signature: Some(sig),
373 };
374 }
375
376 if self.config.warnings_enabled && same_val >= self.config.same_tool_failure_warn_after {
377 return ToolGuardrailDecision {
378 action: "warn".to_string(),
379 code: "same_tool_failure_warning".to_string(),
380 message: self.tool_failure_recovery_hint(tool_name, same_val),
381 tool_name: tool_name.to_string(),
382 count: same_val,
383 signature: Some(sig),
384 };
385 }
386
387 return ToolGuardrailDecision {
388 action: "allow".to_string(),
389 code: "allow".to_string(),
390 message: String::new(),
391 tool_name: tool_name.to_string(),
392 count: exact_val,
393 signature: Some(sig),
394 };
395 }
396
397 self.exact_failure_counts.remove(&sig.args_hash);
398 self.same_tool_failure_counts.remove(tool_name);
399
400 if !self.is_idempotent(tool_name) {
401 self.no_progress.remove(&sig.args_hash);
402 return ToolGuardrailDecision {
403 action: "allow".to_string(),
404 code: "allow".to_string(),
405 message: String::new(),
406 tool_name: tool_name.to_string(),
407 count: 0,
408 signature: Some(sig),
409 };
410 }
411
412 let res_hash = self.compute_result_hash(result);
413 let repeat_count = match self.no_progress.get(&sig.args_hash) {
414 Some(record) if record.result_hash == res_hash => record.repeat_count + 1,
415 _ => 1,
416 };
417
418 self.no_progress.insert(
419 sig.args_hash.clone(),
420 ProgressRecord {
421 result_hash: res_hash,
422 repeat_count,
423 },
424 );
425
426 if self.config.warnings_enabled && repeat_count >= self.config.no_progress_warn_after {
427 return ToolGuardrailDecision {
428 action: "warn".to_string(),
429 code: "idempotent_no_progress_warning".to_string(),
430 message: format!(
431 "{} returned the same result {} times. Use the result already provided or change the query instead of repeating it unchanged.",
432 tool_name, repeat_count
433 ),
434 tool_name: tool_name.to_string(),
435 count: repeat_count,
436 signature: Some(sig),
437 };
438 }
439
440 ToolGuardrailDecision {
441 action: "allow".to_string(),
442 code: "allow".to_string(),
443 message: String::new(),
444 tool_name: tool_name.to_string(),
445 count: repeat_count,
446 signature: Some(sig),
447 }
448 }
449
450 pub fn is_idempotent(&self, tool_name: &str) -> bool {
451 if self.config.mutating_tools.contains(tool_name) {
452 return false;
453 }
454 self.config.idempotent_tools.contains(tool_name)
455 }
456
457 fn compute_result_hash(&self, result: Option<&str>) -> String {
458 let mut canonical = result.unwrap_or("").to_string();
459 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(canonical.trim()) {
460 if parsed.is_object() {
461 canonical = canonical_tool_args(&parsed);
462 }
463 }
464 let mut hasher = Sha256::new();
465 hasher.update(canonical.as_bytes());
466 let res = hasher.finalize();
467 format!("{:x}", res)
468 }
469
470 pub fn tool_failure_recovery_hint(&self, tool_name: &str, count: usize) -> String {
471 let common = format!(
472 "{} has failed {} times this turn. This looks like a loop. Do not switch to text-only replies; keep using tools, but diagnose before retrying. First inspect the latest error/output and verify your assumptions. ",
473 tool_name, count
474 );
475 if tool_name == "terminal" {
476 format!(
477 "{}For terminal failures, run a small diagnostic such as `pwd && ls -la` in the same tool, then try an absolute path, a simpler command, a different working directory, or a different tool such as read_file/write_file/patch.",
478 common
479 )
480 } else {
481 format!(
482 "{}Try different arguments, a narrower query/path, an absolute path when relevant, or a different tool that can make progress. If the blocker is external, report the blocker after one diagnostic attempt instead of repeating the same failing path.",
483 common
484 )
485 }
486 }
487}
488
489pub fn toolguard_synthetic_result(decision: &ToolGuardrailDecision) -> String {
490 serde_json::json!({
491 "error": decision.message,
492 "guardrail": decision.to_metadata(),
493 })
494 .to_string()
495}
496
497pub fn append_toolguard_guidance(result: &str, decision: &ToolGuardrailDecision) -> String {
498 if (decision.action != "warn" && decision.action != "halt") || decision.message.is_empty() {
499 return result.to_string();
500 }
501 let label = if decision.action == "halt" {
502 "Tool loop hard stop"
503 } else {
504 "Tool loop warning"
505 };
506 format!(
507 "{}\n\n[{}: {}; count={}; {}]",
508 result, label, decision.code, decision.count, decision.message
509 )
510}