Skip to main content

llm_core/
schema.rs

1use crate::LlmError;
2
3/// Parse a schema DSL string into a JSON Schema object.
4///
5/// Format: comma-separated (or newline-separated) fields.
6/// Each field: `name [type][:description]`
7/// Types: str (default), int, float, bool
8pub fn parse_schema_dsl(input: &str) -> Result<serde_json::Value, LlmError> {
9    let input = input.trim();
10    if input.is_empty() {
11        return Err(LlmError::Config("empty schema DSL input".into()));
12    }
13
14    let fields: Vec<&str> = if input.contains('\n') {
15        input.split('\n').collect()
16    } else {
17        input.split(',').collect()
18    };
19
20    let mut properties = serde_json::Map::new();
21    let mut required = Vec::new();
22
23    for field in fields {
24        let field = field.trim();
25        if field.is_empty() {
26            continue;
27        }
28
29        let mut parts = field.splitn(2, char::is_whitespace);
30        let name = parts
31            .next()
32            .ok_or_else(|| LlmError::Config(format!("invalid field: {field}")))?
33            .trim();
34
35        let mut json_type = "string".to_string();
36        let mut description: Option<String> = None;
37
38        if let Some(rest) = parts.next() {
39            let rest = rest.trim();
40            if !rest.is_empty() {
41                // Check for type:description or just type or just :description
42                if let Some((type_part, desc_part)) = rest.split_once(':') {
43                    let type_part = type_part.trim();
44                    if !type_part.is_empty() {
45                        json_type = map_type(type_part)?;
46                    }
47                    let desc_part = desc_part.trim();
48                    if !desc_part.is_empty() {
49                        description = Some(desc_part.to_string());
50                    }
51                } else {
52                    json_type = map_type(rest)?;
53                }
54            }
55        }
56
57        let mut prop = serde_json::Map::new();
58        prop.insert("type".into(), serde_json::Value::String(json_type));
59        if let Some(desc) = description {
60            prop.insert("description".into(), serde_json::Value::String(desc));
61        }
62        properties.insert(name.to_string(), serde_json::Value::Object(prop));
63        required.push(serde_json::Value::String(name.to_string()));
64    }
65
66    if properties.is_empty() {
67        return Err(LlmError::Config("no valid fields in schema DSL".into()));
68    }
69
70    Ok(serde_json::json!({
71        "type": "object",
72        "properties": properties,
73        "required": required,
74    }))
75}
76
77/// Wrap a schema in an array structure for --schema-multi.
78pub fn multi_schema(schema: serde_json::Value) -> serde_json::Value {
79    serde_json::json!({
80        "type": "object",
81        "properties": {
82            "items": {
83                "type": "array",
84                "items": schema,
85            }
86        },
87        "required": ["items"],
88    })
89}
90
91fn map_type(t: &str) -> Result<String, LlmError> {
92    match t {
93        "str" => Ok("string".into()),
94        "int" => Ok("integer".into()),
95        "float" => Ok("number".into()),
96        "bool" => Ok("boolean".into()),
97        other => Err(LlmError::Config(format!("unknown type: {other}"))),
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_single_string_field() {
        let result = parse_schema_dsl("name str").unwrap();
        assert_eq!(result["type"], "object");
        assert_eq!(result["properties"]["name"]["type"], "string");
        assert_eq!(result["required"][0], "name");
    }

    #[test]
    fn parse_each_type() {
        let result = parse_schema_dsl("a int, b float, c bool, d str").unwrap();
        assert_eq!(result["properties"]["a"]["type"], "integer");
        assert_eq!(result["properties"]["b"]["type"], "number");
        assert_eq!(result["properties"]["c"]["type"], "boolean");
        assert_eq!(result["properties"]["d"]["type"], "string");
    }

    #[test]
    fn parse_multiple_fields() {
        let result = parse_schema_dsl("name str, age int, active bool").unwrap();
        let props = result["properties"].as_object().unwrap();
        assert_eq!(props.len(), 3);
        let required = result["required"].as_array().unwrap();
        assert_eq!(required.len(), 3);
    }

    #[test]
    fn parse_field_with_description() {
        let result = parse_schema_dsl("age int:The person's age").unwrap();
        assert_eq!(result["properties"]["age"]["type"], "integer");
        assert_eq!(
            result["properties"]["age"]["description"],
            "The person's age"
        );
    }

    #[test]
    fn parse_mixed_descriptions() {
        let result = parse_schema_dsl("name str, age int:The age, active bool").unwrap();
        assert!(result["properties"]["name"]["description"].is_null());
        assert_eq!(result["properties"]["age"]["description"], "The age");
        assert!(result["properties"]["active"]["description"].is_null());
    }

    #[test]
    fn parse_default_type_is_string() {
        let result = parse_schema_dsl("name").unwrap();
        assert_eq!(result["properties"]["name"]["type"], "string");
    }

    #[test]
    fn parse_description_without_type() {
        // Type omitted, description present: the type defaults to string.
        let result = parse_schema_dsl("name :The full name").unwrap();
        assert_eq!(result["properties"]["name"]["type"], "string");
        assert_eq!(result["properties"]["name"]["description"], "The full name");
    }

    #[test]
    fn parse_description_may_contain_colons() {
        let result = parse_schema_dsl("url str:Link such as https://example.com").unwrap();
        assert_eq!(result["properties"]["url"]["type"], "string");
        assert_eq!(
            result["properties"]["url"]["description"],
            "Link such as https://example.com"
        );
    }

    #[test]
    fn parse_whitespace_tolerance() {
        let result = parse_schema_dsl("  name   str  ,  age   int  ").unwrap();
        assert_eq!(result["properties"]["name"]["type"], "string");
        assert_eq!(result["properties"]["age"]["type"], "integer");
    }

    #[test]
    fn parse_newline_separated() {
        let result = parse_schema_dsl("name str\nage int\nactive bool").unwrap();
        let props = result["properties"].as_object().unwrap();
        assert_eq!(props.len(), 3);
    }

    #[test]
    fn parse_newline_mode_allows_commas_in_description() {
        let result = parse_schema_dsl("bio str:Likes cats, dogs\nage int").unwrap();
        assert_eq!(result["properties"]["bio"]["description"], "Likes cats, dogs");
        assert_eq!(result["properties"]["age"]["type"], "integer");
    }

    #[test]
    fn parse_crlf_newlines() {
        let result = parse_schema_dsl("name str\r\nage int\r\n").unwrap();
        assert_eq!(result["properties"]["name"]["type"], "string");
        assert_eq!(result["properties"]["age"]["type"], "integer");
        assert_eq!(result["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn required_preserves_input_order() {
        let result = parse_schema_dsl("b int, a str").unwrap();
        assert_eq!(result["required"], serde_json::json!(["b", "a"]));
    }

    #[test]
    fn parse_empty_string_error() {
        let result = parse_schema_dsl("");
        assert!(result.is_err());
    }

    #[test]
    fn parse_whitespace_only_error() {
        let err = parse_schema_dsl("   \n\t ").unwrap_err().to_string();
        assert!(err.contains("empty schema DSL input"));
    }

    #[test]
    fn parse_separators_only_error() {
        let err = parse_schema_dsl(" , ,, ").unwrap_err().to_string();
        assert!(err.contains("no valid fields"));
    }

    #[test]
    fn parse_invalid_type_error() {
        let result = parse_schema_dsl("name xyz");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("unknown type"));
    }

    #[test]
    fn multi_schema_wraps_in_array() {
        let schema = parse_schema_dsl("name str, age int").unwrap();
        let multi = multi_schema(schema.clone());
        assert_eq!(multi["type"], "object");
        assert_eq!(multi["properties"]["items"]["type"], "array");
        assert_eq!(multi["properties"]["items"]["items"], schema);
        assert_eq!(multi["required"][0], "items");
    }
}