use regex::Regex;
use std::cmp::Reverse;
use std::collections::HashMap;
use std::sync::LazyLock;
use crate::config::RouteRateLimitConfig;
/// Matches a hyphenated UUID: 8-4-4-4-12 hexadecimal digits, upper or lower case.
///
/// The expression is not anchored, so `replace_all` will also hit a UUID that
/// is embedded inside a longer piece of text.
static UUID_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    const UUID_PATTERN: &str =
        r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}";
    Regex::new(UUID_PATTERN).expect("UUID regex is valid")
});
/// Route rate-limit rules compiled from configuration, ready for per-request lookup.
///
/// Built with `CompiledRoutePatterns::compile` and queried with `match_route`.
#[derive(Debug, Clone)]
pub struct CompiledRoutePatterns {
// Literal paths that carry a method prefix, keyed as `"METHOD /path"`.
method_exact: HashMap<String, RouteRateLimitConfig>,
// Literal paths without a method prefix, keyed by the path itself.
exact: HashMap<String, RouteRateLimitConfig>,
// Regex-matched routes, sorted so the highest specificity is tried first.
patterns: Vec<CompiledPattern>,
}
/// A route rule matched by regex rather than by exact path lookup.
#[derive(Debug, Clone)]
struct CompiledPattern {
// The full route key as written in configuration (including any method prefix).
// `#[allow(dead_code)]` silences the lint in case no non-derived code reads it.
#[allow(dead_code)]
original: String,
// HTTP method this rule is restricted to; `None` applies to every method.
method: Option<String>,
// Anchored (`^...$`) regex compiled from the path part of the route key.
regex: Regex,
config: RouteRateLimitConfig,
// Ranking score; patterns with a higher score are tried first.
specificity: usize,
}
impl CompiledRoutePatterns {
pub fn compile(routes: &HashMap<String, RouteRateLimitConfig>) -> Self {
let mut method_exact = HashMap::new();
let mut exact = HashMap::new();
let mut patterns = Vec::new();
for (pattern, config) in routes {
let (method, path) = Self::parse_method_prefix(pattern);
if Self::has_wildcards(&path) || path.contains("{id}") {
let regex = Self::compile_pattern_to_regex(&path);
let specificity = Self::calculate_specificity(&path);
patterns.push(CompiledPattern {
original: pattern.clone(),
method,
regex,
config: config.clone(),
specificity,
});
} else if let Some(m) = method {
let key = format!("{} {}", m, path);
method_exact.insert(key, config.clone());
} else {
exact.insert(path, config.clone());
}
}
patterns.sort_by_key(|p| Reverse(p.specificity));
Self {
method_exact,
exact,
patterns,
}
}
pub fn match_route(&self, method: &str, path: &str) -> Option<&RouteRateLimitConfig> {
let normalized = normalize_path(path);
let method_key = format!("{} {}", method, normalized);
if let Some(config) = self.method_exact.get(&method_key) {
return Some(config);
}
if let Some(config) = self.exact.get(&normalized) {
return Some(config);
}
for pattern in &self.patterns {
if let Some(ref m) = pattern.method {
if m != method {
continue;
}
}
if pattern.regex.is_match(&normalized) {
return Some(&pattern.config);
}
}
None
}
pub fn is_empty(&self) -> bool {
self.method_exact.is_empty() && self.exact.is_empty() && self.patterns.is_empty()
}
fn parse_method_prefix(pattern: &str) -> (Option<String>, String) {
let trimmed = pattern.trim();
let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
for method in methods {
if let Some(rest) = trimmed.strip_prefix(method) {
let rest = rest.trim_start();
if rest.starts_with('/') {
return (Some(method.to_string()), rest.to_string());
}
}
}
(None, trimmed.to_string())
}
fn has_wildcards(path: &str) -> bool {
path.contains('*')
}
fn compile_pattern_to_regex(pattern: &str) -> Regex {
let mut regex_str = String::from("^");
let mut chars = pattern.chars().peekable();
while let Some(c) = chars.next() {
match c {
'*' => {
if chars.peek() == Some(&'*') {
chars.next();
regex_str.push_str(".*");
} else {
regex_str.push_str("[^/]+");
}
}
'{' => {
for c in chars.by_ref() {
if c == '}' {
break;
}
}
regex_str.push_str("[^/]+");
}
'.' | '+' | '?' | '(' | ')' | '[' | ']' | '^' | '$' | '|' | '\\' => {
regex_str.push('\\');
regex_str.push(c);
}
_ => {
regex_str.push(c);
}
}
}
regex_str.push('$');
Regex::new(®ex_str).expect("Generated regex should be valid")
}
fn calculate_specificity(pattern: &str) -> usize {
let mut score = 0;
for segment in pattern.split('/') {
if !segment.is_empty() && !segment.contains('*') && !segment.contains('{') {
score += 10;
} else if segment == "*" {
score += 5; } else if segment == "**" {
score += 1; } else if segment.contains('{') {
score += 7; }
}
score += pattern.len();
score
}
}
/// Collapses identifier-like parts of a request path into the `{id}` placeholder.
///
/// First, every UUID match anywhere in `path` is replaced with `{id}`. Then every
/// non-empty `/`-separated segment made up solely of ASCII digits becomes `{id}`.
/// All other segments are kept unchanged, including the empty ones produced by
/// leading, trailing or doubled slashes.
///
/// Example: `/api/v1/users/123/posts/456` becomes `/api/v1/users/{id}/posts/{id}`.
pub fn normalize_path(path: &str) -> String {
    let without_uuids = UUID_REGEX.replace_all(path, "{id}");
    // Map each segment to a borrowed `&str` (either the static placeholder or
    // the segment itself) so only the final `join` allocates.
    without_uuids
        .split('/')
        .map(|segment| {
            let is_numeric = !segment.is_empty() && segment.bytes().all(|b| b.is_ascii_digit());
            if is_numeric {
                "{id}"
            } else {
                segment
            }
        })
        .collect::<Vec<&str>>()
        .join("/")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_path_uuids() {
        let with_uuid = "/api/v1/users/550e8400-e29b-41d4-a716-446655440000";
        assert_eq!(normalize_path(with_uuid), "/api/v1/users/{id}");
    }

    #[test]
    fn test_normalize_path_numeric_ids() {
        let cases = [
            ("/api/v1/users/123", "/api/v1/users/{id}"),
            ("/api/v1/users/123/posts/456", "/api/v1/users/{id}/posts/{id}"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_path(input), expected, "normalizing {input}");
        }
    }

    #[test]
    fn test_normalize_path_preserves_version() {
        // `v1` / `v2` mix letters and digits, so they must survive normalization.
        assert_eq!(normalize_path("/api/v1/users"), "/api/v1/users");
        assert_eq!(normalize_path("/api/v2/users/123"), "/api/v2/users/{id}");
    }

    #[test]
    fn test_normalize_path_no_ids() {
        for unchanged in ["/api/v1/users", "/health"] {
            assert_eq!(normalize_path(unchanged), unchanged);
        }
    }

    #[test]
    fn test_parse_method_prefix() {
        let cases = [
            ("POST /api/v1/users", Some("POST"), "/api/v1/users"),
            ("/api/v1/users", None, "/api/v1/users"),
            ("GET /api/v1/users", Some("GET"), "/api/v1/users"),
        ];
        for (input, expected_method, expected_path) in cases {
            let (method, path) = CompiledRoutePatterns::parse_method_prefix(input);
            assert_eq!(method.as_deref(), expected_method, "method of {input}");
            assert_eq!(path, expected_path, "path of {input}");
        }
    }

    #[test]
    fn test_compile_pattern_to_regex() {
        let single = CompiledRoutePatterns::compile_pattern_to_regex("/api/v1/users/*");
        for path in ["/api/v1/users/123", "/api/v1/users/abc"] {
            assert!(single.is_match(path), "{path} should match");
        }
        assert!(!single.is_match("/api/v1/users/123/posts"));

        let middle = CompiledRoutePatterns::compile_pattern_to_regex("/api/*/admin");
        for path in ["/api/v1/admin", "/api/v2/admin"] {
            assert!(middle.is_match(path), "{path} should match");
        }
        assert!(!middle.is_match("/api/v1/v2/admin"));

        let deep = CompiledRoutePatterns::compile_pattern_to_regex("/api/**/admin");
        for path in ["/api/v1/admin", "/api/v1/v2/admin", "/api/foo/bar/baz/admin"] {
            assert!(deep.is_match(path), "{path} should match");
        }
    }

    #[test]
    fn test_compile_pattern_with_placeholder() {
        let placeholder = CompiledRoutePatterns::compile_pattern_to_regex("/api/v1/users/{id}");
        for path in ["/api/v1/users/123", "/api/v1/users/abc"] {
            assert!(placeholder.is_match(path), "{path} should match");
        }
        assert!(!placeholder.is_match("/api/v1/users/123/posts"));
    }

    #[test]
    fn test_match_route_exact() {
        let routes = HashMap::from([(
            "/api/v1/users".to_string(),
            RouteRateLimitConfig {
                requests_per_minute: 100,
                burst_size: 10,
                per_user: true,
            },
        )]);
        let compiled = CompiledRoutePatterns::compile(&routes);

        let hit = compiled
            .match_route("GET", "/api/v1/users")
            .expect("exact route should match");
        assert_eq!(hit.requests_per_minute, 100);
        assert!(compiled.match_route("GET", "/api/v1/posts").is_none());
    }

    #[test]
    fn test_match_route_method_prefix() {
        let routes = HashMap::from([
            (
                "POST /api/v1/uploads".to_string(),
                RouteRateLimitConfig {
                    requests_per_minute: 10,
                    burst_size: 2,
                    per_user: true,
                },
            ),
            (
                "/api/v1/uploads".to_string(),
                RouteRateLimitConfig {
                    requests_per_minute: 100,
                    burst_size: 10,
                    per_user: true,
                },
            ),
        ]);
        let compiled = CompiledRoutePatterns::compile(&routes);

        let post = compiled
            .match_route("POST", "/api/v1/uploads")
            .expect("POST-specific route should match");
        assert_eq!(post.requests_per_minute, 10);

        let get = compiled
            .match_route("GET", "/api/v1/uploads")
            .expect("method-less route should match GET");
        assert_eq!(get.requests_per_minute, 100);
    }

    #[test]
    fn test_match_route_with_id_normalization() {
        let routes = HashMap::from([(
            "/api/v1/users/{id}".to_string(),
            RouteRateLimitConfig {
                requests_per_minute: 50,
                burst_size: 5,
                per_user: true,
            },
        )]);
        let compiled = CompiledRoutePatterns::compile(&routes);

        for path in [
            "/api/v1/users/123",
            "/api/v1/users/550e8400-e29b-41d4-a716-446655440000",
        ] {
            let hit = compiled
                .match_route("GET", path)
                .expect("id route should match");
            assert_eq!(hit.requests_per_minute, 50, "limit for {path}");
        }
    }

    #[test]
    fn test_match_route_wildcard() {
        let routes = HashMap::from([(
            "/api/*/admin/*".to_string(),
            RouteRateLimitConfig {
                requests_per_minute: 20,
                burst_size: 2,
                per_user: true,
            },
        )]);
        let compiled = CompiledRoutePatterns::compile(&routes);

        for path in ["/api/v1/admin/users", "/api/v2/admin/settings"] {
            let hit = compiled
                .match_route("GET", path)
                .expect("wildcard route should match");
            assert_eq!(hit.requests_per_minute, 20, "limit for {path}");
        }
    }

    #[test]
    fn test_specificity_ordering() {
        let routes = HashMap::from([
            (
                "/api/v1/*".to_string(),
                RouteRateLimitConfig {
                    requests_per_minute: 100,
                    burst_size: 10,
                    per_user: true,
                },
            ),
            (
                "/api/v1/users".to_string(),
                RouteRateLimitConfig {
                    requests_per_minute: 50,
                    burst_size: 5,
                    per_user: true,
                },
            ),
        ]);
        let compiled = CompiledRoutePatterns::compile(&routes);

        // The exact route beats the wildcard for the path it names...
        let users = compiled
            .match_route("GET", "/api/v1/users")
            .expect("exact route should match");
        assert_eq!(users.requests_per_minute, 50);

        // ...while the wildcard still covers its siblings.
        let posts = compiled
            .match_route("GET", "/api/v1/posts")
            .expect("wildcard route should match");
        assert_eq!(posts.requests_per_minute, 100);
    }
}