Skip to main content

ai_agent/
tool_validation.rs

1// Source: /data/home/swei/claudecode/openclaudecode/src/services/tools/toolExecution.ts
2//! Input validation for tool calls.
3//!
4//! Translated from TypeScript checkPermissionsAndCallTool validation step.
5
6use crate::types::{ToolDefinition, ToolInputSchema};
7
8/// Validate tool input against the tool's JSON Schema.
9/// Returns Ok(()) if valid, or Err with a human-readable error message.
10///
11/// Matches TypeScript's Zod schema validation with structured error messages.
12pub fn validate_tool_input(
13    name: &str,
14    input: &serde_json::Value,
15    tools: &[ToolDefinition],
16) -> Result<(), String> {
17    let tool = tools
18        .iter()
19        .find(|t| t.name == name)
20        .ok_or(format!("Tool '{}' not found", name))?;
21    validate_against_schema(name, input, &tool.input_schema)
22}
23
24/// Validate input against a specific tool's schema.
25fn validate_against_schema(
26    tool_name: &str,
27    input: &serde_json::Value,
28    schema: &ToolInputSchema,
29) -> Result<(), String> {
30    let properties = schema
31        .properties
32        .as_object()
33        .ok_or_else(|| format!("Invalid schema for tool '{}'", tool_name))?;
34    let required = schema.required.as_ref();
35
36    let mut errors: Vec<String> = Vec::new();
37
38    // Check required fields
39    if let Some(req) = required {
40        for field in req {
41            if !input.get(field).is_some() {
42                errors.push(format!("The required parameter `{}` is missing", field));
43            }
44        }
45    }
46
47    // Check each provided field against schema
48    if let Some(obj) = input.as_object() {
49        for (key, value) in obj {
50            if let Some(prop_schema) = properties.get(key.as_str()) {
51                if let Some(prop_type) = prop_schema.get("type") {
52                    if let Some(prop_type_str) = prop_type.as_str() {
53                        if !check_type(value, prop_type_str) {
54                            let received = json_type_name(value);
55                            errors.push(format!(
56                                "The parameter `{}` type is expected as `{}` but provided as `{}`",
57                                key, prop_type_str, received
58                            ));
59                        }
60                    }
61                }
62            }
63        }
64    }
65
66    // Check for unexpected parameters (properties not in schema)
67    if let Some(obj) = input.as_object() {
68        for key in obj.keys() {
69            if !properties.contains_key(key.as_str()) {
70                errors.push(format!("An unexpected parameter `{}` was provided", key));
71            }
72        }
73    }
74
75    if errors.is_empty() {
76        Ok(())
77    } else {
78        let issue_word = if errors.len() > 1 { "issues" } else { "issue" };
79        Err(format!(
80            "{} failed due to the following {}:\n{}",
81            tool_name,
82            issue_word,
83            errors.join("\n")
84        ))
85    }
86}
87
88/// Check if a JSON value matches the expected schema type.
89fn check_type(value: &serde_json::Value, expected_type: &str) -> bool {
90    match expected_type {
91        "string" => value.is_string(),
92        "number" => value.is_number(),
93        "integer" => value.is_number() && value.as_i64().is_some(),
94        "boolean" => value.is_boolean(),
95        "array" => value.is_array(),
96        "object" => value.is_object(),
97        "null" => value.is_null(),
98        _ => true, // Unknown types are permissive
99    }
100}
101
102/// Get the JSON type name for error messages.
103fn json_type_name(value: &serde_json::Value) -> &'static str {
104    match value {
105        serde_json::Value::Null => "null",
106        serde_json::Value::Bool(_) => "boolean",
107        serde_json::Value::Number(_) => "number",
108        serde_json::Value::String(_) => "string",
109        serde_json::Value::Array(_) => "array",
110        serde_json::Value::Object(_) => "object",
111    }
112}
113
114/// Tool definition lookup by name.
115/// Matches TypeScript's findToolByName which also checks aliases.
116pub fn find_tool_by_name<'a>(
117    tools: &'a [ToolDefinition],
118    name: &str,
119) -> Option<&'a ToolDefinition> {
120    tools.iter().find(|t| t.name == name).or_else(|| {
121        // Fallback: check if it's a deprecated alias
122        match name {
123            "Edit" => tools.iter().find(|t| t.name == "FileEdit"),
124            _ => None,
125        }
126    })
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolInputSchema;

    /// Build a minimal `ToolDefinition` whose input schema has the given
    /// `properties` object and optional list of required parameter names.
    fn make_tool(
        name: &str,
        properties: serde_json::Value,
        required: Option<Vec<String>>,
    ) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties,
                required,
            },
            annotations: None,
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior: None,
        }
    }

    #[test]
    fn test_valid_input() {
        let tool = make_tool(
            "Bash",
            serde_json::json!({
                "command": { "type": "string" }
            }),
            Some(vec!["command".to_string()]),
        );
        let input = serde_json::json!({ "command": "ls -la" });
        assert!(validate_tool_input("Bash", &input, &[tool]).is_ok());
    }

    #[test]
    fn test_missing_required_field() {
        let tool = make_tool(
            "Bash",
            serde_json::json!({
                "command": { "type": "string" }
            }),
            Some(vec!["command".to_string()]),
        );
        let input = serde_json::json!({});
        let err = validate_tool_input("Bash", &input, &[tool]).unwrap_err();
        assert!(err.contains("The required parameter `command` is missing"));
    }

    #[test]
    fn test_type_mismatch() {
        let tool = make_tool(
            "Bash",
            serde_json::json!({
                "command": { "type": "string" }
            }),
            Some(vec!["command".to_string()]),
        );
        let input = serde_json::json!({ "command": 123 });
        let err = validate_tool_input("Bash", &input, &[tool]).unwrap_err();
        assert!(err.contains("type is expected as `string` but provided as `number`"));
    }

    #[test]
    fn test_unexpected_parameter() {
        let tool = make_tool(
            "Bash",
            serde_json::json!({
                "command": { "type": "string" }
            }),
            Some(vec!["command".to_string()]),
        );
        let input = serde_json::json!({ "command": "ls", "unknown_field": "val" });
        let err = validate_tool_input("Bash", &input, &[tool]).unwrap_err();
        assert!(err.contains("An unexpected parameter `unknown_field` was provided"));
    }

    #[test]
    fn test_alias_resolution() {
        let tools = vec![
            make_tool("Read", serde_json::json!({}), None),
            make_tool("FileEdit", serde_json::json!({}), None),
        ];

        // An exact name resolves to the tool carrying that name.
        assert_eq!(
            find_tool_by_name(&tools, "Read").map(|t| t.name.as_str()),
            Some("Read")
        );
        // The deprecated alias `Edit` resolves to `FileEdit`.
        assert_eq!(
            find_tool_by_name(&tools, "Edit").map(|t| t.name.as_str()),
            Some("FileEdit")
        );
        assert!(find_tool_by_name(&tools, "NonExistent").is_none());

        // The alias only helps when its target tool is actually registered.
        let without_target = vec![make_tool("Read", serde_json::json!({}), None)];
        assert!(find_tool_by_name(&without_target, "Edit").is_none());
    }
}