use loop_guardrail::{
canonical_tool_args, ToolCallGuardrailConfig, ToolCallGuardrailController,
ToolCallSignature, append_toolguard_guidance, toolguard_synthetic_result,
};
use serde_json::json;
/// Canonical argument text must not depend on the order in which object keys
/// were written, either at the top level or inside a nested object.
#[test]
fn test_canonical_args_sorting() {
    // Flat object: identical members, opposite key order.
    let keys_descending = json!({"b": 2, "a": 1});
    let keys_ascending = json!({"a": 1, "b": 2});
    assert_eq!(
        canonical_tool_args(&keys_descending),
        canonical_tool_args(&keys_ascending)
    );
    // For this input the canonical form is compact JSON with `a` before `b`.
    assert_eq!(canonical_tool_args(&keys_descending), "{\"a\":1,\"b\":2}");

    // Nested object: key order differs at both the outer and the inner level.
    let nested_first = json!({"x": {"y": 2, "z": 1}, "a": 5});
    let nested_second = json!({"a": 5, "x": {"z": 1, "y": 2}});
    assert_eq!(
        canonical_tool_args(&nested_first),
        canonical_tool_args(&nested_second)
    );
}
/// Two signatures built from the same tool name and arguments must share an
/// `args_hash`; changing an argument value must produce a different hash.
#[test]
fn test_signature_hashing() {
    let tool = "test";
    let args_one = json!({"x": 1});
    let args_one_again = json!({"x": 1});
    let args_two = json!({"x": 2});

    let baseline = ToolCallSignature::from_call(tool, Some(&args_one));
    let repeated = ToolCallSignature::from_call(tool, Some(&args_one_again));
    let changed = ToolCallSignature::from_call(tool, Some(&args_two));

    assert_eq!(baseline.args_hash, repeated.args_hash);
    assert_ne!(baseline.args_hash, changed.args_hash);
}
/// Repeating the exact same failing call escalates from "allow" to "warn"
/// and finally to a blocking decision.
///
/// With `exact_failure_warn_after = 2` and `exact_failure_block_after = 3`,
/// the second recorded failure is reported as "warn", and the `before_call`
/// that follows the third recorded failure is "block" and asks to halt.
#[test]
fn test_failure_counting_and_warnings() {
    let mut guard = ToolCallGuardrailController::new(Some(ToolCallGuardrailConfig {
        hard_stop_enabled: true,
        warnings_enabled: true,
        exact_failure_warn_after: 2,
        exact_failure_block_after: 3,
        ..Default::default()
    }));

    // Every attempt uses the same tool, arguments and error text.
    let tool = "terminal";
    let args = json!({"cmd": "ls"});
    let failure_output = "Error: cmd not found";

    // Attempt 1: permitted, and the failure is merely recorded.
    let pre_first = guard.before_call(tool, Some(&args));
    assert!(pre_first.allows_execution());
    let post_first = guard.after_call(tool, Some(&args), Some(failure_output), Some(true));
    assert_eq!(post_first.action, "allow");

    // Attempt 2: still permitted, but the second failure triggers a warning.
    let pre_second = guard.before_call(tool, Some(&args));
    assert!(pre_second.allows_execution());
    let post_second = guard.after_call(tool, Some(&args), Some(failure_output), Some(true));
    assert_eq!(post_second.action, "warn");

    // Attempt 3: still permitted; the post-call decision is not inspected here.
    let pre_third = guard.before_call(tool, Some(&args));
    assert!(pre_third.allows_execution());
    let _post_third = guard.after_call(tool, Some(&args), Some(failure_output), Some(true));

    // Attempt 4: refused before execution.
    let pre_fourth = guard.before_call(tool, Some(&args));
    assert_eq!(pre_fourth.action, "block");
    assert!(pre_fourth.should_halt());
}
/// A call that keeps succeeding with the same output is escalated too.
///
/// With `no_progress_warn_after = 2` and `no_progress_block_after = 3`, the
/// second identical non-failing result is reported as "warn", and the
/// `before_call` issued after the third one is "block".
#[test]
fn test_no_progress_tracking() {
    let mut guard = ToolCallGuardrailController::new(Some(ToolCallGuardrailConfig {
        hard_stop_enabled: true,
        warnings_enabled: true,
        no_progress_warn_after: 2,
        no_progress_block_after: 3,
        ..Default::default()
    }));

    let tool = "read_file";
    let file_args = json!({"path": "file.txt"});
    let same_output = "hello";

    // Nothing has been recorded yet, so the first call may run.
    let pre_first = guard.before_call(tool, Some(&file_args));
    assert!(pre_first.allows_execution());

    // Three identical results are recorded back to back, with no
    // `before_call` in between.
    let post_first = guard.after_call(tool, Some(&file_args), Some(same_output), Some(false));
    assert_eq!(post_first.action, "allow");

    let post_second = guard.after_call(tool, Some(&file_args), Some(same_output), Some(false));
    assert_eq!(post_second.action, "warn");

    let _post_third = guard.after_call(tool, Some(&file_args), Some(same_output), Some(false));

    // The next attempt is refused before execution.
    let pre_next = guard.before_call(tool, Some(&file_args));
    assert_eq!(pre_next.action, "block");
}
/// Checks the text helpers that turn a guardrail decision into output.
///
/// A controller built with no explicit config (`None`) reports "warn" on the
/// second identical failure. That decision is then rendered two ways:
/// appended to existing tool output, and as a standalone synthetic result.
#[test]
fn test_guidance_formatting() {
    let mut guard = ToolCallGuardrailController::new(None);
    let tool = "terminal";
    let cmd_args = json!({"cmd": "ls"});

    // Two identical failures in a row; only the second decision is inspected.
    let _first = guard.after_call(tool, Some(&cmd_args), Some("Error"), Some(true));
    let warning = guard.after_call(tool, Some(&cmd_args), Some("Error"), Some(true));
    assert_eq!(warning.action, "warn");

    // The appended guidance carries the warning marker and the repeat count.
    let with_guidance = append_toolguard_guidance("Error output", &warning);
    assert!(with_guidance.contains("[Tool loop warning:"));
    assert!(with_guidance.contains("count=2"));

    // The synthetic result contains an `"error":` key.
    let synthetic = toolguard_synthetic_result(&warning);
    assert!(synthetic.contains("\"error\":"));
}
/// Smoke test: many before/after call pairs must finish within a fixed
/// wall-clock budget.
///
/// Each iteration uses a unique `query` argument and alternates between two
/// tool names; odd iterations are reported as failures. The controller is
/// built with no explicit config (`None`).
///
/// This is a coarse timing check, not a benchmark: the result depends on the
/// build profile and on machine load, so the failure message reports the
/// measured time to make a slow run easy to diagnose.
#[test]
fn test_scale_guardrail_performance() {
    /// Number of before/after call pairs to run.
    const ITERATIONS: u32 = 10_000;
    /// Upper bound on the total wall-clock time for all iterations.
    const BUDGET: std::time::Duration = std::time::Duration::from_secs(1);

    let mut ctrl = ToolCallGuardrailController::new(None);
    let start = std::time::Instant::now();

    for i in 0..ITERATIONS {
        let failed = i % 2 != 0;
        let tool_name = if failed { "read_file" } else { "web_search" };
        // A distinct argument per iteration, so no two calls are identical.
        let args = json!({"query": format!("term_{}", i)});

        let _ = ctrl.before_call(tool_name, Some(&args));
        let _ = ctrl.after_call(
            tool_name,
            Some(&args),
            Some(if failed { "Error" } else { "result" }),
            Some(failed),
        );
    }

    let duration = start.elapsed();
    assert!(
        duration < BUDGET,
        "{} guardrail call pairs took {:?}, exceeding the {:?} budget",
        ITERATIONS,
        duration,
        BUDGET
    );
}