use crate::errors::{AshError, AshErrorCode, InternalReason};
/// Largest accepted binding value, in bytes, measured after surrounding
/// whitespace has been trimmed (8 KiB).
pub const MAX_BINDING_VALUE_LENGTH: usize = 8 * 1024;
/// Smallest accepted binding value, in bytes, measured after trimming.
///
/// `ash_normalize_binding_value` enforces this through its "cannot be empty"
/// check rather than by reading this constant.
pub const MIN_BINDING_VALUE_LENGTH: usize = 1;
/// The kind of context a binding value is attached to.
///
/// The `Display` form is the lowercase variant name (for example `"ip"`), and
/// that form is what appears inside validation error messages.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BindingType {
    /// Route binding. Rejected by `ash_normalize_binding_value`; route values
    /// must go through `ash_normalize_binding()`, which has path/query-aware
    /// normalization.
    Route,
    /// IP binding. Must be ASCII and must not contain spaces.
    Ip,
    /// Device binding. Only the generic trim/length/control-character rules apply.
    Device,
    /// Session binding. Only the generic rules apply.
    Session,
    /// User binding. Normalized to Unicode NFC in addition to the generic rules.
    User,
    /// Tenant binding. Only the generic rules apply.
    Tenant,
    /// Application-defined binding. Only the generic rules apply.
    Custom,
}
// User-facing name of a binding type: the lowercase variant name.
impl std::fmt::Display for BindingType {
    /// Writes the lowercase name (e.g. `"session"`). Formatter flags such as
    /// width and fill are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            BindingType::Route => "route",
            BindingType::Ip => "ip",
            BindingType::Device => "device",
            BindingType::Session => "session",
            BindingType::User => "user",
            BindingType::Tenant => "tenant",
            BindingType::Custom => "custom",
        };
        f.write_str(name)
    }
}
/// Output of `ash_normalize_binding_value`: the canonical value together with
/// bookkeeping about how it was derived from the raw input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NormalizedBindingValue {
    /// The normalized value: trimmed, and NFC-composed for `User` bindings.
    pub value: String,
    /// The binding type the value was normalized for.
    pub binding_type: BindingType,
    /// Byte length of the raw input, before trimming.
    pub original_length: usize,
    /// `true` when trimming removed leading and/or trailing whitespace.
    pub was_trimmed: bool,
}
/// Validates and normalizes a binding value of any type except
/// [`BindingType::Route`].
///
/// Processing steps:
///
/// 1. [`BindingType::Route`] is rejected immediately, because route bindings
///    need the path/query-aware normalization in `ash_normalize_binding()`.
///    Doing this first means callers always get that redirect message. They
///    never get a generic "empty" or "control character" error for a route.
/// 2. Leading and trailing whitespace is removed with [`str::trim`].
/// 3. The trimmed value must satisfy three conditions:
///    * it is non-empty;
///    * it is at most [`MAX_BINDING_VALUE_LENGTH`] bytes;
///    * it contains no NUL, CR/LF or other control character ([`char::is_control`]).
/// 4. Type-specific rules are applied:
///    * [`BindingType::Ip`]: ASCII only, with no interior spaces. The address
///      syntax itself is not checked here.
///    * [`BindingType::User`]: converted to Unicode NFC. NFC output can be
///      longer in bytes than its input, because composition-excluded code
///      points such as U+0958 stay decomposed. The length limit is therefore
///      enforced again on the normalized string.
///    * Every other type: the trimmed value is returned as-is.
///
/// In the result, `original_length` is the byte length of `value` before
/// trimming, and `was_trimmed` records whether trimming removed anything.
///
/// # Errors
///
/// Returns an [`AshError`] with [`AshErrorCode::ValidationError`] if any rule
/// above is violated. Positions in error messages are byte offsets into the
/// trimmed value.
pub fn ash_normalize_binding_value(
    binding_type: BindingType,
    value: &str,
) -> Result<NormalizedBindingValue, AshError> {
    if matches!(binding_type, BindingType::Route) {
        return Err(AshError::new(
            AshErrorCode::ValidationError,
            "Use ash_normalize_binding() for Route bindings — it has specialized path/query normalization",
        ));
    }

    let original_length = value.len();
    let trimmed = value.trim();
    let was_trimmed = trimmed.len() != original_length;

    validate_binding_text(binding_type, trimmed)?;

    let normalized = match binding_type {
        BindingType::Route => {
            unreachable!("Route bindings are rejected at the top of this function")
        }
        BindingType::Ip => {
            if !trimmed.is_ascii() {
                return Err(binding_validation_error(
                    "IP binding must contain only ASCII characters".to_string(),
                ));
            }
            if trimmed.contains(' ') {
                return Err(binding_validation_error(
                    "IP binding must not contain spaces".to_string(),
                ));
            }
            trimmed.to_string()
        }
        BindingType::User => {
            use unicode_normalization::UnicodeNormalization;
            let composed: String = trimmed.nfc().collect();
            // NFC can grow the byte length, so re-apply the limit to what is returned.
            check_binding_length(binding_type, &composed)?;
            composed
        }
        BindingType::Device | BindingType::Session | BindingType::Tenant | BindingType::Custom => {
            trimmed.to_string()
        }
    };

    Ok(NormalizedBindingValue {
        value: normalized,
        binding_type,
        original_length,
        was_trimmed,
    })
}

/// Wraps `message` in a `ValidationError` with the generic internal reason
/// that every binding-value check uses.
fn binding_validation_error(message: String) -> AshError {
    AshError::with_reason(
        AshErrorCode::ValidationError,
        InternalReason::General,
        message,
    )
}

/// Fails if `value` is longer than [`MAX_BINDING_VALUE_LENGTH`] bytes.
fn check_binding_length(binding_type: BindingType, value: &str) -> Result<(), AshError> {
    if value.len() > MAX_BINDING_VALUE_LENGTH {
        return Err(binding_validation_error(format!(
            "Binding value for '{}' exceeds maximum length of {} bytes",
            binding_type, MAX_BINDING_VALUE_LENGTH
        )));
    }
    Ok(())
}

/// Generic checks shared by all non-route types. `value` must already be
/// trimmed: it must be non-empty, within the length limit, and free of
/// NUL / CR / LF / other control characters.
fn validate_binding_text(binding_type: BindingType, value: &str) -> Result<(), AshError> {
    if value.is_empty() {
        return Err(binding_validation_error(format!(
            "Binding value for '{}' cannot be empty",
            binding_type
        )));
    }
    check_binding_length(binding_type, value)?;

    // Report the first offending character. NUL and CR/LF get dedicated
    // wording even though `is_control` would also match them.
    let offending = value.char_indices().find_map(|(pos, ch)| {
        let what = match ch {
            '\0' => "NULL byte",
            '\r' | '\n' => "newline",
            c if c.is_control() => "control character",
            _ => return None,
        };
        Some((pos, what))
    });
    if let Some((pos, what)) = offending {
        return Err(binding_validation_error(format!(
            "Binding value for '{}' contains {} at position {}",
            binding_type, what, pos
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trim_whitespace() {
        let r = ash_normalize_binding_value(BindingType::Device, " dev_123 ").unwrap();
        assert_eq!(r.value, "dev_123");
        assert!(r.was_trimmed);
    }

    #[test]
    fn test_no_trim_needed() {
        let r = ash_normalize_binding_value(BindingType::Device, "dev_123").unwrap();
        assert_eq!(r.value, "dev_123");
        assert!(!r.was_trimmed);
    }

    #[test]
    fn test_reject_empty() {
        assert!(ash_normalize_binding_value(BindingType::Session, "").is_err());
    }

    #[test]
    fn test_reject_whitespace_only() {
        assert!(ash_normalize_binding_value(BindingType::Session, " ").is_err());
    }

    #[test]
    fn test_reject_null_byte() {
        assert!(ash_normalize_binding_value(BindingType::Device, "dev\x00abc").is_err());
    }

    #[test]
    fn test_reject_newline() {
        assert!(ash_normalize_binding_value(BindingType::Device, "dev\nabc").is_err());
        assert!(ash_normalize_binding_value(BindingType::Device, "dev\rabc").is_err());
    }

    #[test]
    fn test_reject_control_chars() {
        assert!(ash_normalize_binding_value(BindingType::Device, "dev\x01abc").is_err());
        assert!(ash_normalize_binding_value(BindingType::Device, "dev\x1Fabc").is_err());
    }

    #[test]
    fn test_reject_interior_tab() {
        // Tab is trimmed at the edges but is a control character inside the value.
        let err = ash_normalize_binding_value(BindingType::Device, "dev\tabc").unwrap_err();
        assert!(err.message().contains("control character"));
    }

    #[test]
    fn test_error_reports_byte_position() {
        let err = ash_normalize_binding_value(BindingType::Device, "dev\x00abc").unwrap_err();
        assert!(err.message().contains("'device'"));
        assert!(err.message().contains("NULL byte at position 3"));
    }

    #[test]
    fn test_reject_too_long() {
        let long = "a".repeat(MAX_BINDING_VALUE_LENGTH + 1);
        assert!(ash_normalize_binding_value(BindingType::Custom, &long).is_err());
    }

    #[test]
    fn test_accept_max_length() {
        let max = "a".repeat(MAX_BINDING_VALUE_LENGTH);
        assert!(ash_normalize_binding_value(BindingType::Custom, &max).is_ok());
    }

    #[test]
    fn test_length_limit_counts_bytes() {
        // U+00E9 is two bytes in UTF-8, so half as many characters fit.
        let at_limit = "\u{e9}".repeat(MAX_BINDING_VALUE_LENGTH / 2);
        assert!(ash_normalize_binding_value(BindingType::Custom, &at_limit).is_ok());
        let over_limit = "\u{e9}".repeat(MAX_BINDING_VALUE_LENGTH / 2 + 1);
        assert!(ash_normalize_binding_value(BindingType::Custom, &over_limit).is_err());
    }

    #[test]
    fn test_length_limit_applies_after_trim() {
        let padded = format!(" {} ", "a".repeat(MAX_BINDING_VALUE_LENGTH));
        let r = ash_normalize_binding_value(BindingType::Custom, &padded).unwrap();
        assert_eq!(r.value.len(), MAX_BINDING_VALUE_LENGTH);
        assert_eq!(r.original_length, MAX_BINDING_VALUE_LENGTH + 2);
        assert!(r.was_trimmed);
    }

    #[test]
    fn test_route_type_rejected() {
        let err = ash_normalize_binding_value(BindingType::Route, "POST|/api|").unwrap_err();
        assert!(err.message().contains("ash_normalize_binding"));
    }

    #[test]
    fn test_ip_valid_ipv4() {
        let r = ash_normalize_binding_value(BindingType::Ip, "192.168.1.1").unwrap();
        assert_eq!(r.value, "192.168.1.1");
    }

    #[test]
    fn test_ip_valid_ipv6() {
        let r = ash_normalize_binding_value(BindingType::Ip, "::1").unwrap();
        assert_eq!(r.value, "::1");
    }

    #[test]
    fn test_ip_trimmed() {
        let r = ash_normalize_binding_value(BindingType::Ip, " 10.0.0.1 ").unwrap();
        assert_eq!(r.value, "10.0.0.1");
        assert!(r.was_trimmed);
    }

    #[test]
    fn test_ip_reject_non_ascii() {
        assert!(ash_normalize_binding_value(BindingType::Ip, "192.168.١.1").is_err());
    }

    #[test]
    fn test_ip_reject_spaces() {
        assert!(ash_normalize_binding_value(BindingType::Ip, "192.168.1.1 extra").is_err());
    }

    #[test]
    fn test_user_nfc_normalization() {
        let decomposed = "caf\u{0065}\u{0301}";
        let r = ash_normalize_binding_value(BindingType::User, decomposed).unwrap();
        assert_eq!(r.value, "café");
    }

    #[test]
    fn test_user_already_nfc() {
        let r = ash_normalize_binding_value(BindingType::User, "user@example.com").unwrap();
        assert_eq!(r.value, "user@example.com");
    }

    #[test]
    fn test_user_nfc_is_not_reported_as_trim() {
        let r = ash_normalize_binding_value(BindingType::User, "cafe\u{0301}").unwrap();
        assert_eq!(r.value, "caf\u{e9}");
        assert!(!r.was_trimmed);
    }

    #[test]
    fn test_trim_unicode_whitespace() {
        // `str::trim` strips Unicode White_Space (ideographic space, NBSP), not just ASCII.
        let r = ash_normalize_binding_value(BindingType::Session, "\u{3000}sess\u{a0}").unwrap();
        assert_eq!(r.value, "sess");
        assert!(r.was_trimmed);
    }

    #[test]
    fn test_binding_type_preserved() {
        let r = ash_normalize_binding_value(BindingType::Tenant, "acme").unwrap();
        assert_eq!(r.binding_type, BindingType::Tenant);
    }

    #[test]
    fn test_original_length_tracked() {
        // Two spaces on each side: 2 + 3 + 2 = 7 bytes before trimming.
        let raw = "  abc  ";
        assert_eq!(raw.len(), 7);
        let r = ash_normalize_binding_value(BindingType::Device, raw).unwrap();
        assert_eq!(r.original_length, 7);
        assert_eq!(r.value, "abc");
    }

    #[test]
    fn test_custom_type_accepts_unicode() {
        let r = ash_normalize_binding_value(BindingType::Custom, "مستخدم").unwrap();
        assert_eq!(r.value, "مستخدم");
    }

    #[test]
    fn test_binding_type_display() {
        assert_eq!(BindingType::Route.to_string(), "route");
        assert_eq!(BindingType::Ip.to_string(), "ip");
        assert_eq!(BindingType::Device.to_string(), "device");
        assert_eq!(BindingType::Session.to_string(), "session");
        assert_eq!(BindingType::User.to_string(), "user");
        assert_eq!(BindingType::Tenant.to_string(), "tenant");
        assert_eq!(BindingType::Custom.to_string(), "custom");
    }
}