use crate::clients::base::ToolCall;
use crate::core::tool_spec::{ParamModel, ToolSpec};
use crate::guardrails::nudge::Nudge;
use indexmap::{IndexMap, IndexSet};
use serde_json::Value;
/// The specific way in which a single tool-call argument failed validation.
///
/// Rendered to text by `ArgValidationError::message`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ArgValidationKind {
    /// A parameter whose model reports `is_required()` was absent from the arguments.
    MissingRequired,
    /// The value's JSON type does not match the declared parameter type.
    WrongType {
        /// JSON type name the parameter model expects (e.g. `"string"`, `"integer"`).
        expected: String,
        /// JSON type name of the value that was actually supplied.
        actual: String,
    },
    /// A key was supplied that the object does not declare while its schema sets
    /// `additionalProperties: false`.
    ExtraArgument,
    /// A string value is not one of the declared enum values.
    EnumMismatch {
        /// The permitted values, in declaration order.
        allowed: Vec<String>,
    },
}
/// One argument-validation failure for a tool call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ArgValidationError {
    /// Name of the tool whose call was validated.
    pub tool: String,
    /// Location of the offending argument, using `.` between object keys and `[i]` for
    /// array elements (e.g. `filters.tags[2]`).
    pub path: String,
    /// What went wrong at `path`.
    pub kind: ArgValidationKind,
}
impl ArgValidationError {
    /// Builds an error for `tool` at argument location `path`.
    pub fn new(tool: impl Into<String>, path: impl Into<String>, kind: ArgValidationKind) -> Self {
        let (tool, path) = (tool.into(), path.into());
        Self { tool, path, kind }
    }

    /// Renders a short, human-readable description of the failure that starts with the
    /// argument path. The tool name is not included.
    pub fn message(&self) -> String {
        let path = &self.path;
        match &self.kind {
            ArgValidationKind::MissingRequired => format!("{path} is required"),
            ArgValidationKind::WrongType { expected, actual } => {
                format!("{path} must be {expected}, got {actual}")
            }
            ArgValidationKind::ExtraArgument => format!("{path} is not allowed"),
            ArgValidationKind::EnumMismatch { allowed } => {
                let choices = allowed.join(", ");
                format!("{path} must be one of: {choices}")
            }
        }
    }
}
/// A rule broken by a model's proposed tool call(s).
///
/// Per-variant descriptions below paraphrase the variant and field names; confirm the exact
/// triggering conditions against the code that produces each variant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GuardrailViolation {
    /// No tool call was made (presumably when one was expected).
    NoToolCall,
    /// A tool name was called that is not among the available tools.
    UnknownTool {
        /// The tool name that was called.
        called: String,
        /// Tool names that were available.
        available: Vec<String>,
    },
    /// A terminal tool was called while steps were still pending.
    PrematureTerminal {
        /// The terminal tool that was called.
        terminal: String,
        /// Steps that had not yet been completed.
        pending: Vec<String>,
    },
    /// A tool was called before steps it depends on.
    MissingPrerequisite {
        /// The tool that was called.
        tool: String,
        /// The prerequisite steps that were missing.
        missing: Vec<String>,
    },
    /// A tool call's arguments failed validation.
    InvalidArguments {
        /// The tool whose arguments were invalid.
        tool: String,
        /// Every argument-level failure found for this call.
        errors: Vec<ArgValidationError>,
    },
    /// A batch of tool calls was judged unsafe to run together.
    UnsafeBatch {
        /// Human-readable explanation.
        reason: String,
    },
    /// The same kind of failure has happened repeatedly.
    RepeatedFailure {
        /// Label identifying the failure kind.
        kind: String,
        /// Number of occurrences.
        count: usize,
    },
    /// The called tool is probably not the right one.
    WrongToolLikely {
        /// The tool that was called.
        called: String,
        /// Tools suggested instead.
        suggested: Vec<String>,
    },
}
/// Snapshot of workflow progress used when deciding which tools may be called next.
///
/// See `GuardrailState::from_parts` for how the derived lists are computed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GuardrailState {
    /// Steps already completed.
    pub completed_steps: Vec<String>,
    /// Steps still outstanding.
    pub pending_steps: Vec<String>,
    /// Tools permitted for the next call.
    pub allowed_next_tools: Vec<String>,
    /// Tools that must not be called yet.
    pub blocked_tools: Vec<String>,
    /// Tools that end the workflow.
    pub terminal_tools: Vec<String>,
}
impl GuardrailState {
    /// Derives a [`GuardrailState`] from completed/pending steps, the full tool list and the
    /// set of terminal tools.
    ///
    /// * With no pending steps, every tool in `tool_names` is allowed and nothing is blocked.
    /// * Otherwise only the pending steps are allowed, and every terminal tool is blocked.
    ///
    /// `terminal_tools` is copied in its set iteration order.
    pub fn from_parts(
        completed_steps: Vec<String>,
        pending_steps: Vec<String>,
        tool_names: &[String],
        terminal_tools: &IndexSet<String>,
    ) -> Self {
        let terminal_tools: Vec<String> = terminal_tools.iter().cloned().collect();
        let (allowed_next_tools, blocked_tools) = if pending_steps.is_empty() {
            (tool_names.to_vec(), Vec::new())
        } else {
            (pending_steps.clone(), terminal_tools.clone())
        };
        Self {
            completed_steps,
            pending_steps,
            allowed_next_tools,
            blocked_tools,
            terminal_tools,
        }
    }
}
/// Outcome of running guardrails over a model's proposed tool calls.
#[derive(Clone, Debug, PartialEq)]
pub enum GuardrailDecision {
    /// Let these tool calls proceed.
    Allow(Vec<ToolCall>),
    /// A violation was found and is answered with a nudge.
    Nudge {
        /// The rule that was broken.
        violation: GuardrailViolation,
        /// The nudge to deliver.
        nudge: Nudge,
    },
    /// A violation was found and the call(s) are rejected.
    Reject {
        /// The rule that was broken.
        violation: GuardrailViolation,
        /// Human-readable rejection text.
        message: String,
    },
}
/// Validates one tool call's arguments against the tool's declared parameters.
///
/// Only tools whose parameter model is `ParamModel::Object` are checked; any other parameter
/// model yields no errors. The tool's JSON Schema, when present, is used both as the schema
/// of the top-level argument object and as the root for resolving `$ref`s.
pub fn validate_tool_arguments(call: &ToolCall, spec: &ToolSpec) -> Vec<ArgValidationError> {
    let ParamModel::Object { properties, .. } = &spec.parameters else {
        return Vec::new();
    };
    let schema = spec.json_schema.as_ref();
    let mut found = Vec::new();
    validate_object(&call.tool, "", &call.args, properties, schema, schema, &mut found);
    found
}
/// Validates the arguments of every call in `calls`, concatenating the errors in call order.
///
/// Calls naming a tool that has no entry in `specs` are skipped here and contribute no
/// errors.
pub fn validate_tool_call_batch(
    calls: &[ToolCall],
    specs: &IndexMap<String, ToolSpec>,
) -> Vec<ArgValidationError> {
    calls
        .iter()
        .filter_map(|call| specs.get(&call.tool).map(|spec| (call, spec)))
        .flat_map(|(call, spec)| validate_tool_arguments(call, spec))
        .collect()
}
/// Validates an argument object against its declared `properties`, appending errors.
///
/// Declared properties are checked in declaration order:
/// * present values are validated recursively with their property schema;
/// * absent values produce `MissingRequired` when the model is required.
///
/// Afterwards, if the resolved schema sets `additionalProperties: false`, every undeclared key
/// produces `ExtraArgument`, in the order the keys appear in `args`.
fn validate_object(
    tool: &str,
    base_path: &str,
    args: &IndexMap<String, Value>,
    properties: &IndexMap<String, ParamModel>,
    schema: Option<&Value>,
    root_schema: Option<&Value>,
    errors: &mut Vec<ArgValidationError>,
) {
    let resolved = resolve_schema(schema, root_schema);

    for (field, field_model) in properties {
        let field_path = join_path(base_path, field);
        if let Some(field_value) = args.get(field) {
            let field_schema = property_schema(resolved, field, root_schema);
            validate_value(
                tool,
                &field_path,
                field_value,
                field_model,
                field_schema,
                root_schema,
                errors,
            );
        } else if field_model.is_required() {
            errors.push(ArgValidationError::new(
                tool,
                field_path,
                ArgValidationKind::MissingRequired,
            ));
        }
    }

    if !additional_properties_false(resolved) {
        return;
    }
    errors.extend(
        args.keys()
            .filter(|key| !properties.contains_key(key.as_str()))
            .map(|key| {
                ArgValidationError::new(
                    tool,
                    join_path(base_path, key),
                    ArgValidationKind::ExtraArgument,
                )
            }),
    );
}
fn validate_value(
tool: &str,
path: &str,
value: &Value,
model: &ParamModel,
schema: Option<&Value>,
root_schema: Option<&Value>,
errors: &mut Vec<ArgValidationError>,
) {
match model {
ParamModel::String { enum_values, .. } => {
if let Some(actual) = wrong_type(value, "string") {
errors.push(wrong_type_error(tool, path, "string", actual));
return;
}
if let (Some(allowed), Some(actual)) = (enum_values, value.as_str()) {
if !allowed.iter().any(|item| item == actual) {
errors.push(ArgValidationError::new(
tool,
path,
ArgValidationKind::EnumMismatch {
allowed: allowed.clone(),
},
));
}
}
}
ParamModel::Number { .. } => {
if let Some(actual) = wrong_type(value, "number") {
errors.push(wrong_type_error(tool, path, "number", actual));
}
}
ParamModel::Boolean { .. } => {
if let Some(actual) = wrong_type(value, "boolean") {
errors.push(wrong_type_error(tool, path, "boolean", actual));
}
}
ParamModel::Integer { .. } => {
if !(value.as_i64().is_some() || value.as_u64().is_some()) {
errors.push(wrong_type_error(tool, path, "integer", actual_type(value)));
}
}
ParamModel::Object { properties, .. } => {
let Some(obj) = value.as_object() else {
errors.push(wrong_type_error(tool, path, "object", actual_type(value)));
return;
};
let nested_args: IndexMap<String, Value> = obj
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
validate_object(
tool,
path,
&nested_args,
properties,
schema,
root_schema,
errors,
);
}
ParamModel::Array { items, .. } => {
let Some(values) = value.as_array() else {
errors.push(wrong_type_error(tool, path, "array", actual_type(value)));
return;
};
let item_schema = item_schema(schema, root_schema);
for (idx, item) in values.iter().enumerate() {
validate_value(
tool,
&join_index(path, idx),
item,
items,
item_schema,
root_schema,
errors,
);
}
}
ParamModel::Unsupported { .. } => {}
}
}
/// Checks `value` against a JSON Schema primitive type name.
///
/// Returns `None` when `value` is an instance of `expected`. Otherwise it returns
/// `Some(actual)`, where `actual` is the JSON type name of `value` (see [`actual_type`]).
///
/// Recognised names are `"string"`, `"number"`, `"integer"`, `"boolean"`, `"object"`,
/// `"array"` and `"null"`. Any other name never matches, so it is always reported as a
/// mismatch.
fn wrong_type(value: &Value, expected: &str) -> Option<&'static str> {
    let is_match = match expected {
        "string" => value.is_string(),
        "number" => value.is_number(),
        // Same rule as the `ParamModel::Integer` check: only values representable as
        // i64 or u64 count, so floating-point numbers are rejected.
        "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
        "boolean" => value.is_boolean(),
        "object" => value.is_object(),
        "array" => value.is_array(),
        "null" => value.is_null(),
        _ => false,
    };
    if is_match {
        None
    } else {
        Some(actual_type(value))
    }
}
/// Shorthand for an [`ArgValidationError`] of kind `WrongType { expected, actual }`.
fn wrong_type_error(
    tool: &str,
    path: &str,
    expected: impl Into<String>,
    actual: impl Into<String>,
) -> ArgValidationError {
    let kind = ArgValidationKind::WrongType {
        expected: expected.into(),
        actual: actual.into(),
    };
    ArgValidationError::new(tool, path, kind)
}
/// Returns the JSON type name of `value`, as used in `WrongType` error messages.
///
/// Every number maps to `"number"`; integers are not distinguished here.
fn actual_type(value: &Value) -> &'static str {
    match *value {
        Value::Null => "null",
        Value::Bool(..) => "boolean",
        Value::Number(..) => "number",
        Value::String(..) => "string",
        Value::Array(..) => "array",
        Value::Object(..) => "object",
    }
}
/// Returns `true` only when `schema` contains exactly `"additionalProperties": false`.
///
/// A missing schema, a missing key, `true`, or a schema-valued `additionalProperties` all
/// return `false`.
fn additional_properties_false(schema: Option<&Value>) -> bool {
    let flag = schema.and_then(|node| node.get("additionalProperties"));
    matches!(flag, Some(Value::Bool(false)))
}
/// Looks up the schema of property `name` inside an object schema.
///
/// The object schema is resolved first (`$ref` / nullable `anyOf`). The returned property
/// schema itself is returned raw, without resolution.
fn property_schema<'a>(
    schema: Option<&'a Value>,
    name: &str,
    root_schema: Option<&'a Value>,
) -> Option<&'a Value> {
    resolve_schema(schema, root_schema)?
        .get("properties")?
        .as_object()?
        .get(name)
}
/// Returns the raw `items` schema of an array schema, after resolving the array schema itself.
fn item_schema<'a>(schema: Option<&'a Value>, root_schema: Option<&'a Value>) -> Option<&'a Value> {
    resolve_schema(schema, root_schema).and_then(|resolved| resolved.get("items"))
}
/// Unwraps a schema fragment to the concrete schema it stands for.
///
/// * A `$ref` takes precedence and is looked up via `resolve_ref`. The target is returned
///   as-is, without being resolved further.
/// * Otherwise, for an `anyOf`, the first variant whose `type` is not `"null"` is chosen and
///   resolved recursively. If every variant is null-typed, the result is `None`.
/// * Any other schema is returned unchanged.
fn resolve_schema<'a>(
    schema: Option<&'a Value>,
    root_schema: Option<&'a Value>,
) -> Option<&'a Value> {
    let node = schema?;
    if let Some(target) = node.get("$ref").and_then(Value::as_str) {
        return resolve_ref(target, root_schema);
    }
    match node.get("anyOf").and_then(Value::as_array) {
        Some(variants) => {
            let non_null = variants
                .iter()
                .find(|variant| variant.get("type").and_then(Value::as_str) != Some("null"))?;
            resolve_schema(Some(non_null), root_schema)
        }
        None => Some(node),
    }
}
fn resolve_ref<'a>(reference: &str, root_schema: Option<&'a Value>) -> Option<&'a Value> {
let name = reference.strip_prefix("#/$defs/")?;
root_schema?
.get("$defs")
.and_then(Value::as_object)?
.get(name)
}
/// Appends object key `field` to argument path `base`, separated by `.`.
///
/// An empty `base` denotes the top level, so the result is just `field`.
fn join_path(base: &str, field: &str) -> String {
    match base {
        "" => field.to_owned(),
        _ => [base, field].join("."),
    }
}
/// Appends an array index to argument path `base` as `base[index]`.
fn join_index(base: &str, index: usize) -> String {
    let digits = index.to_string();
    let mut path = String::with_capacity(base.len() + digits.len() + 2);
    path.push_str(base);
    path.push('[');
    path.push_str(&digits);
    path.push(']');
    path
}