1use std::collections::HashMap;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use serde::{Deserialize, Serialize};
14
15use crate::{WafDecision, WafRequest};
16
/// Upper bound on the number of live token buckets (one per rule + client key).
///
/// When a request needs a new bucket and the table is already this large,
/// `EnhancedRateLimiter::check` first evicts buckets idle for 60 s and, if the
/// table is still full, rejects the request with a `RateLimit` decision
/// (retry after 1 s) instead of growing the table further.
const MAX_BUCKETS: usize = 100_000;
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
21#[serde(rename_all = "snake_case")]
22pub enum KeySource {
23 #[default]
25 Ip,
26 Header(String),
28 Cookie(String),
30}
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
34#[serde(rename_all = "snake_case")]
35pub enum DelayMode {
36 #[default]
38 NoDelay,
39 Delay,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RateLimitRule {
46 pub name: String,
48 pub pattern: String,
50 pub rpm: u32,
52 #[serde(default)]
54 pub burst: u32,
55 #[serde(default)]
57 pub key_source: KeySource,
58 #[serde(default)]
60 pub delay_mode: DelayMode,
61}
62
/// Token bucket for one (rule, client key) pair.
///
/// Holds at most `max_tokens` (= `rpm + burst`) tokens, starts full, and
/// refills continuously at `refill_rate` (= `rpm / 60`) tokens per second.
/// Every admitted request spends one token.
struct TokenBucket {
    /// Tokens currently available; fractional so sub-second refill is kept.
    tokens: f64,
    /// Capacity of the bucket (`rpm + burst`).
    max_tokens: f64,
    /// Refill speed in tokens per second.
    refill_rate: f64,
    /// When the bucket was last refilled (i.e. last `try_consume` call).
    /// `EnhancedRateLimiter::cleanup` reads this as the last-activity time.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket for `rpm` requests per minute plus `burst`
    /// requests of extra headroom.
    fn new(rpm: u32, burst: u32) -> Self {
        // Add in f64: summing in u32 overflows for large configured values
        // (panic in debug builds, wrap to a tiny capacity in release).
        let max_tokens = f64::from(rpm) + f64::from(burst);
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate: f64::from(rpm) / 60.0,
            last_refill: Instant::now(),
        }
    }

    /// Refills for the time elapsed since the last call, then tries to spend
    /// one token.
    ///
    /// Returns `(allowed, remaining, secs)`:
    /// - admitted: `(true, whole tokens left, seconds until the bucket is full)`
    /// - rejected: `(false, 0, seconds until one token is available)`
    ///
    /// When the bucket never refills (`rpm == 0`), `secs` is a fixed 60.
    fn try_consume(&mut self) -> (bool, u32, u64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);

        if self.tokens < 1.0 {
            return (false, 0, self.secs_to_accumulate(1.0 - self.tokens));
        }

        self.tokens -= 1.0;
        // Float-to-int `as` saturates, so capacities above u32::MAX report u32::MAX.
        let remaining = self.tokens.floor() as u32;
        let until_full = self.secs_to_accumulate(self.max_tokens - self.tokens);
        (true, remaining, until_full)
    }

    /// Seconds (rounded up) needed to refill `amount` tokens, or 60 when the
    /// bucket does not refill at all.
    fn secs_to_accumulate(&self, amount: f64) -> u64 {
        if self.refill_rate > 0.0 {
            (amount / self.refill_rate).ceil() as u64
        } else {
            60
        }
    }
}
114
115pub struct EnhancedRateLimiter {
117 rules: Vec<RateLimitRule>,
118 buckets: DashMap<String, TokenBucket>,
120}
121
122impl EnhancedRateLimiter {
123 pub fn new(rules: Vec<RateLimitRule>) -> Self {
124 Self {
125 rules,
126 buckets: DashMap::new(),
127 }
128 }
129
130 fn extract_key(&self, rule: &RateLimitRule, req: &WafRequest) -> String {
136 let client_key = match &rule.key_source {
137 KeySource::Ip => normalize_ip_for_rate_limit(req.client_ip),
138 KeySource::Header(name) => {
139 let lower = name.to_lowercase();
140 req.headers
141 .iter()
142 .find(|(k, _)| k.to_lowercase() == lower)
143 .map(|(_, v)| v.clone())
144 .unwrap_or_else(|| req.client_ip.to_string())
145 }
146 KeySource::Cookie(name) => {
147 extract_cookie(&req.headers, name).unwrap_or_else(|| req.client_ip.to_string())
148 }
149 };
150 format!("{}:{}", rule.name, client_key)
151 }
152
153 pub fn check(&self, req: &WafRequest) -> Option<(WafDecision, Vec<(String, String)>)> {
156 for rule in &self.rules {
157 if !path_matches(&rule.pattern, &req.path) {
158 continue;
159 }
160
161 let bucket_key = self.extract_key(rule, req);
162
163 if !self.buckets.contains_key(&bucket_key) && self.buckets.len() >= MAX_BUCKETS {
165 self.cleanup(std::time::Duration::from_secs(60));
166 if self.buckets.len() >= MAX_BUCKETS {
167 let limit = rule.rpm + rule.burst;
168 return Some((
169 WafDecision::RateLimit { retry_after: 1 },
170 vec![
171 ("RateLimit-Limit".into(), limit.to_string()),
172 ("RateLimit-Remaining".into(), "0".to_string()),
173 ("RateLimit-Reset".into(), "1".to_string()),
174 ("Retry-After".into(), "1".to_string()),
175 ],
176 ));
177 }
178 }
179
180 let mut entry = self
181 .buckets
182 .entry(bucket_key)
183 .or_insert_with(|| TokenBucket::new(rule.rpm, rule.burst));
184
185 let (allowed, remaining, reset_secs) = entry.try_consume();
186 let limit = rule.rpm + rule.burst;
187
188 let headers = vec![
189 ("RateLimit-Limit".into(), limit.to_string()),
190 ("RateLimit-Remaining".into(), remaining.to_string()),
191 ("RateLimit-Reset".into(), reset_secs.to_string()),
192 ];
193
194 if !allowed {
195 let mut hdrs = headers;
196 hdrs.push(("Retry-After".into(), reset_secs.to_string()));
197
198 return Some((
199 WafDecision::RateLimit {
200 retry_after: reset_secs,
201 },
202 hdrs,
203 ));
204 }
205
206 }
210
211 None
212 }
213
214 pub fn cleanup(&self, max_age: Duration) {
216 let now = Instant::now();
217 self.buckets
218 .retain(|_, bucket| now.duration_since(bucket.last_refill) < max_age);
219 }
220}
221
/// Reduces a client address to the identity used for IP-keyed rate limiting.
///
/// * IPv4 addresses are used as-is (`"203.0.113.5"`).
/// * IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`, as reported by dual-stack
///   sockets) are unwrapped to their IPv4 form. Otherwise their all-zero upper
///   64 bits would put every IPv4 client into the single key `::/64`.
/// * All other IPv6 addresses are truncated to their /64 prefix
///   (`"2001:db8:1:2::/64"`): end sites are normally assigned at least a /64,
///   so one client could otherwise rotate addresses to dodge the limit.
fn normalize_ip_for_rate_limit(ip: std::net::IpAddr) -> String {
    match ip {
        std::net::IpAddr::V4(v4) => v4.to_string(),
        std::net::IpAddr::V6(v6) => match v6.octets() {
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
                std::net::Ipv4Addr::new(a, b, c, d).to_string()
            }
            _ => {
                let [s0, s1, s2, s3, ..] = v6.segments();
                let prefix = std::net::Ipv6Addr::new(s0, s1, s2, s3, 0, 0, 0, 0);
                format!("{prefix}/64")
            }
        },
    }
}
248
/// Looks up `cookie_name` in the request's `Cookie` header.
///
/// The header name is matched ASCII-case-insensitively. The header value is
/// split on `;`; pairs without `=` are ignored, and names and values are
/// whitespace-trimmed. Only the first `=` separates name from value, so values
/// may themselves contain `=`. Returns the first matching value, or `None`
/// when the header or the cookie is absent.
fn extract_cookie(headers: &HashMap<String, String>, cookie_name: &str) -> Option<String> {
    let (_, raw) = headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case("cookie"))?;

    raw.split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(name, _)| name.trim() == cookie_name)
        .map(|(_, value)| value.trim().to_string())
}
266
/// Reports whether `path` is covered by the rule `pattern`.
///
/// Supported forms:
/// - `*` or `/**`: every path;
/// - `<base>/**`: `<base>` itself and anything beneath `<base>/`;
/// - `<base>/*`: `<base>`, `<base>/`, and paths exactly one segment below `<base>`;
/// - anything else: exact string equality.
///
/// Matching is purely textual: no case folding, percent-decoding or dot-segment
/// normalization happens here.
fn path_matches(pattern: &str, path: &str) -> bool {
    if matches!(pattern, "*" | "/**") {
        return true;
    }

    if let Some(base) = pattern.strip_suffix("/**") {
        let below_base = matches!(path.strip_prefix(base), Some(tail) if tail.starts_with('/'));
        return path == base || below_base;
    }

    if let Some(base) = pattern.strip_suffix("/*") {
        return match path.strip_prefix(base) {
            // The base path itself.
            Some("") => true,
            // One level down: nothing after the leading '/' may contain another '/'.
            Some(tail) if tail.starts_with('/') => !tail[1..].contains('/'),
            _ => false,
        };
    }

    pattern == path
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GET` request from `ip` to `path` carrying `headers`.
    fn make_req_with_headers(ip: &str, path: &str, headers: Vec<(&str, &str)>) -> WafRequest {
        let headers: HashMap<String, String> = headers
            .into_iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect();
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: path.into(),
            query: None,
            headers,
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    /// Builds a header-less `GET` request from `ip` to `path`.
    fn make_req(ip: &str, path: &str) -> WafRequest {
        make_req_with_headers(ip, path, Vec::new())
    }

    /// Builds a rule with `DelayMode::NoDelay`.
    fn rule(
        name: &str,
        pattern: &str,
        rpm: u32,
        burst: u32,
        key_source: KeySource,
    ) -> RateLimitRule {
        RateLimitRule {
            name: name.into(),
            pattern: pattern.into(),
            rpm,
            burst,
            key_source,
            delay_mode: DelayMode::NoDelay,
        }
    }

    /// Builds a limiter holding a single IP-keyed rule.
    fn ip_limiter(name: &str, pattern: &str, rpm: u32, burst: u32) -> EnhancedRateLimiter {
        EnhancedRateLimiter::new(vec![rule(name, pattern, rpm, burst, KeySource::Ip)])
    }

    /// Returns the value of the first header called `name`, if any.
    fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
        headers
            .iter()
            .find(|(k, _)| k == name)
            .map(|(_, v)| v.as_str())
    }

    #[test]
    fn no_rules_allows_all() {
        let limiter = EnhancedRateLimiter::new(vec![]);
        assert!(limiter.check(&make_req("10.0.0.1", "/api/data")).is_none());
    }

    #[test]
    fn within_limit_allows() {
        let limiter = ip_limiter("api", "/api/**", 60, 10);
        assert!(limiter.check(&make_req("10.0.0.1", "/api/data")).is_none());
    }

    #[test]
    fn exceeds_limit_blocks() {
        let limiter = ip_limiter("strict", "/api/**", 2, 0);
        let req = make_req("10.0.0.1", "/api/data");

        for _ in 0..2 {
            assert!(limiter.check(&req).is_none());
        }

        let (decision, headers) = limiter
            .check(&req)
            .expect("third request must be rate limited");
        assert!(matches!(decision, WafDecision::RateLimit { .. }));
        for name in ["Retry-After", "RateLimit-Limit", "RateLimit-Remaining"] {
            assert!(header_value(&headers, name).is_some(), "missing {name}");
        }
    }

    #[test]
    fn burst_allows_extra() {
        let limiter = ip_limiter("burst-test", "/**", 2, 3);
        let req = make_req("10.0.0.1", "/page");

        // Capacity is rpm + burst = 5.
        for _ in 0..5 {
            assert!(limiter.check(&req).is_none());
        }
        assert!(limiter.check(&req).is_some());
    }

    #[test]
    fn different_ips_have_separate_limits() {
        let limiter = ip_limiter("per-ip", "/**", 1, 0);

        assert!(limiter.check(&make_req("10.0.0.1", "/")).is_none());
        assert!(limiter.check(&make_req("10.0.0.1", "/")).is_some());
        assert!(limiter.check(&make_req("10.0.0.2", "/")).is_none());
    }

    #[test]
    fn non_matching_path_skipped() {
        let limiter = ip_limiter("api-only", "/api/**", 1, 0);
        let req = make_req("10.0.0.1", "/static/file.js");

        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_none());
    }

    #[test]
    fn header_key_source() {
        let limiter = EnhancedRateLimiter::new(vec![rule(
            "by-api-key",
            "/**",
            1,
            0,
            KeySource::Header("X-API-Key".into()),
        )]);
        let key_a = make_req_with_headers("10.0.0.1", "/api", vec![("X-API-Key", "key-a")]);
        let key_b = make_req_with_headers("10.0.0.2", "/api", vec![("X-API-Key", "key-b")]);

        assert!(limiter.check(&key_a).is_none());
        assert!(limiter.check(&key_a).is_some());
        assert!(limiter.check(&key_b).is_none());
    }

    #[test]
    fn cookie_key_source() {
        let limiter = EnhancedRateLimiter::new(vec![rule(
            "by-session",
            "/**",
            1,
            0,
            KeySource::Cookie("session_id".into()),
        )]);
        let req = make_req_with_headers(
            "10.0.0.1",
            "/",
            vec![("Cookie", "session_id=abc123; other=val")],
        );

        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_some());
    }

    #[test]
    fn path_matching() {
        assert!(path_matches("/api/**", "/api/users"));
        assert!(path_matches("/api/**", "/api/users/123/details"));
        assert!(path_matches("/api/**", "/api"));
        assert!(!path_matches("/api/**", "/static/file"));
        assert!(path_matches("/**", "/anything"));
        assert!(path_matches("*", "/anything"));
        assert!(path_matches("/health", "/health"));
        assert!(!path_matches("/health", "/healthz"));
    }

    #[test]
    fn extract_cookie_works() {
        let headers: HashMap<String, String> =
            [("Cookie".to_string(), "a=1; session_id=abc; b=2".to_string())]
                .into_iter()
                .collect();

        assert_eq!(extract_cookie(&headers, "session_id"), Some("abc".into()));
        assert_eq!(extract_cookie(&headers, "a"), Some("1".into()));
        assert_eq!(extract_cookie(&headers, "missing"), None);
    }

    #[test]
    fn cleanup_removes_stale_buckets() {
        let limiter = ip_limiter("test", "/**", 60, 0);
        let _ = limiter.check(&make_req("10.0.0.1", "/"));
        assert!(!limiter.buckets.is_empty());

        // A zero max-age makes every bucket stale.
        limiter.cleanup(Duration::from_secs(0));
        assert!(limiter.buckets.is_empty());
    }

    #[test]
    fn rate_limit_headers_correct() {
        let limiter = ip_limiter("strict", "/**", 1, 0);
        let req = make_req("10.0.0.1", "/");

        let _ = limiter.check(&req);
        let (_, headers) = limiter.check(&req).unwrap();

        assert_eq!(header_value(&headers, "RateLimit-Limit"), Some("1"));
        assert_eq!(header_value(&headers, "RateLimit-Remaining"), Some("0"));
        let retry = header_value(&headers, "Retry-After").unwrap();
        assert!(!retry.is_empty());
    }
}