runbeam_sdk/validation/
validator.rs

1//! Core validation logic for TOML configurations against schema definitions.
2//!
3//! This module implements the actual validation that checks TOML content
4//! against parsed schema definitions.
5
6use crate::validation::error::ValidationError;
7use crate::validation::schema::{FieldDefinition, Schema, TableDefinition};
8
9/// Validate TOML content against a schema
10pub fn validate(content_toml: &str, schema_toml: &str) -> Result<(), ValidationError> {
11    // Parse schema
12    let schema = Schema::from_str(schema_toml)?;
13
14    // Parse content
15    let content: toml::Value =
16        toml::from_str(content_toml).map_err(|e| ValidationError::TomlParseError(e.to_string()))?;
17
18    // Validate
19    let mut errors = Vec::new();
20    validate_document(&content, &schema, &mut errors);
21
22    if errors.is_empty() {
23        Ok(())
24    } else if errors.len() == 1 {
25        Err(errors.into_iter().next().unwrap())
26    } else {
27        Err(ValidationError::Multiple(errors))
28    }
29}
30
31/// Recursively validate tables, building full paths for nested tables
32fn validate_tables_recursive(
33    table: &toml::map::Map<String, toml::Value>,
34    parent_path: &str,
35    schema: &Schema,
36    errors: &mut Vec<ValidationError>,
37) {
38    for (key, value) in table.iter() {
39        let table_path = if parent_path.is_empty() {
40            key.to_string()
41        } else {
42            format!("{}.{}", parent_path, key)
43        };
44
45        // Check if this table path matches a schema definition
46        if let Some(table_def) = schema.find_table(&table_path) {
47            validate_table(value, table_def, &table_path, schema, errors);
48        } else if value.is_table() {
49            // No direct match, check if we need to recurse deeper for dotted table names
50            // like [network.default] which becomes {network: {default: {...}}}
51            if let Some(nested_table) = value.as_table() {
52                // Try to find a pattern that might match this nested structure
53                let has_pattern_match = schema
54                    .tables
55                    .values()
56                    .any(|t| t.is_pattern && schema.matches_pattern(&table_path, &t.name));
57
58                if has_pattern_match {
59                    // This table itself doesn't match, but its children might match a pattern
60                    validate_tables_recursive(nested_table, &table_path, schema, errors);
61                } else {
62                    // Check if any nested tables might match
63                    validate_tables_recursive(nested_table, &table_path, schema, errors);
64                }
65            } else {
66                // Not a table, and doesn't match - this is unexpected
67                errors.push(ValidationError::UnexpectedTable {
68                    table_path: table_path.clone(),
69                });
70            }
71        }
72    }
73}
74
75/// Validate the entire TOML document
76fn validate_document(content: &toml::Value, schema: &Schema, errors: &mut Vec<ValidationError>) {
77    let Some(root_table) = content.as_table() else {
78        errors.push(ValidationError::TomlParseError(
79            "Root of TOML must be a table".to_string(),
80        ));
81        return;
82    };
83
84    // Validate each table in the content
85    // Handle both direct tables and nested/dotted tables
86    validate_tables_recursive(root_table, "", schema, errors);
87
88    // Check for missing required tables
89    for table_def in schema.get_concrete_tables() {
90        if table_def.required && !root_table.contains_key(&table_def.name) {
91            errors.push(ValidationError::MissingRequiredField {
92                field_path: table_def.name.clone(),
93            });
94        }
95    }
96}
97
98/// Validate a table against its schema definition
99fn validate_table(
100    table_value: &toml::Value,
101    table_def: &TableDefinition,
102    table_path: &str,
103    schema: &Schema,
104    errors: &mut Vec<ValidationError>,
105) {
106    let Some(table) = table_value.as_table() else {
107        errors.push(ValidationError::InvalidType {
108            field_path: table_path.to_string(),
109            expected: "table".to_string(),
110            found: get_type_name(table_value),
111        });
112        return;
113    };
114
115    // Validate pattern constraint if this is a pattern table
116    if table_def.is_pattern {
117        if let Some(pattern_constraint) = &table_def.pattern_constraint {
118            // Extract the dynamic part of the table name
119            // e.g., for "network.default" with pattern "network.*", extract "default"
120            if let Some(dynamic_part) = extract_dynamic_part(table_path, &table_def.name) {
121                if !pattern_constraint.is_match(&dynamic_part) {
122                    errors.push(ValidationError::PatternMismatch {
123                        field_path: table_path.to_string(),
124                        pattern: pattern_constraint.as_str().to_string(),
125                    });
126                }
127            }
128        }
129    }
130
131    // Validate each field in the schema
132    for field_def in table_def.get_fields() {
133        let field_path = format!("{}.{}", table_path, field_def.name);
134
135        // Handle nested field paths (e.g., "tcp_config.bind_address")
136        let field_value = get_nested_field(table, &field_def.name);
137
138        // Check required fields
139        // If required_if is specified, it takes precedence over the base required flag
140        let is_required = if field_def.required_if.is_some() {
141            field_def.is_conditionally_required(&toml::Value::Table(table.clone()))
142        } else {
143            field_def.required
144        };
145
146        if is_required && field_value.is_none() {
147            if field_def.required_if.is_some() {
148                errors.push(ValidationError::ConditionalRequirementFailed {
149                    field_path: field_path.clone(),
150                    condition: field_def.required_if.as_ref().unwrap().clone(),
151                });
152            } else {
153                errors.push(ValidationError::MissingRequiredField {
154                    field_path: field_path.clone(),
155                });
156            }
157            continue;
158        }
159
160        // Validate field value if present
161        if let Some(value) = field_value {
162            validate_field(value, field_def, &field_path, schema, errors);
163        }
164    }
165
166    // Check for unexpected fields (optional strict mode - currently lenient)
167    // This could be enabled with a flag in the future
168}
169
170/// Validate a field value against its definition
171fn validate_field(
172    value: &toml::Value,
173    field_def: &FieldDefinition,
174    field_path: &str,
175    schema: &Schema,
176    errors: &mut Vec<ValidationError>,
177) {
178    // Validate type
179    if !validate_type(value, &field_def.field_type) {
180        errors.push(ValidationError::InvalidType {
181            field_path: field_path.to_string(),
182            expected: field_def.field_type.clone(),
183            found: get_type_name(value),
184        });
185        return;
186    }
187
188    // Validate enum values
189    if let Some(enum_values) = &field_def.enum_values {
190        if let Some(str_value) = value.as_str() {
191            if !enum_values.contains(&str_value.to_string()) {
192                errors.push(ValidationError::InvalidEnumValue {
193                    field_path: field_path.to_string(),
194                    value: str_value.to_string(),
195                    allowed: enum_values.clone(),
196                });
197            }
198        }
199    }
200
201    // Validate numeric ranges
202    if let Some(int_value) = value.as_integer() {
203        if let Some(min) = field_def.min {
204            if int_value < min {
205                errors.push(ValidationError::OutOfRange {
206                    field_path: field_path.to_string(),
207                    value: int_value.to_string(),
208                    min: Some(min.to_string()),
209                    max: field_def.max.map(|m| m.to_string()),
210                });
211            }
212        }
213        if let Some(max) = field_def.max {
214            if int_value > max {
215                errors.push(ValidationError::OutOfRange {
216                    field_path: field_path.to_string(),
217                    value: int_value.to_string(),
218                    min: field_def.min.map(|m| m.to_string()),
219                    max: Some(max.to_string()),
220                });
221            }
222        }
223    }
224
225    // Validate pattern for string values
226    if let Some(pattern) = &field_def.pattern {
227        if let Some(str_value) = value.as_str() {
228            if !pattern.is_match(str_value) {
229                errors.push(ValidationError::PatternMismatch {
230                    field_path: field_path.to_string(),
231                    pattern: pattern.as_str().to_string(),
232                });
233            }
234        }
235    }
236
237    // Validate arrays
238    if let Some(array) = value.as_array() {
239        // Validate array length
240        if let Some(min_items) = field_def.min_items {
241            if array.len() < min_items {
242                errors.push(ValidationError::InvalidArrayLength {
243                    field_path: field_path.to_string(),
244                    length: array.len(),
245                    min: Some(min_items),
246                    max: field_def.max_items,
247                });
248            }
249        }
250        if let Some(max_items) = field_def.max_items {
251            if array.len() > max_items {
252                errors.push(ValidationError::InvalidArrayLength {
253                    field_path: field_path.to_string(),
254                    length: array.len(),
255                    min: field_def.min_items,
256                    max: Some(max_items),
257                });
258            }
259        }
260
261        // Validate array item types
262        if let Some(expected_item_type) = &field_def.array_item_type {
263            for (i, item) in array.iter().enumerate() {
264                if !validate_type(item, expected_item_type) {
265                    errors.push(ValidationError::InvalidType {
266                        field_path: format!("{}[{}]", field_path, i),
267                        expected: expected_item_type.clone(),
268                        found: get_type_name(item),
269                    });
270                }
271            }
272        }
273    }
274
275    // Validate nested tables recursively
276    if field_def.field_type == "table" {
277        if let Some(table_value) = value.as_table() {
278            // For nested tables, we need to find the corresponding table definition
279            // This is simplified - in a full implementation, we'd need to handle
280            // nested table schemas more comprehensively
281            for (nested_key, nested_value) in table_value.iter() {
282                let nested_path = format!("{}.{}", field_path, nested_key);
283                if let Some(nested_table_def) = schema.find_table(&nested_path) {
284                    validate_table(nested_value, nested_table_def, &nested_path, schema, errors);
285                }
286            }
287        }
288    }
289}
290
291/// Check if a value matches the expected type
292fn validate_type(value: &toml::Value, expected_type: &str) -> bool {
293    match expected_type {
294        "string" => value.is_str(),
295        "integer" => value.is_integer(),
296        "boolean" => value.is_bool(),
297        "float" => value.is_float() || value.is_integer(), // Allow integers as floats
298        "array" => value.is_array(),
299        "table" => value.is_table(),
300        _ => false,
301    }
302}
303
304/// Get a human-readable type name for a TOML value
305fn get_type_name(value: &toml::Value) -> String {
306    match value {
307        toml::Value::String(_) => "string".to_string(),
308        toml::Value::Integer(_) => "integer".to_string(),
309        toml::Value::Float(_) => "float".to_string(),
310        toml::Value::Boolean(_) => "boolean".to_string(),
311        toml::Value::Array(_) => "array".to_string(),
312        toml::Value::Table(_) => "table".to_string(),
313        toml::Value::Datetime(_) => "datetime".to_string(),
314    }
315}
316
317/// Extract nested field value from a table using dot notation
318///
319/// Supports paths like "tcp_config.bind_address"
320fn get_nested_field<'a>(
321    table: &'a toml::map::Map<String, toml::Value>,
322    path: &str,
323) -> Option<&'a toml::Value> {
324    let parts: Vec<&str> = path.split('.').collect();
325
326    if parts.len() == 1 {
327        return table.get(path);
328    }
329
330    let mut current = table.get(parts[0])?;
331
332    for part in &parts[1..] {
333        current = current.as_table()?.get(*part)?;
334    }
335
336    Some(current)
337}
338
/// Return the segment of `table_path` that lines up with the first `*` segment
/// of `pattern`.
///
/// Both strings are split on `.`. The result is `None` when:
/// * `pattern` contains no `*` at all;
/// * the two have different segment counts;
/// * no segment of `pattern` is exactly `*`.
///
/// Example: `extract_dynamic_part("network.default", "network.*")` returns
/// `Some("default")`.
fn extract_dynamic_part(table_path: &str, pattern: &str) -> Option<String> {
    if !pattern.contains('*') {
        return None;
    }

    let segment_count = |s: &str| s.split('.').count();
    if segment_count(pattern) != segment_count(table_path) {
        return None;
    }

    pattern
        .split('.')
        .zip(table_path.split('.'))
        .find_map(|(pat, seg)| (pat == "*").then(|| seg.to_string()))
}
362
#[cfg(test)]
mod tests {
    // Unit tests for schema validation: the happy path, one test per constraint
    // kind, pattern tables, dotted field names, and the private path helpers.
    use super::*;

    /// Schema with one required `config` table. It covers a required string,
    /// an integer limited to 1..=65535, a boolean, a string enum, and a string
    /// array that needs at least one item.
    fn simple_schema() -> &'static str {
        r#"
[schema]
version = "1.0"
description = "Simple test schema"

[[table]]
name = "config"
required = true

[[table.field]]
name = "name"
type = "string"
required = true

[[table.field]]
name = "port"
type = "integer"
required = false
min = 1
max = 65535

[[table.field]]
name = "enabled"
type = "boolean"
required = false

[[table.field]]
name = "log_level"
type = "string"
required = false
enum = ["debug", "info", "warn", "error"]

[[table.field]]
name = "tags"
type = "array"
array_item_type = "string"
min_items = 1
"#
    }

    /// A document satisfying every constraint of `simple_schema` validates cleanly.
    #[test]
    fn test_valid_toml() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
port = 8080
enabled = true
log_level = "info"
tags = ["api", "production"]
"#;

        assert!(validate(content, schema).is_ok());
    }

    /// Omitting the required `name` field yields exactly one `MissingRequiredField`.
    #[test]
    fn test_missing_required_field() {
        let schema = simple_schema();
        let content = r#"
[config]
port = 8080
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(
            error,
            ValidationError::MissingRequiredField { .. }
        ));
    }

    /// A string where an integer is expected yields `InvalidType`.
    #[test]
    fn test_invalid_type() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
port = "not a number"
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::InvalidType { .. }));
    }

    /// A `log_level` outside the enum list yields `InvalidEnumValue`.
    #[test]
    fn test_invalid_enum_value() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
log_level = "invalid"
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::InvalidEnumValue { .. }));
    }

    /// A `port` above `max = 65535` yields `OutOfRange`.
    #[test]
    fn test_out_of_range() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
port = 999999
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::OutOfRange { .. }));
    }

    /// An empty `tags` array violates `min_items = 1`.
    #[test]
    fn test_invalid_array_length() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
tags = []
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::InvalidArrayLength { .. }));
    }

    /// Integer items in a string array are each reported as `InvalidType`.
    #[test]
    fn test_invalid_array_item_type() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
tags = [1, 2, 3]
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        // Should be Multiple errors or InvalidType
        match error {
            ValidationError::Multiple(errors) => {
                assert!(errors
                    .iter()
                    .any(|e| matches!(e, ValidationError::InvalidType { .. })));
            }
            ValidationError::InvalidType { .. } => {}
            _ => panic!("Expected InvalidType or Multiple errors"),
        }
    }

    /// Dotted headers `[network.<name>]` are validated against the `network.*`
    /// pattern table when `<name>` satisfies the pattern constraint.
    #[test]
    fn test_pattern_tables() {
        let schema = r#"
[schema]
version = "1.0"
description = "Pattern table test"

[[table]]
name = "network.*"
pattern = true
pattern_constraint = "^[a-z0-9_-]+$"

[[table.field]]
name = "bind_address"
type = "string"
required = true
"#;

        let content = r#"
[network.default]
bind_address = "0.0.0.0"

[network.management]
bind_address = "127.0.0.1"
"#;

        let result = validate(content, schema);
        if let Err(e) = &result {
            eprintln!("Validation error: {}", e);
        }
        assert!(result.is_ok());
    }

    /// Uppercase letters in the dynamic segment violate `^[a-z0-9_-]+$`,
    /// yielding `PatternMismatch`.
    #[test]
    fn test_pattern_constraint_violation() {
        let schema = r#"
[schema]
version = "1.0"
description = "Pattern constraint test"

[[table]]
name = "network.*"
pattern = true
pattern_constraint = "^[a-z0-9_-]+$"

[[table.field]]
name = "bind_address"
type = "string"
required = true
"#;

        let content = r#"
[network.INVALID_NAME]
bind_address = "0.0.0.0"
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::PatternMismatch { .. }));
    }

    /// Dotted field names (`tcp_config.bind_address`) resolve through the
    /// nested table produced by the `[network.tcp_config]` header.
    #[test]
    fn test_nested_fields() {
        let schema = r#"
[schema]
version = "1.0"
description = "Nested fields test"

[[table]]
name = "network"
required = true

[[table.field]]
name = "tcp_config.bind_address"
type = "string"
required = true

[[table.field]]
name = "tcp_config.port"
type = "integer"
required = true
"#;

        let content = r#"
[network.tcp_config]
bind_address = "0.0.0.0"
port = 8080
"#;

        assert!(validate(content, schema).is_ok());
    }

    /// `get_nested_field` walks one level of nesting by dotted path.
    #[test]
    fn test_get_nested_field() {
        let mut table = toml::map::Map::new();
        let mut tcp_config = toml::map::Map::new();
        tcp_config.insert(
            "bind_address".to_string(),
            toml::Value::String("0.0.0.0".to_string()),
        );
        table.insert("tcp_config".to_string(), toml::Value::Table(tcp_config));

        let value = get_nested_field(&table, "tcp_config.bind_address");
        assert!(value.is_some());
        assert_eq!(value.unwrap().as_str(), Some("0.0.0.0"));
    }

    /// `extract_dynamic_part` returns the `*` segment, and `None` when segment
    /// counts differ.
    #[test]
    fn test_extract_dynamic_part() {
        assert_eq!(
            extract_dynamic_part("network.default", "network.*"),
            Some("default".to_string())
        );
        assert_eq!(
            extract_dynamic_part("network.management", "network.*"),
            Some("management".to_string())
        );
        assert_eq!(extract_dynamic_part("network", "network.*"), None);
    }
}