use std::collections::HashMap;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use crate::{WafDecision, WafRequest};
/// Upper bound on the number of live token buckets. When the bucket map holds
/// this many entries, `EnhancedRateLimiter::check` first evicts idle buckets
/// and, if the map is still full, rejects requests that would need a new one.
const MAX_BUCKETS: usize = 100_000;
/// Selects which request attribute identifies a client for rate limiting.
///
/// Variant names are (de)serialized in `snake_case`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum KeySource {
/// Key on the client IP address, as normalized by
/// `normalize_ip_for_rate_limit`. This is the default.
#[default]
Ip,
/// Key on the value of the named request header (name matched
/// case-insensitively); falls back to the client IP when the header is absent.
Header(String),
/// Key on the value of the named cookie; falls back to the client IP when
/// the cookie is absent.
Cookie(String),
}
/// Delay-behaviour selector carried on a [`RateLimitRule`].
///
/// Variant names are (de)serialized in `snake_case`.
///
/// NOTE(review): nothing in this file's visible code reads this value, so the
/// exact semantics of each variant are presumably defined by the caller —
/// confirm before relying on them.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DelayMode {
/// The default mode.
#[default]
NoDelay,
/// The alternative mode (presumably: delay instead of reject — TODO confirm).
Delay,
}
/// One rate-limit rule: which paths it applies to, how fast tokens refill,
/// and how clients are told apart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitRule {
/// Rule name; used as the prefix of every bucket key created for this rule.
pub name: String,
/// Path pattern understood by `path_matches`: `*`, `/**`, `<prefix>/**`,
/// `<prefix>/*`, or an exact path.
pub pattern: String,
/// Sustained rate in requests per minute; buckets refill at `rpm / 60`
/// tokens per second.
pub rpm: u32,
/// Extra capacity on top of `rpm`; bucket capacity is `rpm + burst`.
/// Defaults to 0 when omitted from the serialized form.
#[serde(default)]
pub burst: u32,
/// Which request attribute identifies the client. Defaults to the IP.
#[serde(default)]
pub key_source: KeySource,
/// Delay mode for this rule. Not read by the visible code in this file.
#[serde(default)]
pub delay_mode: DelayMode,
}
/// Token-bucket state for a single rate-limit key.
struct TokenBucket {
    /// Tokens currently available. Fractional, because refill is computed
    /// continuously from elapsed time.
    tokens: f64,
    /// Bucket capacity: `rpm + burst`.
    max_tokens: f64,
    /// Tokens added per second (`rpm / 60`).
    refill_rate: f64,
    /// When the bucket was last touched (construction or `try_consume`).
    last_refill: Instant,
}

impl TokenBucket {
    /// Seconds reported when the refill rate is zero and no refill time can
    /// be computed.
    const ZERO_RATE_WAIT_SECS: u64 = 60;

    /// Creates a full bucket holding `rpm + burst` tokens that refills at
    /// `rpm / 60` tokens per second.
    fn new(rpm: u32, burst: u32) -> Self {
        // Widen before adding: `rpm + burst` evaluated in `u32` overflows for
        // large configured values (panic in debug builds, wrap in release).
        let max_tokens = f64::from(rpm) + f64::from(burst);
        let refill_rate = f64::from(rpm) / 60.0;
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Whole seconds (rounded up) needed to refill `missing` tokens, or
    /// `ZERO_RATE_WAIT_SECS` when the bucket never refills.
    fn secs_to_refill(&self, missing: f64) -> u64 {
        if self.refill_rate > 0.0 {
            // Float-to-int `as` saturates, so huge quotients cannot wrap.
            (missing / self.refill_rate).ceil() as u64
        } else {
            Self::ZERO_RATE_WAIT_SECS
        }
    }

    /// Refills the bucket for the time elapsed since the last call, then
    /// tries to take one token.
    ///
    /// Returns `(allowed, remaining, secs)`:
    /// * allowed — `secs` is the time until the bucket is full again and
    ///   `remaining` is the number of whole tokens left;
    /// * denied — `remaining` is 0 and `secs` is the time until one token is
    ///   available.
    fn try_consume(&mut self) -> (bool, u32, u64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);

        if self.tokens < 1.0 {
            return (false, 0, self.secs_to_refill(1.0 - self.tokens));
        }

        self.tokens -= 1.0;
        let remaining = self.tokens.floor() as u32;
        let reset_secs = self.secs_to_refill(self.max_tokens - self.tokens);
        (true, remaining, reset_secs)
    }
}
/// Rule-driven rate limiter that keeps one token bucket per (rule, client key).
pub struct EnhancedRateLimiter {
/// Rules, evaluated in order against each request's path.
rules: Vec<RateLimitRule>,
/// Bucket state keyed by `"<rule name>:<client key>"`.
buckets: DashMap<String, TokenBucket>,
}
impl EnhancedRateLimiter {
    /// Creates a limiter for `rules` with no buckets allocated yet.
    pub fn new(rules: Vec<RateLimitRule>) -> Self {
        Self {
            rules,
            buckets: DashMap::new(),
        }
    }

    /// Builds the bucket key `"<rule name>:<client key>"` for `req` under `rule`.
    ///
    /// The client key comes from the rule's `key_source`. When the configured
    /// header or cookie is absent, the key falls back to the *normalized*
    /// client IP, so the fallback groups IPv6 clients the same way
    /// `KeySource::Ip` does instead of giving each address its own bucket.
    fn extract_key(&self, rule: &RateLimitRule, req: &WafRequest) -> String {
        let ip_key = || normalize_ip_for_rate_limit(req.client_ip);
        let client_key = match &rule.key_source {
            KeySource::Ip => ip_key(),
            KeySource::Header(name) => req
                .headers
                .iter()
                // ASCII case-insensitive compare, as `extract_cookie` does;
                // avoids allocating a lowercase copy of every header name.
                .find(|(k, _)| k.eq_ignore_ascii_case(name))
                .map(|(_, v)| v.clone())
                .unwrap_or_else(ip_key),
            KeySource::Cookie(name) => {
                extract_cookie(&req.headers, name).unwrap_or_else(ip_key)
            }
        };
        format!("{}:{}", rule.name, client_key)
    }

    /// Builds the rejection decision together with its response headers
    /// (`RateLimit-Limit`, `RateLimit-Remaining`, `RateLimit-Reset`,
    /// `Retry-After`).
    fn rejection(
        limit: u64,
        remaining: u32,
        retry_after: u64,
    ) -> (WafDecision, Vec<(String, String)>) {
        let secs = retry_after.to_string();
        (
            WafDecision::RateLimit { retry_after },
            vec![
                ("RateLimit-Limit".into(), limit.to_string()),
                ("RateLimit-Remaining".into(), remaining.to_string()),
                ("RateLimit-Reset".into(), secs.clone()),
                ("Retry-After".into(), secs),
            ],
        )
    }

    /// Charges one token against every rule whose pattern matches the
    /// request path, in rule order.
    ///
    /// Returns `None` when no matching rule rejects the request. Returns
    /// `Some((WafDecision::RateLimit { .. }, headers))` from the first rule
    /// that rejects it, either because its bucket is empty or because the
    /// bucket map is full and the request would need a new bucket.
    pub fn check(&self, req: &WafRequest) -> Option<(WafDecision, Vec<(String, String)>)> {
        for rule in &self.rules {
            if !path_matches(&rule.pattern, &req.path) {
                continue;
            }

            // Summed in u64: `rpm + burst` in u32 can overflow for large
            // configured values.
            let limit = u64::from(rule.rpm) + u64::from(rule.burst);
            let bucket_key = self.extract_key(rule, req);

            // Only a request that needs a *new* bucket can grow the map.
            // The check and the later insert are separate map operations, so
            // under concurrency the cap is approximate rather than exact.
            if !self.buckets.contains_key(&bucket_key) && self.buckets.len() >= MAX_BUCKETS {
                self.cleanup(Duration::from_secs(60));
                if self.buckets.len() >= MAX_BUCKETS {
                    return Some(Self::rejection(limit, 0, 1));
                }
            }

            let mut bucket = self
                .buckets
                .entry(bucket_key)
                .or_insert_with(|| TokenBucket::new(rule.rpm, rule.burst));
            let (allowed, remaining, reset_secs) = bucket.try_consume();
            // Release the map entry before doing any string formatting.
            drop(bucket);

            if !allowed {
                return Some(Self::rejection(limit, remaining, reset_secs));
            }
        }
        None
    }

    /// Removes every bucket that has not been touched for at least `max_age`.
    pub fn cleanup(&self, max_age: Duration) {
        let now = Instant::now();
        self.buckets
            .retain(|_, bucket| now.duration_since(bucket.last_refill) < max_age);
    }
}
/// Maps a client address to the string used as its rate-limit identity.
///
/// * IPv4 addresses are used verbatim.
/// * IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped and treated
///   as the embedded IPv4 address. Their upper 64 bits are all zero, so
///   truncating them to a /64 would put every such client in the single
///   bucket `::/64`.
/// * Any other IPv6 address is reduced to its /64 prefix, rendered as
///   `"<prefix>/64"`, so all addresses in one /64 share a bucket.
fn normalize_ip_for_rate_limit(ip: std::net::IpAddr) -> String {
    use std::net::{IpAddr, Ipv6Addr};

    let v6 = match ip {
        IpAddr::V4(v4) => return v4.to_string(),
        IpAddr::V6(v6) => v6,
    };

    if let Some(v4) = v6.to_ipv4_mapped() {
        return v4.to_string();
    }

    // Keep the first four 16-bit groups (the upper 64 bits), zero the rest.
    let [a, b, c, d, ..] = v6.segments();
    let prefix = Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0);
    format!("{prefix}/64")
}
/// Returns the value of the cookie called `cookie_name`, if present.
///
/// The `Cookie` header is located by ASCII case-insensitive name. Its value
/// is split on `;`, and each piece is split at its first `=`. Names and
/// values are trimmed of surrounding whitespace, pieces without `=` are
/// skipped, and the first piece whose name matches wins.
fn extract_cookie(headers: &HashMap<String, String>, cookie_name: &str) -> Option<String> {
    let (_, raw) = headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case("cookie"))?;

    raw.split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, _)| key.trim() == cookie_name)
        .map(|(_, val)| val.trim().to_string())
}
/// Tests `path` against a rule `pattern`.
///
/// Supported forms:
/// * `*` or `/**` — matches every path;
/// * `<prefix>/**` — matches `<prefix>` itself and anything below it;
/// * `<prefix>/*` — matches `<prefix>` itself and paths below it whose
///   remainder after `<prefix>` contains at most one `/`;
/// * anything else — exact string comparison.
fn path_matches(pattern: &str, path: &str) -> bool {
    /// The part of `path` after `prefix`, provided `path` is `prefix` itself
    /// or continues with a `/` right after it.
    fn tail_under<'p>(path: &'p str, prefix: &str) -> Option<&'p str> {
        path.strip_prefix(prefix)
            .filter(|tail| tail.is_empty() || tail.starts_with('/'))
    }

    match pattern {
        "*" | "/**" => true,
        _ => {
            if let Some(prefix) = pattern.strip_suffix("/**") {
                tail_under(path, prefix).is_some()
            } else if let Some(prefix) = pattern.strip_suffix("/*") {
                match tail_under(path, prefix) {
                    Some(tail) => tail.matches('/').count() <= 1,
                    None => false,
                }
            } else {
                pattern == path
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GET request from `ip` for `path` carrying `headers`.
    fn make_req_with_headers(ip: &str, path: &str, headers: Vec<(&str, &str)>) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: path.into(),
            query: None,
            headers: headers
                .into_iter()
                .map(|(name, value)| (name.into(), value.into()))
                .collect(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    /// Builds a header-less GET request from `ip` for `path`.
    fn make_req(ip: &str, path: &str) -> WafRequest {
        make_req_with_headers(ip, path, Vec::new())
    }

    /// Builds a `NoDelay` rule from the parameters the tests actually vary.
    fn rule(
        name: &str,
        pattern: &str,
        rpm: u32,
        burst: u32,
        key_source: KeySource,
    ) -> RateLimitRule {
        RateLimitRule {
            name: name.into(),
            pattern: pattern.into(),
            rpm,
            burst,
            key_source,
            delay_mode: DelayMode::NoDelay,
        }
    }

    /// Builds a limiter configured with exactly one rule.
    fn limiter_with(only_rule: RateLimitRule) -> EnhancedRateLimiter {
        EnhancedRateLimiter::new(vec![only_rule])
    }

    /// Looks up a response header by exact name, panicking when it is missing.
    fn header<'h>(headers: &'h [(String, String)], name: &str) -> &'h str {
        let (_, value) = headers.iter().find(|(k, _)| k == name).unwrap();
        value
    }

    #[test]
    fn no_rules_allows_all() {
        let limiter = EnhancedRateLimiter::new(vec![]);
        let outcome = limiter.check(&make_req("10.0.0.1", "/api/data"));
        assert!(outcome.is_none());
    }

    #[test]
    fn within_limit_allows() {
        let limiter = limiter_with(rule("api", "/api/**", 60, 10, KeySource::Ip));
        let req = make_req("10.0.0.1", "/api/data");
        assert!(limiter.check(&req).is_none());
    }

    #[test]
    fn exceeds_limit_blocks() {
        let limiter = limiter_with(rule("strict", "/api/**", 2, 0, KeySource::Ip));
        let req = make_req("10.0.0.1", "/api/data");

        // The first two requests fit the budget; the third must be rejected.
        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_none());
        let result = limiter.check(&req);
        assert!(result.is_some());

        let (decision, headers) = result.unwrap();
        assert!(matches!(decision, WafDecision::RateLimit { .. }));
        for expected in ["Retry-After", "RateLimit-Limit", "RateLimit-Remaining"] {
            assert!(headers.iter().any(|(k, _)| k == expected));
        }
    }

    #[test]
    fn burst_allows_extra() {
        let limiter = limiter_with(rule("burst-test", "/**", 2, 3, KeySource::Ip));
        let req = make_req("10.0.0.1", "/page");

        // rpm (2) + burst (3) = 5 requests pass before the bucket is empty.
        for _ in 0..5 {
            assert!(limiter.check(&req).is_none());
        }
        assert!(limiter.check(&req).is_some());
    }

    #[test]
    fn different_ips_have_separate_limits() {
        let limiter = limiter_with(rule("per-ip", "/**", 1, 0, KeySource::Ip));

        assert!(limiter.check(&make_req("10.0.0.1", "/")).is_none());
        assert!(limiter.check(&make_req("10.0.0.1", "/")).is_some());
        assert!(limiter.check(&make_req("10.0.0.2", "/")).is_none());
    }

    #[test]
    fn non_matching_path_skipped() {
        let limiter = limiter_with(rule("api-only", "/api/**", 1, 0, KeySource::Ip));

        // Neither request matches the rule, so neither is ever limited.
        for _ in 0..2 {
            let outcome = limiter.check(&make_req("10.0.0.1", "/static/file.js"));
            assert!(outcome.is_none());
        }
    }

    #[test]
    fn header_key_source() {
        let limiter = limiter_with(rule(
            "by-api-key",
            "/**",
            1,
            0,
            KeySource::Header("X-API-Key".into()),
        ));
        let req1 = make_req_with_headers("10.0.0.1", "/api", vec![("X-API-Key", "key-a")]);
        let req2 = make_req_with_headers("10.0.0.2", "/api", vec![("X-API-Key", "key-b")]);

        assert!(limiter.check(&req1).is_none());
        assert!(limiter.check(&req1).is_some());
        assert!(limiter.check(&req2).is_none());
    }

    #[test]
    fn cookie_key_source() {
        let limiter = limiter_with(rule(
            "by-session",
            "/**",
            1,
            0,
            KeySource::Cookie("session_id".into()),
        ));
        let req = make_req_with_headers(
            "10.0.0.1",
            "/",
            vec![("Cookie", "session_id=abc123; other=val")],
        );

        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_some());
    }

    #[test]
    fn path_matching() {
        assert!(path_matches("/api/**", "/api/users"));
        assert!(path_matches("/api/**", "/api/users/123/details"));
        assert!(path_matches("/api/**", "/api"));
        assert!(!path_matches("/api/**", "/static/file"));
        assert!(path_matches("/**", "/anything"));
        assert!(path_matches("*", "/anything"));
        assert!(path_matches("/health", "/health"));
        assert!(!path_matches("/health", "/healthz"));
    }

    #[test]
    fn extract_cookie_works() {
        let mut headers = HashMap::new();
        headers.insert("Cookie".into(), "a=1; session_id=abc; b=2".into());

        assert_eq!(extract_cookie(&headers, "session_id"), Some("abc".into()));
        assert_eq!(extract_cookie(&headers, "a"), Some("1".into()));
        assert_eq!(extract_cookie(&headers, "missing"), None);
    }

    #[test]
    fn cleanup_removes_stale_buckets() {
        let limiter = limiter_with(rule("test", "/**", 60, 0, KeySource::Ip));
        let req = make_req("10.0.0.1", "/");

        limiter.check(&req);
        assert!(!limiter.buckets.is_empty());

        // A zero max age makes every bucket count as stale.
        limiter.cleanup(Duration::from_secs(0));
        assert!(limiter.buckets.is_empty());
    }

    #[test]
    fn rate_limit_headers_correct() {
        let limiter = limiter_with(rule("strict", "/**", 1, 0, KeySource::Ip));
        let req = make_req("10.0.0.1", "/");

        // The first call spends the only token; the second is rejected.
        limiter.check(&req);
        let (_, headers) = limiter.check(&req).unwrap();

        assert_eq!(header(&headers, "RateLimit-Limit"), "1");
        assert_eq!(header(&headers, "RateLimit-Remaining"), "0");
        assert!(!header(&headers, "Retry-After").is_empty());
    }
}