// oxi_agent/structured_output.rs
use serde_json::Value;
7
8#[derive(Debug, Clone, Default)]
10pub enum OutputMode {
11 #[default]
13 Text,
14 Json,
16 ValidatedJson {
18 schema: Value,
20 },
21}
22
23impl OutputMode {
24 pub fn requires_json(&self) -> bool {
26 matches!(self, OutputMode::Json | OutputMode::ValidatedJson { .. })
27 }
28}
29
30pub struct StructuredOutput;
32
33impl StructuredOutput {
34 pub fn extract(content: &str, mode: &OutputMode) -> Result<Value, StructuredOutputError> {
40 match mode {
41 OutputMode::Text => Ok(Value::String(content.to_string())),
42 OutputMode::Json => Self::extract_json(content),
43 OutputMode::ValidatedJson { schema } => {
44 let json = Self::extract_json(content)?;
45 Self::validate(&json, schema)?;
46 Ok(json)
47 }
48 }
49 }
50
51 pub fn extract_json(content: &str) -> Result<Value, StructuredOutputError> {
58 if let Ok(v) = serde_json::from_str::<Value>(content) {
60 return Ok(v);
61 }
62
63 if let Some(start) = content.find("```json") {
65 let json_start = start + 7;
66 if let Some(end) = content[json_start..].find("```") {
67 let json_str = content[json_start..json_start + end].trim();
68 return serde_json::from_str(json_str).map_err(|e| {
69 StructuredOutputError::ParseError(format!(
70 "JSON parse error in code block: {}",
71 e
72 ))
73 });
74 }
75 }
76
77 for (open, close) in [('{', '}'), ('[', ']')] {
79 if let Some(start) = content.find(open) {
80 let substr = &content[start..];
81 if let Some(end) = Self::find_matching_bracket(substr, open, close) {
82 let json_str = &substr[..=end];
83 if let Ok(v) = serde_json::from_str(json_str) {
84 return Ok(v);
85 }
86 }
87 }
88 }
89
90 Err(StructuredOutputError::NotFound(
91 "No JSON found in response".into(),
92 ))
93 }
94
95 pub fn validate(json: &Value, schema: &Value) -> Result<(), StructuredOutputError> {
100 if let Some(expected_type) = schema.get("type").and_then(|t| t.as_str()) {
102 let actual_matches = match expected_type {
103 "object" => json.is_object(),
104 "array" => json.is_array(),
105 "string" => json.is_string(),
106 "number" => json.is_number(),
107 "integer" => json.is_i64() || json.is_u64(),
108 "boolean" => json.is_boolean(),
109 "null" => json.is_null(),
110 _ => true,
111 };
112 if !actual_matches {
113 return Err(StructuredOutputError::ValidationError(format!(
114 "Expected type '{}', got '{}'",
115 expected_type,
116 json_type_name(json)
117 )));
118 }
119 }
120
121 if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
123 if let Some(obj) = json.as_object() {
124 for field in required {
125 if let Some(name) = field.as_str() {
126 if !obj.contains_key(name) {
127 return Err(StructuredOutputError::ValidationError(format!(
128 "Missing required field: '{}'",
129 name
130 )));
131 }
132 }
133 }
134 }
135 }
136
137 Ok(())
138 }
139
140 fn find_matching_bracket(s: &str, open: char, close: char) -> Option<usize> {
142 let mut depth = 0;
143 let mut in_string = false;
144 let mut escape_next = false;
145
146 for (i, c) in s.char_indices() {
147 if escape_next {
148 escape_next = false;
149 continue;
150 }
151 match c {
152 '\\' if in_string => escape_next = true,
153 '"' => in_string = !in_string,
154 _ if in_string => {}
155 c if c == open => depth += 1,
156 c if c == close => {
157 depth -= 1;
158 if depth == 0 {
159 return Some(i);
160 }
161 }
162 _ => {}
163 }
164 }
165 None
166 }
167}
168
169#[derive(Debug, thiserror::Error)]
171pub enum StructuredOutputError {
172 #[error("JSON not found: {0}")]
174 NotFound(String),
175
176 #[error("{0}")]
178 ParseError(String),
179
180 #[error("Validation error: {0}")]
182 ValidationError(String),
183}
184
185fn json_type_name(v: &Value) -> &'static str {
187 match v {
188 Value::Null => "null",
189 Value::Bool(_) => "boolean",
190 Value::Number(_) => "number",
191 Value::String(_) => "string",
192 Value::Array(_) => "array",
193 Value::Object(_) => "object",
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_extract_text_mode() {
        let result = StructuredOutput::extract("hello world", &OutputMode::Text).unwrap();
        assert_eq!(result, Value::String("hello world".to_string()));
    }

    #[test]
    fn test_text_mode_does_not_parse_json() {
        let content = r#"{"a": 1}"#;
        let result = StructuredOutput::extract(content, &OutputMode::Text).unwrap();
        assert_eq!(result, Value::String(content.to_string()));
    }

    #[test]
    fn test_extract_pure_json() {
        let json = r#"{"name": "test", "value": 42}"#;
        let result = StructuredOutput::extract(json, &OutputMode::Json).unwrap();
        assert_eq!(result["name"], "test");
        assert_eq!(result["value"], 42);
    }

    #[test]
    fn test_extract_json_code_block() {
        let content = "Here is the result:\n```json\n{\"status\": \"ok\"}\n```\nDone.";
        let result = StructuredOutput::extract(content, &OutputMode::Json).unwrap();
        assert_eq!(result["status"], "ok");
    }

    #[test]
    fn test_invalid_json_code_block_is_parse_error() {
        let content = "Result:\n```json\n{not valid json}\n```";
        let result = StructuredOutput::extract_json(content);
        assert!(matches!(result, Err(StructuredOutputError::ParseError(_))));
    }

    #[test]
    fn test_unterminated_code_block_falls_back_to_brackets() {
        let content = "```json\n{\"status\": \"ok\"}";
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["status"], "ok");
    }

    #[test]
    fn test_extract_json_embedded_brackets() {
        let content = "The answer is {\"x\": 1, \"y\": 2} as shown above.";
        let result = StructuredOutput::extract(content, &OutputMode::Json).unwrap();
        assert_eq!(result["x"], 1);
        assert_eq!(result["y"], 2);
    }

    #[test]
    fn test_extract_json_array() {
        let content = "Results: [1, 2, 3]";
        let result = StructuredOutput::extract(content, &OutputMode::Json).unwrap();
        assert_eq!(result, json!([1, 2, 3]));
    }

    #[test]
    fn test_extract_json_not_found() {
        let content = "No JSON here, just plain text.";
        let result = StructuredOutput::extract(content, &OutputMode::Json);
        assert!(matches!(result, Err(StructuredOutputError::NotFound(_))));
    }

    #[test]
    fn test_validated_json_success() {
        let schema = json!({
            "type": "object",
            "required": ["name"]
        });
        let content = r#"{"name": "test", "value": 42}"#;
        let result =
            StructuredOutput::extract(content, &OutputMode::ValidatedJson { schema }).unwrap();
        assert_eq!(result["name"], "test");
    }

    #[test]
    fn test_validated_json_wrong_type() {
        let schema = json!({"type": "array"});
        let content = r#"{"name": "test"}"#;
        let result = StructuredOutput::extract(content, &OutputMode::ValidatedJson { schema });
        match result {
            Err(StructuredOutputError::ValidationError(msg)) => {
                assert_eq!(msg, "Expected type 'array', got 'object'");
            }
            other => panic!("expected ValidationError, got {:?}", other),
        }
    }

    #[test]
    fn test_validated_json_missing_required() {
        let schema = json!({
            "type": "object",
            "required": ["name", "age"]
        });
        let content = r#"{"name": "test"}"#;
        let result = StructuredOutput::extract(content, &OutputMode::ValidatedJson { schema });
        match result {
            Err(StructuredOutputError::ValidationError(msg)) => {
                assert_eq!(msg, "Missing required field: 'age'");
            }
            other => panic!("expected ValidationError, got {:?}", other),
        }
    }

    #[test]
    fn test_validate_integer_type() {
        let schema = json!({"type": "integer"});
        assert!(StructuredOutput::validate(&json!(3), &schema).is_ok());
        assert!(StructuredOutput::validate(&json!(3.5), &schema).is_err());
    }

    #[test]
    fn test_validate_without_type_accepts_anything() {
        let schema = json!({});
        assert!(StructuredOutput::validate(&json!(42), &schema).is_ok());
        assert!(StructuredOutput::validate(&json!(null), &schema).is_ok());
    }

    #[test]
    fn test_required_ignored_for_non_objects() {
        let schema = json!({"required": ["x"]});
        assert!(StructuredOutput::validate(&json!([1, 2]), &schema).is_ok());
    }

    #[test]
    fn test_nested_brackets() {
        let content = r#"Result: {"a": {"b": [1, 2]}, "c": 3}"#;
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["a"]["b"], json!([1, 2]));
        assert_eq!(result["c"], 3);
    }

    #[test]
    fn test_json_with_string_containing_brackets() {
        let content = r#"{"text": "hello {world}"}"#;
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["text"], "hello {world}");
    }

    #[test]
    fn test_escaped_quotes_inside_embedded_json() {
        let content = r#"Reply: {"q": "say \"hi\" {x}"} end"#;
        let result = StructuredOutput::extract_json(content).unwrap();
        assert_eq!(result["q"], "say \"hi\" {x}");
    }

    #[test]
    fn test_output_mode_requires_json() {
        assert!(!OutputMode::Text.requires_json());
        assert!(OutputMode::Json.requires_json());
        assert!(OutputMode::ValidatedJson { schema: json!({}) }.requires_json());
    }
}