Skip to main content

lash_sansio/tool_contract/
schema_validation.rs

1use serde_json::Value;
2
3use crate::tool_contract::ToolContract;
4
5#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
6pub struct LashSchema {
7    pub schema: Value,
8}
9
10impl LashSchema {
11    pub fn new(schema: Value) -> Self {
12        Self { schema }
13    }
14
15    pub fn any() -> Self {
16        Self::new(serde_json::json!({}))
17    }
18
19    pub fn object(properties: serde_json::Map<String, Value>, required: Vec<String>) -> Self {
20        let mut schema = serde_json::Map::new();
21        schema.insert("type".to_string(), Value::String("object".to_string()));
22        schema.insert("properties".to_string(), Value::Object(properties));
23        if !required.is_empty() {
24            schema.insert(
25                "required".to_string(),
26                Value::Array(required.into_iter().map(Value::String).collect()),
27            );
28        }
29        schema.insert("additionalProperties".to_string(), Value::Bool(true));
30        Self::new(Value::Object(schema))
31    }
32
33    pub fn validate(&self, value: &Value) -> Result<(), String> {
34        validate_schema("", &self.schema, value)
35    }
36}
37
38pub fn validate_tool_input(contract: &ToolContract, args: &Value) -> Result<(), String> {
39    LashSchema::new(contract.input_schema.clone()).validate(args)
40}
41
42fn validate_schema(path: &str, schema: &Value, value: &Value) -> Result<(), String> {
43    if let Some(any_of) = schema
44        .get("anyOf")
45        .or_else(|| schema.get("oneOf"))
46        .and_then(Value::as_array)
47    {
48        if any_of
49            .iter()
50            .any(|subschema| validate_schema(path, subschema, value).is_ok())
51        {
52            return Ok(());
53        }
54        return Err(format!(
55            "{}: expected {}, got {}",
56            display_path(path),
57            expected_description(schema),
58            display_value(value)
59        ));
60    }
61
62    if let Some(enum_values) = schema.get("enum").and_then(Value::as_array)
63        && !enum_values.iter().any(|candidate| candidate == value)
64    {
65        return Err(format!(
66            "{}: expected one of {}, got {}",
67            display_path(path),
68            enum_values
69                .iter()
70                .map(display_value)
71                .collect::<Vec<_>>()
72                .join(", "),
73            display_value(value)
74        ));
75    }
76
77    if let Some(type_value) = schema.get("type")
78        && !matches_type(type_value, value)
79    {
80        return Err(format!(
81            "{}: expected {}, got {}",
82            display_path(path),
83            expected_description(schema),
84            display_value(value)
85        ));
86    }
87
88    if let Some(maximum) = schema.get("maximum").and_then(Value::as_f64)
89        && value.as_f64().is_some_and(|actual| actual > maximum)
90    {
91        return Err(format!(
92            "{}: expected {}, got {}",
93            display_path(path),
94            expected_description(schema),
95            display_value(value)
96        ));
97    }
98    if let Some(minimum) = schema.get("minimum").and_then(Value::as_f64)
99        && value.as_f64().is_some_and(|actual| actual < minimum)
100    {
101        return Err(format!(
102            "{}: expected {}, got {}",
103            display_path(path),
104            expected_description(schema),
105            display_value(value)
106        ));
107    }
108
109    if let Some(max_length) = schema.get("maxLength").and_then(Value::as_u64)
110        && value
111            .as_str()
112            .is_some_and(|actual| actual.chars().count() as u64 > max_length)
113    {
114        return Err(format!(
115            "{}: expected {}, got {}",
116            display_path(path),
117            expected_description(schema),
118            display_value(value)
119        ));
120    }
121    if let Some(min_length) = schema.get("minLength").and_then(Value::as_u64)
122        && value
123            .as_str()
124            .is_some_and(|actual| (actual.chars().count() as u64) < min_length)
125    {
126        return Err(format!(
127            "{}: expected {}, got {}",
128            display_path(path),
129            expected_description(schema),
130            display_value(value)
131        ));
132    }
133
134    if let Some(items) = value.as_array() {
135        if let Some(max_items) = schema.get("maxItems").and_then(Value::as_u64)
136            && items.len() as u64 > max_items
137        {
138            return Err(format!(
139                "{}: expected {}, got {} items",
140                display_path(path),
141                expected_description(schema),
142                items.len()
143            ));
144        }
145        if let Some(min_items) = schema.get("minItems").and_then(Value::as_u64)
146            && (items.len() as u64) < min_items
147        {
148            return Err(format!(
149                "{}: expected {}, got {} items",
150                display_path(path),
151                expected_description(schema),
152                items.len()
153            ));
154        }
155        if let Some(item_schema) = schema.get("items") {
156            for (idx, item) in items.iter().enumerate() {
157                validate_schema(&format!("{path}[{idx}]"), item_schema, item)?;
158            }
159        }
160    }
161
162    if let Some(object) = value.as_object() {
163        let required = schema
164            .get("required")
165            .and_then(Value::as_array)
166            .into_iter()
167            .flatten()
168            .filter_map(Value::as_str);
169        for property in required {
170            if !object.contains_key(property) {
171                return Err(format!(
172                    "{}: required property missing",
173                    join_path(path, property)
174                ));
175            }
176        }
177
178        if let Some(properties) = schema.get("properties").and_then(Value::as_object) {
179            for (name, property_schema) in properties {
180                if let Some(property_value) = object.get(name) {
181                    validate_schema(&join_path(path, name), property_schema, property_value)?;
182                }
183            }
184            match schema.get("additionalProperties") {
185                Some(Value::Bool(true)) => {}
186                Some(Value::Object(additional_schema)) => {
187                    let additional_schema = Value::Object(additional_schema.clone());
188                    for (name, property_value) in object {
189                        if is_internal_argument(name) || properties.contains_key(name) {
190                            continue;
191                        }
192                        validate_schema(
193                            &join_path(path, name),
194                            &additional_schema,
195                            property_value,
196                        )?;
197                    }
198                }
199                _ => {
200                    for name in object.keys() {
201                        if is_internal_argument(name) {
202                            continue;
203                        }
204                        if !properties.contains_key(name) {
205                            return Err(format!("{}: unexpected property", join_path(path, name)));
206                        }
207                    }
208                }
209            }
210        }
211    }
212
213    Ok(())
214}
215
216fn matches_type(type_value: &Value, value: &Value) -> bool {
217    match type_value {
218        Value::String(ty) => matches_single_type(ty, value),
219        Value::Array(types) => types
220            .iter()
221            .filter_map(Value::as_str)
222            .any(|ty| matches_single_type(ty, value)),
223        _ => true,
224    }
225}
226
227fn matches_single_type(ty: &str, value: &Value) -> bool {
228    match ty {
229        "null" => value.is_null(),
230        "boolean" => value.is_boolean(),
231        "string" => value.is_string(),
232        "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
233        "number" => value.is_number(),
234        "array" => value.is_array(),
235        "object" => value.is_object(),
236        _ => true,
237    }
238}
239
240fn expected_description(schema: &Value) -> String {
241    let mut parts = Vec::new();
242    if let Some(type_value) = schema.get("type") {
243        parts.push(type_description(type_value, schema));
244    } else if schema.get("anyOf").is_some() || schema.get("oneOf").is_some() {
245        parts.push("matching schema".to_string());
246    }
247    if let Some(minimum) = schema.get("minimum") {
248        parts.push(format!(">= {}", display_value(minimum)));
249    }
250    if let Some(maximum) = schema.get("maximum") {
251        parts.push(format!("<= {}", display_value(maximum)));
252    }
253    if let Some(min_length) = schema.get("minLength") {
254        parts.push(format!("length >= {}", display_value(min_length)));
255    }
256    if let Some(max_length) = schema.get("maxLength") {
257        parts.push(format!("length <= {}", display_value(max_length)));
258    }
259    if let Some(min_items) = schema.get("minItems") {
260        parts.push(format!("items >= {}", display_value(min_items)));
261    }
262    if let Some(max_items) = schema.get("maxItems") {
263        parts.push(format!("items <= {}", display_value(max_items)));
264    }
265    if parts.is_empty() {
266        "valid value".to_string()
267    } else {
268        parts.join(" ")
269    }
270}
271
272fn type_description(type_value: &Value, schema: &Value) -> String {
273    match type_value {
274        Value::String(ty) => type_name(ty, schema).to_string(),
275        Value::Array(types) => types
276            .iter()
277            .filter_map(Value::as_str)
278            .map(|ty| type_name(ty, schema).to_string())
279            .collect::<Vec<_>>()
280            .join(" or "),
281        _ => "valid value".to_string(),
282    }
283}
284
285fn type_name<'a>(ty: &'a str, schema: &'a Value) -> &'a str {
286    if ty == "array" && schema.get("items").is_some() {
287        "array"
288    } else {
289        ty
290    }
291}
292
/// Renders a validation path for an error message.
///
/// The empty path, which denotes the root value, is shown as `arguments`;
/// any other path is shown unchanged.
fn display_path(path: &str) -> String {
    let shown = if path.is_empty() { "arguments" } else { path };
    shown.to_owned()
}
300
/// Appends the property `key` to the validation path `base`.
///
/// At the root (empty `base`) the result is just `key`; otherwise the two
/// are joined with a dot.
fn join_path(base: &str, key: &str) -> String {
    match base {
        "" => key.to_owned(),
        parent => [parent, key].join("."),
    }
}
308
/// Reports whether `name` is the internal `__session_id__` argument, which
/// the object validation never treats as an undeclared property.
fn is_internal_argument(name: &str) -> bool {
    matches!(name, "__session_id__")
}
312
313fn display_value(value: &Value) -> String {
314    match value {
315        Value::String(text) => format!("{text:?}"),
316        _ => value.to_string(),
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolDefinition;
    use serde_json::json;

    /// Builds a raw tool definition with the given id, name and input schema
    /// (empty description, `{}` as the final argument, as every case here
    /// uses) and validates `args` against its contract.
    fn validate_args(
        id: &'static str,
        name: &'static str,
        input_schema: Value,
        args: Value,
    ) -> Result<(), String> {
        let tool = ToolDefinition::raw(id, name, "", input_schema, json!({}));
        let contract = tool.contract();
        validate_tool_input(&contract, &args)
    }

    #[test]
    fn validation_reports_missing_required_property_by_path() {
        let input_schema = json!({
            "type": "object",
            "properties": {
                "access_token": { "type": "string" }
            },
            "required": ["access_token"],
            "additionalProperties": false
        });

        let error = validate_args("tool:spotify", "spotify", input_schema, json!({})).unwrap_err();

        assert_eq!(error, "access_token: required property missing");
    }

    #[test]
    fn validation_reports_numeric_limits_by_path() {
        let input_schema = json!({
            "type": "object",
            "properties": {
                "page_limit": { "type": "integer", "maximum": 20 }
            },
            "required": ["page_limit"],
            "additionalProperties": false
        });
        let args = json!({ "page_limit": 100 });

        let error = validate_args("tool:spotify", "spotify", input_schema, args).unwrap_err();

        assert_eq!(error, "page_limit: expected integer <= 20, got 100");
    }

    #[test]
    fn validation_rejects_unknown_property_when_additional_properties_is_omitted() {
        let input_schema = json!({
            "type": "object",
            "properties": {
                "min_created_at": { "type": "string" },
                "max_created_at": { "type": "string" },
                "limit": { "type": "integer", "maximum": 100 }
            },
            "required": ["limit"]
        });
        let args = json!({
            "min_datetime": "2024-01-01T00:00:00Z",
            "limit": 20
        });

        let error = validate_args(
            "tool:mcp__appworld__venmo_show_transactions",
            "mcp__appworld__venmo_show_transactions",
            input_schema,
            args,
        )
        .unwrap_err();

        assert_eq!(error, "min_datetime: unexpected property");
    }

    #[test]
    fn validation_allows_unknown_property_when_additional_properties_is_true() {
        let input_schema = json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" }
            },
            "additionalProperties": true
        });
        let args = json!({
            "path": "README.md",
            "unknown": "preserved"
        });

        validate_args("tool:open", "open", input_schema, args).unwrap();
    }

    #[test]
    fn validation_preserves_internal_session_id_argument() {
        let input_schema = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            }
        });
        let args = json!({
            "query": "hello",
            "__session_id__": "session"
        });

        validate_args("tool:lashlang_tool", "lashlang_tool", input_schema, args).unwrap();
    }
}