Skip to main content

embacle/
structured_output.rs

1// ABOUTME: Standalone function forcing any LlmProvider to return schema-valid JSON
2// ABOUTME: Includes lightweight JSON schema validation, markdown fence extraction, and retry loop
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7//! # Structured Output Enforcement
8//!
9//! Forces any [`LlmProvider`](crate::types::LlmProvider) to return JSON that validates against a provided
10//! JSON Schema. The module injects schema instructions into the system message,
11//! extracts JSON from the response (including markdown fences), validates against
12//! the schema, and retries with validation feedback on failure.
13//!
14//! ## Schema Validation Coverage
15//!
16//! The built-in validator checks `type`, `required`, recursive `properties`,
17//! array `items`, `enum` values, numeric `minimum`/`maximum`, and
18//! `additionalProperties: false`. It does not cover the full JSON Schema
19//! specification (e.g., `oneOf`, `anyOf`, `$ref`, `pattern`).
20
21use std::fmt;
22
23use serde_json::Value;
24use tracing::{info, warn};
25
26use crate::types::{ChatMessage, ChatRequest, LlmProvider, MessageRole, RunnerError};
27
/// Request configuration for structured JSON output.
///
/// Consumed by [`request_structured_output`], which sends `request` (augmented with
/// schema instructions) to a provider and validates the reply against `schema`.
#[derive(Debug, Clone)]
pub struct StructuredOutputRequest {
    /// The original chat request; its messages are copied, never mutated in place
    pub request: ChatRequest,
    /// JSON Schema the response must conform to (checked by the built-in subset validator)
    pub schema: Value,
    /// Maximum retry attempts after the first one, on JSON parse or schema validation
    /// failure; the provider is called at most `max_retries + 1` times
    pub max_retries: u32,
}
38
/// A single schema validation error.
///
/// Displayed as `"{path}: {message}"`; this is also the form fed back to the model
/// when a retry is requested.
#[derive(Debug, Clone)]
pub struct SchemaValidationError {
    /// Human-readable error description
    pub message: String,
    /// JSON path where the error occurred (e.g., "$.name", "$.items[2]")
    pub path: String,
}
47
48impl fmt::Display for SchemaValidationError {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(f, "{}: {}", self.path, self.message)
51    }
52}
53
54/// Request structured JSON output from any provider, with schema validation and retry.
55///
56/// # Flow
57///
58/// 1. Append schema instructions to the system message
59/// 2. Call `provider.complete()` and extract JSON from the response
60/// 3. Validate against the schema
61/// 4. On failure, append errors as user feedback and retry up to `max_retries`
62/// 5. After exhaustion, return [`RunnerError::external_service`]
63///
64/// # Errors
65///
66/// Returns [`RunnerError`] if the provider fails or validation is exhausted.
67pub async fn request_structured_output(
68    provider: &dyn LlmProvider,
69    structured_request: &StructuredOutputRequest,
70) -> Result<Value, RunnerError> {
71    let schema_str = serde_json::to_string_pretty(&structured_request.schema)
72        .map_err(|e| RunnerError::internal(format!("failed to serialize schema: {e}")))?;
73
74    let schema_instruction = format!(
75        "\n\nYou MUST respond with ONLY valid JSON that conforms to the following JSON Schema. \
76         Do NOT include any explanatory text, markdown formatting, or anything other than the \
77         JSON object.\n\nSchema:\n```json\n{schema_str}\n```"
78    );
79
80    let mut messages = structured_request.request.messages.clone();
81
82    // Inject schema instruction into the system message
83    inject_schema_instruction(&mut messages, &schema_instruction);
84
85    let total_attempts = structured_request.max_retries + 1;
86    for attempt in 0..total_attempts {
87        let request = ChatRequest {
88            messages: messages.clone(),
89            model: structured_request.request.model.clone(),
90            temperature: structured_request.request.temperature,
91            max_tokens: structured_request.request.max_tokens,
92            stream: false,
93            tools: structured_request.request.tools.clone(),
94            tool_choice: structured_request.request.tool_choice.clone(),
95            top_p: structured_request.request.top_p,
96            stop: structured_request.request.stop.clone(),
97            response_format: structured_request.request.response_format.clone(),
98            turn_id: structured_request.request.turn_id,
99        };
100
101        let response = provider.complete(&request).await?;
102
103        // Try to extract JSON from the response
104        let json_str = extract_json_from_response(&response.content);
105
106        let parsed: Value = match serde_json::from_str(&json_str) {
107            Ok(v) => v,
108            Err(parse_err) => {
109                warn!(
110                    attempt,
111                    error = %parse_err,
112                    "structured output: failed to parse JSON from response"
113                );
114                if attempt < structured_request.max_retries {
115                    messages.push(ChatMessage::assistant(response.content.clone()));
116                    messages.push(ChatMessage::user(format!(
117                        "Your response was not valid JSON: {parse_err}. \
118                         Please respond with ONLY a valid JSON object matching the schema."
119                    )));
120                }
121                continue;
122            }
123        };
124
125        let errors = validate_against_schema(&parsed, &structured_request.schema);
126
127        if errors.is_empty() {
128            info!(attempt, "structured output: validation passed");
129            return Ok(parsed);
130        }
131
132        warn!(
133            attempt,
134            error_count = errors.len(),
135            "structured output: schema validation failed"
136        );
137
138        if attempt < structured_request.max_retries {
139            let error_feedback: Vec<String> = errors.iter().map(ToString::to_string).collect();
140            messages.push(ChatMessage::assistant(response.content.clone()));
141            messages.push(ChatMessage::user(format!(
142                "Your JSON response had validation errors:\n- {}\n\
143                 Please fix these and respond with ONLY a valid JSON object.",
144                error_feedback.join("\n- ")
145            )));
146        }
147    }
148
149    Err(RunnerError::external_service(
150        provider.name(),
151        "structured output validation exhausted after all retries",
152    ))
153}
154
155/// Inject schema instruction into the system message, or create one
156fn inject_schema_instruction(messages: &mut Vec<ChatMessage>, instruction: &str) {
157    if let Some(first) = messages.first_mut() {
158        if first.role == MessageRole::System {
159            let augmented = format!("{}{instruction}", first.content);
160            *first = ChatMessage::system(augmented);
161            return;
162        }
163    }
164    messages.insert(0, ChatMessage::system(instruction.to_owned()));
165}
166
/// Extract JSON content from a response, handling markdown code fences.
///
/// CLI runners often wrap JSON in markdown fences (e.g. `` ```json ... ``` ``) or
/// surround it with prose. Candidates are tried in this order:
///
/// 1. The whole trimmed response, when it starts with `{` or with a `[` that
///    plausibly opens a JSON array (so top-level array schemas can succeed).
/// 2. The body of the first markdown fence, under the same rule.
/// 3. The first `{` anywhere in the response.
///
/// The chosen candidate is cut at its matching closing delimiter. If nothing
/// matches, the trimmed response is returned unchanged so the caller's JSON
/// parser reports the problem.
///
/// Arrays are only recognised at the start of the response or fence. This keeps a
/// leading prose tag such as `[Note]` from hiding a later JSON object. A leading
/// numeric footnote such as `[1]` is still taken as an array.
pub fn extract_json_from_response(content: &str) -> String {
    let trimmed = content.trim();

    // Fast path: the response already starts with a JSON container
    if let Some(json) = leading_json(trimmed) {
        return json;
    }

    // Try to find JSON inside markdown fences
    if let Some(start) = trimmed.find("```") {
        let after_fence = &trimmed[start + 3..];
        // Skip optional language tag (e.g., "json")
        let content_start = after_fence.find('\n').map_or(0, |pos| pos + 1);
        let fence_content = &after_fence[content_start..];

        if let Some(end) = fence_content.find("```") {
            if let Some(json) = leading_json(fence_content[..end].trim()) {
                return json;
            }
        }
    }

    // Last resort: find the first `{` and extract from there
    if let Some(brace_pos) = trimmed.find('{') {
        return extract_braced_json(&trimmed[brace_pos..]);
    }

    trimmed.to_owned()
}

/// Return the JSON object or array that `text` starts with, if any.
fn leading_json(text: &str) -> Option<String> {
    if text.starts_with('{') || looks_like_json_array(text) {
        Some(extract_braced_json(text))
    } else {
        None
    }
}

/// Heuristic: does `text` start with `[` followed by something that can begin a JSON value?
///
/// Distinguishes `[1, 2]`, `[]` or `[{"a": 1}]` from prose tags such as `[Note]`.
fn looks_like_json_array(text: &str) -> bool {
    match text.strip_prefix('[') {
        Some(rest) => {
            let rest = rest.trim_start();
            rest.starts_with(|c: char| {
                matches!(c, '{' | '[' | '"' | ']' | '-') || c.is_ascii_digit()
            }) || ["true", "false", "null"]
                .iter()
                .any(|literal| rest.starts_with(*literal))
        }
        None => false,
    }
}

/// Cut `text` at the delimiter that closes its opening `{` or `[`.
///
/// Nesting depth is tracked across both `{}` and `[]`. Delimiters inside JSON
/// string literals are ignored, and escaped quotes are handled. Returns
/// everything up to and including the closing delimiter, or all of `text` when
/// the container is never closed (the parser then reports the truncation).
/// Expects `text` to start with `{` or `[`.
fn extract_braced_json(text: &str) -> String {
    let mut depth: i32 = 0;
    let mut in_string = false;
    let mut escape_next = false;

    for (i, ch) in text.char_indices() {
        if escape_next {
            escape_next = false;
            continue;
        }

        match ch {
            '\\' if in_string => escape_next = true,
            '"' => in_string = !in_string,
            '{' | '[' if !in_string => depth += 1,
            '}' | ']' if !in_string => {
                depth -= 1;
                if depth == 0 {
                    return text[..=i].to_owned();
                }
            }
            _ => {}
        }
    }

    text.to_owned()
}
231
232/// Validate a JSON value against a schema.
233///
234/// Checks: `type`, `required`, recursive `properties`, array `items`,
235/// `enum` values, numeric `minimum`/`maximum`, and `additionalProperties: false`.
236pub fn validate_against_schema(value: &Value, schema: &Value) -> Vec<SchemaValidationError> {
237    let mut errors = Vec::new();
238    validate_value(value, schema, "$", &mut errors);
239    errors
240}
241
242fn validate_value(
243    value: &Value,
244    schema: &Value,
245    path: &str,
246    errors: &mut Vec<SchemaValidationError>,
247) {
248    // Check type
249    if let Some(expected_type) = schema.get("type").and_then(Value::as_str) {
250        let actual_type = json_type_name(value);
251        if actual_type != expected_type {
252            errors.push(SchemaValidationError {
253                message: format!("expected type \"{expected_type}\", got \"{actual_type}\""),
254                path: path.to_owned(),
255            });
256            return;
257        }
258    }
259
260    // Check enum constraint
261    if let Some(enum_values) = schema.get("enum").and_then(Value::as_array) {
262        if !enum_values.contains(value) {
263            errors.push(SchemaValidationError {
264                message: format!("value not in enum: expected one of {enum_values:?}, got {value}"),
265                path: path.to_owned(),
266            });
267            return;
268        }
269    }
270
271    // Numeric bounds (minimum, maximum)
272    if let Some(num) = value.as_f64() {
273        if let Some(min) = schema.get("minimum").and_then(Value::as_f64) {
274            if num < min {
275                errors.push(SchemaValidationError {
276                    message: format!("value {num} is less than minimum {min}"),
277                    path: path.to_owned(),
278                });
279            }
280        }
281        if let Some(max) = schema.get("maximum").and_then(Value::as_f64) {
282            if num > max {
283                errors.push(SchemaValidationError {
284                    message: format!("value {num} exceeds maximum {max}"),
285                    path: path.to_owned(),
286                });
287            }
288        }
289    }
290
291    // For objects: check required fields, property types (recursive), additional properties
292    if let Some(obj) = value.as_object() {
293        if let Some(required) = schema.get("required").and_then(Value::as_array) {
294            for req in required {
295                if let Some(field_name) = req.as_str() {
296                    if !obj.contains_key(field_name) {
297                        errors.push(SchemaValidationError {
298                            message: format!("missing required field \"{field_name}\""),
299                            path: format!("{path}.{field_name}"),
300                        });
301                    }
302                }
303            }
304        }
305
306        if let Some(properties) = schema.get("properties").and_then(Value::as_object) {
307            for (prop_name, prop_schema) in properties {
308                if let Some(prop_value) = obj.get(prop_name) {
309                    let prop_path = format!("{path}.{prop_name}");
310                    // Recurse into nested properties
311                    validate_value(prop_value, prop_schema, &prop_path, errors);
312                }
313            }
314
315            // Check additionalProperties: false
316            if schema.get("additionalProperties") == Some(&Value::Bool(false)) {
317                for key in obj.keys() {
318                    if !properties.contains_key(key) {
319                        errors.push(SchemaValidationError {
320                            message: format!("unexpected additional property \"{key}\""),
321                            path: format!("{path}.{key}"),
322                        });
323                    }
324                }
325            }
326        }
327    }
328
329    // For arrays: validate items against the items schema
330    if let Some(arr) = value.as_array() {
331        if let Some(items_schema) = schema.get("items") {
332            for (i, item) in arr.iter().enumerate() {
333                let item_path = format!("{path}[{i}]");
334                validate_value(item, items_schema, &item_path, errors);
335            }
336        }
337    }
338}
339
340/// Map a JSON value to its JSON Schema type name
341fn json_type_name(value: &Value) -> &'static str {
342    match value {
343        Value::Null => "null",
344        Value::Bool(_) => "boolean",
345        Value::Number(n) => {
346            if n.is_i64() || n.is_u64() {
347                "integer"
348            } else {
349                "number"
350            }
351        }
352        Value::String(_) => "string",
353        Value::Array(_) => "array",
354        Value::Object(_) => "object",
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        ChatMessage, ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider,
        RunnerError,
    };
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// Scripted provider for the retry-loop tests.
    ///
    /// Each `complete` call removes and returns the next canned response, in order.
    /// Once the script is exhausted it returns an internal error.
    struct TestProvider {
        /// Remaining canned responses, consumed front to back.
        responses: Mutex<Vec<Result<ChatResponse, RunnerError>>>,
        /// Number of `complete` calls made so far.
        /// NOTE(review): not asserted by any test visible in this module.
        call_count: AtomicU32,
    }

    impl TestProvider {
        /// Build a provider that replays `responses` in order.
        fn new(responses: Vec<Result<ChatResponse, RunnerError>>) -> Self {
            Self {
                responses: Mutex::new(responses),
                call_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        fn name(&self) -> &'static str {
            "test"
        }
        fn display_name(&self) -> &str {
            "Test Provider"
        }
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities::text_only()
        }
        fn default_model(&self) -> &'static str {
            "test-model"
        }
        fn available_models(&self) -> &[String] {
            &[]
        }
        async fn complete(&self, _request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let mut responses = self.responses.lock().expect("test lock"); // Safe: test assertion
            // Pop from the front so responses replay in the order they were supplied.
            if responses.is_empty() {
                Err(RunnerError::internal("no more test responses"))
            } else {
                responses.remove(0)
            }
        }
        async fn complete_stream(&self, _request: &ChatRequest) -> Result<ChatStream, RunnerError> {
            // Streaming is never exercised: structured output always sends `stream: false`.
            Err(RunnerError::internal("not supported"))
        }
        async fn health_check(&self) -> Result<bool, RunnerError> {
            Ok(true)
        }
    }

    /// Build a minimal successful `ChatResponse` carrying `content`.
    fn make_response(content: &str) -> ChatResponse {
        ChatResponse {
            content: content.to_owned(),
            model: "test-model".to_owned(),
            usage: None,
            finish_reason: Some("stop".to_owned()),
            warnings: None,
            tool_calls: None,
        }
    }

    // --- validate_against_schema tests ---

    #[test]
    fn validate_valid_object() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        });

        let value = json!({"name": "Alice", "age": 30});
        let errors = validate_against_schema(&value, &schema);
        assert!(errors.is_empty());
    }

    #[test]
    fn validate_missing_required_fields() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        });

        let value = json!({"name": "Alice"});
        let errors = validate_against_schema(&value, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("age"));
    }

    #[test]
    fn validate_wrong_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        // Both properties have the wrong type, so each produces its own error.
        let value = json!({"name": 42, "age": "not a number"});
        let errors = validate_against_schema(&value, &schema);
        assert_eq!(errors.len(), 2);
    }

    #[test]
    fn validate_wrong_root_type() {
        let schema = json!({"type": "object"});
        let value = json!("just a string");
        let errors = validate_against_schema(&value, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("expected type \"object\""));
    }

    // --- extract_json_from_response tests ---

    #[test]
    fn extract_raw_json() {
        let content = r#"{"name": "Alice", "age": 30}"#;
        let extracted = extract_json_from_response(content);
        let parsed: Value = serde_json::from_str(&extracted).expect("valid JSON"); // Safe: test assertion
        assert_eq!(parsed["name"], "Alice");
    }

    #[test]
    fn extract_json_from_markdown_fences() {
        // Prose before and after the fence must be discarded.
        let content = "Here is the result:\n```json\n{\"name\": \"Bob\", \"age\": 25}\n```\nDone.";
        let extracted = extract_json_from_response(content);
        let parsed: Value = serde_json::from_str(&extracted).expect("valid JSON"); // Safe: test assertion
        assert_eq!(parsed["name"], "Bob");
    }

    #[test]
    fn extract_json_with_nested_braces() {
        // The depth counter must not stop at the inner object's closing brace.
        let content = r#"{"outer": {"inner": "value"}, "list": [1, 2]}"#;
        let extracted = extract_json_from_response(content);
        let parsed: Value = serde_json::from_str(&extracted).expect("valid JSON"); // Safe: test assertion
        assert_eq!(parsed["outer"]["inner"], "value");
    }

    // --- full retry loop tests ---

    #[tokio::test]
    async fn full_retry_loop_eventual_success() {
        // First reply is unparseable; the retry returns valid JSON.
        let provider = TestProvider::new(vec![
            Ok(make_response("not json at all")),
            Ok(make_response(r#"{"name": "Alice", "age": 30}"#)),
        ]);

        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        });

        let structured = StructuredOutputRequest {
            request: ChatRequest::new(vec![ChatMessage::user("give me data")]),
            schema,
            max_retries: 2,
        };

        let result = request_structured_output(&provider, &structured)
            .await
            .expect("should succeed on retry"); // Safe: test assertion
        assert_eq!(result["name"], "Alice");
        assert_eq!(result["age"], 30);
    }

    #[tokio::test]
    async fn exhaustion_returns_error() {
        // max_retries = 2 means three attempts, and every scripted reply is garbage.
        let provider = TestProvider::new(vec![
            Ok(make_response("garbage")),
            Ok(make_response("still garbage")),
            Ok(make_response("nope")),
        ]);

        let schema = json!({
            "type": "object",
            "required": ["name"]
        });

        let structured = StructuredOutputRequest {
            request: ChatRequest::new(vec![ChatMessage::user("give me data")]),
            schema,
            max_retries: 2,
        };

        let result = request_structured_output(&provider, &structured).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("exhausted"));
    }

    // --- enhanced validation tests ---

    #[test]
    fn validate_nested_object() {
        let schema = json!({
            "type": "object",
            "properties": {
                "address": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "zip": {"type": "string"}
                    },
                    "required": ["city"]
                }
            },
            "required": ["address"]
        });

        let valid = json!({"address": {"city": "Paris", "zip": "75001"}});
        assert!(validate_against_schema(&valid, &schema).is_empty());

        // A nested `required` violation is reported with the nested path.
        let missing_city = json!({"address": {"zip": "75001"}});
        let errors = validate_against_schema(&missing_city, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].path.contains("city"));

        let wrong_type = json!({"address": {"city": 42}});
        let errors = validate_against_schema(&wrong_type, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("expected type \"string\""));
    }

    #[test]
    fn validate_array_items() {
        let schema = json!({
            "type": "array",
            "items": {"type": "string"}
        });

        let valid = json!(["a", "b", "c"]);
        assert!(validate_against_schema(&valid, &schema).is_empty());

        // Item errors carry the element index in their path.
        let invalid = json!(["a", 42, "c"]);
        let errors = validate_against_schema(&invalid, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].path.contains("[1]"));
    }

    #[test]
    fn validate_enum_values() {
        let schema = json!({
            "type": "string",
            "enum": ["red", "green", "blue"]
        });

        let valid = json!("green");
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let invalid = json!("yellow");
        let errors = validate_against_schema(&invalid, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("not in enum"));
    }

    #[test]
    fn validate_numeric_bounds() {
        let schema = json!({
            "type": "integer",
            "minimum": 0,
            "maximum": 100
        });

        let valid = json!(50);
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let too_low = json!(-1);
        let errors = validate_against_schema(&too_low, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("less than minimum"));

        let too_high = json!(101);
        let errors = validate_against_schema(&too_high, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("exceeds maximum"));
    }

    #[test]
    fn validate_additional_properties_false() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let valid = json!({"name": "Alice"});
        assert!(validate_against_schema(&valid, &schema).is_empty());

        // `age` is not declared in `properties`, so it is rejected.
        let with_extra = json!({"name": "Alice", "age": 30});
        let errors = validate_against_schema(&with_extra, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("unexpected additional property"));
    }
}