Skip to main content

loop_guardrail/
lib.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use sha2::{Digest, Sha256};
3
4pub fn canonical_tool_args(args: &serde_json::Value) -> String {
5    fn canonicalize(v: &serde_json::Value) -> serde_json::Value {
6        match v {
7            serde_json::Value::Object(map) => {
8                let mut sorted = BTreeMap::new();
9                for (k, val) in map {
10                    sorted.insert(k.clone(), canonicalize(val));
11                }
12                serde_json::Value::Object(sorted.into_iter().collect())
13            }
14            serde_json::Value::Array(arr) => {
15                serde_json::Value::Array(arr.iter().map(canonicalize).collect())
16            }
17            _ => v.clone(),
18        }
19    }
20    serde_json::to_string(&canonicalize(args)).unwrap_or_default()
21}
22
23#[derive(Debug, Clone, serde::Serialize)]
24pub struct ToolCallSignature {
25    pub tool_name: String,
26    pub args_hash: String,
27}
28
29impl ToolCallSignature {
30    pub fn from_call(tool_name: &str, args: Option<&serde_json::Value>) -> Self {
31        let default_val = serde_json::Value::Object(serde_json::Map::new());
32        let val = args.unwrap_or(&default_val);
33        let canonical = canonical_tool_args(val);
34        let mut hasher = Sha256::new();
35        hasher.update(canonical.as_bytes());
36        let result = hasher.finalize();
37        let args_hash = format!("{:x}", result);
38        Self {
39            tool_name: tool_name.to_string(),
40            args_hash,
41        }
42    }
43
44    pub fn to_metadata(&self) -> serde_json::Value {
45        serde_json::json!({
46            "tool_name": self.tool_name,
47            "args_hash": self.args_hash,
48        })
49    }
50}
51
52#[derive(Debug, Clone, serde::Serialize)]
53pub struct ToolGuardrailDecision {
54    pub action: String, // "allow" | "warn" | "block" | "halt"
55    pub code: String,
56    pub message: String,
57    pub tool_name: String,
58    pub count: usize,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub signature: Option<ToolCallSignature>,
61}
62
63impl ToolGuardrailDecision {
64    pub fn allows_execution(&self) -> bool {
65        self.action == "allow" || self.action == "warn"
66    }
67
68    pub fn should_halt(&self) -> bool {
69        self.action == "block" || self.action == "halt"
70    }
71
72    pub fn to_metadata(&self) -> serde_json::Value {
73        let mut map = serde_json::json!({
74            "action": self.action,
75            "code": self.code,
76            "message": self.message,
77            "tool_name": self.tool_name,
78            "count": self.count,
79        });
80        if let Some(ref sig) = self.signature {
81            map.as_object_mut().unwrap().insert("signature".to_string(), sig.to_metadata());
82        }
83        map
84    }
85}
86
/// Thresholds and tool classifications used by the tool-call guardrail.
#[derive(Debug, Clone)]
pub struct ToolCallGuardrailConfig {
    /// Whether "warn" decisions are produced at all.
    pub warnings_enabled: bool,
    /// Whether "block"/"halt" decisions are produced at all.
    pub hard_stop_enabled: bool,
    /// Identical-call failure count at which a warning is issued.
    pub exact_failure_warn_after: usize,
    /// Identical-call failure count at which the call is blocked.
    pub exact_failure_block_after: usize,
    /// Per-tool failure count at which a warning is issued.
    pub same_tool_failure_warn_after: usize,
    /// Per-tool failure count at which the tool is halted.
    pub same_tool_failure_halt_after: usize,
    /// Identical-result repeat count at which a warning is issued.
    pub no_progress_warn_after: usize,
    /// Identical-result repeat count at which the call is blocked.
    pub no_progress_block_after: usize,
    /// Tools whose repeated identical results are tracked as "no progress".
    pub idempotent_tools: HashSet<String>,
    /// Tools that are never treated as idempotent, even if also listed above.
    pub mutating_tools: HashSet<String>,
}

impl Default for ToolCallGuardrailConfig {
    /// Warnings on, hard stops off, with the built-in tool name lists.
    fn default() -> Self {
        const READ_ONLY_TOOLS: [&str; 16] = [
            "read_file",
            "search_files",
            "web_search",
            "web_extract",
            "session_search",
            "browser_snapshot",
            "browser_console",
            "browser_get_images",
            "mcp_filesystem_read_file",
            "mcp_filesystem_read_text_file",
            "mcp_filesystem_read_multiple_files",
            "mcp_filesystem_list_directory",
            "mcp_filesystem_list_directory_with_sizes",
            "mcp_filesystem_directory_tree",
            "mcp_filesystem_get_file_info",
            "mcp_filesystem_search_files",
        ];
        const SIDE_EFFECT_TOOLS: [&str; 16] = [
            "terminal",
            "execute_code",
            "write_file",
            "patch",
            "todo",
            "memory",
            "skill_manage",
            "browser_click",
            "browser_type",
            "browser_press",
            "browser_scroll",
            "browser_navigate",
            "send_message",
            "cronjob",
            "delegate_task",
            "process",
        ];

        /// Copies a slice of static names into an owned set.
        fn owned_set(names: &[&str]) -> HashSet<String> {
            names.iter().map(|name| (*name).to_owned()).collect()
        }

        Self {
            warnings_enabled: true,
            hard_stop_enabled: false,
            exact_failure_warn_after: 2,
            exact_failure_block_after: 5,
            same_tool_failure_warn_after: 3,
            same_tool_failure_halt_after: 8,
            no_progress_warn_after: 2,
            no_progress_block_after: 5,
            idempotent_tools: owned_set(&READ_ONLY_TOOLS),
            mutating_tools: owned_set(&SIDE_EFFECT_TOOLS),
        }
    }
}
161
162pub fn file_mutation_result_landed(tool_name: &str, result: &str) -> bool {
163    if tool_name != "write_file" && tool_name != "patch" {
164        return false;
165    }
166    if let Ok(data) = serde_json::from_str::<serde_json::Value>(result.trim()) {
167        if let Some(obj) = data.as_object() {
168            if obj.contains_key("error") && !obj["error"].is_null() && obj["error"] != false {
169                return false;
170            }
171            if tool_name == "write_file" {
172                return obj.contains_key("bytes_written");
173            }
174            if tool_name == "patch" {
175                return obj.get("success").and_then(|v| v.as_bool()) == Some(true);
176            }
177        }
178    }
179    false
180}
181
182pub fn classify_tool_failure(tool_name: &str, result: Option<&str>) -> (bool, String) {
183    let result_str = match result {
184        None => return (false, String::new()),
185        Some(r) => r,
186    };
187    if file_mutation_result_landed(tool_name, result_str) {
188        return (false, String::new());
189    }
190
191    if tool_name == "terminal" {
192        if let Ok(data) = serde_json::from_str::<serde_json::Value>(result_str.trim()) {
193            if let Some(obj) = data.as_object() {
194                if let Some(exit_code) = obj.get("exit_code").and_then(|v| v.as_i64()) {
195                    if exit_code != 0 {
196                        return (true, format!(" [exit {}]", exit_code));
197                    }
198                }
199            }
200        }
201        return (false, String::new());
202    }
203
204    if tool_name == "memory" {
205        if let Ok(data) = serde_json::from_str::<serde_json::Value>(result_str.trim()) {
206            if let Some(obj) = data.as_object() {
207                if obj.get("success").and_then(|v| v.as_bool()) == Some(false) {
208                    let err_str = obj.get("error").and_then(|v| v.as_str()).unwrap_or("");
209                    if err_str.contains("exceed the limit") {
210                        return (true, " [full]".to_string());
211                    }
212                }
213            }
214        }
215    }
216
217    let limit = 500.min(result_str.len());
218    let lower = result_str[..limit].to_lowercase();
219    if lower.contains("\"error\"") || lower.contains("\"failed\"") || result_str.starts_with("Error") {
220        return (true, " [error]".to_string());
221    }
222
223    (false, String::new())
224}
225
226#[derive(Debug, Clone)]
227struct ProgressRecord {
228    result_hash: String,
229    repeat_count: usize,
230}
231
232pub struct ToolCallGuardrailController {
233    pub config: ToolCallGuardrailConfig,
234    exact_failure_counts: HashMap<String, usize>,
235    same_tool_failure_counts: HashMap<String, usize>,
236    no_progress: HashMap<String, ProgressRecord>,
237    active_halt_decision: Option<ToolGuardrailDecision>,
238}
239
240impl ToolCallGuardrailController {
241    pub fn new(config: Option<ToolCallGuardrailConfig>) -> Self {
242        Self {
243            config: config.unwrap_or_default(),
244            exact_failure_counts: HashMap::new(),
245            same_tool_failure_counts: HashMap::new(),
246            no_progress: HashMap::new(),
247            active_halt_decision: None,
248        }
249    }
250
251    pub fn reset_for_turn(&mut self) {
252        self.exact_failure_counts.clear();
253        self.same_tool_failure_counts.clear();
254        self.no_progress.clear();
255        self.active_halt_decision = None;
256    }
257
258    pub fn halt_decision(&self) -> Option<&ToolGuardrailDecision> {
259        self.active_halt_decision.as_ref()
260    }
261
262    pub fn before_call(&mut self, tool_name: &str, args: Option<&serde_json::Value>) -> ToolGuardrailDecision {
263        let sig = ToolCallSignature::from_call(tool_name, args);
264        if !self.config.hard_stop_enabled {
265            return ToolGuardrailDecision {
266                action: "allow".to_string(),
267                code: "allow".to_string(),
268                message: String::new(),
269                tool_name: tool_name.to_string(),
270                count: 0,
271                signature: Some(sig),
272            };
273        }
274
275        let exact_count = *self.exact_failure_counts.get(&sig.args_hash).unwrap_or(&0);
276        if exact_count >= self.config.exact_failure_block_after {
277            let dec = ToolGuardrailDecision {
278                action: "block".to_string(),
279                code: "repeated_exact_failure_block".to_string(),
280                message: format!(
281                    "Blocked {}: the same tool call failed {} times with identical arguments. Stop retrying it unchanged; change strategy or explain the blocker.",
282                    tool_name, exact_count
283                ),
284                tool_name: tool_name.to_string(),
285                count: exact_count,
286                signature: Some(sig),
287            };
288            self.active_halt_decision = Some(dec.clone());
289            return dec;
290        }
291
292        if self.is_idempotent(tool_name) {
293            if let Some(record) = self.no_progress.get(&sig.args_hash) {
294                let repeat_count = record.repeat_count;
295                if repeat_count >= self.config.no_progress_block_after {
296                    let dec = ToolGuardrailDecision {
297                        action: "block".to_string(),
298                        code: "idempotent_no_progress_block".to_string(),
299                        message: format!(
300                            "Blocked {}: this read-only call returned the same result {} times. Stop repeating it unchanged; use the result already provided or try a different query.",
301                            tool_name, repeat_count
302                        ),
303                        tool_name: tool_name.to_string(),
304                        count: repeat_count,
305                        signature: Some(sig),
306                    };
307                    self.active_halt_decision = Some(dec.clone());
308                    return dec;
309                }
310            }
311        }
312
313        ToolGuardrailDecision {
314            action: "allow".to_string(),
315            code: "allow".to_string(),
316            message: String::new(),
317            tool_name: tool_name.to_string(),
318            count: 0,
319            signature: Some(sig),
320        }
321    }
322
323    pub fn after_call(
324        &mut self,
325        tool_name: &str,
326        args: Option<&serde_json::Value>,
327        result: Option<&str>,
328        failed: Option<bool>,
329    ) -> ToolGuardrailDecision {
330        let sig = ToolCallSignature::from_call(tool_name, args);
331        let is_failed = failed.unwrap_or_else(|| {
332            let (f, _) = classify_tool_failure(tool_name, result);
333            f
334        });
335
336        if is_failed {
337            let exact_count = self.exact_failure_counts.entry(sig.args_hash.clone()).or_insert(0);
338            *exact_count += 1;
339            let exact_val = *exact_count;
340            self.no_progress.remove(&sig.args_hash);
341
342            let same_count = self.same_tool_failure_counts.entry(tool_name.to_string()).or_insert(0);
343            *same_count += 1;
344            let same_val = *same_count;
345
346            if self.config.hard_stop_enabled && same_val >= self.config.same_tool_failure_halt_after {
347                let dec = ToolGuardrailDecision {
348                    action: "halt".to_string(),
349                    code: "same_tool_failure_halt".to_string(),
350                    message: format!(
351                        "Stopped {}: it failed {} times this turn. Stop retrying the same failing tool path and choose a different approach.",
352                        tool_name, same_val
353                    ),
354                    tool_name: tool_name.to_string(),
355                    count: same_val,
356                    signature: Some(sig),
357                };
358                self.active_halt_decision = Some(dec.clone());
359                return dec;
360            }
361
362            if self.config.warnings_enabled && exact_val >= self.config.exact_failure_warn_after {
363                return ToolGuardrailDecision {
364                    action: "warn".to_string(),
365                    code: "repeated_exact_failure_warning".to_string(),
366                    message: format!(
367                        "{} has failed {} times with identical arguments. This looks like a loop; inspect the error and change strategy instead of retrying it unchanged.",
368                        tool_name, exact_val
369                    ),
370                    tool_name: tool_name.to_string(),
371                    count: exact_val,
372                    signature: Some(sig),
373                };
374            }
375
376            if self.config.warnings_enabled && same_val >= self.config.same_tool_failure_warn_after {
377                return ToolGuardrailDecision {
378                    action: "warn".to_string(),
379                    code: "same_tool_failure_warning".to_string(),
380                    message: self.tool_failure_recovery_hint(tool_name, same_val),
381                    tool_name: tool_name.to_string(),
382                    count: same_val,
383                    signature: Some(sig),
384                };
385            }
386
387            return ToolGuardrailDecision {
388                action: "allow".to_string(),
389                code: "allow".to_string(),
390                message: String::new(),
391                tool_name: tool_name.to_string(),
392                count: exact_val,
393                signature: Some(sig),
394            };
395        }
396
397        self.exact_failure_counts.remove(&sig.args_hash);
398        self.same_tool_failure_counts.remove(tool_name);
399
400        if !self.is_idempotent(tool_name) {
401            self.no_progress.remove(&sig.args_hash);
402            return ToolGuardrailDecision {
403                action: "allow".to_string(),
404                code: "allow".to_string(),
405                message: String::new(),
406                tool_name: tool_name.to_string(),
407                count: 0,
408                signature: Some(sig),
409            };
410        }
411
412        let res_hash = self.compute_result_hash(result);
413        let repeat_count = match self.no_progress.get(&sig.args_hash) {
414            Some(record) if record.result_hash == res_hash => record.repeat_count + 1,
415            _ => 1,
416        };
417
418        self.no_progress.insert(
419            sig.args_hash.clone(),
420            ProgressRecord {
421                result_hash: res_hash,
422                repeat_count,
423            },
424        );
425
426        if self.config.warnings_enabled && repeat_count >= self.config.no_progress_warn_after {
427            return ToolGuardrailDecision {
428                action: "warn".to_string(),
429                code: "idempotent_no_progress_warning".to_string(),
430                message: format!(
431                    "{} returned the same result {} times. Use the result already provided or change the query instead of repeating it unchanged.",
432                    tool_name, repeat_count
433                ),
434                tool_name: tool_name.to_string(),
435                count: repeat_count,
436                signature: Some(sig),
437            };
438        }
439
440        ToolGuardrailDecision {
441            action: "allow".to_string(),
442            code: "allow".to_string(),
443            message: String::new(),
444            tool_name: tool_name.to_string(),
445            count: repeat_count,
446            signature: Some(sig),
447        }
448    }
449
450    pub fn is_idempotent(&self, tool_name: &str) -> bool {
451        if self.config.mutating_tools.contains(tool_name) {
452            return false;
453        }
454        self.config.idempotent_tools.contains(tool_name)
455    }
456
457    fn compute_result_hash(&self, result: Option<&str>) -> String {
458        let mut canonical = result.unwrap_or("").to_string();
459        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(canonical.trim()) {
460            if parsed.is_object() {
461                canonical = canonical_tool_args(&parsed);
462            }
463        }
464        let mut hasher = Sha256::new();
465        hasher.update(canonical.as_bytes());
466        let res = hasher.finalize();
467        format!("{:x}", res)
468    }
469
470    pub fn tool_failure_recovery_hint(&self, tool_name: &str, count: usize) -> String {
471        let common = format!(
472            "{} has failed {} times this turn. This looks like a loop. Do not switch to text-only replies; keep using tools, but diagnose before retrying. First inspect the latest error/output and verify your assumptions. ",
473            tool_name, count
474        );
475        if tool_name == "terminal" {
476            format!(
477                "{}For terminal failures, run a small diagnostic such as `pwd && ls -la` in the same tool, then try an absolute path, a simpler command, a different working directory, or a different tool such as read_file/write_file/patch.",
478                common
479            )
480        } else {
481            format!(
482                "{}Try different arguments, a narrower query/path, an absolute path when relevant, or a different tool that can make progress. If the blocker is external, report the blocker after one diagnostic attempt instead of repeating the same failing path.",
483                common
484            )
485        }
486    }
487}
488
489pub fn toolguard_synthetic_result(decision: &ToolGuardrailDecision) -> String {
490    serde_json::json!({
491        "error": decision.message,
492        "guardrail": decision.to_metadata(),
493    })
494    .to_string()
495}
496
497pub fn append_toolguard_guidance(result: &str, decision: &ToolGuardrailDecision) -> String {
498    if (decision.action != "warn" && decision.action != "halt") || decision.message.is_empty() {
499        return result.to_string();
500    }
501    let label = if decision.action == "halt" {
502        "Tool loop hard stop"
503    } else {
504        "Tool loop warning"
505    };
506    format!(
507        "{}\n\n[{}: {}; count={}; {}]",
508        result, label, decision.code, decision.count, decision.message
509    )
510}