1use std::fmt;
31use std::net::SocketAddr;
32use std::path::Path;
33
34use crate::error::ValidationErrors;
35use crate::Result;
36
37pub trait Validatable {
39 fn validate(&self) -> Result<()>;
43
44 fn validate_collect(&self, errors: &mut ValidationErrors);
46}
47
/// Tracks the dotted path of nested sections being validated (e.g. `engine.modbus`).
#[derive(Debug, Clone, Default)]
pub struct ValidationContext {
    path: Vec<String>,
}

impl ValidationContext {
    /// Creates a context positioned at the root (empty path).
    pub fn new() -> Self {
        Self { path: Vec::new() }
    }

    /// Descends into the section named `field`.
    pub fn enter(&mut self, field: impl Into<String>) {
        let segment: String = field.into();
        self.path.push(segment);
    }

    /// Returns to the parent section; does nothing at the root.
    pub fn leave(&mut self) {
        let _ = self.path.pop();
    }

    /// Returns the current path with segments joined by `.` (empty at the root).
    pub fn path(&self) -> String {
        self.path.join(".")
    }

    /// Returns `name` qualified by the current path, e.g. `engine.max_devices`.
    pub fn field(&self, name: &str) -> String {
        if self.path.is_empty() {
            return name.to_owned();
        }
        let mut qualified = self.path();
        qualified.push('.');
        qualified.push_str(name);
        qualified
    }

    /// Runs `f` with `field` pushed onto the path, then pops it again.
    pub fn with_field<F>(&mut self, field: impl Into<String>, f: F)
    where
        F: FnOnce(&mut Self),
    {
        self.path.push(field.into());
        f(self);
        self.path.pop();
    }
}
95
/// A reusable check applied to a value of type `T`.
///
/// `T` may be unsized (for example `str`). Implementors must be `Send + Sync`.
pub trait ValidationRule<T: ?Sized>: Send + Sync {
    /// Checks `value`, returning `Err` with a human-readable message when the
    /// value violates the rule.
    fn validate(&self, value: &T) -> std::result::Result<(), String>;

    /// Returns a short human-readable summary of what the rule requires.
    fn description(&self) -> &str;
}
104
105pub struct RangeRule<T> {
107 min: Option<T>,
108 max: Option<T>,
109 description: String,
110}
111
112impl<T: PartialOrd + fmt::Display + Copy> RangeRule<T> {
113 pub fn new(min: Option<T>, max: Option<T>) -> Self {
115 let description = match (&min, &max) {
116 (Some(min), Some(max)) => format!("Value must be between {} and {}", min, max),
117 (Some(min), None) => format!("Value must be at least {}", min),
118 (None, Some(max)) => format!("Value must be at most {}", max),
119 (None, None) => "No range constraint".to_string(),
120 };
121 Self { min, max, description }
122 }
123
124 pub fn min(min: T) -> Self {
126 Self::new(Some(min), None)
127 }
128
129 pub fn max(max: T) -> Self {
131 Self::new(None, Some(max))
132 }
133
134 pub fn between(min: T, max: T) -> Self {
136 Self::new(Some(min), Some(max))
137 }
138}
139
140impl<T: PartialOrd + fmt::Display + Copy + Send + Sync> ValidationRule<T> for RangeRule<T> {
141 fn validate(&self, value: &T) -> std::result::Result<(), String> {
142 if let Some(min) = &self.min {
143 if value < min {
144 return Err(format!("Value {} is below minimum {}", value, min));
145 }
146 }
147 if let Some(max) = &self.max {
148 if value > max {
149 return Err(format!("Value {} exceeds maximum {}", value, max));
150 }
151 }
152 Ok(())
153 }
154
155 fn description(&self) -> &str {
156 &self.description
157 }
158}
159
160pub struct StringLengthRule {
162 min: Option<usize>,
163 max: Option<usize>,
164 description: String,
165}
166
167impl StringLengthRule {
168 pub fn new(min: Option<usize>, max: Option<usize>) -> Self {
170 let description = match (min, max) {
171 (Some(min), Some(max)) => format!("Length must be between {} and {}", min, max),
172 (Some(min), None) => format!("Length must be at least {}", min),
173 (None, Some(max)) => format!("Length must be at most {}", max),
174 (None, None) => "No length constraint".to_string(),
175 };
176 Self { min, max, description }
177 }
178
179 pub fn non_empty() -> Self {
181 Self::new(Some(1), None)
182 }
183
184 pub fn max(max: usize) -> Self {
186 Self::new(None, Some(max))
187 }
188}
189
190impl ValidationRule<String> for StringLengthRule {
191 fn validate(&self, value: &String) -> std::result::Result<(), String> {
192 let len = value.len();
193 if let Some(min) = self.min {
194 if len < min {
195 return Err(format!("String length {} is below minimum {}", len, min));
196 }
197 }
198 if let Some(max) = self.max {
199 if len > max {
200 return Err(format!("String length {} exceeds maximum {}", len, max));
201 }
202 }
203 Ok(())
204 }
205
206 fn description(&self) -> &str {
207 &self.description
208 }
209}
210
211impl ValidationRule<str> for StringLengthRule {
212 fn validate(&self, value: &str) -> std::result::Result<(), String> {
213 self.validate(&value.to_string())
214 }
215
216 fn description(&self) -> &str {
217 &self.description
218 }
219}
220
221pub struct PathExistsRule {
223 check_file: bool,
224 check_dir: bool,
225 description: String,
226}
227
228impl PathExistsRule {
229 pub fn file() -> Self {
231 Self {
232 check_file: true,
233 check_dir: false,
234 description: "Path must be an existing file".to_string(),
235 }
236 }
237
238 pub fn directory() -> Self {
240 Self {
241 check_file: false,
242 check_dir: true,
243 description: "Path must be an existing directory".to_string(),
244 }
245 }
246
247 pub fn exists() -> Self {
249 Self {
250 check_file: false,
251 check_dir: false,
252 description: "Path must exist".to_string(),
253 }
254 }
255}
256
257impl<P: AsRef<Path>> ValidationRule<P> for PathExistsRule {
258 fn validate(&self, value: &P) -> std::result::Result<(), String> {
259 let path = value.as_ref();
260 if self.check_file {
261 if !path.is_file() {
262 return Err(format!("Path '{}' is not a file", path.display()));
263 }
264 } else if self.check_dir {
265 if !path.is_dir() {
266 return Err(format!("Path '{}' is not a directory", path.display()));
267 }
268 } else if !path.exists() {
269 return Err(format!("Path '{}' does not exist", path.display()));
270 }
271 Ok(())
272 }
273
274 fn description(&self) -> &str {
275 &self.description
276 }
277}
278
279pub struct SocketAddrRule {
281 require_ipv4: bool,
282 port_range: Option<(u16, u16)>,
283 description: String,
284}
285
286impl SocketAddrRule {
287 pub fn new() -> Self {
289 Self {
290 require_ipv4: false,
291 port_range: None,
292 description: "Must be a valid socket address".to_string(),
293 }
294 }
295
296 pub fn ipv4_only(mut self) -> Self {
298 self.require_ipv4 = true;
299 self.description = "Must be a valid IPv4 socket address".to_string();
300 self
301 }
302
303 pub fn port_range(mut self, min: u16, max: u16) -> Self {
305 self.port_range = Some((min, max));
306 self.description = format!(
307 "Must be a valid socket address with port between {} and {}",
308 min, max
309 );
310 self
311 }
312
313 pub fn non_privileged_port(self) -> Self {
315 self.port_range(1024, 65535)
316 }
317}
318
319impl Default for SocketAddrRule {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl ValidationRule<SocketAddr> for SocketAddrRule {
326 fn validate(&self, value: &SocketAddr) -> std::result::Result<(), String> {
327 if self.require_ipv4 && value.is_ipv6() {
328 return Err("IPv4 address required".to_string());
329 }
330
331 if let Some((min, max)) = self.port_range {
332 let port = value.port();
333 if port < min || port > max {
334 return Err(format!("Port {} must be between {} and {}", port, min, max));
335 }
336 }
337
338 Ok(())
339 }
340
341 fn description(&self) -> &str {
342 &self.description
343 }
344}
345
346#[derive(Default)]
348pub struct Validator {
349 errors: ValidationErrors,
350 context: ValidationContext,
351}
352
353impl Validator {
354 pub fn new() -> Self {
356 Self::default()
357 }
358
359 pub fn context(&self) -> &ValidationContext {
361 &self.context
362 }
363
364 pub fn context_mut(&mut self) -> &mut ValidationContext {
366 &mut self.context
367 }
368
369 pub fn add_error(&mut self, field: &str, message: impl Into<String>) {
371 let full_path = self.context.field(field);
372 self.errors.add(full_path, message);
373 }
374
375 pub fn add_if(&mut self, condition: bool, field: &str, message: impl Into<String>) {
377 if condition {
378 self.add_error(field, message);
379 }
380 }
381
382 pub fn validate_field<T, R>(&mut self, field: &str, value: &T, rule: &R)
384 where
385 R: ValidationRule<T>,
386 {
387 if let Err(msg) = rule.validate(value) {
388 self.add_error(field, msg);
389 }
390 }
391
392 pub fn require_non_empty(&mut self, field: &str, value: &str) {
394 if value.trim().is_empty() {
395 self.add_error(field, "Value cannot be empty");
396 }
397 }
398
399 pub fn require_positive<T: PartialOrd + Default + fmt::Display>(&mut self, field: &str, value: T) {
401 if value <= T::default() {
402 self.add_error(field, format!("Value must be positive, got {}", value));
403 }
404 }
405
406 pub fn require_range<T: PartialOrd + fmt::Display + Copy>(
408 &mut self,
409 field: &str,
410 value: T,
411 min: T,
412 max: T,
413 ) {
414 if value < min || value > max {
415 self.add_error(
416 field,
417 format!("Value {} must be between {} and {}", value, min, max),
418 );
419 }
420 }
421
422 pub fn validate_nested<F>(&mut self, field: &str, f: F)
424 where
425 F: FnOnce(&mut Self),
426 {
427 self.context.enter(field);
428 f(self);
429 self.context.leave();
430 }
431
432 pub fn has_errors(&self) -> bool {
434 !self.errors.is_empty()
435 }
436
437 pub fn errors(&self) -> &ValidationErrors {
439 &self.errors
440 }
441
442 pub fn into_errors(self) -> ValidationErrors {
444 self.errors
445 }
446
447 pub fn into_result(self) -> Result<()> {
449 self.errors.into_result(())
450 }
451
452 pub fn merge(&mut self, other: Validator) {
454 self.errors.merge(other.errors);
455 }
456}
457
458pub struct CrossFieldValidator<'a, T> {
460 config: &'a T,
461 errors: ValidationErrors,
462}
463
464impl<'a, T> CrossFieldValidator<'a, T> {
465 pub fn new(config: &'a T) -> Self {
467 Self {
468 config,
469 errors: ValidationErrors::new(),
470 }
471 }
472
473 pub fn config(&self) -> &T {
475 self.config
476 }
477
478 pub fn add_error(&mut self, fields: &[&str], message: impl Into<String>) {
480 let field_name = fields.join(", ");
481 self.errors.add(field_name, message);
482 }
483
484 pub fn add_if(&mut self, condition: bool, fields: &[&str], message: impl Into<String>) {
486 if condition {
487 self.add_error(fields, message);
488 }
489 }
490
491 pub fn has_errors(&self) -> bool {
493 !self.errors.is_empty()
494 }
495
496 pub fn into_errors(self) -> ValidationErrors {
498 self.errors
499 }
500}
501
/// Records validation failures on a validator with compact syntax.
///
/// Forms:
/// - `validate!(v, field, cond, msg)`: records `msg` under `field` when `cond` is **false**.
/// - `validate!(v, field, range value, min, max)`: forwards to `v.require_range(...)`.
/// - `validate!(v, field, non_empty value)`: forwards to `v.require_non_empty(...)`.
/// - `validate!(v, field, positive value)`: forwards to `v.require_positive(...)`.
///
/// `v` can be any expression that has these methods, presumably a `Validator`.
// NOTE(review): the generic condition arm is listed first. The keyword forms rely
// on that arm failing to match (e.g. `range x` parses as the expression `range`
// followed by a stray token). Confirm with a test before reordering the arms.
#[macro_export]
macro_rules! validate {
    // Generic condition: the message is recorded when `$cond` does NOT hold.
    ($validator:expr, $field:expr, $cond:expr, $msg:expr) => {
        $validator.add_if(!$cond, $field, $msg);
    };

    // Inclusive range check.
    ($validator:expr, $field:expr, range $value:expr, $min:expr, $max:expr) => {
        $validator.require_range($field, $value, $min, $max);
    };

    // Non-empty (after trimming) string check.
    ($validator:expr, $field:expr, non_empty $value:expr) => {
        $validator.require_non_empty($field, $value);
    };

    // Strictly-positive check.
    ($validator:expr, $field:expr, positive $value:expr) => {
        $validator.require_positive($field, $value);
    };
}
525
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Inclusive bounds on both ends.
    #[test]
    fn test_range_rule() {
        let rule = RangeRule::between(0, 100);

        assert!(rule.validate(&50).is_ok());
        assert!(rule.validate(&0).is_ok());
        assert!(rule.validate(&100).is_ok());
        assert!(rule.validate(&-1).is_err());
        assert!(rule.validate(&101).is_err());
    }

    /// One-sided ranges only enforce the bound that was set.
    #[test]
    fn test_range_rule_one_sided() {
        let at_least = RangeRule::min(10u32);
        assert!(at_least.validate(&10).is_ok());
        assert!(at_least.validate(&u32::MAX).is_ok());
        assert!(at_least.validate(&9).is_err());

        let at_most = RangeRule::max(10u32);
        assert!(at_most.validate(&0).is_ok());
        assert!(at_most.validate(&11).is_err());
        assert_eq!(at_most.description(), "Value must be at most 10");
    }

    #[test]
    fn test_string_length_rule() {
        let rule = StringLengthRule::new(Some(3), Some(10));

        assert!(rule.validate(&"hello".to_string()).is_ok());
        assert!(rule.validate(&"ab".to_string()).is_err());
        assert!(rule.validate(&"this is too long".to_string()).is_err());
    }

    /// The `str` impl applies the same limits as the `String` impl.
    #[test]
    fn test_string_length_rule_str() {
        let rule = StringLengthRule::max(3);

        assert!(rule.validate("abc").is_ok());
        assert!(rule.validate("abcd").is_err());
    }

    #[test]
    fn test_string_non_empty() {
        let rule = StringLengthRule::non_empty();

        assert!(rule.validate(&"hello".to_string()).is_ok());
        assert!(rule.validate(&"".to_string()).is_err());
    }

    #[test]
    fn test_socket_addr_rule() {
        let rule = SocketAddrRule::new().non_privileged_port();

        let valid: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let invalid: SocketAddr = "127.0.0.1:80".parse().unwrap();

        assert!(rule.validate(&valid).is_ok());
        assert!(rule.validate(&invalid).is_err());
    }

    /// `ipv4_only` rejects IPv6 addresses regardless of port.
    #[test]
    fn test_socket_addr_ipv4_only() {
        let rule = SocketAddrRule::new().ipv4_only();

        let v4: SocketAddr = "127.0.0.1:80".parse().unwrap();
        let v6: SocketAddr = "[::1]:80".parse().unwrap();

        assert!(rule.validate(&v4).is_ok());
        assert!(rule.validate(&v6).is_err());
    }

    #[test]
    fn test_validator() {
        let mut validator = Validator::new();

        validator.require_non_empty("name", "");
        validator.require_positive("count", -1i32);
        validator.require_range("percent", 150, 0, 100);

        assert!(validator.has_errors());
        assert_eq!(validator.errors().len(), 3);
    }

    #[test]
    fn test_validator_nested() {
        let mut validator = Validator::new();

        validator.validate_nested("engine", |v| {
            v.add_error("max_devices", "Too low");
        });

        let errors = validator.into_errors();
        assert!(errors.get("engine.max_devices").is_some());
    }

    /// Rule failures inside a nested section are keyed by the qualified path.
    #[test]
    fn test_validator_validate_field_nested() {
        let mut validator = Validator::new();

        validator.validate_nested("modbus", |v| {
            v.validate_field("port", &70000u32, &RangeRule::max(65535u32));
            v.validate_field("timeout", &5u32, &RangeRule::max(65535u32));
        });

        let errors = validator.into_errors();
        assert!(errors.get("modbus.port").is_some());
        assert!(errors.get("modbus.timeout").is_none());
    }

    #[test]
    fn test_validation_context() {
        let mut ctx = ValidationContext::new();

        ctx.enter("engine");
        assert_eq!(ctx.field("max_devices"), "engine.max_devices");

        ctx.enter("modbus");
        assert_eq!(ctx.field("port"), "engine.modbus.port");

        ctx.leave();
        assert_eq!(ctx.field("workers"), "engine.workers");

        ctx.leave();
        assert_eq!(ctx.field("name"), "name");
    }

    /// `with_field` restores the path after the closure returns.
    #[test]
    fn test_validation_context_with_field() {
        let mut ctx = ValidationContext::new();

        ctx.with_field("engine", |c| {
            assert_eq!(c.path(), "engine");
            c.with_field("modbus", |c| assert_eq!(c.field("port"), "engine.modbus.port"));
            assert_eq!(c.path(), "engine");
        });

        assert_eq!(ctx.path(), "");
    }

    #[test]
    fn test_cross_field_validator() {
        struct Config {
            min: u32,
            max: u32,
        }

        let config = Config { min: 100, max: 50 };
        let mut validator = CrossFieldValidator::new(&config);

        validator.add_if(
            config.min > config.max,
            &["min", "max"],
            "min cannot be greater than max",
        );

        assert!(validator.has_errors());
    }

    /// A false condition records nothing.
    #[test]
    fn test_cross_field_validator_no_error() {
        let config = (1u32, 2u32);
        let mut validator = CrossFieldValidator::new(&config);

        validator.add_if(
            config.0 > config.1,
            &["min", "max"],
            "min cannot be greater than max",
        );

        assert!(!validator.has_errors());
    }

    #[test]
    fn test_validate_macro() {
        let mut validator = Validator::new();
        let value = 150;

        validate!(validator, "percent", value <= 100, "Must be <= 100");
        assert!(validator.has_errors());

        let mut validator2 = Validator::new();
        let value2 = 50;
        validate!(validator2, "percent", value2 <= 100, "Must be <= 100");
        assert!(!validator2.has_errors());
    }

    #[test]
    fn test_path_exists_rule_file() {
        let rule = PathExistsRule::file();
        let non_existent = PathBuf::from("/non/existent/file.txt");
        assert!(rule.validate(&non_existent).is_err());
    }

    /// "." (the working directory) always exists and is a directory, not a file.
    #[test]
    fn test_path_exists_rule_directory() {
        let cwd = PathBuf::from(".");

        assert!(PathExistsRule::directory().validate(&cwd).is_ok());
        assert!(PathExistsRule::exists().validate(&cwd).is_ok());
        assert!(PathExistsRule::file().validate(&cwd).is_err());
        assert!(PathExistsRule::exists()
            .validate(&PathBuf::from("/non/existent/dir"))
            .is_err());
    }
}
650}