Skip to main content

ai_lib_rust/structured/
validator.rs

1//! Output validator for structured responses.
2//!
3//! Validates JSON data against JSON schemas, supporting:
4//! - Basic type validation (string, integer, number, boolean, array, object, null)
5//! - Field constraints (minLength, maxLength, minimum, maximum, pattern, enum)
6//! - Array constraints (minItems, maxItems, items schema)
7//! - Nested validation (recursive object and array validation)
8//! - Additional properties control
9
10use crate::structured::error::{ValidationError, ValidationResult};
11use regex::Regex;
12use serde_json::Value;
13use std::collections::HashSet;
14
15/// Validator for structured output.
16///
17/// Validates JSON data against JSON schemas with full error reporting.
18pub struct OutputValidator {
19    /// The JSON schema to validate against
20    schema: Option<Value>,
21    /// Whether to use strict validation mode
22    strict: bool,
23}
24
25impl OutputValidator {
26    /// Create a new validator with a schema.
27    ///
28    /// # Arguments
29    ///
30    /// * `schema` - JSON schema as a serde_json::Value
31    /// * `strict` - Whether to use strict validation (disallow extra properties by default)
32    pub fn new(schema: Value, strict: bool) -> Self {
33        Self {
34            schema: Some(schema),
35            strict,
36        }
37    }
38
39    /// Create a new validator with a schema (strict mode enabled).
40    pub fn strict(schema: Value) -> Self {
41        Self::new(schema, true)
42    }
43
44    /// Create a new validator with a schema (strict mode disabled).
45    pub fn lenient(schema: Value) -> Self {
46        Self::new(schema, false)
47    }
48
49    /// Create a validator without a schema (permissive mode).
50    pub fn permissive() -> Self {
51        Self {
52            schema: None,
53            strict: false,
54        }
55    }
56
57    /// Validate data against the schema.
58    ///
59    /// # Arguments
60    ///
61    /// * `data` - Data to validate (can be JSON string, JSON value, or arbitrary Rust value)
62    ///
63    /// # Returns
64    ///
65    /// A ValidationResult with validation status and any errors.
66    pub fn validate(&self, data: impl IntoValidatorData) -> ValidationResult {
67        let parsed = data.into_value();
68
69        // If no schema is configured, always succeed
70        let schema = match &self.schema {
71            Some(s) => s.clone(),
72            None => return ValidationResult::success(parsed),
73        };
74
75        self.validate_against_schema(&parsed, &schema, "")
76    }
77
78    /// Validate data and return the validated value or merge errors.
79    ///
80    /// # Arguments
81    ///
82    /// * `data` - Data to validate
83    ///
84    /// # Returns
85    ///
86    /// Ok(validated_value) if validation succeeds, Err(errors) if it fails.
87    pub fn validate_or_fail(
88        &self,
89        data: impl IntoValidatorData,
90    ) -> Result<Value, Vec<ValidationError>> {
91        self.validate(data).into_result()
92    }
93
94    /// Validate data against a schema at a specific path.
95    fn validate_against_schema(
96        &self,
97        data: &Value,
98        schema: &Value,
99        path: &str,
100    ) -> ValidationResult {
101        let mut errors = Vec::new();
102
103        // Type validation
104        let schema_type = schema.get("type").and_then(|t| t.as_str());
105        if let Some(type_name) = schema_type {
106            if let Err(e) = self.validate_type(data, type_name, path) {
107                errors.push(e);
108                return ValidationResult::failure(errors);
109            }
110        }
111
112        // Null handling (nullable)
113        let is_nullable = schema
114            .get("nullable")
115            .and_then(|n| n.as_bool())
116            .unwrap_or(false);
117        if is_nullable && data.is_null() {
118            return ValidationResult::success(data.clone());
119        }
120
121        // String-specific validation
122        if schema_type == Some("string") && data.is_string() {
123            self.validate_string(data, schema, path, &mut errors);
124        }
125
126        // Number-specific validation
127        if matches!(schema_type, Some("integer") | Some("number")) {
128            if let Some(num) = data.as_f64() {
129                self.validate_number(num, schema, path, &mut errors);
130            }
131        }
132
133        // Array validation
134        if schema_type == Some("array") && data.is_array() {
135            self.validate_array(data, schema, path, &mut errors);
136        }
137
138        // Object validation
139        if schema_type == Some("object") && data.is_object() {
140            self.validate_object(data, schema, path, &mut errors);
141        }
142
143        // Enum validation
144        if let Some(enum_values) = schema.get("enum").and_then(|e| e.as_array()) {
145            self.validate_enum(data, enum_values, path, &mut errors);
146        }
147
148        if errors.is_empty() {
149            ValidationResult::success(data.clone())
150        } else {
151            ValidationResult::failure(errors)
152        }
153    }
154
155    /// Validate the type of a value.
156    fn validate_type(
157        &self,
158        data: &Value,
159        expected_type: &str,
160        path: &str,
161    ) -> Result<(), ValidationError> {
162        let is_valid = match expected_type {
163            "string" => data.is_string(),
164            "integer" => data.is_i64(),
165            "number" => data.is_number(),
166            "boolean" => data.is_boolean(),
167            "array" => data.is_array(),
168            "object" => data.is_object(),
169            "null" => data.is_null(),
170            _ => true, // Unknown type, accept anything
171        };
172
173        if !is_valid {
174            let actual_type = match data {
175                Value::String(_) => "string",
176                Value::Number(_) => {
177                    if data.as_i64().is_some() {
178                        "integer"
179                    } else {
180                        "number"
181                    }
182                }
183                Value::Bool(_) => "boolean",
184                Value::Array(_) => "array",
185                Value::Object(_) => "object",
186                Value::Null => "null",
187            };
188            Err(ValidationError::with_path(
189                format!("Expected type '{}', got '{}'", expected_type, actual_type),
190                path.to_string(),
191            ))
192        } else {
193            Ok(())
194        }
195    }
196
197    /// Validate string constraints.
198    fn validate_string(
199        &self,
200        data: &Value,
201        schema: &Value,
202        path: &str,
203        errors: &mut Vec<ValidationError>,
204    ) {
205        let s = match data.as_str() {
206            Some(s) => s,
207            None => return,
208        };
209
210        // minLength
211        if let Some(min_length) = schema.get("minLength").and_then(|m| m.as_u64()) {
212            if s.len() < min_length as usize {
213                errors.push(ValidationError::with_path(
214                    format!("String too short (minimum {} characters)", min_length),
215                    path.to_string(),
216                ));
217            }
218        }
219
220        // maxLength
221        if let Some(max_length) = schema.get("maxLength").and_then(|m| m.as_u64()) {
222            if s.len() > max_length as usize {
223                errors.push(ValidationError::with_path(
224                    format!("String too long (maximum {} characters)", max_length),
225                    path.to_string(),
226                ));
227            }
228        }
229
230        // pattern (regex)
231        if let Some(pattern) = schema.get("pattern").and_then(|p| p.as_str()) {
232            match Regex::new(pattern) {
233                Ok(re) => {
234                    if !re.is_match(s) {
235                        errors.push(ValidationError::with_path(
236                            "String does not match required pattern".to_string(),
237                            path.to_string(),
238                        ));
239                    }
240                }
241                Err(_) => {
242                    // Invalid regex, skip validation
243                }
244            }
245        }
246    }
247
248    /// Validate number constraints.
249    fn validate_number(
250        &self,
251        value: f64,
252        schema: &Value,
253        path: &str,
254        errors: &mut Vec<ValidationError>,
255    ) {
256        // minimum
257        if let Some(minimum) = schema.get("minimum").and_then(|m| m.as_f64()) {
258            if value < minimum {
259                errors.push(ValidationError::with_path(
260                    format!("Value below minimum ({})", minimum),
261                    path.to_string(),
262                ));
263            }
264        }
265
266        // maximum
267        if let Some(maximum) = schema.get("maximum").and_then(|m| m.as_f64()) {
268            if value > maximum {
269                errors.push(ValidationError::with_path(
270                    format!("Value above maximum ({})", maximum),
271                    path.to_string(),
272                ));
273            }
274        }
275    }
276
277    /// Validate array constraints.
278    fn validate_array(
279        &self,
280        data: &Value,
281        schema: &Value,
282        path: &str,
283        errors: &mut Vec<ValidationError>,
284    ) {
285        let arr = match data.as_array() {
286            Some(a) => a,
287            None => return,
288        };
289
290        // minItems
291        if let Some(min_items) = schema.get("minItems").and_then(|m| m.as_u64()) {
292            if arr.len() < min_items as usize {
293                errors.push(ValidationError::with_path(
294                    format!("Array too short (minimum {} items)", min_items),
295                    path.to_string(),
296                ));
297            }
298        }
299
300        // maxItems
301        if let Some(max_items) = schema.get("maxItems").and_then(|m| m.as_u64()) {
302            if arr.len() > max_items as usize {
303                errors.push(ValidationError::with_path(
304                    format!("Array too long (maximum {} items)", max_items),
305                    path.to_string(),
306                ));
307            }
308        }
309
310        // items (validate each element)
311        if let Some(items_schema) = schema.get("items") {
312            for (i, item) in arr.iter().enumerate() {
313                let item_path = format!("{}[{}]", path, i);
314                let result = self.validate_against_schema(item, items_schema, &item_path);
315                if !result.is_valid() {
316                    errors.extend(result.errors);
317                }
318            }
319        }
320    }
321
322    /// Validate object constraints.
323    fn validate_object(
324        &self,
325        data: &Value,
326        schema: &Value,
327        path: &str,
328        errors: &mut Vec<ValidationError>,
329    ) {
330        let obj = match data.as_object() {
331            Some(o) => o,
332            None => return,
333        };
334
335        // required properties
336        let required: Vec<String> = schema
337            .get("required")
338            .and_then(|r| r.as_array())
339            .map(|arr| {
340                arr.iter()
341                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
342                    .collect()
343            })
344            .unwrap_or_default();
345
346        for prop_name in &required {
347            if !obj.contains_key(prop_name) {
348                errors.push(ValidationError::with_path(
349                    format!("Missing required property: {}", prop_name),
350                    format!("{}.{}", path, prop_name),
351                ));
352            }
353        }
354
355        // properties (validate each property)
356        let empty_props: Value = serde_json::json!({});
357        let properties = schema
358            .get("properties")
359            .and_then(|p| p.as_object())
360            .unwrap_or_else(|| empty_props.as_object().unwrap());
361
362        for (prop_name, prop_schema) in properties {
363            if let Some(prop_value) = obj.get(prop_name) {
364                let prop_path = format!("{}.{}", path, prop_name);
365                let result = self.validate_against_schema(prop_value, prop_schema, &prop_path);
366                if !result.is_valid() {
367                    errors.extend(result.errors);
368                }
369            }
370        }
371
372        // additionalProperties
373        let additional_props = schema
374            .get("additionalProperties")
375            .and_then(|a| a.as_bool())
376            .unwrap_or(!self.strict); // Default to opposite of strict mode
377
378        if !additional_props {
379            let allowed_keys: HashSet<&str> = properties.keys().map(|k| k.as_str()).collect();
380            for key in obj.keys() {
381                if !allowed_keys.contains(key.as_str()) {
382                    errors.push(ValidationError::with_path(
383                        format!("Additional property not allowed: {}", key),
384                        format!("{}.{}", path, key),
385                    ));
386                }
387            }
388        }
389
390        // additionalProperties as schema
391        if let Some(additional_schema) =
392            schema.get("additionalProperties").and_then(
393                |a| {
394                    if a.is_boolean() {
395                        None
396                    } else {
397                        Some(a)
398                    }
399                },
400            )
401        {
402            let allowed_keys: HashSet<&str> = properties.keys().map(|k| k.as_str()).collect();
403            for (key, value) in obj {
404                if !allowed_keys.contains(key.as_str()) {
405                    let prop_path = format!("{}.{}", path, key);
406                    let result = self.validate_against_schema(value, additional_schema, &prop_path);
407                    if !result.is_valid() {
408                        errors.extend(result.errors);
409                    }
410                }
411            }
412        }
413    }
414
415    /// Validate enum constraint.
416    fn validate_enum(
417        &self,
418        data: &Value,
419        enum_values: &[Value],
420        path: &str,
421        errors: &mut Vec<ValidationError>,
422    ) {
423        if !enum_values.contains(data) {
424            let allowed: Vec<String> = enum_values
425                .iter()
426                .map(|v| match v {
427                    Value::String(s) => format!("\"{}\"", s),
428                    _ => v.to_string(),
429                })
430                .collect();
431            errors.push(ValidationError::with_path(
432                format!("Value not in allowed enum values: {}", allowed.join(", ")),
433                path.to_string(),
434            ));
435        }
436    }
437}
438
439/// Trait for types that can be converted to validator data.
440pub trait IntoValidatorData {
441    fn into_value(self) -> Value;
442}
443
444impl IntoValidatorData for Value {
445    fn into_value(self) -> Value {
446        self
447    }
448}
449
450impl IntoValidatorData for &Value {
451    fn into_value(self) -> Value {
452        self.clone()
453    }
454}
455
456impl IntoValidatorData for &str {
457    fn into_value(self) -> Value {
458        // Try to parse as JSON, fall back to string
459        serde_json::from_str(self).unwrap_or_else(|_| Value::String(self.to_string()))
460    }
461}
462
463impl IntoValidatorData for String {
464    fn into_value(self) -> Value {
465        // Try to parse as JSON, fall back to string
466        #[allow(clippy::unnecessary_lazy_evaluations)]
467        serde_json::from_str(&self).unwrap_or_else(|_| Value::String(self))
468    }
469}
470
471impl IntoValidatorData for i64 {
472    fn into_value(self) -> Value {
473        Value::Number(self.into())
474    }
475}
476
477impl IntoValidatorData for i32 {
478    fn into_value(self) -> Value {
479        Value::Number(self.into())
480    }
481}
482
483impl IntoValidatorData for u64 {
484    fn into_value(self) -> Value {
485        Value::Number(self.into())
486    }
487}
488
489impl IntoValidatorData for u32 {
490    fn into_value(self) -> Value {
491        Value::Number(self.into())
492    }
493}
494
495impl IntoValidatorData for f64 {
496    fn into_value(self) -> Value {
497        serde_json::Number::from_f64(self)
498            .map(Value::Number)
499            .unwrap_or(Value::Null)
500    }
501}
502
503impl IntoValidatorData for f32 {
504    fn into_value(self) -> Value {
505        serde_json::Number::from_f64(self as f64)
506            .map(Value::Number)
507            .unwrap_or(Value::Null)
508    }
509}
510
511impl IntoValidatorData for bool {
512    fn into_value(self) -> Value {
513        Value::Bool(self)
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Schema accepting any string.
    fn make_string_schema() -> Value {
        json!({ "type": "string" })
    }

    /// String schema with optional `minLength` / `maxLength` bounds.
    fn make_string_schema_with_length(min: Option<u64>, max: Option<u64>) -> Value {
        let mut schema = make_string_schema();
        for &(keyword, bound) in &[("minLength", min), ("maxLength", max)] {
            if let Some(limit) = bound {
                schema[keyword] = limit.into();
            }
        }
        schema
    }

    /// Object schema with a string `name`, an integer `age`, and the given
    /// list of required property names (omitted when empty).
    fn make_object_schema(required: Vec<String>) -> Value {
        let mut schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            }
        });
        if required.is_empty() {
            return schema;
        }
        schema["required"] = json!(required);
        schema
    }

    /// Schema for an array whose items must all be strings.
    fn make_array_schema() -> Value {
        json!({
            "type": "array",
            "items": { "type": "string" }
        })
    }

    #[test]
    fn test_validator_basic_string() {
        let validator = OutputValidator::lenient(make_string_schema());

        assert!(validator.validate("hello").is_valid());
    }

    #[test]
    fn test_validator_string_min_length() {
        let schema = make_string_schema_with_length(Some(5), None);
        let validator = OutputValidator::lenient(schema);

        let outcome = validator.validate("hi");
        assert!(!outcome.is_valid());
        assert!(outcome.error_messages()[0].contains("too short"));
    }

    #[test]
    fn test_validator_string_max_length() {
        let schema = make_string_schema_with_length(None, Some(3));
        let validator = OutputValidator::lenient(schema);

        let outcome = validator.validate("hello");
        assert!(!outcome.is_valid());
        assert!(outcome.error_messages()[0].contains("too long"));
    }

    #[test]
    fn test_validator_integer_type() {
        let validator = OutputValidator::lenient(json!({ "type": "integer" }));

        // A real integer is accepted.
        assert!(validator.validate(42_i32).is_valid());

        // A string that merely looks like an integer is rejected.
        let text_value = Value::String("42".to_string());
        assert!(!validator.validate(text_value).is_valid());
    }

    #[test]
    fn test_validator_object_required() {
        let validator =
            OutputValidator::lenient(make_object_schema(vec!["name".to_string()]));

        let outcome = validator.validate(json!({ "age": 30 }));
        assert!(!outcome.is_valid());
        assert!(outcome.error_messages()[0].contains("Missing required"));
    }

    #[test]
    fn test_validator_array_items() {
        let validator = OutputValidator::lenient(make_array_schema());

        let all_strings = json!(["hello", "world"]);
        assert!(validator.validate(all_strings).is_valid());

        let all_numbers = json!([1, 2, 3]);
        assert!(!validator.validate(all_numbers).is_valid());
    }

    #[test]
    fn test_validator_enum() {
        let validator = OutputValidator::lenient(json!({
            "type": "string",
            "enum": ["red", "green", "blue"]
        }));

        assert!(validator.validate("red").is_valid());

        let outcome = validator.validate("yellow");
        assert!(!outcome.is_valid());
        assert!(outcome.error_messages()[0].contains("not in allowed enum"));
    }

    #[test]
    fn test_validator_permissive() {
        let validator = OutputValidator::permissive();

        let outcome = validator.validate(json!({ "arbitrary": "data" }));
        assert!(outcome.is_valid());
    }

    #[test]
    fn test_validator_strict_additional_properties() {
        let validator = OutputValidator::strict(make_object_schema(vec![]));

        let outcome = validator.validate(json!({ "name": "Alice", "extra": "data" }));
        assert!(!outcome.is_valid());
        assert!(outcome.error_messages()[0].contains("Additional property not allowed"));
    }

    #[test]
    fn test_validator_nested_object() {
        let validator = OutputValidator::lenient(json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }
            }
        }));

        let outcome = validator.validate(json!({ "user": { "age": 30 } }));
        assert!(!outcome.is_valid());
        // Check that error message contains "required" instead of "missing"
        let first_message = outcome.error_messages()[0].to_lowercase();
        assert!(first_message.contains("required"));
    }

    #[test]
    fn test_validate_or_fail() {
        let string_validator = OutputValidator::lenient(make_string_schema());
        assert!(string_validator.validate_or_fail("hello").is_ok());

        let integer_validator = OutputValidator::lenient(json!({ "type": "integer" }));
        assert!(integer_validator.validate_or_fail("hello").is_err());
    }

    #[test]
    fn test_validation_result_merge() {
        // Merging only successes stays valid.
        let all_ok = ValidationResult::merge(vec![
            ValidationResult::success(json!(1)),
            ValidationResult::success(json!(2)),
        ]);
        assert!(all_ok.is_valid());

        // A single failure makes the merged result invalid.
        let failed = ValidationResult::from_error(ValidationError::without_path("Test error"));
        let with_failure =
            ValidationResult::merge(vec![ValidationResult::success(json!(1)), failed]);
        assert!(!with_failure.is_valid());
    }
}