Skip to main content

oxi_agent/
structured_output.rs

1//! Structured output extraction and validation.
2//!
3//! Provides utilities for extracting structured JSON from agent responses
4//! and optionally validating them against a JSON Schema.
5
6use serde_json::Value;
7
8/// Output mode for agent responses.
9#[derive(Debug, Clone, Default)]
10pub enum OutputMode {
11    /// Return the raw text response as-is.
12    #[default]
13    Text,
14    /// Extract JSON from the response.
15    Json,
16    /// Extract JSON and validate against a schema.
17    ValidatedJson {
18        /// JSON Schema to validate against.
19        schema: Value,
20    },
21}
22
23impl OutputMode {
24    /// Check if this mode requires JSON extraction.
25    pub fn requires_json(&self) -> bool {
26        matches!(self, OutputMode::Json | OutputMode::ValidatedJson { .. })
27    }
28}
29
/// Stateless entry point for turning agent responses into JSON [`Value`]s.
///
/// The type carries no data; all functionality is exposed through associated
/// functions such as [`StructuredOutput::extract`].
pub struct StructuredOutput;
32
33impl StructuredOutput {
34    /// Extract structured output from agent response content.
35    ///
36    /// - `OutputMode::Text` → returns the content as a JSON string.
37    /// - `OutputMode::Json` → extracts JSON from the content.
38    /// - `OutputMode::ValidatedJson` → extracts and validates JSON.
39    pub fn extract(content: &str, mode: &OutputMode) -> Result<Value, StructuredOutputError> {
40        match mode {
41            OutputMode::Text => Ok(Value::String(content.to_string())),
42            OutputMode::Json => Self::extract_json(content),
43            OutputMode::ValidatedJson { schema } => {
44                let json = Self::extract_json(content)?;
45                Self::validate(&json, schema)?;
46                Ok(json)
47            }
48        }
49    }
50
51    /// Extract JSON from text content.
52    ///
53    /// Tries, in order:
54    /// 1. Parse the entire content as JSON.
55    /// 2. Extract from a ` ```json ... ``` ` code block.
56    /// 3. Find a matching `{ ... }` or `[ ... ]` bracket pair.
57    pub fn extract_json(content: &str) -> Result<Value, StructuredOutputError> {
58        // 1. Entire content is JSON
59        if let Ok(v) = serde_json::from_str::<Value>(content) {
60            return Ok(v);
61        }
62
63        // 2. ```json ... ``` block
64        if let Some(start) = content.find("```json") {
65            let json_start = start + 7;
66            if let Some(end) = content[json_start..].find("```") {
67                let json_str = content[json_start..json_start + end].trim();
68                return serde_json::from_str(json_str).map_err(|e| {
69                    StructuredOutputError::ParseError(format!(
70                        "JSON parse error in code block: {}",
71                        e
72                    ))
73                });
74            }
75        }
76
77        // 3. Find matching brackets
78        for (open, close) in [('{', '}'), ('[', ']')] {
79            if let Some(start) = content.find(open) {
80                let substr = &content[start..];
81                if let Some(end) = Self::find_matching_bracket(substr, open, close) {
82                    let json_str = &substr[..=end];
83                    if let Ok(v) = serde_json::from_str(json_str) {
84                        return Ok(v);
85                    }
86                }
87            }
88        }
89
90        Err(StructuredOutputError::NotFound(
91            "No JSON found in response".into(),
92        ))
93    }
94
95    /// Validate a JSON value against a JSON Schema.
96    ///
97    /// Uses basic type/required field validation. For full JSON Schema
98    /// validation, enable the `jsonschema` feature (future work).
99    pub fn validate(json: &Value, schema: &Value) -> Result<(), StructuredOutputError> {
100        // Basic validation: check "type" keyword
101        if let Some(expected_type) = schema.get("type").and_then(|t| t.as_str()) {
102            let actual_matches = match expected_type {
103                "object" => json.is_object(),
104                "array" => json.is_array(),
105                "string" => json.is_string(),
106                "number" => json.is_number(),
107                "integer" => json.is_i64() || json.is_u64(),
108                "boolean" => json.is_boolean(),
109                "null" => json.is_null(),
110                _ => true,
111            };
112            if !actual_matches {
113                return Err(StructuredOutputError::ValidationError(format!(
114                    "Expected type '{}', got '{}'",
115                    expected_type,
116                    json_type_name(json)
117                )));
118            }
119        }
120
121        // Check "required" fields
122        if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
123            if let Some(obj) = json.as_object() {
124                for field in required {
125                    if let Some(name) = field.as_str() {
126                        if !obj.contains_key(name) {
127                            return Err(StructuredOutputError::ValidationError(format!(
128                                "Missing required field: '{}'",
129                                name
130                            )));
131                        }
132                    }
133                }
134            }
135        }
136
137        Ok(())
138    }
139
140    /// Find the index of the closing bracket that matches the first opening bracket.
141    fn find_matching_bracket(s: &str, open: char, close: char) -> Option<usize> {
142        let mut depth = 0;
143        let mut in_string = false;
144        let mut escape_next = false;
145
146        for (i, c) in s.char_indices() {
147            if escape_next {
148                escape_next = false;
149                continue;
150            }
151            match c {
152                '\\' if in_string => escape_next = true,
153                '"' => in_string = !in_string,
154                _ if in_string => {}
155                c if c == open => depth += 1,
156                c if c == close => {
157                    depth -= 1;
158                    if depth == 0 {
159                        return Some(i);
160                    }
161                }
162                _ => {}
163            }
164        }
165        None
166    }
167}
168
169/// Errors during structured output extraction.
170#[derive(Debug, thiserror::Error)]
171pub enum StructuredOutputError {
172    /// JSON could not be found in the response.
173    #[error("JSON not found: {0}")]
174    NotFound(String),
175
176    /// JSON parsing failed.
177    #[error("{0}")]
178    ParseError(String),
179
180    /// Schema validation failed.
181    #[error("Validation error: {0}")]
182    ValidationError(String),
183}
184
185/// Get a human-readable type name for a JSON value.
186fn json_type_name(v: &Value) -> &'static str {
187    match v {
188        Value::Null => "null",
189        Value::Bool(_) => "boolean",
190        Value::Number(_) => "number",
191        Value::String(_) => "string",
192        Value::Array(_) => "array",
193        Value::Object(_) => "object",
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_extract_text_mode() {
        let result = StructuredOutput::extract("hello world", &OutputMode::Text).unwrap();
        assert_eq!(result, Value::String("hello world".to_string()));
    }

    #[test]
    fn test_extract_pure_json() {
        let json = r#"{"name": "test", "value": 42}"#;
        let result = StructuredOutput::extract(json, &OutputMode::Json).unwrap();
        assert_eq!(result["name"], "test");
        assert_eq!(result["value"], 42);
    }

    #[test]
    fn test_extract_json_code_block() {
        let content = "Here is the result:\n```json\n{\"status\": \"ok\"}\n```\nDone.";
        let result = StructuredOutput::extract(content, &OutputMode::Json).unwrap();
        assert_eq!(result["status"], "ok");
    }

    #[test]
    fn test_extract_json_embedded_brackets() {
        let content = "The answer is {\"x\": 1, \"y\": 2} as shown above.";
        let result = StructuredOutput::extract(content, &OutputMode::Json).unwrap();
        assert_eq!(result["x"], 1);
    }

    #[test]
    fn test_extract_json_array() {
        let content = "Results: [1, 2, 3]";
        let result = StructuredOutput::extract(content, &OutputMode::Json).unwrap();
        assert_eq!(result, json!([1, 2, 3]));
    }

    #[test]
    fn test_extract_json_not_found() {
        let content = "No JSON here, just plain text.";
        let result = StructuredOutput::extract(content, &OutputMode::Json);
        assert!(result.is_err());
    }

    #[test]
    fn test_validated_json_success() {
        let schema = json!({
            "type": "object",
            "required": ["name"]
        });
        let content = r#"{"name": "test", "value": 42}"#;
        let result =
            StructuredOutput::extract(content, &OutputMode::ValidatedJson { schema }).unwrap();
        assert_eq!(result["name"], "test");
    }

    #[test]
    fn test_validated_json_wrong_type() {
        let schema = json!({"type": "array"});
        let content = r#"{"name": "test"}"#;
        let result = StructuredOutput::extract(content, &OutputMode::ValidatedJson { schema });
        assert!(result.is_err());
    }

    #[test]
    fn test_validated_json_missing_required() {
        let schema = json!({
            "type": "object",
            "required": ["name", "age"]
        });
        let content = r#"{"name": "test"}"#;
        let result = StructuredOutput::extract(content, &OutputMode::ValidatedJson { schema });
        assert!(result.is_err());
    }

    #[test]
    fn test_nested_brackets() {
        let content = r#"Result: {"a": {"b": [1, 2]}, "c": 3}"#;
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["a"]["b"], json!([1, 2]));
        assert_eq!(result["c"], 3);
    }

    #[test]
    fn test_json_with_string_containing_brackets() {
        let content = r#"{"text": "hello {world}"}"#;
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["text"], "hello {world}");
    }

    #[test]
    fn test_output_mode_requires_json() {
        assert!(!OutputMode::Text.requires_json());
        assert!(OutputMode::Json.requires_json());
        assert!(OutputMode::ValidatedJson { schema: json!({}) }.requires_json());
    }

    // --- Error variants and edge cases -------------------------------------------------

    #[test]
    fn test_extract_json_invalid_code_block_is_parse_error() {
        let content = "```json\n{not valid}\n```";
        let err = StructuredOutput::extract_json(content).unwrap_err();
        assert!(matches!(err, StructuredOutputError::ParseError(_)));
    }

    #[test]
    fn test_extract_json_plain_text_is_not_found() {
        let err = StructuredOutput::extract_json("nothing to see here").unwrap_err();
        assert!(matches!(err, StructuredOutputError::NotFound(_)));
    }

    #[test]
    fn test_extract_json_escaped_quotes_in_string() {
        // The `{` inside the string and the escaped quotes must not confuse bracket matching.
        let content = r#"Note: {"q": "say \"hi\" {"}"#;
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["q"], "say \"hi\" {");
    }

    #[test]
    fn test_text_mode_keeps_content_verbatim() {
        let content = "  {\"looks\": \"like json\"}  ";
        let result = StructuredOutput::extract(content, &OutputMode::Text).unwrap();
        assert_eq!(result, Value::String(content.to_string()));
    }

    #[test]
    fn test_default_output_mode_is_text() {
        assert!(matches!(OutputMode::default(), OutputMode::Text));
    }

    #[test]
    fn test_validate_type_error_message() {
        let err = StructuredOutput::validate(&json!({"a": 1}), &json!({"type": "array"}))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Validation error: Expected type 'array', got 'object'"
        );
    }

    #[test]
    fn test_validate_integer_and_number() {
        let integer = json!({"type": "integer"});
        assert!(StructuredOutput::validate(&json!(42), &integer).is_ok());
        assert!(StructuredOutput::validate(&json!(1.5), &integer).is_err());
        assert!(StructuredOutput::validate(&json!(7), &json!({"type": "number"})).is_ok());
    }

    #[test]
    fn test_validate_unknown_type_is_not_checked() {
        assert!(StructuredOutput::validate(&json!(1), &json!({"type": "custom"})).is_ok());
    }

    #[test]
    fn test_validate_required_ignored_for_non_objects() {
        let schema = json!({"required": ["name"]});
        assert!(StructuredOutput::validate(&json!([1, 2]), &schema).is_ok());
    }

    #[test]
    fn test_json_type_name_covers_all_variants() {
        assert_eq!(json_type_name(&json!(null)), "null");
        assert_eq!(json_type_name(&json!(true)), "boolean");
        assert_eq!(json_type_name(&json!(1.5)), "number");
        assert_eq!(json_type_name(&json!("s")), "string");
        assert_eq!(json_type_name(&json!([])), "array");
        assert_eq!(json_type_name(&json!({})), "object");
    }
}