Skip to main content

roder_api/
tool_schema.rs

1use serde::{Deserialize, Serialize};
2use serde_json::{Map, Value};
3
4#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
5#[serde(rename_all = "snake_case")]
6pub enum ToolSchemaMode {
7    Strict,
8    Warning,
9}
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(rename_all = "camelCase")]
13pub struct ToolSchemaPolicy {
14    pub mode: ToolSchemaMode,
15}
16
17impl ToolSchemaPolicy {
18    pub fn strict() -> Self {
19        Self {
20            mode: ToolSchemaMode::Strict,
21        }
22    }
23
24    pub fn warning() -> Self {
25        Self {
26            mode: ToolSchemaMode::Warning,
27        }
28    }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(rename_all = "snake_case")]
33pub enum ToolSchemaLintKind {
34    NestedRequiredArray,
35    MissingAdditionalProperties,
36    AmbiguousFieldName,
37    MismatchedDefault,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
41#[serde(rename_all = "camelCase")]
42pub struct ToolSchemaLint {
43    pub tool_name: String,
44    pub pointer: String,
45    pub kind: ToolSchemaLintKind,
46    pub message: String,
47    pub severity: ToolSchemaMode,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51#[serde(rename_all = "camelCase")]
52pub struct ToolSchemaReport {
53    pub tool_name: String,
54    pub schema: Value,
55    #[serde(default)]
56    pub lints: Vec<ToolSchemaLint>,
57}
58
59pub fn normalize_tool_schema(
60    tool_name: &str,
61    schema: &Value,
62    policy: ToolSchemaPolicy,
63) -> ToolSchemaReport {
64    let schema = normalize_value(schema);
65    let mut lints = Vec::new();
66    lint_value(tool_name, "", &schema, policy, &mut lints);
67    ToolSchemaReport {
68        tool_name: tool_name.to_string(),
69        schema,
70        lints,
71    }
72}
73
74fn normalize_value(value: &Value) -> Value {
75    match value {
76        Value::Object(object) => {
77            let mut normalized = Map::new();
78            push_key("type", object, &mut normalized);
79            push_key("required", object, &mut normalized);
80            if let Some(properties) = object.get("properties") {
81                normalized.insert("properties".to_string(), normalize_properties(properties));
82            }
83            push_key("additionalProperties", object, &mut normalized);
84            let mut rest = object
85                .iter()
86                .filter(|(key, _)| {
87                    !matches!(
88                        key.as_str(),
89                        "type" | "required" | "properties" | "additionalProperties"
90                    )
91                })
92                .collect::<Vec<_>>();
93            rest.sort_by_key(|(key, _)| *key);
94            for (key, value) in rest {
95                normalized.insert(key.clone(), normalize_value(value));
96            }
97            Value::Object(normalized)
98        }
99        Value::Array(items) => Value::Array(items.iter().map(normalize_value).collect()),
100        _ => value.clone(),
101    }
102}
103
104fn normalize_properties(value: &Value) -> Value {
105    let Value::Object(properties) = value else {
106        return normalize_value(value);
107    };
108    let mut normalized = Map::new();
109    let mut entries = properties.iter().collect::<Vec<_>>();
110    entries.sort_by_key(|(key, _)| *key);
111    for (key, value) in entries {
112        normalized.insert(key.clone(), normalize_value(value));
113    }
114    Value::Object(normalized)
115}
116
117fn push_key(key: &str, source: &Map<String, Value>, target: &mut Map<String, Value>) {
118    if let Some(value) = source.get(key) {
119        target.insert(key.to_string(), normalize_value(value));
120    }
121}
122
123fn lint_value(
124    tool_name: &str,
125    pointer: &str,
126    value: &Value,
127    policy: ToolSchemaPolicy,
128    lints: &mut Vec<ToolSchemaLint>,
129) {
130    let Value::Object(object) = value else {
131        return;
132    };
133    if object.get("type").and_then(Value::as_str) == Some("object")
134        && object.get("additionalProperties").and_then(Value::as_bool) != Some(false)
135    {
136        push_lint(
137            lints,
138            tool_name,
139            pointer,
140            ToolSchemaLintKind::MissingAdditionalProperties,
141            "object schema should set additionalProperties: false",
142            policy,
143        );
144    }
145    if !pointer.is_empty() && object.get("required").is_some_and(Value::is_array) {
146        push_lint(
147            lints,
148            tool_name,
149            pointer,
150            ToolSchemaLintKind::NestedRequiredArray,
151            "nested object schemas with required arrays are brittle for model tool calls",
152            policy,
153        );
154    }
155    if let Some(properties) = object.get("properties").and_then(Value::as_object) {
156        for (name, property) in properties {
157            if matches!(name.as_str(), "file" | "text" | "input" | "value") {
158                push_lint(
159                    lints,
160                    tool_name,
161                    &format!("{pointer}/properties/{name}"),
162                    ToolSchemaLintKind::AmbiguousFieldName,
163                    "prefer specific coding-agent argument names such as path, content, query, or command",
164                    policy,
165                );
166            }
167            if let Some(default) = property.get("default")
168                && default.is_null()
169            {
170                push_lint(
171                    lints,
172                    tool_name,
173                    &format!("{pointer}/properties/{name}/default"),
174                    ToolSchemaLintKind::MismatchedDefault,
175                    "null defaults are ambiguous unless the runtime applies the same default",
176                    policy,
177                );
178            }
179            lint_value(
180                tool_name,
181                &format!("{pointer}/properties/{name}"),
182                property,
183                policy,
184                lints,
185            );
186        }
187    }
188}
189
190fn push_lint(
191    lints: &mut Vec<ToolSchemaLint>,
192    tool_name: &str,
193    pointer: &str,
194    kind: ToolSchemaLintKind,
195    message: &str,
196    policy: ToolSchemaPolicy,
197) {
198    lints.push(ToolSchemaLint {
199        tool_name: tool_name.to_string(),
200        pointer: if pointer.is_empty() {
201            "/".to_string()
202        } else {
203            pointer.to_string()
204        },
205        kind,
206        message: message.to_string(),
207        severity: policy.mode,
208    });
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_schema_normalizes_required_before_properties_at_every_object_layer() {
        let schema = serde_json::json!({
            "additionalProperties": false,
            "properties": {
                "edits": {
                    "properties": {
                        "new_string": { "type": "string" },
                        "old_string": { "type": "string" }
                    },
                    "additionalProperties": false,
                    "required": ["old_string", "new_string"],
                    "type": "object"
                },
                "path": { "type": "string" }
            },
            "required": ["path", "edits"],
            "type": "object"
        });

        let report = normalize_tool_schema("multi_edit", &schema, ToolSchemaPolicy::strict());
        let json = serde_json::to_string(&report.schema).unwrap();
        let root_required = json.find(r#""required""#).unwrap();
        let root_properties = json.find(r#""properties""#).unwrap();
        let nested = json.find(r#""edits":{"#).unwrap();
        let nested_required = json[nested..].find(r#""required""#).unwrap() + nested;
        let nested_properties = json[nested..].find(r#""properties""#).unwrap() + nested;

        assert!(root_required < root_properties, "{json}");
        assert!(nested_required < nested_properties, "{json}");
    }

    #[test]
    fn tool_schema_sorts_property_names_without_treating_them_as_keywords() {
        let schema = serde_json::json!({
            "description": "Read a file",
            "additionalProperties": false,
            "properties": {
                "type": { "description": "kind", "type": "string" },
                "alpha": { "type": "string" }
            },
            "type": "object"
        });

        let report = normalize_tool_schema("read", &schema, ToolSchemaPolicy::strict());
        let json = serde_json::to_string(&report.schema).unwrap();

        assert_eq!(
            json,
            concat!(
                r#"{"type":"object","properties":{"alpha":{"type":"string"},"#,
                r#""type":{"type":"string","description":"kind"}},"#,
                r#""additionalProperties":false,"description":"Read a file"}"#
            )
        );
    }

    #[test]
    fn tool_schema_lints_include_tool_name_and_json_pointer() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "input": {
                    "type": "object",
                    "required": ["value"],
                    "properties": {
                        "value": { "type": "string", "default": null }
                    }
                }
            }
        });

        let report = normalize_tool_schema("bad_tool", &schema, ToolSchemaPolicy::strict());

        assert!(report.lints.iter().any(|lint| lint.tool_name == "bad_tool"
            && lint.pointer == "/"
            && lint.kind == ToolSchemaLintKind::MissingAdditionalProperties));
        assert!(
            report
                .lints
                .iter()
                .any(|lint| lint.pointer == "/properties/input"
                    && lint.kind == ToolSchemaLintKind::NestedRequiredArray)
        );
    }

    #[test]
    fn tool_schema_lints_every_kind_in_traversal_order() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "input": {
                    "type": "object",
                    "required": ["value"],
                    "properties": {
                        "value": { "type": "string", "default": null }
                    }
                }
            }
        });

        let report = normalize_tool_schema("bad_tool", &schema, ToolSchemaPolicy::strict());
        let found: Vec<(&str, &ToolSchemaLintKind)> = report
            .lints
            .iter()
            .map(|lint| (lint.pointer.as_str(), &lint.kind))
            .collect();

        assert_eq!(
            found,
            vec![
                ("/", &ToolSchemaLintKind::MissingAdditionalProperties),
                ("/properties/input", &ToolSchemaLintKind::AmbiguousFieldName),
                ("/properties/input", &ToolSchemaLintKind::MissingAdditionalProperties),
                ("/properties/input", &ToolSchemaLintKind::NestedRequiredArray),
                (
                    "/properties/input/properties/value",
                    &ToolSchemaLintKind::AmbiguousFieldName
                ),
                (
                    "/properties/input/properties/value/default",
                    &ToolSchemaLintKind::MismatchedDefault
                ),
            ]
        );
        assert!(report.lints.iter().all(|lint| lint.tool_name == "bad_tool"
            && lint.severity == ToolSchemaMode::Strict));
    }

    #[test]
    fn tool_schema_closed_objects_only_flag_nested_required() {
        let schema = serde_json::json!({
            "type": "object",
            "additionalProperties": false,
            "required": ["path", "edits"],
            "properties": {
                "path": { "type": "string" },
                "edits": {
                    "type": "object",
                    "additionalProperties": false,
                    "required": ["old_string"],
                    "properties": {
                        "old_string": { "type": "string" }
                    }
                }
            }
        });

        let report = normalize_tool_schema("multi_edit", &schema, ToolSchemaPolicy::strict());

        assert_eq!(report.lints.len(), 1, "{:?}", report.lints);
        assert_eq!(report.lints[0].pointer, "/properties/edits");
        assert_eq!(report.lints[0].kind, ToolSchemaLintKind::NestedRequiredArray);
    }

    #[test]
    fn tool_schema_lint_severity_follows_policy_mode() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } }
        });

        let report = normalize_tool_schema("read_file", &schema, ToolSchemaPolicy::warning());

        assert_eq!(report.lints.len(), 1, "{:?}", report.lints);
        assert_eq!(report.lints[0].pointer, "/");
        assert_eq!(report.lints[0].kind, ToolSchemaLintKind::MissingAdditionalProperties);
        assert_eq!(report.lints[0].severity, ToolSchemaMode::Warning);
    }
}