Skip to main content

agent_io/llm/
schema.rs

//! JSON Schema optimizer for tool definitions.
2
3use serde_json::{Map, Value, json};
4use std::collections::HashSet;
5
6/// Schema optimizer for creating LLM-compatible JSON schemas
7pub struct SchemaOptimizer;
8
9impl SchemaOptimizer {
10    /// Create an optimized JSON schema from a raw schema
11    ///
12    /// This function:
13    /// - Flattens $ref and $defs
14    /// - Removes additionalProperties
15    /// - Ensures all properties are required
16    pub fn optimize(schema: &Value) -> Value {
17        let mut result = schema.clone();
18        Self::optimize_recursive(&mut result);
19        result
20    }
21
22    fn optimize_recursive(value: &mut Value) {
23        match value {
24            Value::Object(obj) => {
25                // Remove additionalProperties
26                obj.remove("additionalProperties");
27
28                // Handle $ref by inlining the definition
29                if let Some(ref_val) = obj.get("$ref").and_then(|r| r.as_str())
30                    && let Some(defs) = obj.get("$defs")
31                    && let Some(def) = ref_val.strip_prefix("#/$defs/")
32                    && let Some(resolved) = defs.get(def)
33                {
34                    *value = resolved.clone();
35                    Self::optimize_recursive(value);
36                    return;
37                }
38
39                // Remove $defs if present
40                obj.remove("$defs");
41
42                // Make all properties required
43                if let Some(properties) = obj.get("properties").and_then(|p| p.as_object()) {
44                    let all_keys: HashSet<&str> = properties.keys().map(|k| k.as_str()).collect();
45                    obj.insert("required".to_string(), json!(all_keys));
46                }
47
48                // Recursively process nested objects
49                for (_, v) in obj.iter_mut() {
50                    Self::optimize_recursive(v);
51                }
52            }
53            Value::Array(arr) => {
54                for item in arr.iter_mut() {
55                    Self::optimize_recursive(item);
56                }
57            }
58            _ => {}
59        }
60    }
61
62    /// Create a tool definition from a JSON schema
63    pub fn create_tool_definition(
64        name: impl Into<String>,
65        description: impl Into<String>,
66        schema: Value,
67    ) -> super::ToolDefinition {
68        let optimized = Self::optimize(&schema);
69        let parameters = optimized.as_object().cloned().unwrap_or_else(|| {
70            let mut map = Map::new();
71            map.insert("type".to_string(), json!("object"));
72            map.insert("properties".to_string(), json!({}));
73            map
74        });
75
76        super::ToolDefinition {
77            name: name.into(),
78            description: description.into(),
79            parameters,
80            strict: true,
81        }
82    }
83
84    /// Create a minimal schema for a simple string parameter
85    pub fn string_schema() -> Value {
86        json!({
87            "type": "object",
88            "properties": {
89                "value": {
90                    "type": "string",
91                    "description": "The string value"
92                }
93            },
94            "required": ["value"]
95        })
96    }
97
98    /// Create a schema for multiple string parameters
99    pub fn string_params_schema(params: &[(&str, &str)]) -> Value {
100        let properties: Map<String, Value> = params
101            .iter()
102            .map(|(name, desc)| {
103                (
104                    name.to_string(),
105                    json!({
106                        "type": "string",
107                        "description": desc
108                    }),
109                )
110            })
111            .collect();
112
113        let required: Vec<&str> = params.iter().map(|(name, _)| *name).collect();
114
115        json!({
116            "type": "object",
117            "properties": properties,
118            "required": required
119        })
120    }
121}
122
123#[cfg(test)]
124mod tests {
125    use super::*;
126    use serde_json::json;
127
128    #[test]
129    fn test_optimize_schema() {
130        let schema = json!({
131            "$ref": "#/$defs/MyType",
132            "$defs": {
133                "MyType": {
134                    "type": "object",
135                    "properties": {
136                        "name": { "type": "string" }
137                    },
138                    "additionalProperties": false
139                }
140            }
141        });
142
143        let optimized = SchemaOptimizer::optimize(&schema);
144
145        assert!(optimized.get("$ref").is_none());
146        assert!(optimized.get("$defs").is_none());
147        assert!(optimized.get("additionalProperties").is_none());
148        assert!(optimized.get("required").is_some());
149    }
150}