Skip to main content

mabi_core/config/
validation.rs

//! Configuration validation system.
//!
//! This module provides comprehensive validation for configuration types with:
//! - Field-level validation rules
//! - Cross-field validation (constraints between multiple fields)
//! - Nested configuration validation
//! - Protocol-specific validation
//! - Custom validation rules
//!
//! # Example
//!
//! ```rust,ignore
//! use mabi_core::config::{EngineConfig, Validator};
//!
//! let config = EngineConfig::default();
//! let result = config.validate();
//!
//! match result {
//!     Ok(()) => println!("Configuration is valid"),
//!     Err(errors) => {
//!         for (field, messages) in errors.iter() {
//!             for msg in messages {
//!                 eprintln!("  {}: {}", field, msg);
//!             }
//!         }
//!     }
//! }
//! ```
29
30use std::fmt;
31use std::net::SocketAddr;
32use std::path::Path;
33
34use crate::error::ValidationErrors;
35use crate::Result;
36
37/// Trait for types that can be validated.
38pub trait Validatable {
39    /// Validate this configuration.
40    ///
41    /// Returns Ok(()) if valid, or Err with validation errors.
42    fn validate(&self) -> Result<()>;
43
44    /// Validate and collect errors without failing immediately.
45    fn validate_collect(&self, errors: &mut ValidationErrors);
46}
47
/// Keeps track of the position inside a nested configuration while it is
/// being validated, so that errors can be reported with a full field path.
#[derive(Debug, Clone, Default)]
pub struct ValidationContext {
    /// Field names from the configuration root down to the current position.
    path: Vec<String>,
}

impl ValidationContext {
    /// Build a context positioned at the configuration root (empty path).
    pub fn new() -> Self {
        Self { path: Vec::new() }
    }

    /// Descend into the nested field `field`.
    pub fn enter(&mut self, field: impl Into<String>) {
        let segment: String = field.into();
        self.path.push(segment);
    }

    /// Step back out of the most recently entered field.
    ///
    /// Has no effect when the context is already at the root.
    pub fn leave(&mut self) {
        let _ = self.path.pop();
    }

    /// Render the current position as a dot-separated path.
    pub fn path(&self) -> String {
        self.path.join(".")
    }

    /// Build the full path of the field `name` at the current position.
    ///
    /// At the root this is just `name`; otherwise the current path and
    /// `name` are joined with a dot.
    pub fn field(&self, name: &str) -> String {
        match self.path.as_slice() {
            [] => name.to_owned(),
            segments => format!("{}.{}", segments.join("."), name),
        }
    }

    /// Run `f` with the context positioned inside `field`, then step back
    /// out again once `f` returns.
    pub fn with_field<F>(&mut self, field: impl Into<String>, f: F)
    where
        F: FnOnce(&mut Self),
    {
        self.path.push(field.into());
        f(self);
        self.path.pop();
    }
}
95
/// A reusable check that accepts or rejects a value of type `T`.
pub trait ValidationRule<T: ?Sized>: Send + Sync {
    /// Check `value` against this rule.
    ///
    /// Yields `Ok(())` when the value is accepted, otherwise `Err` holding
    /// a message that explains the rejection.
    fn validate(&self, value: &T) -> std::result::Result<(), String>;

    /// Human-readable description of what this rule requires.
    fn description(&self) -> &str;
}
104
105/// Rule that checks if a value is within a range.
106pub struct RangeRule<T> {
107    min: Option<T>,
108    max: Option<T>,
109    description: String,
110}
111
112impl<T: PartialOrd + fmt::Display + Copy> RangeRule<T> {
113    /// Create a range rule with minimum and maximum.
114    pub fn new(min: Option<T>, max: Option<T>) -> Self {
115        let description = match (&min, &max) {
116            (Some(min), Some(max)) => format!("Value must be between {} and {}", min, max),
117            (Some(min), None) => format!("Value must be at least {}", min),
118            (None, Some(max)) => format!("Value must be at most {}", max),
119            (None, None) => "No range constraint".to_string(),
120        };
121        Self {
122            min,
123            max,
124            description,
125        }
126    }
127
128    /// Create a minimum-only rule.
129    pub fn min(min: T) -> Self {
130        Self::new(Some(min), None)
131    }
132
133    /// Create a maximum-only rule.
134    pub fn max(max: T) -> Self {
135        Self::new(None, Some(max))
136    }
137
138    /// Create a between rule.
139    pub fn between(min: T, max: T) -> Self {
140        Self::new(Some(min), Some(max))
141    }
142}
143
144impl<T: PartialOrd + fmt::Display + Copy + Send + Sync> ValidationRule<T> for RangeRule<T> {
145    fn validate(&self, value: &T) -> std::result::Result<(), String> {
146        if let Some(min) = &self.min {
147            if value < min {
148                return Err(format!("Value {} is below minimum {}", value, min));
149            }
150        }
151        if let Some(max) = &self.max {
152            if value > max {
153                return Err(format!("Value {} exceeds maximum {}", value, max));
154            }
155        }
156        Ok(())
157    }
158
159    fn description(&self) -> &str {
160        &self.description
161    }
162}
163
164/// Rule that checks string length.
165pub struct StringLengthRule {
166    min: Option<usize>,
167    max: Option<usize>,
168    description: String,
169}
170
171impl StringLengthRule {
172    /// Create a string length rule.
173    pub fn new(min: Option<usize>, max: Option<usize>) -> Self {
174        let description = match (min, max) {
175            (Some(min), Some(max)) => format!("Length must be between {} and {}", min, max),
176            (Some(min), None) => format!("Length must be at least {}", min),
177            (None, Some(max)) => format!("Length must be at most {}", max),
178            (None, None) => "No length constraint".to_string(),
179        };
180        Self {
181            min,
182            max,
183            description,
184        }
185    }
186
187    /// Create a non-empty rule.
188    pub fn non_empty() -> Self {
189        Self::new(Some(1), None)
190    }
191
192    /// Create a max length rule.
193    pub fn max(max: usize) -> Self {
194        Self::new(None, Some(max))
195    }
196}
197
198impl ValidationRule<String> for StringLengthRule {
199    fn validate(&self, value: &String) -> std::result::Result<(), String> {
200        let len = value.len();
201        if let Some(min) = self.min {
202            if len < min {
203                return Err(format!("String length {} is below minimum {}", len, min));
204            }
205        }
206        if let Some(max) = self.max {
207            if len > max {
208                return Err(format!("String length {} exceeds maximum {}", len, max));
209            }
210        }
211        Ok(())
212    }
213
214    fn description(&self) -> &str {
215        &self.description
216    }
217}
218
219impl ValidationRule<str> for StringLengthRule {
220    fn validate(&self, value: &str) -> std::result::Result<(), String> {
221        self.validate(&value.to_string())
222    }
223
224    fn description(&self) -> &str {
225        &self.description
226    }
227}
228
229/// Rule that checks if a path exists.
230pub struct PathExistsRule {
231    check_file: bool,
232    check_dir: bool,
233    description: String,
234}
235
236impl PathExistsRule {
237    /// Create a rule that checks if a file exists.
238    pub fn file() -> Self {
239        Self {
240            check_file: true,
241            check_dir: false,
242            description: "Path must be an existing file".to_string(),
243        }
244    }
245
246    /// Create a rule that checks if a directory exists.
247    pub fn directory() -> Self {
248        Self {
249            check_file: false,
250            check_dir: true,
251            description: "Path must be an existing directory".to_string(),
252        }
253    }
254
255    /// Create a rule that checks if path exists (file or directory).
256    pub fn exists() -> Self {
257        Self {
258            check_file: false,
259            check_dir: false,
260            description: "Path must exist".to_string(),
261        }
262    }
263}
264
265impl<P: AsRef<Path>> ValidationRule<P> for PathExistsRule {
266    fn validate(&self, value: &P) -> std::result::Result<(), String> {
267        let path = value.as_ref();
268        if self.check_file {
269            if !path.is_file() {
270                return Err(format!("Path '{}' is not a file", path.display()));
271            }
272        } else if self.check_dir {
273            if !path.is_dir() {
274                return Err(format!("Path '{}' is not a directory", path.display()));
275            }
276        } else if !path.exists() {
277            return Err(format!("Path '{}' does not exist", path.display()));
278        }
279        Ok(())
280    }
281
282    fn description(&self) -> &str {
283        &self.description
284    }
285}
286
287/// Rule that validates socket addresses.
288pub struct SocketAddrRule {
289    require_ipv4: bool,
290    port_range: Option<(u16, u16)>,
291    description: String,
292}
293
294impl SocketAddrRule {
295    /// Create a basic socket address rule.
296    pub fn new() -> Self {
297        Self {
298            require_ipv4: false,
299            port_range: None,
300            description: "Must be a valid socket address".to_string(),
301        }
302    }
303
304    /// Require IPv4 address.
305    pub fn ipv4_only(mut self) -> Self {
306        self.require_ipv4 = true;
307        self.description = "Must be a valid IPv4 socket address".to_string();
308        self
309    }
310
311    /// Require port in range.
312    pub fn port_range(mut self, min: u16, max: u16) -> Self {
313        self.port_range = Some((min, max));
314        self.description = format!(
315            "Must be a valid socket address with port between {} and {}",
316            min, max
317        );
318        self
319    }
320
321    /// Require non-privileged port (>= 1024).
322    pub fn non_privileged_port(self) -> Self {
323        self.port_range(1024, 65535)
324    }
325}
326
327impl Default for SocketAddrRule {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333impl ValidationRule<SocketAddr> for SocketAddrRule {
334    fn validate(&self, value: &SocketAddr) -> std::result::Result<(), String> {
335        if self.require_ipv4 && value.is_ipv6() {
336            return Err("IPv4 address required".to_string());
337        }
338
339        if let Some((min, max)) = self.port_range {
340            let port = value.port();
341            if port < min || port > max {
342                return Err(format!("Port {} must be between {} and {}", port, min, max));
343            }
344        }
345
346        Ok(())
347    }
348
349    fn description(&self) -> &str {
350        &self.description
351    }
352}
353
354/// Validator for collecting and running multiple validation rules.
355#[derive(Default)]
356pub struct Validator {
357    errors: ValidationErrors,
358    context: ValidationContext,
359}
360
361impl Validator {
362    /// Create a new validator.
363    pub fn new() -> Self {
364        Self::default()
365    }
366
367    /// Get current validation context.
368    pub fn context(&self) -> &ValidationContext {
369        &self.context
370    }
371
372    /// Get mutable validation context.
373    pub fn context_mut(&mut self) -> &mut ValidationContext {
374        &mut self.context
375    }
376
377    /// Add an error for a field.
378    pub fn add_error(&mut self, field: &str, message: impl Into<String>) {
379        let full_path = self.context.field(field);
380        self.errors.add(full_path, message);
381    }
382
383    /// Add error if condition is true.
384    pub fn add_if(&mut self, condition: bool, field: &str, message: impl Into<String>) {
385        if condition {
386            self.add_error(field, message);
387        }
388    }
389
390    /// Validate a value with a rule.
391    pub fn validate_field<T, R>(&mut self, field: &str, value: &T, rule: &R)
392    where
393        R: ValidationRule<T>,
394    {
395        if let Err(msg) = rule.validate(value) {
396            self.add_error(field, msg);
397        }
398    }
399
400    /// Validate a required string field is not empty.
401    pub fn require_non_empty(&mut self, field: &str, value: &str) {
402        if value.trim().is_empty() {
403            self.add_error(field, "Value cannot be empty");
404        }
405    }
406
407    /// Validate a numeric value is positive.
408    pub fn require_positive<T: PartialOrd + Default + fmt::Display>(
409        &mut self,
410        field: &str,
411        value: T,
412    ) {
413        if value <= T::default() {
414            self.add_error(field, format!("Value must be positive, got {}", value));
415        }
416    }
417
418    /// Validate a value is in a range.
419    pub fn require_range<T: PartialOrd + fmt::Display + Copy>(
420        &mut self,
421        field: &str,
422        value: T,
423        min: T,
424        max: T,
425    ) {
426        if value < min || value > max {
427            self.add_error(
428                field,
429                format!("Value {} must be between {} and {}", value, min, max),
430            );
431        }
432    }
433
434    /// Validate with a nested context.
435    pub fn validate_nested<F>(&mut self, field: &str, f: F)
436    where
437        F: FnOnce(&mut Self),
438    {
439        self.context.enter(field);
440        f(self);
441        self.context.leave();
442    }
443
444    /// Check if there are any errors.
445    pub fn has_errors(&self) -> bool {
446        !self.errors.is_empty()
447    }
448
449    /// Get the collected errors.
450    pub fn errors(&self) -> &ValidationErrors {
451        &self.errors
452    }
453
454    /// Consume and return the errors.
455    pub fn into_errors(self) -> ValidationErrors {
456        self.errors
457    }
458
459    /// Convert to Result.
460    pub fn into_result(self) -> Result<()> {
461        self.errors.into_result(())
462    }
463
464    /// Merge errors from another validator.
465    pub fn merge(&mut self, other: Validator) {
466        self.errors.merge(other.errors);
467    }
468}
469
470/// Cross-field validation helper.
471pub struct CrossFieldValidator<'a, T> {
472    config: &'a T,
473    errors: ValidationErrors,
474}
475
476impl<'a, T> CrossFieldValidator<'a, T> {
477    /// Create a new cross-field validator.
478    pub fn new(config: &'a T) -> Self {
479        Self {
480            config,
481            errors: ValidationErrors::new(),
482        }
483    }
484
485    /// Get the configuration being validated.
486    pub fn config(&self) -> &T {
487        self.config
488    }
489
490    /// Add a cross-field validation error.
491    pub fn add_error(&mut self, fields: &[&str], message: impl Into<String>) {
492        let field_name = fields.join(", ");
493        self.errors.add(field_name, message);
494    }
495
496    /// Add error if condition is true.
497    pub fn add_if(&mut self, condition: bool, fields: &[&str], message: impl Into<String>) {
498        if condition {
499            self.add_error(fields, message);
500        }
501    }
502
503    /// Check if there are any errors.
504    pub fn has_errors(&self) -> bool {
505        !self.errors.is_empty()
506    }
507
508    /// Consume and return the errors.
509    pub fn into_errors(self) -> ValidationErrors {
510        self.errors
511    }
512}
513
/// Shorthand for the most common validator calls.
///
/// Four forms are accepted, each taking the validator and the field name
/// first:
///
/// - `validate!(v, field, cond, msg)` reports `msg` when `cond` is false;
/// - `validate!(v, field, range value, min, max)` calls `require_range`;
/// - `validate!(v, field, non_empty value)` calls `require_non_empty`;
/// - `validate!(v, field, positive value)` calls `require_positive`.
#[macro_export]
macro_rules! validate {
    // Plain condition: the error is added when the condition does NOT hold.
    ($v:expr, $name:expr, $holds:expr, $message:expr) => {
        $v.add_if(!$holds, $name, $message);
    };

    // Inclusive range check.
    ($v:expr, $name:expr, range $actual:expr, $lower:expr, $upper:expr) => {
        $v.require_range($name, $actual, $lower, $upper);
    };

    // String must not be empty.
    ($v:expr, $name:expr, non_empty $actual:expr) => {
        $v.require_non_empty($name, $actual);
    };

    // Number must be positive.
    ($v:expr, $name:expr, positive $actual:expr) => {
        $v.require_positive($name, $actual);
    };
}
537
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Both ends of a `between` rule are inclusive.
    #[test]
    fn test_range_rule() {
        let percent = RangeRule::between(0, 100);

        for accepted in [50, 0, 100] {
            assert!(percent.validate(&accepted).is_ok());
        }
        for rejected in [-1, 101] {
            assert!(percent.validate(&rejected).is_err());
        }
    }

    /// Strings shorter or longer than the bounds are rejected.
    #[test]
    fn test_string_length_rule() {
        let bounded = StringLengthRule::new(Some(3), Some(10));

        let fits = String::from("hello");
        let too_short = String::from("ab");
        let too_long = String::from("this is too long");

        assert!(bounded.validate(&fits).is_ok());
        assert!(bounded.validate(&too_short).is_err());
        assert!(bounded.validate(&too_long).is_err());
    }

    /// `non_empty` rejects only the empty string.
    #[test]
    fn test_string_non_empty() {
        let required = StringLengthRule::non_empty();

        let filled = String::from("hello");
        let blank = String::new();

        assert!(required.validate(&filled).is_ok());
        assert!(required.validate(&blank).is_err());
    }

    /// Ports below 1024 fail the non-privileged check.
    #[test]
    fn test_socket_addr_rule() {
        let unprivileged = SocketAddrRule::new().non_privileged_port();

        let high_port: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let low_port: SocketAddr = "127.0.0.1:80".parse().unwrap();

        assert!(unprivileged.validate(&high_port).is_ok());
        assert!(unprivileged.validate(&low_port).is_err());
    }

    /// Each failing helper call contributes one error.
    #[test]
    fn test_validator() {
        let mut checks = Validator::new();

        checks.require_non_empty("name", "");
        checks.require_positive("count", -1i32);
        checks.require_range("percent", 150, 0, 100);

        assert!(checks.has_errors());
        assert_eq!(checks.errors().len(), 3);
    }

    /// Errors added inside a nested scope carry the scope as a prefix.
    #[test]
    fn test_validator_nested() {
        let mut checks = Validator::new();

        checks.validate_nested("engine", |nested| {
            nested.add_error("max_devices", "Too low");
        });

        let collected = checks.into_errors();
        assert!(collected.get("engine.max_devices").is_some());
    }

    /// Entering and leaving fields adjusts the reported path.
    #[test]
    fn test_validation_context() {
        let mut position = ValidationContext::new();

        position.enter("engine");
        assert_eq!(position.field("max_devices"), "engine.max_devices");

        position.enter("modbus");
        assert_eq!(position.field("port"), "engine.modbus.port");

        position.leave();
        assert_eq!(position.field("workers"), "engine.workers");

        position.leave();
        assert_eq!(position.field("name"), "name");
    }

    /// A violated constraint between two fields is recorded.
    #[test]
    fn test_cross_field_validator() {
        struct Config {
            min: u32,
            max: u32,
        }

        let config = Config { min: 100, max: 50 };
        let mut cross = CrossFieldValidator::new(&config);

        let inverted = config.min > config.max;
        cross.add_if(inverted, &["min", "max"], "min cannot be greater than max");

        assert!(cross.has_errors());
    }

    /// The condition form reports an error only when the condition is false.
    #[test]
    fn test_validate_macro() {
        let mut failing = Validator::new();
        let too_big = 150;

        validate!(failing, "percent", too_big <= 100, "Must be <= 100");
        assert!(failing.has_errors());

        let mut passing = Validator::new();
        let in_bounds = 50;
        validate!(passing, "percent", in_bounds <= 100, "Must be <= 100");
        assert!(!passing.has_errors());
    }

    /// A path that does not exist is not a file.
    #[test]
    fn test_path_exists_rule_file() {
        let must_be_file = PathExistsRule::file();
        let missing = PathBuf::from("/non/existent/file.txt");
        assert!(must_be_file.validate(&missing).is_err());
    }
}