1use std::fmt;
31use std::net::SocketAddr;
32use std::path::Path;
33
34use crate::error::ValidationErrors;
35use crate::Result;
36
37pub trait Validatable {
39 fn validate(&self) -> Result<()>;
43
44 fn validate_collect(&self, errors: &mut ValidationErrors);
46}
47
/// Tracks the dotted path of nested configuration sections being validated.
///
/// Each [`enter`](Self::enter) pushes one path segment and each
/// [`leave`](Self::leave) pops the most recent one, so field names can be
/// reported as fully-qualified keys such as `engine.modbus.port`.
#[derive(Debug, Clone, Default)]
pub struct ValidationContext {
    /// Path segments, outermost section first.
    path: Vec<String>,
}

impl ValidationContext {
    /// Creates a context positioned at the root (empty path).
    pub fn new() -> Self {
        Self { path: Vec::new() }
    }

    /// Descends into the section named `field`.
    pub fn enter(&mut self, field: impl Into<String>) {
        let segment: String = field.into();
        self.path.push(segment);
    }

    /// Returns to the parent section; a no-op when already at the root.
    pub fn leave(&mut self) {
        let _ = self.path.pop();
    }

    /// Returns the current path joined with `.` (empty string at the root).
    pub fn path(&self) -> String {
        self.path.join(".")
    }

    /// Returns the fully-qualified key for `name` inside the current section.
    ///
    /// At the root this is just `name`; otherwise it is `<path>.<name>`.
    pub fn field(&self, name: &str) -> String {
        self.path
            .iter()
            .map(String::as_str)
            .chain(std::iter::once(name))
            .collect::<Vec<&str>>()
            .join(".")
    }

    /// Runs `f` with `field` pushed onto the path, then pops it again.
    pub fn with_field<F>(&mut self, field: impl Into<String>, f: F)
    where
        F: FnOnce(&mut Self),
    {
        self.enter(field);
        f(self);
        self.leave();
    }
}
95
/// A reusable, thread-safe check applied to a single value of type `T`.
///
/// `T` may be unsized (for example `str`), so rules can validate borrowed
/// data directly, and the trait can be used as `dyn ValidationRule<T>`.
pub trait ValidationRule<T: ?Sized>: Send + Sync {
    /// Short human-readable summary of what the rule requires.
    fn description(&self) -> &str;

    /// Checks `value`, returning `Err` with a human-readable message on failure.
    fn validate(&self, value: &T) -> std::result::Result<(), String>;
}
104
/// Numeric range check whose bounds are each optional.
pub struct RangeRule<T> {
    /// Lower bound, if any.
    min: Option<T>,
    /// Upper bound, if any.
    max: Option<T>,
    /// Human-readable summary, built once when the rule is constructed.
    description: String,
}

impl<T: PartialOrd + fmt::Display + Copy> RangeRule<T> {
    /// Builds a rule from optional lower and upper bounds.
    pub fn new(min: Option<T>, max: Option<T>) -> Self {
        // `T: Copy`, so matching on the tuple by value leaves `min`/`max` usable.
        let description = match (min, max) {
            (None, None) => String::from("No range constraint"),
            (Some(lo), None) => format!("Value must be at least {}", lo),
            (None, Some(hi)) => format!("Value must be at most {}", hi),
            (Some(lo), Some(hi)) => format!("Value must be between {} and {}", lo, hi),
        };
        Self {
            min,
            max,
            description,
        }
    }

    /// Rule with only a lower bound.
    pub fn min(min: T) -> Self {
        Self::new(Some(min), None)
    }

    /// Rule with only an upper bound.
    pub fn max(max: T) -> Self {
        Self::new(None, Some(max))
    }

    /// Rule with both a lower and an upper bound.
    pub fn between(min: T, max: T) -> Self {
        Self::new(Some(min), Some(max))
    }
}
143
144impl<T: PartialOrd + fmt::Display + Copy + Send + Sync> ValidationRule<T> for RangeRule<T> {
145 fn validate(&self, value: &T) -> std::result::Result<(), String> {
146 if let Some(min) = &self.min {
147 if value < min {
148 return Err(format!("Value {} is below minimum {}", value, min));
149 }
150 }
151 if let Some(max) = &self.max {
152 if value > max {
153 return Err(format!("Value {} exceeds maximum {}", value, max));
154 }
155 }
156 Ok(())
157 }
158
159 fn description(&self) -> &str {
160 &self.description
161 }
162}
163
/// Checks a string's length against optional lower and upper bounds.
///
/// Length is the UTF-8 byte length (`str::len`), so non-ASCII characters
/// count as more than one.
pub struct StringLengthRule {
    /// Minimum accepted length, if any.
    min: Option<usize>,
    /// Maximum accepted length, if any.
    max: Option<usize>,
    /// Human-readable summary, built once when the rule is constructed.
    description: String,
}

impl StringLengthRule {
    /// Builds a rule from optional minimum and maximum lengths.
    pub fn new(min: Option<usize>, max: Option<usize>) -> Self {
        let description = match (min, max) {
            (Some(min), Some(max)) => format!("Length must be between {} and {}", min, max),
            (Some(min), None) => format!("Length must be at least {}", min),
            (None, Some(max)) => format!("Length must be at most {}", max),
            (None, None) => "No length constraint".to_string(),
        };
        Self {
            min,
            max,
            description,
        }
    }

    /// Rule requiring at least one byte.
    pub fn non_empty() -> Self {
        Self::new(Some(1), None)
    }

    /// Rule with only a minimum length (mirrors `RangeRule::min`).
    pub fn min(min: usize) -> Self {
        Self::new(Some(min), None)
    }

    /// Rule with only a maximum length.
    pub fn max(max: usize) -> Self {
        Self::new(None, Some(max))
    }

    /// Rule with both a minimum and a maximum length (mirrors `RangeRule::between`).
    pub fn between(min: usize, max: usize) -> Self {
        Self::new(Some(min), Some(max))
    }
}
197
198impl ValidationRule<String> for StringLengthRule {
199 fn validate(&self, value: &String) -> std::result::Result<(), String> {
200 let len = value.len();
201 if let Some(min) = self.min {
202 if len < min {
203 return Err(format!("String length {} is below minimum {}", len, min));
204 }
205 }
206 if let Some(max) = self.max {
207 if len > max {
208 return Err(format!("String length {} exceeds maximum {}", len, max));
209 }
210 }
211 Ok(())
212 }
213
214 fn description(&self) -> &str {
215 &self.description
216 }
217}
218
219impl ValidationRule<str> for StringLengthRule {
220 fn validate(&self, value: &str) -> std::result::Result<(), String> {
221 self.validate(&value.to_string())
222 }
223
224 fn description(&self) -> &str {
225 &self.description
226 }
227}
228
/// Rule that checks a filesystem path exists, optionally requiring it to be a
/// regular file or a directory.
pub struct PathExistsRule {
    /// Require the path to be a regular file.
    check_file: bool,
    /// Require the path to be a directory.
    check_dir: bool,
    /// Human-readable summary returned as the rule description.
    description: String,
}

impl PathExistsRule {
    /// Requires the path to be an existing regular file.
    pub fn file() -> Self {
        Self::with_checks(true, false, "Path must be an existing file")
    }

    /// Requires the path to be an existing directory.
    pub fn directory() -> Self {
        Self::with_checks(false, true, "Path must be an existing directory")
    }

    /// Requires only that something exists at the path.
    pub fn exists() -> Self {
        Self::with_checks(false, false, "Path must exist")
    }

    /// Shared constructor behind the public presets.
    fn with_checks(check_file: bool, check_dir: bool, description: &str) -> Self {
        Self {
            check_file,
            check_dir,
            description: description.to_owned(),
        }
    }
}
264
265impl<P: AsRef<Path>> ValidationRule<P> for PathExistsRule {
266 fn validate(&self, value: &P) -> std::result::Result<(), String> {
267 let path = value.as_ref();
268 if self.check_file {
269 if !path.is_file() {
270 return Err(format!("Path '{}' is not a file", path.display()));
271 }
272 } else if self.check_dir {
273 if !path.is_dir() {
274 return Err(format!("Path '{}' is not a directory", path.display()));
275 }
276 } else if !path.exists() {
277 return Err(format!("Path '{}' does not exist", path.display()));
278 }
279 Ok(())
280 }
281
282 fn description(&self) -> &str {
283 &self.description
284 }
285}
286
/// Validation rule for socket addresses.
///
/// It can restrict the address family to IPv4 and the port to an inclusive
/// range. Configure it with the builder-style methods, in any order. The
/// description always reflects every active constraint.
pub struct SocketAddrRule {
    /// Reject IPv6 addresses when `true`.
    require_ipv4: bool,
    /// Inclusive `(min, max)` port bounds, if any.
    port_range: Option<(u16, u16)>,
    /// Human-readable summary, kept in sync with the two fields above.
    description: String,
}

impl SocketAddrRule {
    /// Creates a rule that accepts any socket address.
    pub fn new() -> Self {
        let mut rule = Self {
            require_ipv4: false,
            port_range: None,
            description: String::new(),
        };
        rule.refresh_description();
        rule
    }

    /// Additionally requires an IPv4 address.
    pub fn ipv4_only(mut self) -> Self {
        self.require_ipv4 = true;
        self.refresh_description();
        self
    }

    /// Additionally requires the port to lie in `min..=max`.
    pub fn port_range(mut self, min: u16, max: u16) -> Self {
        self.port_range = Some((min, max));
        self.refresh_description();
        self
    }

    /// Shorthand for `port_range(1024, 65535)`.
    pub fn non_privileged_port(self) -> Self {
        self.port_range(1024, 65535)
    }

    /// Rebuilds `description` from all active constraints.
    ///
    /// Each builder call rebuilds the text from every field, so a later call
    /// cannot overwrite and drop a constraint that an earlier call added.
    fn refresh_description(&mut self) {
        let family = if self.require_ipv4 {
            "IPv4 socket address"
        } else {
            "socket address"
        };
        self.description = match self.port_range {
            Some((min, max)) => format!(
                "Must be a valid {} with port between {} and {}",
                family, min, max
            ),
            None => format!("Must be a valid {}", family),
        };
    }
}

impl Default for SocketAddrRule {
    fn default() -> Self {
        Self::new()
    }
}
332
333impl ValidationRule<SocketAddr> for SocketAddrRule {
334 fn validate(&self, value: &SocketAddr) -> std::result::Result<(), String> {
335 if self.require_ipv4 && value.is_ipv6() {
336 return Err("IPv4 address required".to_string());
337 }
338
339 if let Some((min, max)) = self.port_range {
340 let port = value.port();
341 if port < min || port > max {
342 return Err(format!("Port {} must be between {} and {}", port, min, max));
343 }
344 }
345
346 Ok(())
347 }
348
349 fn description(&self) -> &str {
350 &self.description
351 }
352}
353
354#[derive(Default)]
356pub struct Validator {
357 errors: ValidationErrors,
358 context: ValidationContext,
359}
360
361impl Validator {
362 pub fn new() -> Self {
364 Self::default()
365 }
366
367 pub fn context(&self) -> &ValidationContext {
369 &self.context
370 }
371
372 pub fn context_mut(&mut self) -> &mut ValidationContext {
374 &mut self.context
375 }
376
377 pub fn add_error(&mut self, field: &str, message: impl Into<String>) {
379 let full_path = self.context.field(field);
380 self.errors.add(full_path, message);
381 }
382
383 pub fn add_if(&mut self, condition: bool, field: &str, message: impl Into<String>) {
385 if condition {
386 self.add_error(field, message);
387 }
388 }
389
390 pub fn validate_field<T, R>(&mut self, field: &str, value: &T, rule: &R)
392 where
393 R: ValidationRule<T>,
394 {
395 if let Err(msg) = rule.validate(value) {
396 self.add_error(field, msg);
397 }
398 }
399
400 pub fn require_non_empty(&mut self, field: &str, value: &str) {
402 if value.trim().is_empty() {
403 self.add_error(field, "Value cannot be empty");
404 }
405 }
406
407 pub fn require_positive<T: PartialOrd + Default + fmt::Display>(
409 &mut self,
410 field: &str,
411 value: T,
412 ) {
413 if value <= T::default() {
414 self.add_error(field, format!("Value must be positive, got {}", value));
415 }
416 }
417
418 pub fn require_range<T: PartialOrd + fmt::Display + Copy>(
420 &mut self,
421 field: &str,
422 value: T,
423 min: T,
424 max: T,
425 ) {
426 if value < min || value > max {
427 self.add_error(
428 field,
429 format!("Value {} must be between {} and {}", value, min, max),
430 );
431 }
432 }
433
434 pub fn validate_nested<F>(&mut self, field: &str, f: F)
436 where
437 F: FnOnce(&mut Self),
438 {
439 self.context.enter(field);
440 f(self);
441 self.context.leave();
442 }
443
444 pub fn has_errors(&self) -> bool {
446 !self.errors.is_empty()
447 }
448
449 pub fn errors(&self) -> &ValidationErrors {
451 &self.errors
452 }
453
454 pub fn into_errors(self) -> ValidationErrors {
456 self.errors
457 }
458
459 pub fn into_result(self) -> Result<()> {
461 self.errors.into_result(())
462 }
463
464 pub fn merge(&mut self, other: Validator) {
466 self.errors.merge(other.errors);
467 }
468}
469
470pub struct CrossFieldValidator<'a, T> {
472 config: &'a T,
473 errors: ValidationErrors,
474}
475
476impl<'a, T> CrossFieldValidator<'a, T> {
477 pub fn new(config: &'a T) -> Self {
479 Self {
480 config,
481 errors: ValidationErrors::new(),
482 }
483 }
484
485 pub fn config(&self) -> &T {
487 self.config
488 }
489
490 pub fn add_error(&mut self, fields: &[&str], message: impl Into<String>) {
492 let field_name = fields.join(", ");
493 self.errors.add(field_name, message);
494 }
495
496 pub fn add_if(&mut self, condition: bool, fields: &[&str], message: impl Into<String>) {
498 if condition {
499 self.add_error(fields, message);
500 }
501 }
502
503 pub fn has_errors(&self) -> bool {
505 !self.errors.is_empty()
506 }
507
508 pub fn into_errors(self) -> ValidationErrors {
510 self.errors
511 }
512}
513
/// Shorthand for the common validator checks.
///
/// Supported forms, tried in this order:
///
/// * `validate!(v, field, condition, message)` records `message` when
///   `condition` is **false**. It expands to `add_if(!condition, ..)`.
/// * `validate!(v, field, range value, min, max)` calls `require_range`.
/// * `validate!(v, field, non_empty value)` calls `require_non_empty`.
/// * `validate!(v, field, positive value)` calls `require_positive`.
///
/// The receiver only needs methods with those names.
#[macro_export]
macro_rules! validate {
    // Keep the arm order unchanged: reordering could change which arm an
    // existing invocation selects.
    ($target:expr, $name:expr, $ok:expr, $message:expr) => {
        $target.add_if(!$ok, $name, $message);
    };

    ($target:expr, $name:expr, range $value:expr, $lo:expr, $hi:expr) => {
        $target.require_range($name, $value, $lo, $hi);
    };

    ($target:expr, $name:expr, non_empty $value:expr) => {
        $target.require_non_empty($name, $value);
    };

    ($target:expr, $name:expr, positive $value:expr) => {
        $target.require_positive($name, $value);
    };
}
537
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_range_rule() {
        let rule = RangeRule::between(0, 100);

        for value in &[50, 0, 100] {
            assert!(rule.validate(value).is_ok(), "{} should be accepted", value);
        }
        for value in &[-1, 101] {
            assert!(rule.validate(value).is_err(), "{} should be rejected", value);
        }
    }

    #[test]
    fn test_string_length_rule() {
        let rule = StringLengthRule::new(Some(3), Some(10));
        let check = |text: &str| rule.validate(&text.to_string());

        assert!(check("hello").is_ok());
        assert!(check("ab").is_err());
        assert!(check("this is too long").is_err());
    }

    #[test]
    fn test_string_non_empty() {
        let rule = StringLengthRule::non_empty();
        let owned = |text: &str| text.to_string();

        assert!(rule.validate(&owned("hello")).is_ok());
        assert!(rule.validate(&owned("")).is_err());
    }

    #[test]
    fn test_socket_addr_rule() {
        let rule = SocketAddrRule::new().non_privileged_port();
        let addr = |text: &str| -> SocketAddr { text.parse().unwrap() };

        assert!(rule.validate(&addr("127.0.0.1:8080")).is_ok());
        assert!(rule.validate(&addr("127.0.0.1:80")).is_err());
    }

    #[test]
    fn test_validator() {
        let mut validator = Validator::new();

        validator.require_non_empty("name", "");
        validator.require_positive("count", -1i32);
        validator.require_range("percent", 150, 0, 100);

        assert!(validator.has_errors());
        assert_eq!(validator.errors().len(), 3);
    }

    #[test]
    fn test_validator_nested() {
        let mut validator = Validator::new();
        validator.validate_nested("engine", |nested| nested.add_error("max_devices", "Too low"));

        let errors = validator.into_errors();
        assert!(errors.get("engine.max_devices").is_some());
    }

    #[test]
    fn test_validation_context() {
        let mut ctx = ValidationContext::new();

        ctx.enter("engine");
        assert_eq!(ctx.field("max_devices"), "engine.max_devices");
        ctx.enter("modbus");
        assert_eq!(ctx.field("port"), "engine.modbus.port");

        ctx.leave();
        assert_eq!(ctx.field("workers"), "engine.workers");
        ctx.leave();
        assert_eq!(ctx.field("name"), "name");
    }

    #[test]
    fn test_cross_field_validator() {
        struct Config {
            min: u32,
            max: u32,
        }

        let config = Config { min: 100, max: 50 };
        let inverted = config.min > config.max;
        let mut validator = CrossFieldValidator::new(&config);
        validator.add_if(inverted, &["min", "max"], "min cannot be greater than max");

        assert!(validator.has_errors());
    }

    #[test]
    fn test_validate_macro() {
        let mut failing = Validator::new();
        let too_big = 150;
        validate!(failing, "percent", too_big <= 100, "Must be <= 100");
        assert!(failing.has_errors());

        let mut passing = Validator::new();
        let fine = 50;
        validate!(passing, "percent", fine <= 100, "Must be <= 100");
        assert!(!passing.has_errors());
    }

    #[test]
    fn test_path_exists_rule_file() {
        let rule = PathExistsRule::file();
        let missing = PathBuf::from("/non/existent/file.txt");
        assert!(rule.validate(&missing).is_err());
    }
}