Skip to main content

git_iris/agents/
output_validator.rs

1//! Output validation and error recovery for agent responses
2//!
3//! This module provides robust JSON validation and recovery mechanisms
4//! for handling malformed or partially correct LLM responses.
5
6use anyhow::{Context, Result};
7use schemars::JsonSchema;
8use schemars::schema_for;
9use serde::de::DeserializeOwned;
10use serde_json::{Map, Value};
11
12use crate::agents::debug;
13
/// Outcome of validating an agent response against the expected type.
#[derive(Debug)]
pub struct ValidationResult<T> {
    /// The deserialized value; both constructors in this module set it to `Some`.
    pub value: Option<T>,
    /// Non-fatal problems noticed (and repaired) while parsing.
    pub warnings: Vec<String>,
    /// `true` when the value was only obtained after repairing the input.
    pub recovered: bool,
}

impl<T> ValidationResult<T> {
    /// Result for input that deserialized cleanly: no warnings, no recovery.
    fn success(value: T) -> Self {
        Self {
            recovered: false,
            warnings: Vec::new(),
            value: Some(value),
        }
    }

    /// Result for input that needed repairs; `warnings` describes each one.
    fn recovered(value: T, warnings: Vec<String>) -> Self {
        Self {
            recovered: true,
            warnings,
            value: Some(value),
        }
    }
}
42
43/// Validate and parse JSON with schema validation and error recovery
44///
45/// # Errors
46///
47/// Returns an error when the JSON cannot be parsed even after recovery attempts.
48pub fn validate_and_parse<T>(json_str: &str) -> Result<ValidationResult<T>>
49where
50    T: JsonSchema + DeserializeOwned,
51{
52    let mut warnings = Vec::new();
53
54    // First, try direct parsing
55    match serde_json::from_str::<T>(json_str) {
56        Ok(value) => {
57            debug::debug_json_parse_success(std::any::type_name::<T>());
58            return Ok(ValidationResult::success(value));
59        }
60        Err(e) => {
61            debug::debug_json_parse_error(&format!("Initial parse failed: {}", e));
62            warnings.push(format!("Initial parse failed: {}", e));
63        }
64    }
65
66    // Parse as generic Value for recovery attempts
67    let mut json_value: Value = serde_json::from_str(json_str)
68        .context("Response is not valid JSON - cannot attempt recovery")?;
69
70    // Get the expected schema
71    let schema = schema_for!(T);
72    let schema_value = serde_json::to_value(&schema).unwrap_or(Value::Null);
73
74    // Attempt recovery based on schema
75    if let Some(obj) = json_value.as_object_mut() {
76        recover_missing_fields(obj, &schema_value, &mut warnings);
77        recover_type_mismatches(obj, &schema_value, &mut warnings);
78        recover_null_to_defaults(obj, &schema_value, &mut warnings);
79    }
80
81    // Try parsing again after recovery
82    match serde_json::from_value::<T>(json_value.clone()) {
83        Ok(value) => {
84            debug::debug_context_management(
85                "JSON recovery successful",
86                &format!("{} warnings", warnings.len()),
87            );
88            Ok(ValidationResult::recovered(value, warnings))
89        }
90        Err(e) => {
91            // Final attempt: try to extract just the required fields
92            let final_value = extract_required_fields(&json_value, &schema_value);
93            match serde_json::from_value::<T>(final_value) {
94                Ok(value) => {
95                    warnings.push(format!("Extracted required fields only: {}", e));
96                    Ok(ValidationResult::recovered(value, warnings))
97                }
98                Err(final_e) => Err(anyhow::anyhow!(
99                    "Failed to parse JSON even after recovery attempts: {}",
100                    final_e
101                )),
102            }
103        }
104    }
105}
106
107/// Recover missing required fields by adding defaults
108fn recover_missing_fields(
109    obj: &mut Map<String, Value>,
110    schema: &Value,
111    warnings: &mut Vec<String>,
112) {
113    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
114        return;
115    };
116
117    let required = schema
118        .get("required")
119        .and_then(|r| r.as_array())
120        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
121        .unwrap_or_default();
122
123    for field_name in required {
124        if !obj.contains_key(field_name)
125            && let Some(prop_schema) = properties.get(field_name)
126        {
127            let default_value = get_default_for_type(prop_schema);
128            warnings.push(format!(
129                "Added missing required field '{}' with default value",
130                field_name
131            ));
132            obj.insert(field_name.to_string(), default_value);
133        }
134    }
135}
136
137/// Recover type mismatches by attempting conversion
138fn recover_type_mismatches(
139    obj: &mut Map<String, Value>,
140    schema: &Value,
141    warnings: &mut Vec<String>,
142) {
143    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
144        return;
145    };
146
147    for (field_name, prop_schema) in properties {
148        if let Some(current_value) = obj.get(field_name).cloned() {
149            let expected_type = prop_schema
150                .get("type")
151                .and_then(|t| t.as_str())
152                .unwrap_or("any");
153
154            let converted = match expected_type {
155                "string" => convert_to_string(&current_value),
156                "array" => Some(convert_to_array(&current_value)),
157                "boolean" => convert_to_boolean(&current_value),
158                "integer" | "number" => convert_to_number(&current_value),
159                _ => None,
160            };
161
162            if let Some(new_value) = converted
163                && new_value != current_value
164            {
165                warnings.push(format!(
166                    "Converted field '{}' from {:?} to {}",
167                    field_name,
168                    type_name(&current_value),
169                    expected_type
170                ));
171                obj.insert(field_name.clone(), new_value);
172            }
173        }
174    }
175}
176
177/// Recover null values by replacing with appropriate defaults
178fn recover_null_to_defaults(
179    obj: &mut Map<String, Value>,
180    schema: &Value,
181    warnings: &mut Vec<String>,
182) {
183    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
184        return;
185    };
186
187    for (field_name, prop_schema) in properties {
188        if let Some(Value::Null) = obj.get(field_name) {
189            // Check if field is nullable (has "anyOf" with null type)
190            let is_nullable = prop_schema
191                .get("anyOf")
192                .and_then(|a| a.as_array())
193                .is_some_and(|arr| {
194                    arr.iter()
195                        .any(|v| v.get("type") == Some(&Value::String("null".to_string())))
196                });
197
198            if !is_nullable {
199                let default_value = get_default_for_type(prop_schema);
200                warnings.push(format!(
201                    "Replaced null value in non-nullable field '{}' with default",
202                    field_name
203                ));
204                obj.insert(field_name.clone(), default_value);
205            }
206        }
207    }
208}
209
210/// Extract only required fields from a JSON value
211fn extract_required_fields(value: &Value, schema: &Value) -> Value {
212    let Some(obj) = value.as_object() else {
213        return value.clone();
214    };
215
216    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
217        return value.clone();
218    };
219
220    let required: Vec<&str> = schema
221        .get("required")
222        .and_then(|r| r.as_array())
223        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
224        .unwrap_or_default();
225
226    let mut result = Map::new();
227
228    for field_name in required {
229        if let Some(field_value) = obj.get(field_name) {
230            result.insert(field_name.to_string(), field_value.clone());
231        } else if let Some(prop_schema) = properties.get(field_name) {
232            result.insert(field_name.to_string(), get_default_for_type(prop_schema));
233        }
234    }
235
236    // Also include optional fields that are present
237    for (field_name, field_value) in obj {
238        if !result.contains_key(field_name) {
239            result.insert(field_name.clone(), field_value.clone());
240        }
241    }
242
243    Value::Object(result)
244}
245
246/// Get a sensible default value for a JSON schema type
247fn get_default_for_type(schema: &Value) -> Value {
248    // Check for explicit default first
249    if let Some(default) = schema.get("default") {
250        return default.clone();
251    }
252
253    // Check for anyOf (nullable types)
254    if let Some(any_of) = schema.get("anyOf").and_then(|a| a.as_array()) {
255        for variant in any_of {
256            if variant.get("type") == Some(&Value::String("null".to_string())) {
257                return Value::Null;
258            }
259        }
260        // Use first non-null type's default
261        if let Some(first) = any_of.first() {
262            return get_default_for_type(first);
263        }
264    }
265
266    match schema.get("type").and_then(|t| t.as_str()) {
267        Some("string") => Value::String(String::new()),
268        Some("array") => Value::Array(vec![]),
269        Some("object") => Value::Object(Map::new()),
270        Some("boolean") => Value::Bool(false),
271        Some("integer" | "number") => Value::Number(0.into()),
272        _ => Value::Null,
273    }
274}
275
276/// Convert a value to string if possible
277fn convert_to_string(value: &Value) -> Option<Value> {
278    match value {
279        Value::String(_) => Some(value.clone()),
280        Value::Number(n) => Some(Value::String(n.to_string())),
281        Value::Bool(b) => Some(Value::String(b.to_string())),
282        Value::Null => Some(Value::String(String::new())),
283        Value::Array(arr) => {
284            let strings: Vec<String> = arr
285                .iter()
286                .filter_map(|v| v.as_str().map(String::from))
287                .collect();
288            Some(Value::String(strings.join(", ")))
289        }
290        Value::Object(_) => None,
291    }
292}
293
294/// Convert a value to array
295fn convert_to_array(value: &Value) -> Value {
296    match value {
297        Value::Array(_) => value.clone(),
298        Value::Null => Value::Array(vec![]),
299        // Wrap single value in array
300        other => Value::Array(vec![other.clone()]),
301    }
302}
303
304/// Convert a value to boolean if possible
305fn convert_to_boolean(value: &Value) -> Option<Value> {
306    match value {
307        Value::Bool(_) => Some(value.clone()),
308        Value::String(s) => match s.to_lowercase().as_str() {
309            "true" | "yes" | "1" => Some(Value::Bool(true)),
310            "false" | "no" | "0" | "" => Some(Value::Bool(false)),
311            _ => None,
312        },
313        Value::Number(n) => Some(Value::Bool(n.as_f64().unwrap_or(0.0) != 0.0)),
314        Value::Null => Some(Value::Bool(false)),
315        Value::Array(_) | Value::Object(_) => None,
316    }
317}
318
319/// Convert a value to number if possible
320fn convert_to_number(value: &Value) -> Option<Value> {
321    match value {
322        Value::Number(_) => Some(value.clone()),
323        Value::String(s) => {
324            // Try parsing as integer first to preserve integer semantics
325            if let Ok(i) = s.parse::<i64>() {
326                return Some(Value::Number(i.into()));
327            }
328            // Fall back to float
329            s.parse::<f64>()
330                .ok()
331                .and_then(serde_json::Number::from_f64)
332                .map(Value::Number)
333        }
334        Value::Bool(b) => Some(Value::Number(i32::from(*b).into())),
335        Value::Null | Value::Array(_) | Value::Object(_) => None,
336    }
337}
338
339/// Get a human-readable type name for a JSON value
340fn type_name(value: &Value) -> &'static str {
341    match value {
342        Value::Null => "null",
343        Value::Bool(_) => "boolean",
344        Value::Number(_) => "number",
345        Value::String(_) => "string",
346        Value::Array(_) => "array",
347        Value::Object(_) => "object",
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fixture type: two mandatory strings plus two fields with serde defaults.
    #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
    struct TestOutput {
        title: String,
        message: String,
        #[serde(default)]
        tags: Vec<String>,
        #[serde(default)]
        count: i32,
    }

    /// Run the validator on `json`, panicking if it reports an error.
    fn parse(json: &str) -> ValidationResult<TestOutput> {
        validate_and_parse::<TestOutput>(json).expect("should parse")
    }

    #[test]
    fn test_valid_json_parses_directly() {
        let result =
            parse(r#"{"title": "Test", "message": "Hello", "tags": ["a", "b"], "count": 5}"#);
        assert!(!result.recovered);
        assert!(result.warnings.is_empty());

        let output = result.value.expect("should have value");
        assert_eq!(output.title, "Test");
    }

    #[test]
    fn test_recovers_missing_optional_fields() {
        let output = parse(r#"{"title": "Test", "message": "Hello"}"#)
            .value
            .expect("should have value");
        assert_eq!(output.title, "Test");
        assert!(output.tags.is_empty());
    }

    #[test]
    fn test_converts_number_to_string() {
        let result = parse(r#"{"title": 123, "message": "Hello"}"#);
        assert!(result.recovered);

        let output = result.value.expect("should have value");
        assert_eq!(output.title, "123");
    }

    #[test]
    fn test_converts_single_value_to_array() {
        let result = parse(r#"{"title": "Test", "message": "Hello", "tags": "single"}"#);
        assert!(result.recovered);

        let output = result.value.expect("should have value");
        assert_eq!(output.tags, vec!["single"]);
    }

    #[test]
    fn test_converts_string_to_number() {
        let result = parse(r#"{"title": "Test", "message": "Hello", "count": "42"}"#);
        assert!(result.recovered);

        let output = result.value.expect("should have value");
        assert_eq!(output.count, 42);
    }
}