use regex::Regex;
use std::collections::HashMap;
use super::{ValidationError, ValidationResult, Validator};
/// A country whose postal-code format this module knows how to validate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Country {
    /// United States: `12345` or ZIP+4 `12345-6789`.
    US,
    /// United Kingdom, e.g. `SW1A 1AA` (the inner space is optional).
    UK,
    /// Japan: `123-4567`.
    JP,
    /// Canada, e.g. `K1A 0B1` (the inner space is optional).
    CA,
    /// Germany: `12345`.
    DE,
}

impl Country {
    /// Every supported country, in declaration order.
    const ALL: [Country; 5] = [
        Country::US,
        Country::UK,
        Country::JP,
        Country::CA,
        Country::DE,
    ];

    /// Returns the fully anchored regex a postal code for this country must match.
    ///
    /// Letters are matched as `[A-Z]` only, so input must be upper-cased before matching.
    /// Digits are written `[0-9]` rather than `\d`: the `regex` crate is Unicode-aware by
    /// default, where `\d` means any `\p{Nd}` digit and would accept e.g. fullwidth
    /// `１２３４５` or Arabic-Indic `١٢٣٤٥`, which are never valid postal codes.
    fn pattern(&self) -> &'static str {
        match self {
            Country::US => r"^[0-9]{5}(-[0-9]{4})?$",
            Country::UK => r"^[A-Z]{1,2}[0-9]{1,2}[A-Z]?\s?[0-9][A-Z]{2}$",
            Country::JP => r"^[0-9]{3}-[0-9]{4}$",
            Country::CA => r"^[A-Z][0-9][A-Z]\s?[0-9][A-Z][0-9]$",
            Country::DE => r"^[0-9]{5}$",
        }
    }

    /// Returns the short code for this country (note: `"UK"` for the United Kingdom).
    fn code(&self) -> &'static str {
        match self {
            Country::US => "US",
            Country::UK => "UK",
            Country::JP => "JP",
            Country::CA => "CA",
            Country::DE => "DE",
        }
    }

    /// Returns every supported country, in declaration order.
    fn all() -> Vec<Country> {
        Self::ALL.to_vec()
    }
}
/// Validates postal codes, optionally restricted to a subset of [`Country`] values.
#[derive(Debug)]
pub struct PostalCodeValidator {
    /// Countries whose codes are accepted; `None` accepts every supported country.
    allowed_countries: Option<Vec<Country>>,
    /// Compiled pattern for every supported country (not only the allowed ones), so codes
    /// from a disallowed country can be identified without compiling a regex per call.
    patterns: HashMap<Country, Regex>,
}

impl PostalCodeValidator {
    /// Order in which countries are tried when detecting which format a code matches.
    ///
    /// `US` precedes `DE`, so a bare five-digit code (valid in both) is reported as `US`.
    const DETECTION_ORDER: [Country; 5] = [
        Country::UK,
        Country::CA,
        Country::JP,
        Country::US,
        Country::DE,
    ];

    /// Compiles the pattern of every supported country exactly once.
    fn compile_patterns() -> HashMap<Country, Regex> {
        Country::all()
            .into_iter()
            .map(|country| {
                let regex = Regex::new(country.pattern()).expect("Invalid regex pattern");
                (country, regex)
            })
            .collect()
    }

    /// Creates a validator that accepts postal codes from every supported country.
    pub fn new() -> Self {
        Self {
            allowed_countries: None,
            patterns: Self::compile_patterns(),
        }
    }

    /// Creates a validator that only accepts postal codes from `countries`.
    ///
    /// A code that is well-formed for a country outside this list is rejected with
    /// `ValidationError::PostalCodeCountryNotAllowed`.
    pub fn with_countries(countries: Vec<Country>) -> Self {
        Self {
            allowed_countries: Some(countries),
            patterns: Self::compile_patterns(),
        }
    }

    /// Creates a validator that only accepts postal codes from `country`.
    pub fn for_country(country: Country) -> Self {
        Self::with_countries(vec![country])
    }

    /// Validates `value` and returns the country whose format it matches.
    ///
    /// The input is trimmed and upper-cased before matching.
    ///
    /// # Errors
    ///
    /// - `PostalCodeCountryNotAllowed` if `value` does not match any allowed country but does
    ///   match a supported country outside the allowed list.
    /// - `PostalCodeCountryNotRecognized` if `value` matches no supported format; the error
    ///   carries the trimmed, upper-cased input.
    pub fn validate_with_country(&self, value: &str) -> Result<Country, ValidationError> {
        let value = value.trim().to_uppercase();

        let is_allowed = |country: &Country| match &self.allowed_countries {
            Some(allowed) => allowed.contains(country),
            None => true,
        };
        let fits = |country: &Country| {
            self.patterns
                .get(country)
                .is_some_and(|pattern| pattern.is_match(&value))
        };

        if let Some(&country) = Self::DETECTION_ORDER
            .iter()
            .find(|&c| is_allowed(c) && fits(c))
        {
            return Ok(country);
        }

        // Not valid for any allowed country. If it is still well-formed for some other
        // supported country, report that instead of "not recognized".
        if let Some(allowed) = &self.allowed_countries
            && let Some(country) = Self::DETECTION_ORDER
                .iter()
                .find(|&c| !allowed.contains(c) && fits(c))
        {
            return Err(ValidationError::PostalCodeCountryNotAllowed {
                country: country.code().to_string(),
                allowed_countries: allowed
                    .iter()
                    .map(Country::code)
                    .collect::<Vec<_>>()
                    .join(", "),
            });
        }

        Err(ValidationError::PostalCodeCountryNotRecognized { postal_code: value })
    }
}
impl Default for PostalCodeValidator {
fn default() -> Self {
Self::new()
}
}
impl Validator<str> for PostalCodeValidator {
    /// Succeeds when `value` is a valid postal code for an allowed country.
    ///
    /// The detected country is discarded; call `validate_with_country` to obtain it.
    fn validate(&self, value: &str) -> ValidationResult<()> {
        let _detected = self.validate_with_country(value)?;
        Ok(())
    }
}
impl Validator<String> for PostalCodeValidator {
    /// Delegates to the `Validator<str>` implementation.
    fn validate(&self, value: &String) -> ValidationResult<()> {
        <Self as Validator<str>>::validate(self, value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validator` accepts every code in `codes`.
    fn assert_accepts(validator: &PostalCodeValidator, codes: &[&str]) {
        for &code in codes {
            assert!(validator.validate(code).is_ok(), "expected {code:?} to be accepted");
        }
    }

    /// Asserts that `validator` rejects every code in `codes`.
    fn assert_rejects(validator: &PostalCodeValidator, codes: &[&str]) {
        for &code in codes {
            assert!(validator.validate(code).is_err(), "expected {code:?} to be rejected");
        }
    }

    #[test]
    fn test_us_zip_code() {
        assert_accepts(&PostalCodeValidator::for_country(Country::US), &["12345", "90210"]);
    }

    #[test]
    fn test_us_zip_plus_4() {
        assert_accepts(
            &PostalCodeValidator::for_country(Country::US),
            &["12345-6789", "90210-1234"],
        );
    }

    #[test]
    fn test_us_invalid_format() {
        // Too short, too long, and non-numeric.
        assert_rejects(
            &PostalCodeValidator::for_country(Country::US),
            &["1234", "123456", "ABCDE"],
        );
    }

    #[test]
    fn test_jp_postal_code() {
        assert_accepts(
            &PostalCodeValidator::for_country(Country::JP),
            &["123-4567", "100-0001"],
        );
    }

    #[test]
    fn test_jp_invalid_format() {
        // Missing hyphen, misplaced hyphen, and letters instead of digits.
        assert_rejects(
            &PostalCodeValidator::for_country(Country::JP),
            &["1234567", "12-34567", "ABC-DEFG"],
        );
    }

    #[test]
    fn test_uk_postal_code() {
        assert_accepts(
            &PostalCodeValidator::for_country(Country::UK),
            &["SW1A 1AA", "M1 1AE", "B33 8TH", "CR2 6XH", "DN55 1PT"],
        );
    }

    #[test]
    fn test_uk_postal_code_without_space() {
        assert_accepts(
            &PostalCodeValidator::for_country(Country::UK),
            &["SW1A1AA", "M11AE"],
        );
    }

    #[test]
    fn test_ca_postal_code() {
        assert_accepts(
            &PostalCodeValidator::for_country(Country::CA),
            &["K1A 0B1", "M5W 1E6"],
        );
    }

    #[test]
    fn test_ca_postal_code_without_space() {
        assert_accepts(&PostalCodeValidator::for_country(Country::CA), &["K1A0B1"]);
    }

    #[test]
    fn test_ca_invalid_format() {
        // Truncated code, and digits where letters are required.
        assert_rejects(
            &PostalCodeValidator::for_country(Country::CA),
            &["K1A 0B", "111 111"],
        );
    }

    #[test]
    fn test_de_postal_code() {
        assert_accepts(
            &PostalCodeValidator::for_country(Country::DE),
            &["12345", "10115", "80331"],
        );
    }

    #[test]
    fn test_de_invalid_format() {
        // Too short, too long, and non-numeric.
        assert_rejects(
            &PostalCodeValidator::for_country(Country::DE),
            &["1234", "123456", "ABCDE"],
        );
    }

    #[test]
    fn test_multiple_countries() {
        let validator = PostalCodeValidator::with_countries(vec![Country::US, Country::JP]);
        assert_accepts(&validator, &["12345", "123-4567"]);
        // A well-formed UK postcode, but the UK is not in the allowed list.
        assert_rejects(&validator, &["SW1A 1AA"]);
    }

    #[test]
    fn test_all_countries() {
        assert_accepts(
            &PostalCodeValidator::new(),
            &["12345", "123-4567", "SW1A 1AA", "K1A 0B1", "10115"],
        );
    }

    #[test]
    fn test_validate_with_country_detection() {
        let validator = PostalCodeValidator::new();
        let cases = [
            ("12345-6789", Country::US),
            ("123-4567", Country::JP),
            ("SW1A 1AA", Country::UK),
            ("K1A 0B1", Country::CA),
        ];
        for (code, expected) in cases {
            assert_eq!(validator.validate_with_country(code).unwrap(), expected, "{code}");
        }
        // A bare five-digit code is well-formed in both the US and Germany.
        let ambiguous = validator.validate_with_country("12345").unwrap();
        assert!(matches!(ambiguous, Country::US | Country::DE));
    }

    #[test]
    fn test_validate_with_country_restriction() {
        let validator = PostalCodeValidator::with_countries(vec![Country::US, Country::JP]);
        for code in ["12345", "123-4567"] {
            assert!(validator.validate_with_country(code).is_ok(), "{code}");
        }
        let Err(ValidationError::PostalCodeCountryNotAllowed { country, .. }) =
            validator.validate_with_country("SW1A 1AA")
        else {
            panic!("Expected PostalCodeCountryNotAllowed error");
        };
        assert_eq!(country, "UK");
    }

    #[test]
    fn test_invalid_postal_code() {
        let validator = PostalCodeValidator::new();
        let Err(ValidationError::PostalCodeCountryNotRecognized { postal_code }) =
            validator.validate_with_country("invalid")
        else {
            panic!("Expected PostalCodeCountryNotRecognized error");
        };
        // The error carries the normalized (upper-cased) input.
        assert_eq!(postal_code, "INVALID");
    }

    #[test]
    fn test_case_insensitive() {
        assert_accepts(&PostalCodeValidator::new(), &["sw1a 1aa", "k1a 0b1"]);
    }

    #[test]
    fn test_whitespace_trimming() {
        assert_accepts(&PostalCodeValidator::new(), &[" 12345 ", " SW1A 1AA "]);
    }
}