use std::collections::{BTreeMap, HashMap, HashSet};
use sha2::{Digest, Sha256};
/// Serializes `args` to a canonical JSON string.
///
/// Object keys are emitted in sorted order at every nesting level (arrays keep
/// their element order), so two values that differ only in key order produce
/// the same string. Serialization failure yields an empty string.
pub fn canonical_tool_args(args: &serde_json::Value) -> String {
    /// Rebuilds `value` with every nested object's keys in sorted order.
    fn sort_keys(value: &serde_json::Value) -> serde_json::Value {
        if let Some(entries) = value.as_object() {
            let ordered: BTreeMap<String, serde_json::Value> = entries
                .iter()
                .map(|(key, child)| (key.clone(), sort_keys(child)))
                .collect();
            return serde_json::Value::Object(ordered.into_iter().collect());
        }
        if let Some(items) = value.as_array() {
            return serde_json::Value::Array(items.iter().map(sort_keys).collect());
        }
        // Scalars (null, bool, number, string) are already canonical.
        value.clone()
    }

    let canonical = sort_keys(args);
    serde_json::to_string(&canonical).unwrap_or_default()
}
/// Identity of a tool invocation: the tool name plus a SHA-256 digest of its
/// canonicalized arguments.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ToolCallSignature {
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Lowercase hex SHA-256 of `canonical_tool_args(args)`.
    pub args_hash: String,
}

impl ToolCallSignature {
    /// Builds the signature for `tool_name` called with `args`.
    ///
    /// Missing arguments (`None`) are hashed as an empty JSON object, so they
    /// produce the same `args_hash` as `Some(&{})`.
    pub fn from_call(tool_name: &str, args: Option<&serde_json::Value>) -> Self {
        let canonical = match args {
            Some(value) => canonical_tool_args(value),
            None => canonical_tool_args(&serde_json::Value::Object(serde_json::Map::new())),
        };
        let digest = Sha256::digest(canonical.as_bytes());
        ToolCallSignature {
            tool_name: String::from(tool_name),
            args_hash: format!("{:x}", digest),
        }
    }

    /// Returns the signature as a JSON object with `tool_name` and `args_hash` keys.
    pub fn to_metadata(&self) -> serde_json::Value {
        let mut fields = serde_json::Map::new();
        fields.insert(
            String::from("tool_name"),
            serde_json::Value::String(self.tool_name.clone()),
        );
        fields.insert(
            String::from("args_hash"),
            serde_json::Value::String(self.args_hash.clone()),
        );
        serde_json::Value::Object(fields)
    }
}
/// Outcome of a guardrail check for a single tool call.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ToolGuardrailDecision {
    /// Decision kind; this module produces `"allow"`, `"warn"`, `"block"` or `"halt"`.
    pub action: String,
    /// Machine-readable reason code (e.g. `"repeated_exact_failure_block"`).
    pub code: String,
    /// Human-readable explanation; empty for plain allows.
    pub message: String,
    /// Tool the decision applies to.
    pub tool_name: String,
    /// Counter value associated with the decision (0 when no counter applies).
    pub count: usize,
    /// Call signature, omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signature: Option<ToolCallSignature>,
}

impl ToolGuardrailDecision {
    /// True when the call may still run (`allow` or `warn`).
    pub fn allows_execution(&self) -> bool {
        matches!(self.action.as_str(), "allow" | "warn")
    }

    /// True when the agent loop should stop (`block` or `halt`).
    pub fn should_halt(&self) -> bool {
        matches!(self.action.as_str(), "block" | "halt")
    }

    /// Returns the decision as a JSON object; `signature` is included only when present.
    pub fn to_metadata(&self) -> serde_json::Value {
        let mut fields = serde_json::Map::new();
        fields.insert("action".to_owned(), serde_json::Value::String(self.action.clone()));
        fields.insert("code".to_owned(), serde_json::Value::String(self.code.clone()));
        fields.insert("message".to_owned(), serde_json::Value::String(self.message.clone()));
        fields.insert(
            "tool_name".to_owned(),
            serde_json::Value::String(self.tool_name.clone()),
        );
        fields.insert("count".to_owned(), serde_json::Value::from(self.count));
        if let Some(sig) = &self.signature {
            fields.insert("signature".to_owned(), sig.to_metadata());
        }
        serde_json::Value::Object(fields)
    }
}
/// Thresholds and tool classifications used by the tool-call guardrail controller.
#[derive(Debug, Clone)]
pub struct ToolCallGuardrailConfig {
    /// When true, threshold crossings may produce `warn` decisions.
    pub warnings_enabled: bool,
    /// When false, calls are never blocked or halted; only warnings can be emitted.
    pub hard_stop_enabled: bool,
    /// Failures with identical arguments before a warning is emitted.
    pub exact_failure_warn_after: usize,
    /// Failures with identical arguments before the next attempt is blocked.
    pub exact_failure_block_after: usize,
    /// Failures of the same tool (any arguments) before a recovery-hint warning.
    pub same_tool_failure_warn_after: usize,
    /// Failures of the same tool (any arguments) before the turn is halted.
    pub same_tool_failure_halt_after: usize,
    /// Identical results from an idempotent call before a warning is emitted.
    pub no_progress_warn_after: usize,
    /// Identical results from an idempotent call before the next attempt is blocked.
    pub no_progress_block_after: usize,
    /// Read-only tools eligible for no-progress tracking.
    pub idempotent_tools: HashSet<String>,
    /// Tools with side effects; membership here overrides `idempotent_tools`.
    pub mutating_tools: HashSet<String>,
}

impl Default for ToolCallGuardrailConfig {
    /// Warnings on, hard stops off, and the built-in read-only / mutating tool lists.
    fn default() -> Self {
        /// Converts a list of tool names into an owned set.
        fn owned_set(names: &[&str]) -> HashSet<String> {
            names.iter().map(|name| (*name).to_owned()).collect()
        }

        let idempotent_tools = owned_set(&[
            "read_file",
            "search_files",
            "web_search",
            "web_extract",
            "session_search",
            "browser_snapshot",
            "browser_console",
            "browser_get_images",
            "mcp_filesystem_read_file",
            "mcp_filesystem_read_text_file",
            "mcp_filesystem_read_multiple_files",
            "mcp_filesystem_list_directory",
            "mcp_filesystem_list_directory_with_sizes",
            "mcp_filesystem_directory_tree",
            "mcp_filesystem_get_file_info",
            "mcp_filesystem_search_files",
        ]);
        let mutating_tools = owned_set(&[
            "terminal",
            "execute_code",
            "write_file",
            "patch",
            "todo",
            "memory",
            "skill_manage",
            "browser_click",
            "browser_type",
            "browser_press",
            "browser_scroll",
            "browser_navigate",
            "send_message",
            "cronjob",
            "delegate_task",
            "process",
        ]);

        ToolCallGuardrailConfig {
            idempotent_tools,
            mutating_tools,
            warnings_enabled: true,
            hard_stop_enabled: false,
            exact_failure_warn_after: 2,
            exact_failure_block_after: 5,
            same_tool_failure_warn_after: 3,
            same_tool_failure_halt_after: 8,
            no_progress_warn_after: 2,
            no_progress_block_after: 5,
        }
    }
}
/// Reports whether a `write_file` / `patch` call demonstrably succeeded.
///
/// The result must parse as a JSON object without a truthy-ish `error` field
/// (absent, `null`, or `false` are accepted). Beyond that, `write_file` must
/// report `bytes_written` and `patch` must report `"success": true`. Any other
/// tool, unparsable result, or non-object JSON yields `false`.
pub fn file_mutation_result_landed(tool_name: &str, result: &str) -> bool {
    let is_write = tool_name == "write_file";
    if !is_write && tool_name != "patch" {
        return false;
    }

    let parsed = serde_json::from_str::<serde_json::Value>(result.trim());
    let obj = match parsed {
        Ok(serde_json::Value::Object(obj)) => obj,
        _ => return false,
    };

    if let Some(err) = obj.get("error") {
        // Only `null` and `false` count as "no error".
        if !err.is_null() && err.as_bool() != Some(false) {
            return false;
        }
    }

    if is_write {
        obj.contains_key("bytes_written")
    } else {
        obj.get("success").and_then(serde_json::Value::as_bool) == Some(true)
    }
}
/// Heuristically decides whether a tool result represents a failure.
///
/// Returns `(failed, suffix)`. `suffix` is a short tag such as `" [exit 1]"`,
/// `" [full]"` or `" [error]"`, and is empty when `failed` is false.
///
/// - `None` results and file mutations that landed are never failures.
/// - `terminal`: failed only when the JSON result has a non-zero integer
///   `exit_code`; no other check is applied.
/// - `memory`: a `"success": false` result whose `error` mentions
///   "exceed the limit" is reported as `" [full]"`; otherwise the generic check runs.
/// - Generic check: the first ~500 bytes contain `"error"` or `"failed"`
///   (case-insensitive), or the result starts with `Error`.
pub fn classify_tool_failure(tool_name: &str, result: Option<&str>) -> (bool, String) {
    /// Upper bound (in bytes) on how much of the result the generic check scans.
    const SCAN_LIMIT_BYTES: usize = 500;

    let result_str = match result {
        None => return (false, String::new()),
        Some(r) => r,
    };
    if file_mutation_result_landed(tool_name, result_str) {
        return (false, String::new());
    }
    if tool_name == "terminal" {
        if let Ok(data) = serde_json::from_str::<serde_json::Value>(result_str.trim()) {
            if let Some(exit_code) = data.get("exit_code").and_then(|v| v.as_i64()) {
                if exit_code != 0 {
                    return (true, format!(" [exit {}]", exit_code));
                }
            }
        }
        return (false, String::new());
    }
    if tool_name == "memory" {
        if let Ok(data) = serde_json::from_str::<serde_json::Value>(result_str.trim()) {
            if data.get("success").and_then(|v| v.as_bool()) == Some(false) {
                let err_str = data.get("error").and_then(|v| v.as_str()).unwrap_or("");
                if err_str.contains("exceed the limit") {
                    return (true, " [full]".to_string());
                }
            }
        }
    }
    // Slicing a `&str` at a byte index inside a multi-byte UTF-8 character
    // panics, so back the limit off to the nearest char boundary (index 0 is
    // always a boundary, so this terminates).
    let mut limit = SCAN_LIMIT_BYTES.min(result_str.len());
    while !result_str.is_char_boundary(limit) {
        limit -= 1;
    }
    let lower = result_str[..limit].to_lowercase();
    if lower.contains("\"error\"") || lower.contains("\"failed\"") || result_str.starts_with("Error") {
        return (true, " [error]".to_string());
    }
    (false, String::new())
}
/// Last result seen for a tracked idempotent call, and how many times in a row
/// that same result has been returned.
#[derive(Debug, Clone)]
struct ProgressRecord {
/// Hex SHA-256 of the (canonicalized, if JSON object) result.
result_hash: String,
/// Number of consecutive successful calls that returned `result_hash`.
repeat_count: usize,
}
/// Per-turn state machine that converts tool-call failures and repeated
/// identical results into allow / warn / block / halt decisions.
pub struct ToolCallGuardrailController {
/// Thresholds and tool classifications in effect.
pub config: ToolCallGuardrailConfig,
/// Failure counts per call (keyed by a string derived from the call signature); cleared when that call succeeds.
exact_failure_counts: HashMap<String, usize>,
/// Failure counts per tool name; cleared when that tool succeeds.
same_tool_failure_counts: HashMap<String, usize>,
/// Repeated-result tracking for idempotent calls, keyed like `exact_failure_counts`.
no_progress: HashMap<String, ProgressRecord>,
/// Most recent block/halt decision; cleared by `reset_for_turn`.
active_halt_decision: Option<ToolGuardrailDecision>,
}
impl ToolCallGuardrailController {
pub fn new(config: Option<ToolCallGuardrailConfig>) -> Self {
Self {
config: config.unwrap_or_default(),
exact_failure_counts: HashMap::new(),
same_tool_failure_counts: HashMap::new(),
no_progress: HashMap::new(),
active_halt_decision: None,
}
}
pub fn reset_for_turn(&mut self) {
self.exact_failure_counts.clear();
self.same_tool_failure_counts.clear();
self.no_progress.clear();
self.active_halt_decision = None;
}
pub fn halt_decision(&self) -> Option<&ToolGuardrailDecision> {
self.active_halt_decision.as_ref()
}
pub fn before_call(&mut self, tool_name: &str, args: Option<&serde_json::Value>) -> ToolGuardrailDecision {
let sig = ToolCallSignature::from_call(tool_name, args);
if !self.config.hard_stop_enabled {
return ToolGuardrailDecision {
action: "allow".to_string(),
code: "allow".to_string(),
message: String::new(),
tool_name: tool_name.to_string(),
count: 0,
signature: Some(sig),
};
}
let exact_count = *self.exact_failure_counts.get(&sig.args_hash).unwrap_or(&0);
if exact_count >= self.config.exact_failure_block_after {
let dec = ToolGuardrailDecision {
action: "block".to_string(),
code: "repeated_exact_failure_block".to_string(),
message: format!(
"Blocked {}: the same tool call failed {} times with identical arguments. Stop retrying it unchanged; change strategy or explain the blocker.",
tool_name, exact_count
),
tool_name: tool_name.to_string(),
count: exact_count,
signature: Some(sig),
};
self.active_halt_decision = Some(dec.clone());
return dec;
}
if self.is_idempotent(tool_name) {
if let Some(record) = self.no_progress.get(&sig.args_hash) {
let repeat_count = record.repeat_count;
if repeat_count >= self.config.no_progress_block_after {
let dec = ToolGuardrailDecision {
action: "block".to_string(),
code: "idempotent_no_progress_block".to_string(),
message: format!(
"Blocked {}: this read-only call returned the same result {} times. Stop repeating it unchanged; use the result already provided or try a different query.",
tool_name, repeat_count
),
tool_name: tool_name.to_string(),
count: repeat_count,
signature: Some(sig),
};
self.active_halt_decision = Some(dec.clone());
return dec;
}
}
}
ToolGuardrailDecision {
action: "allow".to_string(),
code: "allow".to_string(),
message: String::new(),
tool_name: tool_name.to_string(),
count: 0,
signature: Some(sig),
}
}
pub fn after_call(
&mut self,
tool_name: &str,
args: Option<&serde_json::Value>,
result: Option<&str>,
failed: Option<bool>,
) -> ToolGuardrailDecision {
let sig = ToolCallSignature::from_call(tool_name, args);
let is_failed = failed.unwrap_or_else(|| {
let (f, _) = classify_tool_failure(tool_name, result);
f
});
if is_failed {
let exact_count = self.exact_failure_counts.entry(sig.args_hash.clone()).or_insert(0);
*exact_count += 1;
let exact_val = *exact_count;
self.no_progress.remove(&sig.args_hash);
let same_count = self.same_tool_failure_counts.entry(tool_name.to_string()).or_insert(0);
*same_count += 1;
let same_val = *same_count;
if self.config.hard_stop_enabled && same_val >= self.config.same_tool_failure_halt_after {
let dec = ToolGuardrailDecision {
action: "halt".to_string(),
code: "same_tool_failure_halt".to_string(),
message: format!(
"Stopped {}: it failed {} times this turn. Stop retrying the same failing tool path and choose a different approach.",
tool_name, same_val
),
tool_name: tool_name.to_string(),
count: same_val,
signature: Some(sig),
};
self.active_halt_decision = Some(dec.clone());
return dec;
}
if self.config.warnings_enabled && exact_val >= self.config.exact_failure_warn_after {
return ToolGuardrailDecision {
action: "warn".to_string(),
code: "repeated_exact_failure_warning".to_string(),
message: format!(
"{} has failed {} times with identical arguments. This looks like a loop; inspect the error and change strategy instead of retrying it unchanged.",
tool_name, exact_val
),
tool_name: tool_name.to_string(),
count: exact_val,
signature: Some(sig),
};
}
if self.config.warnings_enabled && same_val >= self.config.same_tool_failure_warn_after {
return ToolGuardrailDecision {
action: "warn".to_string(),
code: "same_tool_failure_warning".to_string(),
message: self.tool_failure_recovery_hint(tool_name, same_val),
tool_name: tool_name.to_string(),
count: same_val,
signature: Some(sig),
};
}
return ToolGuardrailDecision {
action: "allow".to_string(),
code: "allow".to_string(),
message: String::new(),
tool_name: tool_name.to_string(),
count: exact_val,
signature: Some(sig),
};
}
self.exact_failure_counts.remove(&sig.args_hash);
self.same_tool_failure_counts.remove(tool_name);
if !self.is_idempotent(tool_name) {
self.no_progress.remove(&sig.args_hash);
return ToolGuardrailDecision {
action: "allow".to_string(),
code: "allow".to_string(),
message: String::new(),
tool_name: tool_name.to_string(),
count: 0,
signature: Some(sig),
};
}
let res_hash = self.compute_result_hash(result);
let repeat_count = match self.no_progress.get(&sig.args_hash) {
Some(record) if record.result_hash == res_hash => record.repeat_count + 1,
_ => 1,
};
self.no_progress.insert(
sig.args_hash.clone(),
ProgressRecord {
result_hash: res_hash,
repeat_count,
},
);
if self.config.warnings_enabled && repeat_count >= self.config.no_progress_warn_after {
return ToolGuardrailDecision {
action: "warn".to_string(),
code: "idempotent_no_progress_warning".to_string(),
message: format!(
"{} returned the same result {} times. Use the result already provided or change the query instead of repeating it unchanged.",
tool_name, repeat_count
),
tool_name: tool_name.to_string(),
count: repeat_count,
signature: Some(sig),
};
}
ToolGuardrailDecision {
action: "allow".to_string(),
code: "allow".to_string(),
message: String::new(),
tool_name: tool_name.to_string(),
count: repeat_count,
signature: Some(sig),
}
}
pub fn is_idempotent(&self, tool_name: &str) -> bool {
if self.config.mutating_tools.contains(tool_name) {
return false;
}
self.config.idempotent_tools.contains(tool_name)
}
fn compute_result_hash(&self, result: Option<&str>) -> String {
let mut canonical = result.unwrap_or("").to_string();
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(canonical.trim()) {
if parsed.is_object() {
canonical = canonical_tool_args(&parsed);
}
}
let mut hasher = Sha256::new();
hasher.update(canonical.as_bytes());
let res = hasher.finalize();
format!("{:x}", res)
}
pub fn tool_failure_recovery_hint(&self, tool_name: &str, count: usize) -> String {
let common = format!(
"{} has failed {} times this turn. This looks like a loop. Do not switch to text-only replies; keep using tools, but diagnose before retrying. First inspect the latest error/output and verify your assumptions. ",
tool_name, count
);
if tool_name == "terminal" {
format!(
"{}For terminal failures, run a small diagnostic such as `pwd && ls -la` in the same tool, then try an absolute path, a simpler command, a different working directory, or a different tool such as read_file/write_file/patch.",
common
)
} else {
format!(
"{}Try different arguments, a narrower query/path, an absolute path when relevant, or a different tool that can make progress. If the blocker is external, report the blocker after one diagnostic attempt instead of repeating the same failing path.",
common
)
}
}
}
/// Builds the JSON string returned in place of a tool result when the guardrail
/// prevents execution: `{"error": <message>, "guardrail": <decision metadata>}`.
pub fn toolguard_synthetic_result(decision: &ToolGuardrailDecision) -> String {
    let mut envelope = serde_json::Map::new();
    envelope.insert(
        "error".to_owned(),
        serde_json::Value::String(decision.message.clone()),
    );
    envelope.insert("guardrail".to_owned(), decision.to_metadata());
    serde_json::Value::Object(envelope).to_string()
}
/// Appends a bracketed guardrail note to `result` for `warn` and `halt` decisions.
///
/// Other actions, or decisions with an empty message, return `result` unchanged.
pub fn append_toolguard_guidance(result: &str, decision: &ToolGuardrailDecision) -> String {
    let label = match decision.action.as_str() {
        "halt" => "Tool loop hard stop",
        "warn" => "Tool loop warning",
        _ => return result.to_string(),
    };
    if decision.message.is_empty() {
        return result.to_string();
    }
    format!(
        "{}\n\n[{}: {}; count={}; {}]",
        result, label, decision.code, decision.count, decision.message
    )
}