1use crate::validation::error::ValidationError;
7use crate::validation::schema::{FieldDefinition, Schema, TableDefinition};
8
9pub fn validate(content_toml: &str, schema_toml: &str) -> Result<(), ValidationError> {
11 let schema = Schema::from_str(schema_toml)?;
13
14 let content: toml::Value =
16 toml::from_str(content_toml).map_err(|e| ValidationError::TomlParseError(e.to_string()))?;
17
18 let mut errors = Vec::new();
20 validate_document(&content, &schema, &mut errors);
21
22 if errors.is_empty() {
23 Ok(())
24 } else if errors.len() == 1 {
25 Err(errors.into_iter().next().unwrap())
26 } else {
27 Err(ValidationError::Multiple(errors))
28 }
29}
30
31fn validate_tables_recursive(
33 table: &toml::map::Map<String, toml::Value>,
34 parent_path: &str,
35 schema: &Schema,
36 errors: &mut Vec<ValidationError>,
37) {
38 for (key, value) in table.iter() {
39 let table_path = if parent_path.is_empty() {
40 key.to_string()
41 } else {
42 format!("{}.{}", parent_path, key)
43 };
44
45 if let Some(table_def) = schema.find_table(&table_path) {
47 validate_table(value, table_def, &table_path, schema, errors);
48 } else if value.is_table() {
49 if let Some(nested_table) = value.as_table() {
52 let has_pattern_match = schema
54 .tables
55 .values()
56 .any(|t| t.is_pattern && schema.matches_pattern(&table_path, &t.name));
57
58 if has_pattern_match {
59 validate_tables_recursive(nested_table, &table_path, schema, errors);
61 } else {
62 validate_tables_recursive(nested_table, &table_path, schema, errors);
64 }
65 } else {
66 errors.push(ValidationError::UnexpectedTable {
68 table_path: table_path.clone(),
69 });
70 }
71 }
72 }
73}
74
75fn validate_document(content: &toml::Value, schema: &Schema, errors: &mut Vec<ValidationError>) {
77 let Some(root_table) = content.as_table() else {
78 errors.push(ValidationError::TomlParseError(
79 "Root of TOML must be a table".to_string(),
80 ));
81 return;
82 };
83
84 validate_tables_recursive(root_table, "", schema, errors);
87
88 for table_def in schema.get_concrete_tables() {
90 if table_def.required && !root_table.contains_key(&table_def.name) {
91 errors.push(ValidationError::MissingRequiredField {
92 field_path: table_def.name.clone(),
93 });
94 }
95 }
96}
97
98fn validate_table(
100 table_value: &toml::Value,
101 table_def: &TableDefinition,
102 table_path: &str,
103 schema: &Schema,
104 errors: &mut Vec<ValidationError>,
105) {
106 let Some(table) = table_value.as_table() else {
107 errors.push(ValidationError::InvalidType {
108 field_path: table_path.to_string(),
109 expected: "table".to_string(),
110 found: get_type_name(table_value),
111 });
112 return;
113 };
114
115 if table_def.is_pattern {
117 if let Some(pattern_constraint) = &table_def.pattern_constraint {
118 if let Some(dynamic_part) = extract_dynamic_part(table_path, &table_def.name) {
121 if !pattern_constraint.is_match(&dynamic_part) {
122 errors.push(ValidationError::PatternMismatch {
123 field_path: table_path.to_string(),
124 pattern: pattern_constraint.as_str().to_string(),
125 });
126 }
127 }
128 }
129 }
130
131 for field_def in table_def.get_fields() {
133 let field_path = format!("{}.{}", table_path, field_def.name);
134
135 let field_value = get_nested_field(table, &field_def.name);
137
138 let is_required = if field_def.required_if.is_some() {
141 field_def.is_conditionally_required(&toml::Value::Table(table.clone()))
142 } else {
143 field_def.required
144 };
145
146 if is_required && field_value.is_none() {
147 if field_def.required_if.is_some() {
148 errors.push(ValidationError::ConditionalRequirementFailed {
149 field_path: field_path.clone(),
150 condition: field_def.required_if.as_ref().unwrap().clone(),
151 });
152 } else {
153 errors.push(ValidationError::MissingRequiredField {
154 field_path: field_path.clone(),
155 });
156 }
157 continue;
158 }
159
160 if let Some(value) = field_value {
162 validate_field(value, field_def, &field_path, schema, errors);
163 }
164 }
165
166 }
169
170fn validate_field(
172 value: &toml::Value,
173 field_def: &FieldDefinition,
174 field_path: &str,
175 schema: &Schema,
176 errors: &mut Vec<ValidationError>,
177) {
178 if !validate_type(value, &field_def.field_type) {
180 errors.push(ValidationError::InvalidType {
181 field_path: field_path.to_string(),
182 expected: field_def.field_type.clone(),
183 found: get_type_name(value),
184 });
185 return;
186 }
187
188 if let Some(enum_values) = &field_def.enum_values {
190 if let Some(str_value) = value.as_str() {
191 if !enum_values.contains(&str_value.to_string()) {
192 errors.push(ValidationError::InvalidEnumValue {
193 field_path: field_path.to_string(),
194 value: str_value.to_string(),
195 allowed: enum_values.clone(),
196 });
197 }
198 }
199 }
200
201 if let Some(int_value) = value.as_integer() {
203 if let Some(min) = field_def.min {
204 if int_value < min {
205 errors.push(ValidationError::OutOfRange {
206 field_path: field_path.to_string(),
207 value: int_value.to_string(),
208 min: Some(min.to_string()),
209 max: field_def.max.map(|m| m.to_string()),
210 });
211 }
212 }
213 if let Some(max) = field_def.max {
214 if int_value > max {
215 errors.push(ValidationError::OutOfRange {
216 field_path: field_path.to_string(),
217 value: int_value.to_string(),
218 min: field_def.min.map(|m| m.to_string()),
219 max: Some(max.to_string()),
220 });
221 }
222 }
223 }
224
225 if let Some(pattern) = &field_def.pattern {
227 if let Some(str_value) = value.as_str() {
228 if !pattern.is_match(str_value) {
229 errors.push(ValidationError::PatternMismatch {
230 field_path: field_path.to_string(),
231 pattern: pattern.as_str().to_string(),
232 });
233 }
234 }
235 }
236
237 if let Some(array) = value.as_array() {
239 if let Some(min_items) = field_def.min_items {
241 if array.len() < min_items {
242 errors.push(ValidationError::InvalidArrayLength {
243 field_path: field_path.to_string(),
244 length: array.len(),
245 min: Some(min_items),
246 max: field_def.max_items,
247 });
248 }
249 }
250 if let Some(max_items) = field_def.max_items {
251 if array.len() > max_items {
252 errors.push(ValidationError::InvalidArrayLength {
253 field_path: field_path.to_string(),
254 length: array.len(),
255 min: field_def.min_items,
256 max: Some(max_items),
257 });
258 }
259 }
260
261 if let Some(expected_item_type) = &field_def.array_item_type {
263 for (i, item) in array.iter().enumerate() {
264 if !validate_type(item, expected_item_type) {
265 errors.push(ValidationError::InvalidType {
266 field_path: format!("{}[{}]", field_path, i),
267 expected: expected_item_type.clone(),
268 found: get_type_name(item),
269 });
270 }
271 }
272 }
273 }
274
275 if field_def.field_type == "table" {
277 if let Some(table_value) = value.as_table() {
278 for (nested_key, nested_value) in table_value.iter() {
282 let nested_path = format!("{}.{}", field_path, nested_key);
283 if let Some(nested_table_def) = schema.find_table(&nested_path) {
284 validate_table(nested_value, nested_table_def, &nested_path, schema, errors);
285 }
286 }
287 }
288 }
289}
290
291fn validate_type(value: &toml::Value, expected_type: &str) -> bool {
293 match expected_type {
294 "string" => value.is_str(),
295 "integer" => value.is_integer(),
296 "boolean" => value.is_bool(),
297 "float" => value.is_float() || value.is_integer(), "array" => value.is_array(),
299 "table" => value.is_table(),
300 _ => false,
301 }
302}
303
304fn get_type_name(value: &toml::Value) -> String {
306 match value {
307 toml::Value::String(_) => "string".to_string(),
308 toml::Value::Integer(_) => "integer".to_string(),
309 toml::Value::Float(_) => "float".to_string(),
310 toml::Value::Boolean(_) => "boolean".to_string(),
311 toml::Value::Array(_) => "array".to_string(),
312 toml::Value::Table(_) => "table".to_string(),
313 toml::Value::Datetime(_) => "datetime".to_string(),
314 }
315}
316
317fn get_nested_field<'a>(
321 table: &'a toml::map::Map<String, toml::Value>,
322 path: &str,
323) -> Option<&'a toml::Value> {
324 let parts: Vec<&str> = path.split('.').collect();
325
326 if parts.len() == 1 {
327 return table.get(path);
328 }
329
330 let mut current = table.get(parts[0])?;
331
332 for part in &parts[1..] {
333 current = current.as_table()?.get(*part)?;
334 }
335
336 Some(current)
337}
338
/// Returns the segment of `table_path` that lines up with the first `*` segment of `pattern`.
///
/// Both arguments are split on `.`. `None` is returned when the segment counts differ or
/// when no segment of `pattern` is exactly `*`. Non-wildcard segments are not compared
/// against the path.
fn extract_dynamic_part(table_path: &str, pattern: &str) -> Option<String> {
    let segment_count = |dotted: &str| dotted.split('.').count();
    if segment_count(pattern) != segment_count(table_path) {
        return None;
    }

    // A pattern without any `*` segment simply finds nothing below.
    pattern
        .split('.')
        .zip(table_path.split('.'))
        .find(|(pattern_segment, _)| *pattern_segment == "*")
        .map(|(_, path_segment)| path_segment.to_string())
}
362
// Unit tests for the validator: end-to-end checks through `validate` plus direct tests of
// the `get_nested_field` and `extract_dynamic_part` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema shared by most tests: a required `config` table with a required string
    /// `name`, an optional integer `port` bounded to 1..=65535, an optional boolean
    /// `enabled`, an optional string enum `log_level`, and a `tags` array of strings
    /// with at least one item.
    fn simple_schema() -> &'static str {
        r#"
[schema]
version = "1.0"
description = "Simple test schema"

[[table]]
name = "config"
required = true

[[table.field]]
name = "name"
type = "string"
required = true

[[table.field]]
name = "port"
type = "integer"
required = false
min = 1
max = 65535

[[table.field]]
name = "enabled"
type = "boolean"
required = false

[[table.field]]
name = "log_level"
type = "string"
required = false
enum = ["debug", "info", "warn", "error"]

[[table.field]]
name = "tags"
type = "array"
array_item_type = "string"
min_items = 1
"#
    }

    /// A document that satisfies every constraint of the simple schema is accepted.
    #[test]
    fn test_valid_toml() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
port = 8080
enabled = true
log_level = "info"
tags = ["api", "production"]
"#;

        assert!(validate(content, schema).is_ok());
    }

    /// Omitting the required `name` field yields `MissingRequiredField`.
    #[test]
    fn test_missing_required_field() {
        let schema = simple_schema();
        let content = r#"
[config]
port = 8080
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(
            error,
            ValidationError::MissingRequiredField { .. }
        ));
    }

    /// A string where an integer is declared yields `InvalidType`.
    #[test]
    fn test_invalid_type() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
port = "not a number"
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::InvalidType { .. }));
    }

    /// A `log_level` outside the declared enum yields `InvalidEnumValue`.
    #[test]
    fn test_invalid_enum_value() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
log_level = "invalid"
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::InvalidEnumValue { .. }));
    }

    /// A `port` above the declared maximum yields `OutOfRange`.
    #[test]
    fn test_out_of_range() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
port = 999999
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::OutOfRange { .. }));
    }

    /// An empty `tags` array violates `min_items = 1` and yields `InvalidArrayLength`.
    #[test]
    fn test_invalid_array_length() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
tags = []
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::InvalidArrayLength { .. }));
    }

    /// Integer items in a string array yield `InvalidType`, either alone or inside `Multiple`.
    #[test]
    fn test_invalid_array_item_type() {
        let schema = simple_schema();
        let content = r#"
[config]
name = "test"
tags = [1, 2, 3]
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        match error {
            // Each offending item records its own error, so several are wrapped together.
            ValidationError::Multiple(errors) => {
                assert!(errors
                    .iter()
                    .any(|e| matches!(e, ValidationError::InvalidType { .. })));
            }
            ValidationError::InvalidType { .. } => {}
            _ => panic!("Expected InvalidType or Multiple errors"),
        }
    }

    /// Tables matching the `network.*` pattern whose names satisfy the constraint are accepted.
    #[test]
    fn test_pattern_tables() {
        let schema = r#"
[schema]
version = "1.0"
description = "Pattern table test"

[[table]]
name = "network.*"
pattern = true
pattern_constraint = "^[a-z0-9_-]+$"

[[table.field]]
name = "bind_address"
type = "string"
required = true
"#;

        let content = r#"
[network.default]
bind_address = "0.0.0.0"

[network.management]
bind_address = "127.0.0.1"
"#;

        let result = validate(content, schema);
        // Print the error so a failing run shows why validation was rejected.
        if let Err(e) = &result {
            eprintln!("Validation error: {}", e);
        }
        assert!(result.is_ok());
    }

    /// An upper-case table name breaks the lower-case-only constraint: `PatternMismatch`.
    #[test]
    fn test_pattern_constraint_violation() {
        let schema = r#"
[schema]
version = "1.0"
description = "Pattern constraint test"

[[table]]
name = "network.*"
pattern = true
pattern_constraint = "^[a-z0-9_-]+$"

[[table.field]]
name = "bind_address"
type = "string"
required = true
"#;

        let content = r#"
[network.INVALID_NAME]
bind_address = "0.0.0.0"
"#;

        let result = validate(content, schema);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(error, ValidationError::PatternMismatch { .. }));
    }

    /// Dotted field names (`tcp_config.port`) are resolved inside nested tables.
    #[test]
    fn test_nested_fields() {
        let schema = r#"
[schema]
version = "1.0"
description = "Nested fields test"

[[table]]
name = "network"
required = true

[[table.field]]
name = "tcp_config.bind_address"
type = "string"
required = true

[[table.field]]
name = "tcp_config.port"
type = "integer"
required = true
"#;

        let content = r#"
[network.tcp_config]
bind_address = "0.0.0.0"
port = 8080
"#;

        assert!(validate(content, schema).is_ok());
    }

    /// `get_nested_field` follows a dotted path through a nested table.
    #[test]
    fn test_get_nested_field() {
        let mut table = toml::map::Map::new();
        let mut tcp_config = toml::map::Map::new();
        tcp_config.insert(
            "bind_address".to_string(),
            toml::Value::String("0.0.0.0".to_string()),
        );
        table.insert("tcp_config".to_string(), toml::Value::Table(tcp_config));

        let value = get_nested_field(&table, "tcp_config.bind_address");
        assert!(value.is_some());
        assert_eq!(value.unwrap().as_str(), Some("0.0.0.0"));
    }

    /// `extract_dynamic_part` returns the segment under `*`, or `None` when the number of
    /// segments differs.
    #[test]
    fn test_extract_dynamic_part() {
        assert_eq!(
            extract_dynamic_part("network.default", "network.*"),
            Some("default".to_string())
        );
        assert_eq!(
            extract_dynamic_part("network.management", "network.*"),
            Some("management".to_string())
        );
        assert_eq!(extract_dynamic_part("network", "network.*"), None);
    }
}