use fraiseql_error::{FraiseQLError, Result};
use regex::{Regex, RegexBuilder};
use serde_json::Value;
/// Maximum accepted length, in bytes, of a validation regex's source text.
const MAX_PATTERN_BYTES: usize = 1024;

/// Compiles a validation regex taken from a rule definition, with bounded size.
///
/// The pattern source is capped at [`MAX_PATTERN_BYTES`] bytes, and the
/// compiled regex is capped at 1 MiB via `RegexBuilder::size_limit`.
///
/// # Errors
///
/// Returns a validation error if the pattern source is too long, does not
/// parse, or compiles to something larger than the size limit.
fn compile_pattern(pattern: &str) -> Result<Regex> {
    const COMPILED_SIZE_LIMIT: usize = 1 << 20;

    let byte_len = pattern.len();
    if byte_len > MAX_PATTERN_BYTES {
        let message =
            format!("Validation pattern too long ({byte_len} bytes, max {MAX_PATTERN_BYTES})");
        return Err(FraiseQLError::validation(message));
    }

    let mut builder = RegexBuilder::new(pattern);
    builder.size_limit(COMPILED_SIZE_LIMIT);
    builder.build().map_err(|err| {
        FraiseQLError::validation(format!("Invalid validation pattern '{pattern}': {err}"))
    })
}
/// A validation constraint on a string value, or a conjunction of constraints.
///
/// Rules can be built from JSON with `ValidationRule::from_json` and are
/// checked with `ValidationRule::validate`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ValidationRule {
    /// The value must match this regular expression (`is_match` semantics:
    /// unanchored unless the pattern itself uses `^` / `$`).
    Pattern(Regex),
    /// The value must have exactly this length.
    Length(usize),
    /// The value's length must lie in the inclusive range `min..=max`.
    LengthRange { min: usize, max: usize },
    /// The value must pass the given checksum algorithm.
    Checksum(ChecksumType),
    /// The value must parse as a number in the inclusive range `min..=max`.
    NumericRange { min: f64, max: f64 },
    /// The value must equal one of these strings exactly.
    Enum(Vec<String>),
    /// Every nested rule must pass; they are checked in order and the first
    /// failure is reported.
    All(Vec<ValidationRule>),
}
/// Checksum algorithm applied by `ValidationRule::Checksum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChecksumType {
    /// ISO 7064 MOD 97-10, as used for IBANs (JSON name `"mod97"`).
    Mod97,
    /// Luhn mod-10 check digit, as used for payment card numbers (JSON name `"luhn"`).
    Luhn,
}
impl ValidationRule {
    /// Checks `value` against this rule.
    ///
    /// Lengths (`Length`, `LengthRange`) are counted in Unicode scalar values
    /// (`char`s), not UTF-8 bytes, so a multi-byte character such as `é`
    /// (U+00E9) counts as one. `Pattern` uses `Regex::is_match`, which is
    /// unanchored unless the pattern itself anchors. `NumericRange` rejects
    /// `NaN` and infinities as non-numbers.
    ///
    /// # Errors
    ///
    /// Returns a validation error describing the first constraint `value`
    /// violates. For `All`, nested rules are checked in order and evaluation
    /// stops at the first failure.
    pub fn validate(&self, value: &str) -> Result<()> {
        match self {
            ValidationRule::Pattern(re) => {
                if !re.is_match(value) {
                    return Err(FraiseQLError::validation(format!(
                        "Value '{}' does not match pattern '{}'",
                        value,
                        re.as_str()
                    )));
                }
                Ok(())
            },
            ValidationRule::Length(expected) => {
                // Character count, not byte count: byte length would penalise
                // non-ASCII input for its encoding width.
                let len = value.chars().count();
                if len != *expected {
                    return Err(FraiseQLError::validation(format!(
                        "Value '{}' has length {}, expected {}",
                        value, len, expected
                    )));
                }
                Ok(())
            },
            ValidationRule::LengthRange { min, max } => {
                let len = value.chars().count();
                if !(*min..=*max).contains(&len) {
                    return Err(FraiseQLError::validation(format!(
                        "Value '{}' has length {}, expected between {} and {}",
                        value, len, min, max
                    )));
                }
                Ok(())
            },
            ValidationRule::Checksum(checksum_type) => match checksum_type {
                ChecksumType::Mod97 => validate_mod97(value),
                ChecksumType::Luhn => validate_luhn(value),
            },
            ValidationRule::NumericRange { min, max } => {
                // `f64::from_str` accepts "NaN", "inf" and "infinity". Treat them
                // as non-numbers: NaN in particular would otherwise pass any
                // range, because every comparison involving NaN is false.
                let num = value
                    .parse::<f64>()
                    .ok()
                    .filter(|n| n.is_finite())
                    .ok_or_else(|| {
                        FraiseQLError::validation(format!(
                            "Value '{}' is not a valid number",
                            value
                        ))
                    })?;
                if !(*min..=*max).contains(&num) {
                    return Err(FraiseQLError::validation(format!(
                        "Value {} is outside range [{}, {}]",
                        num, min, max
                    )));
                }
                Ok(())
            },
            ValidationRule::Enum(options) => {
                // Compare as &str so no String is allocated per check.
                if !options.iter().any(|option| option.as_str() == value) {
                    return Err(FraiseQLError::validation(format!(
                        "Value '{}' must be one of: {}",
                        value,
                        options.join(", ")
                    )));
                }
                Ok(())
            },
            ValidationRule::All(rules) => rules.iter().try_for_each(|rule| rule.validate(value)),
        }
    }

    /// Builds a rule from its JSON description.
    ///
    /// A JSON string is compiled as a regex and becomes `Pattern`.
    ///
    /// A JSON object may combine these keys:
    /// - `pattern`: string.
    /// - `length`: non-negative integer.
    /// - `min_length` / `max_length`: non-negative integers.
    /// - `checksum`: `"mod97"` or `"luhn"`.
    /// - `enum`: array; non-string entries are dropped.
    /// - `min` / `max`: numbers.
    ///
    /// Either bound of a range may be given on its own; the missing side is
    /// unbounded. Keys whose value has the wrong JSON type are ignored. An
    /// object producing a single rule returns that rule directly. Several rules
    /// are wrapped in `All`, in the key order listed above.
    ///
    /// # Errors
    ///
    /// Returns a validation error if:
    /// - a pattern fails to compile;
    /// - the checksum name is unknown;
    /// - a range's minimum exceeds its maximum;
    /// - the object yields no rule at all;
    /// - `value` is neither a string nor an object.
    pub fn from_json(value: &Value) -> Result<Self> {
        match value {
            Value::String(pattern) => Ok(ValidationRule::Pattern(compile_pattern(pattern)?)),
            Value::Object(map) => Self::from_json_object(map),
            _ => Err(FraiseQLError::validation(
                "Validation rule must be string or object".to_string(),
            )),
        }
    }

    /// Collects the rules described by the recognised keys of a JSON object.
    fn from_json_object(map: &serde_json::Map<String, Value>) -> Result<Self> {
        // JSON integers arrive as u64; saturate where usize is narrower (32-bit targets).
        fn to_usize(n: u64) -> usize {
            usize::try_from(n).unwrap_or(usize::MAX)
        }

        let mut rules = Vec::new();

        if let Some(Value::String(pattern)) = map.get("pattern") {
            rules.push(ValidationRule::Pattern(compile_pattern(pattern)?));
        }

        if let Some(length) = map.get("length").and_then(Value::as_u64) {
            rules.push(ValidationRule::Length(to_usize(length)));
        }

        let min_length = map.get("min_length").and_then(Value::as_u64);
        let max_length = map.get("max_length").and_then(Value::as_u64);
        if min_length.is_some() || max_length.is_some() {
            let min = min_length.map_or(0, to_usize);
            let max = max_length.map_or(usize::MAX, to_usize);
            if min > max {
                return Err(FraiseQLError::validation(format!(
                    "min_length ({min}) must not exceed max_length ({max})"
                )));
            }
            rules.push(ValidationRule::LengthRange { min, max });
        }

        if let Some(Value::String(checksum)) = map.get("checksum") {
            let checksum_type = match checksum.as_str() {
                "mod97" => ChecksumType::Mod97,
                "luhn" => ChecksumType::Luhn,
                other => {
                    return Err(FraiseQLError::validation(format!(
                        "Unknown checksum type: {}",
                        other
                    )));
                },
            };
            rules.push(ValidationRule::Checksum(checksum_type));
        }

        if let Some(Value::Array(options)) = map.get("enum") {
            let enum_values: Vec<String> =
                options.iter().filter_map(Value::as_str).map(str::to_owned).collect();
            if !enum_values.is_empty() {
                rules.push(ValidationRule::Enum(enum_values));
            }
        }

        let min = map.get("min").and_then(Value::as_f64);
        let max = map.get("max").and_then(Value::as_f64);
        if min.is_some() || max.is_some() {
            let min = min.unwrap_or(f64::NEG_INFINITY);
            let max = max.unwrap_or(f64::INFINITY);
            if min > max {
                return Err(FraiseQLError::validation(format!(
                    "min ({min}) must not exceed max ({max})"
                )));
            }
            rules.push(ValidationRule::NumericRange { min, max });
        }

        if rules.len() > 1 {
            return Ok(ValidationRule::All(rules));
        }
        rules.pop().ok_or_else(|| {
            FraiseQLError::validation("No valid validation rules found".to_string())
        })
    }
}
/// Validates `value` with the ISO 7064 MOD 97-10 check used by IBANs.
///
/// The algorithm has three steps:
/// 1. Move the first four characters to the end.
/// 2. Expand each letter to its two-digit value (`A`/`a` = 10 … `Z`/`z` = 35).
/// 3. Require the resulting decimal number to leave remainder 1 when divided by 97.
///
/// # Errors
///
/// Returns a validation error if:
/// - `value` is shorter than 4 bytes;
/// - it contains anything other than ASCII letters and digits (separators
///   such as spaces are not accepted);
/// - it fails the checksum.
fn validate_mod97(value: &str) -> Result<()> {
    if value.len() < 4 {
        return Err(FraiseQLError::validation("IBAN must be at least 4 characters".to_string()));
    }
    // Only [0-9A-Za-z] is meaningful. Rejecting everything else up front has
    // two effects. The `b - b'A'` letter mapping below can no longer underflow
    // on characters below 'A' (e.g. ' ' or '-'). And the split at byte 4
    // cannot land inside a multi-byte UTF-8 character.
    if !value.bytes().all(|b| b.is_ascii_alphanumeric()) {
        return Err(FraiseQLError::validation(
            "IBAN must contain only ASCII letters and digits".to_string(),
        ));
    }
    let (head, tail) = value.split_at(4);
    // Fold the expanded digits straight into the remainder instead of building
    // the decimal string. `acc` stays below 97, so `acc * 100 + 35` cannot
    // overflow.
    let remainder = tail.bytes().chain(head.bytes()).fold(0u32, |acc, b| {
        if b.is_ascii_digit() {
            (acc * 10 + u32::from(b - b'0')) % 97
        } else {
            // A letter contributes two decimal digits.
            let letter_value = u32::from(b.to_ascii_uppercase() - b'A') + 10;
            (acc * 100 + letter_value) % 97
        }
    });
    if remainder == 1 {
        Ok(())
    } else {
        Err(FraiseQLError::validation("Invalid IBAN checksum".to_string()))
    }
}
/// Validates `value` with the Luhn (mod 10) checksum used by payment card numbers.
///
/// ASCII spaces and hyphens are treated as group separators and skipped, so
/// `"4532 0151 1283 0366"` and `"4532-0151-1283-0366"` are checked like the
/// bare digit string. Any other non-digit character is rejected rather than
/// silently ignored; otherwise input like `"x0"` would pass.
///
/// # Errors
///
/// Returns a validation error if:
/// - `value` contains a character other than an ASCII digit, space or hyphen;
/// - it contains no digits at all;
/// - it fails the checksum.
///
/// Error messages never echo `value`, which may be a card number.
fn validate_luhn(value: &str) -> Result<()> {
    let mut sum = 0u32;
    let mut digit_count = 0usize;
    // Walk right-to-left. Every second digit, starting with the one left of
    // the check digit, is doubled.
    let mut double_next = false;
    for c in value.chars().rev() {
        if matches!(c, ' ' | '-') {
            continue;
        }
        let Some(digit) = c.to_digit(10) else {
            return Err(FraiseQLError::validation(
                "Value may only contain digits, spaces and hyphens".to_string(),
            ));
        };
        let weighted = if double_next {
            let doubled = digit * 2;
            if doubled > 9 {
                doubled - 9
            } else {
                doubled
            }
        } else {
            digit
        };
        // Reduce as we go so the accumulator can never overflow.
        sum = (sum + weighted) % 10;
        digit_count += 1;
        double_next = !double_next;
    }
    if digit_count == 0 {
        return Err(FraiseQLError::validation("Value must contain at least one digit".to_string()));
    }
    if sum == 0 {
        Ok(())
    } else {
        Err(FraiseQLError::validation("Invalid Luhn checksum".to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test with `expected Ok for {what}: {err}` when `outcome` is an error.
    fn assert_ok<E: std::fmt::Display>(outcome: std::result::Result<(), E>, what: &str) {
        if let Err(err) = outcome {
            panic!("expected Ok for {what}: {err}");
        }
    }

    #[test]
    fn test_pattern_validation() {
        let lowercase = Regex::new("^[a-z]+$").expect("valid regex");
        let rule = ValidationRule::Pattern(lowercase);
        assert_ok(rule.validate("hello"), "'hello'");
        assert!(rule.validate("Hello").is_err(), "expected Err for 'Hello' (uppercase)");
    }

    #[test]
    fn test_length_validation() {
        let rule = ValidationRule::Length(3);
        assert_ok(rule.validate("abc"), "len=3 string");
        for (input, label) in [("ab", "len=2"), ("abcd", "len=4")] {
            assert!(rule.validate(input).is_err(), "expected Err for {label} string");
        }
    }

    #[test]
    fn test_mod97_valid() {
        assert_ok(validate_mod97("GB82WEST12345698765432"), "valid IBAN");
    }

    #[test]
    fn test_luhn_valid() {
        assert_ok(validate_luhn("4532015112830366"), "valid Luhn number");
    }

    #[test]
    fn test_enum_validation() {
        let rule = ValidationRule::Enum(vec!["US".to_owned(), "CA".to_owned()]);
        assert_ok(rule.validate("US"), "'US'");
        assert!(rule.validate("UK").is_err(), "expected Err for 'UK' (not in enum)");
    }

    #[test]
    fn test_numeric_range_validation() {
        let rule = ValidationRule::NumericRange { min: 0.0, max: 90.0 };
        assert_ok(rule.validate("45.5"), "45.5");
        assert!(rule.validate("91").is_err(), "expected Err for 91 (out of range)");
    }
}