use std::fmt;
use std::net::SocketAddr;
use std::path::Path;
use crate::error::ValidationErrors;
use crate::Result;
/// A value that can check itself for configuration problems.
///
/// Two entry points are offered: `validate`, which returns a single
/// `crate::Result`, and `validate_collect`, which records problems into a
/// caller-owned `ValidationErrors`.
pub trait Validatable {
/// Checks `self`, returning `Ok(())` when it is valid.
///
/// # Errors
/// Returns the crate's error type when validation fails.
/// NOTE(review): presumably carries the same problems `validate_collect`
/// would record — confirm against implementors.
fn validate(&self) -> Result<()>;
/// Records problems found in `self` into `errors`.
///
/// NOTE(review): the trait does not state whether implementors may clear
/// `errors` or only append to it; appending is assumed — confirm.
fn validate_collect(&self, errors: &mut ValidationErrors);
}
/// Tracks the dotted path of the field currently being validated.
///
/// Segments are pushed with [`enter`](Self::enter) and popped with
/// [`leave`](Self::leave); [`field`](Self::field) renders a child name
/// relative to the current position (e.g. `engine.modbus.port`).
#[derive(Debug, Clone, Default)]
pub struct ValidationContext {
    path: Vec<String>,
}

impl ValidationContext {
    /// Creates a context positioned at the root (no segments).
    pub fn new() -> Self {
        ValidationContext { path: Vec::new() }
    }

    /// Descends into `field`, appending it as a new path segment.
    pub fn enter(&mut self, field: impl Into<String>) {
        let segment: String = field.into();
        self.path.push(segment);
    }

    /// Ascends one level; a no-op when already at the root.
    pub fn leave(&mut self) {
        let _ = self.path.pop();
    }

    /// Returns the current segments joined with `.` (empty string at the root).
    pub fn path(&self) -> String {
        self.path.join(".")
    }

    /// Returns the fully qualified name of `name` beneath the current path.
    ///
    /// At the root this is just `name`; otherwise it is `<path>.<name>`.
    pub fn field(&self, name: &str) -> String {
        match self.path.as_slice() {
            [] => name.to_owned(),
            _ => {
                let mut qualified = self.path();
                qualified.push('.');
                qualified.push_str(name);
                qualified
            }
        }
    }

    /// Runs `f` with `field` pushed onto the path, popping it once `f` returns.
    ///
    /// If `f` panics, the segment is not popped.
    pub fn with_field<F: FnOnce(&mut Self)>(&mut self, field: impl Into<String>, f: F) {
        self.enter(field);
        f(self);
        self.leave();
    }
}
/// A reusable check applied to a single value of type `T`.
///
/// `T` may be unsized (e.g. `str`), and the `Send + Sync` supertraits allow a
/// rule to be shared between threads.
pub trait ValidationRule<T: ?Sized>: Send + Sync {
/// Checks `value`, returning `Err` with a human-readable message on failure.
fn validate(&self, value: &T) -> std::result::Result<(), String>;
/// Returns a human-readable summary of what the rule requires.
fn description(&self) -> &str;
}
/// Checks that a value lies within an optional inclusive `[min, max]` range.
///
/// Either bound may be omitted. Values that cannot be ordered against a bound
/// (e.g. `f64::NAN`, or a NaN bound) are rejected rather than silently accepted.
pub struct RangeRule<T> {
    min: Option<T>,
    max: Option<T>,
    description: String,
}

impl<T: PartialOrd + fmt::Display + Copy> RangeRule<T> {
    /// Creates a rule with the given optional inclusive bounds.
    pub fn new(min: Option<T>, max: Option<T>) -> Self {
        let description = match (&min, &max) {
            (Some(min), Some(max)) => format!("Value must be between {} and {}", min, max),
            (Some(min), None) => format!("Value must be at least {}", min),
            (None, Some(max)) => format!("Value must be at most {}", max),
            (None, None) => "No range constraint".to_string(),
        };
        Self {
            min,
            max,
            description,
        }
    }

    /// Creates a rule requiring `value >= min`.
    pub fn min(min: T) -> Self {
        Self::new(Some(min), None)
    }

    /// Creates a rule requiring `value <= max`.
    pub fn max(max: T) -> Self {
        Self::new(None, Some(max))
    }

    /// Creates a rule requiring `min <= value <= max`.
    pub fn between(min: T, max: T) -> Self {
        Self::new(Some(min), Some(max))
    }
}

impl<T: PartialOrd + fmt::Display + Copy + Send + Sync> ValidationRule<T> for RangeRule<T> {
    fn validate(&self, value: &T) -> std::result::Result<(), String> {
        use std::cmp::Ordering;

        if let Some(min) = &self.min {
            match value.partial_cmp(min) {
                Some(Ordering::Less) => {
                    return Err(format!("Value {} is below minimum {}", value, min));
                }
                // Unordered (NaN on either side): a bare `<` would be false and
                // let the value through, so treat it as a failure explicitly.
                None => {
                    return Err(format!(
                        "Value {} is not comparable with minimum {}",
                        value, min
                    ));
                }
                Some(_) => {}
            }
        }
        if let Some(max) = &self.max {
            match value.partial_cmp(max) {
                Some(Ordering::Greater) => {
                    return Err(format!("Value {} exceeds maximum {}", value, max));
                }
                None => {
                    return Err(format!(
                        "Value {} is not comparable with maximum {}",
                        value, max
                    ));
                }
                Some(_) => {}
            }
        }
        Ok(())
    }

    fn description(&self) -> &str {
        &self.description
    }
}
/// Checks that a string's length lies within an optional inclusive range.
///
/// Length is measured with `str::len`, i.e. in UTF-8 **bytes**, not
/// characters: `"é"` has length 2.
pub struct StringLengthRule {
    min: Option<usize>,
    max: Option<usize>,
    description: String,
}

impl StringLengthRule {
    /// Creates a rule with optional inclusive byte-length bounds.
    pub fn new(min: Option<usize>, max: Option<usize>) -> Self {
        let description = match (min, max) {
            (Some(min), Some(max)) => format!("Length must be between {} and {}", min, max),
            (Some(min), None) => format!("Length must be at least {}", min),
            (None, Some(max)) => format!("Length must be at most {}", max),
            (None, None) => "No length constraint".to_string(),
        };
        Self {
            min,
            max,
            description,
        }
    }

    /// Creates a rule rejecting the empty string (length at least 1).
    pub fn non_empty() -> Self {
        Self::new(Some(1), None)
    }

    /// Creates a rule requiring a length of at least `min`.
    pub fn min(min: usize) -> Self {
        Self::new(Some(min), None)
    }

    /// Creates a rule requiring a length of at most `max`.
    pub fn max(max: usize) -> Self {
        Self::new(None, Some(max))
    }

    /// Creates a rule requiring a length between `min` and `max`, inclusive.
    pub fn between(min: usize, max: usize) -> Self {
        Self::new(Some(min), Some(max))
    }

    /// Shared length check behind both the `String` and `str` impls, so neither
    /// has to allocate in order to delegate to the other.
    fn check_length(&self, value: &str) -> std::result::Result<(), String> {
        let len = value.len();
        if let Some(min) = self.min {
            if len < min {
                return Err(format!("String length {} is below minimum {}", len, min));
            }
        }
        if let Some(max) = self.max {
            if len > max {
                return Err(format!("String length {} exceeds maximum {}", len, max));
            }
        }
        Ok(())
    }
}

impl ValidationRule<String> for StringLengthRule {
    fn validate(&self, value: &String) -> std::result::Result<(), String> {
        self.check_length(value)
    }

    fn description(&self) -> &str {
        &self.description
    }
}

impl ValidationRule<str> for StringLengthRule {
    fn validate(&self, value: &str) -> std::result::Result<(), String> {
        self.check_length(value)
    }

    fn description(&self) -> &str {
        &self.description
    }
}
/// Which filesystem check a [`PathExistsRule`] performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PathKind {
    /// The path must be an existing regular file (`Path::is_file`).
    File,
    /// The path must be an existing directory (`Path::is_dir`).
    Directory,
    /// The path must exist, of any kind (`Path::exists`).
    Any,
}

impl PathKind {
    /// Human-readable summary used by `ValidationRule::description`.
    fn description(self) -> &'static str {
        match self {
            PathKind::File => "Path must be an existing file",
            PathKind::Directory => "Path must be an existing directory",
            PathKind::Any => "Path must exist",
        }
    }
}

/// Checks that a path exists on the filesystem, optionally requiring a file or
/// a directory.
///
/// The filesystem is queried at validation time (via `Path::is_file`,
/// `Path::is_dir` or `Path::exists`, which follow symlinks), so the outcome
/// can change between calls.
pub struct PathExistsRule {
    kind: PathKind,
}

impl PathExistsRule {
    /// Creates a rule requiring an existing regular file.
    pub fn file() -> Self {
        Self {
            kind: PathKind::File,
        }
    }

    /// Creates a rule requiring an existing directory.
    pub fn directory() -> Self {
        Self {
            kind: PathKind::Directory,
        }
    }

    /// Creates a rule requiring only that the path exists.
    pub fn exists() -> Self {
        Self {
            kind: PathKind::Any,
        }
    }
}

impl<P: AsRef<Path>> ValidationRule<P> for PathExistsRule {
    fn validate(&self, value: &P) -> std::result::Result<(), String> {
        let path = value.as_ref();
        let (ok, problem) = match self.kind {
            PathKind::File => (path.is_file(), "is not a file"),
            PathKind::Directory => (path.is_dir(), "is not a directory"),
            PathKind::Any => (path.exists(), "does not exist"),
        };
        if ok {
            Ok(())
        } else {
            Err(format!("Path '{}' {}", path.display(), problem))
        }
    }

    fn description(&self) -> &str {
        self.kind.description()
    }
}
/// Checks a [`SocketAddr`] against optional address-family and port constraints.
///
/// Built fluently: start from [`SocketAddrRule::new`] and chain
/// [`ipv4_only`](Self::ipv4_only), [`port_range`](Self::port_range) or
/// [`non_privileged_port`](Self::non_privileged_port) in any order. The
/// description always reflects every configured constraint.
pub struct SocketAddrRule {
    require_ipv4: bool,
    port_range: Option<(u16, u16)>,
    description: String,
}

impl SocketAddrRule {
    /// Creates a rule that accepts any socket address.
    pub fn new() -> Self {
        let mut rule = Self {
            require_ipv4: false,
            port_range: None,
            description: String::new(),
        };
        rule.refresh_description();
        rule
    }

    /// Additionally rejects IPv6 addresses (`SocketAddr::is_ipv6`), including
    /// IPv4-mapped IPv6 ones.
    #[must_use]
    pub fn ipv4_only(mut self) -> Self {
        self.require_ipv4 = true;
        self.refresh_description();
        self
    }

    /// Additionally requires the port to lie in `min..=max`.
    #[must_use]
    pub fn port_range(mut self, min: u16, max: u16) -> Self {
        self.port_range = Some((min, max));
        self.refresh_description();
        self
    }

    /// Shorthand for `port_range(1024, 65535)`.
    #[must_use]
    pub fn non_privileged_port(self) -> Self {
        self.port_range(1024, 65535)
    }

    /// Rebuilds `description` from all configured constraints, so builder calls
    /// never overwrite each other's wording.
    fn refresh_description(&mut self) {
        let family = if self.require_ipv4 {
            "IPv4 socket address"
        } else {
            "socket address"
        };
        self.description = match self.port_range {
            Some((min, max)) => format!(
                "Must be a valid {} with port between {} and {}",
                family, min, max
            ),
            None => format!("Must be a valid {}", family),
        };
    }
}

impl Default for SocketAddrRule {
    fn default() -> Self {
        Self::new()
    }
}

impl ValidationRule<SocketAddr> for SocketAddrRule {
    fn validate(&self, value: &SocketAddr) -> std::result::Result<(), String> {
        if self.require_ipv4 && value.is_ipv6() {
            return Err("IPv4 address required".to_string());
        }
        if let Some((min, max)) = self.port_range {
            let port = value.port();
            if port < min || port > max {
                return Err(format!("Port {} must be between {} and {}", port, min, max));
            }
        }
        Ok(())
    }

    fn description(&self) -> &str {
        &self.description
    }
}
/// Accumulates validation errors for many fields in a single pass.
///
/// Field names passed to the `add_*` / `require_*` helpers are qualified with
/// the current [`ValidationContext`] path (see
/// [`validate_nested`](Self::validate_nested)), so nested errors are reported
/// as e.g. `engine.max_devices`.
#[derive(Default)]
pub struct Validator {
    errors: ValidationErrors,
    context: ValidationContext,
}

impl Validator {
    /// Creates a validator with no errors and an empty (root) context.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the current field-path context.
    pub fn context(&self) -> &ValidationContext {
        &self.context
    }

    /// Returns the field-path context for manual `enter`/`leave` control.
    pub fn context_mut(&mut self) -> &mut ValidationContext {
        &mut self.context
    }

    /// Records `message` against `field`, qualified by the current context path.
    pub fn add_error(&mut self, field: &str, message: impl Into<String>) {
        let full_path = self.context.field(field);
        self.errors.add(full_path, message);
    }

    /// Records `message` against `field` only when `condition` is true.
    pub fn add_if(&mut self, condition: bool, field: &str, message: impl Into<String>) {
        if condition {
            self.add_error(field, message);
        }
    }

    /// Applies `rule` to `value`, recording the rule's message on failure.
    ///
    /// Both `T` and `R` may be unsized, so `&str` values (checked by a
    /// `ValidationRule<str>`) and `&dyn ValidationRule<T>` rules are accepted.
    pub fn validate_field<T, R>(&mut self, field: &str, value: &T, rule: &R)
    where
        T: ?Sized,
        R: ValidationRule<T> + ?Sized,
    {
        if let Err(msg) = rule.validate(value) {
            self.add_error(field, msg);
        }
    }

    /// Records an error when `value` is empty or whitespace-only.
    pub fn require_non_empty(&mut self, field: &str, value: &str) {
        if value.trim().is_empty() {
            self.add_error(field, "Value cannot be empty");
        }
    }

    /// Records an error unless `value` is strictly greater than `T::default()`.
    ///
    /// Values that do not compare with the default (e.g. `f64::NAN`) are rejected.
    pub fn require_positive<T: PartialOrd + Default + fmt::Display>(
        &mut self,
        field: &str,
        value: T,
    ) {
        // Phrased as a positive test: `value <= default` is false for NaN and
        // would let it through.
        let is_positive = value > T::default();
        if !is_positive {
            self.add_error(field, format!("Value must be positive, got {}", value));
        }
    }

    /// Records an error unless `min <= value <= max`.
    ///
    /// Values that cannot be ordered against the bounds (e.g. `f64::NAN`) are
    /// rejected.
    pub fn require_range<T: PartialOrd + fmt::Display + Copy>(
        &mut self,
        field: &str,
        value: T,
        min: T,
        max: T,
    ) {
        // Positive form so unordered values fail; `value < min || value > max`
        // is false for NaN.
        let in_range = value >= min && value <= max;
        if !in_range {
            self.add_error(
                field,
                format!("Value {} must be between {} and {}", value, min, max),
            );
        }
    }

    /// Runs `f` with `field` pushed onto the context path, popping it afterwards.
    pub fn validate_nested<F>(&mut self, field: &str, f: F)
    where
        F: FnOnce(&mut Self),
    {
        self.context.enter(field);
        f(self);
        self.context.leave();
    }

    /// Returns `true` if any error has been recorded.
    pub fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }

    /// Returns the errors recorded so far.
    pub fn errors(&self) -> &ValidationErrors {
        &self.errors
    }

    /// Consumes the validator, returning the recorded errors.
    pub fn into_errors(self) -> ValidationErrors {
        self.errors
    }

    /// Consumes the validator and converts its errors via
    /// `ValidationErrors::into_result(())`.
    pub fn into_result(self) -> Result<()> {
        self.errors.into_result(())
    }

    /// Merges `other`'s errors into this validator via `ValidationErrors::merge`.
    ///
    /// This validator's current context path is not applied to `other`'s errors.
    pub fn merge(&mut self, other: Validator) {
        self.errors.merge(other.errors);
    }
}
pub struct CrossFieldValidator<'a, T> {
config: &'a T,
errors: ValidationErrors,
}
impl<'a, T> CrossFieldValidator<'a, T> {
pub fn new(config: &'a T) -> Self {
Self {
config,
errors: ValidationErrors::new(),
}
}
pub fn config(&self) -> &T {
self.config
}
pub fn add_error(&mut self, fields: &[&str], message: impl Into<String>) {
let field_name = fields.join(", ");
self.errors.add(field_name, message);
}
pub fn add_if(&mut self, condition: bool, fields: &[&str], message: impl Into<String>) {
if condition {
self.add_error(fields, message);
}
}
pub fn has_errors(&self) -> bool {
!self.errors.is_empty()
}
pub fn into_errors(self) -> ValidationErrors {
self.errors
}
}
/// Records a validation check on a `Validator`-like value in one line.
///
/// Supported forms (a trailing comma is accepted in every form):
///
/// - `validate!(v, field, cond, msg)`: adds `msg` for `field` when `cond` is
///   **false** (expands to `v.add_if(!cond, field, msg)`).
/// - `validate!(v, field, range value, min, max)`: expands to
///   `v.require_range(field, value, min, max)`.
/// - `validate!(v, field, non_empty value)`: expands to
///   `v.require_non_empty(field, value)`.
/// - `validate!(v, field, positive value)`: expands to
///   `v.require_positive(field, value)`.
///
/// Arms are matched in order, so the generic `cond` form takes precedence over
/// the keyword forms. Each expansion is a single `()`-typed method call with no
/// trailing semicolon. The macro therefore works in statement position
/// (`validate!(...);`) and in expression position (e.g. as a closure body)
/// without triggering `semicolon_in_expressions_from_macros`.
#[macro_export]
macro_rules! validate {
    ($validator:expr, $field:expr, $cond:expr, $msg:expr $(,)?) => {
        $validator.add_if(!$cond, $field, $msg)
    };
    ($validator:expr, $field:expr, range $value:expr, $min:expr, $max:expr $(,)?) => {
        $validator.require_range($field, $value, $min, $max)
    };
    ($validator:expr, $field:expr, non_empty $value:expr $(,)?) => {
        $validator.require_non_empty($field, $value)
    };
    ($validator:expr, $field:expr, positive $value:expr $(,)?) => {
        $validator.require_positive($field, $value)
    };
}
#[cfg(test)]
mod tests {
    //! Unit tests for the validation rules, context tracking and validators.
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_range_rule() {
        let rule = RangeRule::between(0, 100);
        for inside in [50, 0, 100] {
            assert!(rule.validate(&inside).is_ok(), "{} should be accepted", inside);
        }
        for outside in [-1, 101] {
            assert!(rule.validate(&outside).is_err(), "{} should be rejected", outside);
        }
    }

    #[test]
    fn test_string_length_rule() {
        let rule = StringLengthRule::new(Some(3), Some(10));
        let check = |s: &str| rule.validate(&s.to_string());
        assert!(check("hello").is_ok());
        assert!(check("ab").is_err());
        assert!(check("this is too long").is_err());
    }

    #[test]
    fn test_string_non_empty() {
        let rule = StringLengthRule::non_empty();
        assert!(rule.validate(&String::from("hello")).is_ok());
        assert!(rule.validate(&String::new()).is_err());
    }

    #[test]
    fn test_socket_addr_rule() {
        let rule = SocketAddrRule::new().non_privileged_port();
        let parse = |s: &str| -> SocketAddr { s.parse().unwrap() };
        assert!(rule.validate(&parse("127.0.0.1:8080")).is_ok());
        assert!(rule.validate(&parse("127.0.0.1:80")).is_err());
    }

    #[test]
    fn test_validator() {
        let mut v = Validator::new();
        v.require_non_empty("name", "");
        v.require_positive("count", -1i32);
        v.require_range("percent", 150, 0, 100);
        assert!(v.has_errors());
        assert_eq!(3, v.errors().len());
    }

    #[test]
    fn test_validator_nested() {
        let mut v = Validator::new();
        v.validate_nested("engine", |inner| inner.add_error("max_devices", "Too low"));
        assert!(v.into_errors().get("engine.max_devices").is_some());
    }

    #[test]
    fn test_validation_context() {
        let mut ctx = ValidationContext::new();
        ctx.enter("engine");
        assert_eq!(ctx.field("max_devices"), "engine.max_devices");
        ctx.enter("modbus");
        assert_eq!(ctx.field("port"), "engine.modbus.port");
        ctx.leave();
        assert_eq!(ctx.field("workers"), "engine.workers");
        ctx.leave();
        assert_eq!(ctx.field("name"), "name");
    }

    #[test]
    fn test_cross_field_validator() {
        struct Config {
            min: u32,
            max: u32,
        }
        let config = Config { min: 100, max: 50 };
        let mut v = CrossFieldValidator::new(&config);
        let inverted = config.min > config.max;
        v.add_if(inverted, &["min", "max"], "min cannot be greater than max");
        assert!(v.has_errors());
    }

    #[test]
    fn test_validate_macro() {
        let mut failing = Validator::new();
        let value = 150;
        validate!(failing, "percent", value <= 100, "Must be <= 100");
        assert!(failing.has_errors());

        let mut passing = Validator::new();
        let value2 = 50;
        validate!(passing, "percent", value2 <= 100, "Must be <= 100");
        assert!(!passing.has_errors());
    }

    #[test]
    fn test_path_exists_rule_file() {
        let missing = PathBuf::from("/non/existent/file.txt");
        assert!(PathExistsRule::file().validate(&missing).is_err());
    }
}