git_iris/agents/
output_validator.rs

1//! Output validation and error recovery for agent responses
2//!
3//! This module provides robust JSON validation and recovery mechanisms
4//! for handling malformed or partially correct LLM responses.
5
6use anyhow::{Context, Result};
7use schemars::JsonSchema;
8use schemars::schema_for;
9use serde::de::DeserializeOwned;
10use serde_json::{Map, Value};
11
12use crate::agents::debug;
13
/// Outcome of validating an agent response, together with recovery details.
#[derive(Debug)]
pub struct ValidationResult<T> {
    /// The parsed value (if successful)
    pub value: Option<T>,
    /// Non-fatal problems noticed while parsing
    pub warnings: Vec<String>,
    /// `true` when the input had to be repaired before it could be parsed
    pub recovered: bool,
}

impl<T> ValidationResult<T> {
    /// Wraps a value that parsed cleanly: no warnings, not recovered.
    fn success(value: T) -> Self {
        Self::build(value, Vec::new(), false)
    }

    /// Wraps a value that parsed only after repairs, keeping the warnings
    /// gathered along the way.
    fn recovered(value: T, warnings: Vec<String>) -> Self {
        Self::build(value, warnings, true)
    }

    /// Shared constructor: both public-facing constructors always carry a value.
    fn build(value: T, warnings: Vec<String>, recovered: bool) -> Self {
        Self {
            recovered,
            warnings,
            value: Some(value),
        }
    }
}
42
43/// Validate and parse JSON with schema validation and error recovery
44pub fn validate_and_parse<T>(json_str: &str) -> Result<ValidationResult<T>>
45where
46    T: JsonSchema + DeserializeOwned,
47{
48    let mut warnings = Vec::new();
49
50    // First, try direct parsing
51    match serde_json::from_str::<T>(json_str) {
52        Ok(value) => {
53            debug::debug_json_parse_success(std::any::type_name::<T>());
54            return Ok(ValidationResult::success(value));
55        }
56        Err(e) => {
57            debug::debug_json_parse_error(&format!("Initial parse failed: {}", e));
58            warnings.push(format!("Initial parse failed: {}", e));
59        }
60    }
61
62    // Parse as generic Value for recovery attempts
63    let mut json_value: Value = serde_json::from_str(json_str)
64        .context("Response is not valid JSON - cannot attempt recovery")?;
65
66    // Get the expected schema
67    let schema = schema_for!(T);
68    let schema_value = serde_json::to_value(&schema).unwrap_or(Value::Null);
69
70    // Attempt recovery based on schema
71    if let Some(obj) = json_value.as_object_mut() {
72        recover_missing_fields(obj, &schema_value, &mut warnings);
73        recover_type_mismatches(obj, &schema_value, &mut warnings);
74        recover_null_to_defaults(obj, &schema_value, &mut warnings);
75    }
76
77    // Try parsing again after recovery
78    match serde_json::from_value::<T>(json_value.clone()) {
79        Ok(value) => {
80            debug::debug_context_management(
81                "JSON recovery successful",
82                &format!("{} warnings", warnings.len()),
83            );
84            Ok(ValidationResult::recovered(value, warnings))
85        }
86        Err(e) => {
87            // Final attempt: try to extract just the required fields
88            let final_value = extract_required_fields(&json_value, &schema_value);
89            match serde_json::from_value::<T>(final_value) {
90                Ok(value) => {
91                    warnings.push(format!("Extracted required fields only: {}", e));
92                    Ok(ValidationResult::recovered(value, warnings))
93                }
94                Err(final_e) => Err(anyhow::anyhow!(
95                    "Failed to parse JSON even after recovery attempts: {}",
96                    final_e
97                )),
98            }
99        }
100    }
101}
102
103/// Recover missing required fields by adding defaults
104fn recover_missing_fields(
105    obj: &mut Map<String, Value>,
106    schema: &Value,
107    warnings: &mut Vec<String>,
108) {
109    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
110        return;
111    };
112
113    let required = schema
114        .get("required")
115        .and_then(|r| r.as_array())
116        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
117        .unwrap_or_default();
118
119    for field_name in required {
120        if !obj.contains_key(field_name)
121            && let Some(prop_schema) = properties.get(field_name)
122        {
123            let default_value = get_default_for_type(prop_schema);
124            warnings.push(format!(
125                "Added missing required field '{}' with default value",
126                field_name
127            ));
128            obj.insert(field_name.to_string(), default_value);
129        }
130    }
131}
132
133/// Recover type mismatches by attempting conversion
134fn recover_type_mismatches(
135    obj: &mut Map<String, Value>,
136    schema: &Value,
137    warnings: &mut Vec<String>,
138) {
139    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
140        return;
141    };
142
143    for (field_name, prop_schema) in properties {
144        if let Some(current_value) = obj.get(field_name).cloned() {
145            let expected_type = prop_schema
146                .get("type")
147                .and_then(|t| t.as_str())
148                .unwrap_or("any");
149
150            let converted = match expected_type {
151                "string" => convert_to_string(&current_value),
152                "array" => Some(convert_to_array(&current_value)),
153                "boolean" => convert_to_boolean(&current_value),
154                "integer" | "number" => convert_to_number(&current_value),
155                _ => None,
156            };
157
158            if let Some(new_value) = converted
159                && new_value != current_value
160            {
161                warnings.push(format!(
162                    "Converted field '{}' from {:?} to {}",
163                    field_name,
164                    type_name(&current_value),
165                    expected_type
166                ));
167                obj.insert(field_name.clone(), new_value);
168            }
169        }
170    }
171}
172
173/// Recover null values by replacing with appropriate defaults
174fn recover_null_to_defaults(
175    obj: &mut Map<String, Value>,
176    schema: &Value,
177    warnings: &mut Vec<String>,
178) {
179    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
180        return;
181    };
182
183    for (field_name, prop_schema) in properties {
184        if let Some(Value::Null) = obj.get(field_name) {
185            // Check if field is nullable (has "anyOf" with null type)
186            let is_nullable = prop_schema
187                .get("anyOf")
188                .and_then(|a| a.as_array())
189                .is_some_and(|arr| {
190                    arr.iter()
191                        .any(|v| v.get("type") == Some(&Value::String("null".to_string())))
192                });
193
194            if !is_nullable {
195                let default_value = get_default_for_type(prop_schema);
196                warnings.push(format!(
197                    "Replaced null value in non-nullable field '{}' with default",
198                    field_name
199                ));
200                obj.insert(field_name.clone(), default_value);
201            }
202        }
203    }
204}
205
206/// Extract only required fields from a JSON value
207fn extract_required_fields(value: &Value, schema: &Value) -> Value {
208    let Some(obj) = value.as_object() else {
209        return value.clone();
210    };
211
212    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
213        return value.clone();
214    };
215
216    let required: Vec<&str> = schema
217        .get("required")
218        .and_then(|r| r.as_array())
219        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
220        .unwrap_or_default();
221
222    let mut result = Map::new();
223
224    for field_name in required {
225        if let Some(field_value) = obj.get(field_name) {
226            result.insert(field_name.to_string(), field_value.clone());
227        } else if let Some(prop_schema) = properties.get(field_name) {
228            result.insert(field_name.to_string(), get_default_for_type(prop_schema));
229        }
230    }
231
232    // Also include optional fields that are present
233    for (field_name, field_value) in obj {
234        if !result.contains_key(field_name) {
235            result.insert(field_name.clone(), field_value.clone());
236        }
237    }
238
239    Value::Object(result)
240}
241
242/// Get a sensible default value for a JSON schema type
243fn get_default_for_type(schema: &Value) -> Value {
244    // Check for explicit default first
245    if let Some(default) = schema.get("default") {
246        return default.clone();
247    }
248
249    // Check for anyOf (nullable types)
250    if let Some(any_of) = schema.get("anyOf").and_then(|a| a.as_array()) {
251        for variant in any_of {
252            if variant.get("type") == Some(&Value::String("null".to_string())) {
253                return Value::Null;
254            }
255        }
256        // Use first non-null type's default
257        if let Some(first) = any_of.first() {
258            return get_default_for_type(first);
259        }
260    }
261
262    match schema.get("type").and_then(|t| t.as_str()) {
263        Some("string") => Value::String(String::new()),
264        Some("array") => Value::Array(vec![]),
265        Some("object") => Value::Object(Map::new()),
266        Some("boolean") => Value::Bool(false),
267        Some("integer" | "number") => Value::Number(0.into()),
268        _ => Value::Null,
269    }
270}
271
272/// Convert a value to string if possible
273fn convert_to_string(value: &Value) -> Option<Value> {
274    match value {
275        Value::String(_) => Some(value.clone()),
276        Value::Number(n) => Some(Value::String(n.to_string())),
277        Value::Bool(b) => Some(Value::String(b.to_string())),
278        Value::Null => Some(Value::String(String::new())),
279        Value::Array(arr) => {
280            let strings: Vec<String> = arr
281                .iter()
282                .filter_map(|v| v.as_str().map(String::from))
283                .collect();
284            Some(Value::String(strings.join(", ")))
285        }
286        Value::Object(_) => None,
287    }
288}
289
290/// Convert a value to array
291fn convert_to_array(value: &Value) -> Value {
292    match value {
293        Value::Array(_) => value.clone(),
294        Value::Null => Value::Array(vec![]),
295        // Wrap single value in array
296        other => Value::Array(vec![other.clone()]),
297    }
298}
299
300/// Convert a value to boolean if possible
301fn convert_to_boolean(value: &Value) -> Option<Value> {
302    match value {
303        Value::Bool(_) => Some(value.clone()),
304        Value::String(s) => match s.to_lowercase().as_str() {
305            "true" | "yes" | "1" => Some(Value::Bool(true)),
306            "false" | "no" | "0" | "" => Some(Value::Bool(false)),
307            _ => None,
308        },
309        Value::Number(n) => Some(Value::Bool(n.as_f64().unwrap_or(0.0) != 0.0)),
310        Value::Null => Some(Value::Bool(false)),
311        Value::Array(_) | Value::Object(_) => None,
312    }
313}
314
315/// Convert a value to number if possible
316fn convert_to_number(value: &Value) -> Option<Value> {
317    match value {
318        Value::Number(_) => Some(value.clone()),
319        Value::String(s) => {
320            // Try parsing as integer first to preserve integer semantics
321            if let Ok(i) = s.parse::<i64>() {
322                return Some(Value::Number(i.into()));
323            }
324            // Fall back to float
325            s.parse::<f64>()
326                .ok()
327                .and_then(serde_json::Number::from_f64)
328                .map(Value::Number)
329        }
330        Value::Bool(b) => Some(Value::Number(i32::from(*b).into())),
331        Value::Null | Value::Array(_) | Value::Object(_) => None,
332    }
333}
334
335/// Get a human-readable type name for a JSON value
336fn type_name(value: &Value) -> &'static str {
337    match value {
338        Value::Null => "null",
339        Value::Bool(_) => "boolean",
340        Value::Number(_) => "number",
341        Value::String(_) => "string",
342        Value::Array(_) => "array",
343        Value::Object(_) => "object",
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fixture with two required strings and two defaulted optional fields.
    #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
    struct TestOutput {
        title: String,
        message: String,
        #[serde(default)]
        tags: Vec<String>,
        #[serde(default)]
        count: i32,
    }

    /// Runs the validator on `json` and unwraps the outer `Result`.
    fn parse(json: &str) -> ValidationResult<TestOutput> {
        validate_and_parse::<TestOutput>(json).expect("should parse")
    }

    /// Well-formed input takes the direct path: no recovery, no warnings.
    #[test]
    fn test_valid_json_parses_directly() {
        let outcome =
            parse(r#"{"title": "Test", "message": "Hello", "tags": ["a", "b"], "count": 5}"#);
        assert!(!outcome.recovered);
        assert!(outcome.warnings.is_empty());
        assert_eq!(outcome.value.expect("should have value").title, "Test");
    }

    /// Omitted optional fields end up with their defaults.
    #[test]
    fn test_recovers_missing_optional_fields() {
        let outcome = parse(r#"{"title": "Test", "message": "Hello"}"#);
        let parsed = outcome.value.expect("should have value");
        assert_eq!(parsed.title, "Test");
        assert!(parsed.tags.is_empty());
    }

    /// A number supplied for a string field is stringified.
    #[test]
    fn test_converts_number_to_string() {
        let outcome = parse(r#"{"title": 123, "message": "Hello"}"#);
        assert!(outcome.recovered);
        assert_eq!(outcome.value.expect("should have value").title, "123");
    }

    /// A bare string supplied for an array field is wrapped in an array.
    #[test]
    fn test_converts_single_value_to_array() {
        let outcome = parse(r#"{"title": "Test", "message": "Hello", "tags": "single"}"#);
        assert!(outcome.recovered);
        assert_eq!(
            outcome.value.expect("should have value").tags,
            vec!["single"]
        );
    }

    /// A numeric string supplied for an integer field is parsed.
    #[test]
    fn test_converts_string_to_number() {
        let outcome = parse(r#"{"title": "Test", "message": "Hello", "count": "42"}"#);
        assert!(outcome.recovered);
        assert_eq!(outcome.value.expect("should have value").count, 42);
    }
}