Skip to main content

bext_waf/
rate_limit.rs

//! Token-bucket rate limiting with per-rule path matching.
//!
//! Each [`RateLimitRule`] targets a URL path pattern and enforces a
//! requests-per-minute rate with additional burst capacity.  Keys can be
//! extracted from the client IP (with IPv6 /64 normalisation), a named header,
//! or a cookie.  When a request is rejected, the limiter returns the
//! `RateLimit-Limit`, `RateLimit-Remaining`, `RateLimit-Reset`, and
//! `Retry-After` response headers alongside the decision.
//! The bucket count is capped at 100 000 to bound memory usage.
8
9use std::collections::HashMap;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use serde::{Deserialize, Serialize};
14
15use crate::{WafDecision, WafRequest};
16
/// Maximum number of live token buckets.
///
/// When a request would create a new bucket while the table already holds
/// this many entries, `EnhancedRateLimiter::check` first prunes buckets that
/// have been idle for 60 seconds and rejects the request if the table is
/// still full afterwards.
const MAX_BUCKETS: usize = 100_000;
18
/// Source for the rate-limit key.
///
/// Selects which attribute of a request identifies the client when a token
/// bucket is looked up. Variants are (de)serialized with `snake_case` names;
/// [`KeySource::Ip`] is the default.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum KeySource {
    /// Key by client IP address (IPv6 addresses are reduced to their /64 prefix).
    #[default]
    Ip,
    /// Key by a specific header value. The header name is matched
    /// case-insensitively; when the header is absent the client IP is used.
    Header(String),
    /// Key by a specific cookie value, read from the `Cookie` header; when
    /// the cookie is absent the client IP is used.
    Cookie(String),
}
31
/// Delay mode when rate limit is approached.
///
/// Variants are (de)serialized with `snake_case` names; [`DelayMode::NoDelay`]
/// is the default.
///
/// NOTE(review): no code visible in this module reads
/// `RateLimitRule::delay_mode`; `EnhancedRateLimiter::check` rejects
/// immediately for either variant. Presumably the value is consumed by the
/// server layer — confirm against the caller.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DelayMode {
    /// Reject immediately when over limit.
    #[default]
    NoDelay,
    /// Delay the response (not implemented at WAF layer — signals the server to add delay).
    Delay,
}
42
/// A single rate-limiting rule.
///
/// A rule applies to every request whose path matches `pattern`. Each
/// distinct client key (see `key_source`) gets its own token bucket that
/// starts full with `rpm + burst` tokens and refills at `rpm / 60` tokens per
/// second.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitRule {
    /// Rule name for identification. It also prefixes the bucket key
    /// (`"name:client_key"`), so rules sharing a name share buckets.
    pub name: String,
    /// Path glob pattern to match: an exact path, `prefix/*` (at most one
    /// segment below `prefix`), `prefix/**` (anything below `prefix`), or
    /// `*` / `/**` (every path).
    pub pattern: String,
    /// Requests per minute (sustained rate).
    pub rpm: u32,
    /// Burst allowance above rpm. Defaults to 0 when omitted.
    #[serde(default)]
    pub burst: u32,
    /// What to key on. Defaults to the client IP when omitted.
    #[serde(default)]
    pub key_source: KeySource,
    /// Delay mode. Defaults to `NoDelay` when omitted.
    #[serde(default)]
    pub delay_mode: DelayMode,
}
62
/// Token bucket state for a single key.
struct TokenBucket {
    /// Tokens currently available; fractional while the bucket refills.
    tokens: f64,
    /// Bucket capacity (`rpm + burst`).
    max_tokens: f64,
    /// Refill speed in tokens per second (`rpm / 60`).
    refill_rate: f64,
    /// Time of creation or of the most recent `try_consume` call.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a full bucket holding `rpm + burst` tokens that refills at
    /// `rpm` tokens per minute.
    fn new(rpm: u32, burst: u32) -> Self {
        // Add in f64: `rpm + burst` in u32 overflows for large configured
        // values (panic in debug builds, silent wrap-around in release).
        let max_tokens = f64::from(rpm) + f64::from(burst);
        let refill_rate = f64::from(rpm) / 60.0;
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Whole seconds (rounded up) needed to accrue `tokens` at the current
    /// refill rate; 60 when the bucket never refills (`rpm == 0`).
    fn secs_to_accrue(&self, tokens: f64) -> u64 {
        if self.refill_rate > 0.0 {
            // Float-to-int `as` casts saturate, so huge values cannot wrap.
            (tokens / self.refill_rate).ceil() as u64
        } else {
            60
        }
    }

    /// Try to consume one token. Returns (allowed, remaining, reset_after_secs).
    ///
    /// When allowed, `reset_after_secs` is the time until the bucket would be
    /// full again; when rejected, it is the time until the next whole token
    /// is available and `remaining` is 0. `remaining` saturates at
    /// `u32::MAX` if the capacity exceeds that.
    fn try_consume(&mut self) -> (bool, u32, u64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        // Credit tokens earned since the previous call, capped at capacity.
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);

        if self.tokens < 1.0 {
            return (false, 0, self.secs_to_accrue(1.0 - self.tokens));
        }

        self.tokens -= 1.0;
        let remaining = self.tokens.floor() as u32;
        let reset_secs = self.secs_to_accrue(self.max_tokens - self.tokens);
        (true, remaining, reset_secs)
    }
}
114
115/// Enhanced rate limiter with per-rule token buckets.
116pub struct EnhancedRateLimiter {
117    rules: Vec<RateLimitRule>,
118    // Key: "rule_name:client_key" → TokenBucket
119    buckets: DashMap<String, TokenBucket>,
120}
121
122impl EnhancedRateLimiter {
123    pub fn new(rules: Vec<RateLimitRule>) -> Self {
124        Self {
125            rules,
126            buckets: DashMap::new(),
127        }
128    }
129
130    /// Extract the rate-limit key for a given rule and request.
131    ///
132    /// For IP-based keys, IPv6 addresses are normalized to their /64 prefix
133    /// to prevent rate-limit bypass via IPv6 address rotation (an attacker
134    /// with a /64 block has 2^64 individual addresses).
135    fn extract_key(&self, rule: &RateLimitRule, req: &WafRequest) -> String {
136        let client_key = match &rule.key_source {
137            KeySource::Ip => normalize_ip_for_rate_limit(req.client_ip),
138            KeySource::Header(name) => {
139                let lower = name.to_lowercase();
140                req.headers
141                    .iter()
142                    .find(|(k, _)| k.to_lowercase() == lower)
143                    .map(|(_, v)| v.clone())
144                    .unwrap_or_else(|| req.client_ip.to_string())
145            }
146            KeySource::Cookie(name) => {
147                extract_cookie(&req.headers, name).unwrap_or_else(|| req.client_ip.to_string())
148            }
149        };
150        format!("{}:{}", rule.name, client_key)
151    }
152
153    /// Check a request against all rate-limit rules.
154    /// Returns `(decision, response_headers)`.
155    pub fn check(&self, req: &WafRequest) -> Option<(WafDecision, Vec<(String, String)>)> {
156        for rule in &self.rules {
157            if !path_matches(&rule.pattern, &req.path) {
158                continue;
159            }
160
161            let bucket_key = self.extract_key(rule, req);
162
163            // Guard against unbounded DashMap growth
164            if !self.buckets.contains_key(&bucket_key) && self.buckets.len() >= MAX_BUCKETS {
165                self.cleanup(std::time::Duration::from_secs(60));
166                if self.buckets.len() >= MAX_BUCKETS {
167                    let limit = rule.rpm + rule.burst;
168                    return Some((
169                        WafDecision::RateLimit { retry_after: 1 },
170                        vec![
171                            ("RateLimit-Limit".into(), limit.to_string()),
172                            ("RateLimit-Remaining".into(), "0".to_string()),
173                            ("RateLimit-Reset".into(), "1".to_string()),
174                            ("Retry-After".into(), "1".to_string()),
175                        ],
176                    ));
177                }
178            }
179
180            let mut entry = self
181                .buckets
182                .entry(bucket_key)
183                .or_insert_with(|| TokenBucket::new(rule.rpm, rule.burst));
184
185            let (allowed, remaining, reset_secs) = entry.try_consume();
186            let limit = rule.rpm + rule.burst;
187
188            let headers = vec![
189                ("RateLimit-Limit".into(), limit.to_string()),
190                ("RateLimit-Remaining".into(), remaining.to_string()),
191                ("RateLimit-Reset".into(), reset_secs.to_string()),
192            ];
193
194            if !allowed {
195                let mut hdrs = headers;
196                hdrs.push(("Retry-After".into(), reset_secs.to_string()));
197
198                return Some((
199                    WafDecision::RateLimit {
200                        retry_after: reset_secs,
201                    },
202                    hdrs,
203                ));
204            }
205
206            // If allowed, we still might want to return headers for the first matching rule.
207            // But per the API contract, we return None for allow.
208            // We'll return headers only on rate limit.
209        }
210
211        None
212    }
213
214    /// Clean up stale buckets that haven't been accessed recently.
215    pub fn cleanup(&self, max_age: Duration) {
216        let now = Instant::now();
217        self.buckets
218            .retain(|_, bucket| now.duration_since(bucket.last_refill) < max_age);
219    }
220}
221
/// Normalize an IP address for rate-limit keying.
///
/// IPv6 addresses are masked to their /64 prefix so that all addresses
/// within the same /64 block share a single rate-limit bucket. This
/// prevents bypass via IPv6 address rotation (a /64 has 2^64 addresses).
/// IPv4 addresses are returned as-is.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`, as seen on dual-stack
/// listeners) are keyed by their embedded IPv4 address. Their upper 64 bits
/// are all zero, so masking them would put every such client into the single
/// bucket `::/64`.
fn normalize_ip_for_rate_limit(ip: std::net::IpAddr) -> String {
    match ip {
        std::net::IpAddr::V4(v4) => v4.to_string(),
        std::net::IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return v4.to_string();
            }
            // Keep the four network segments; zero the interface identifier
            // (last 64 bits) to get the /64 prefix.
            let [a, b, c, d, ..] = v6.segments();
            let prefix = std::net::Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0);
            format!("{prefix}/64")
        }
    }
}
248
/// Look up `cookie_name` in the request's `Cookie` header.
///
/// The header name is matched ASCII case-insensitively. The header value is
/// split on `;`; each `name=value` pair is trimmed and the value of the first
/// pair whose name equals `cookie_name` exactly is returned. Pairs without
/// `=` are skipped. Returns `None` when there is no `Cookie` header or no
/// matching pair.
fn extract_cookie(headers: &HashMap<String, String>, cookie_name: &str) -> Option<String> {
    let (_, raw) = headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case("cookie"))?;

    raw.split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(name, _)| name.trim() == cookie_name)
        .map(|(_, value)| value.trim().to_string())
}
266
/// Minimal glob matcher for request paths.
///
/// Supported patterns:
/// - `*` and `/**` match every path;
/// - `prefix/**` matches `prefix` itself and anything below `prefix/`;
/// - `prefix/*` matches `prefix` itself and paths below `prefix/` whose
///   remainder contains at most one `/`;
/// - anything else must equal the path exactly.
fn path_matches(pattern: &str, path: &str) -> bool {
    /// Remainder of `path` after `prefix` when `path` is `prefix` itself or
    /// lies below `prefix/`; `None` otherwise.
    fn tail_under<'p>(path: &'p str, prefix: &str) -> Option<&'p str> {
        path.strip_prefix(prefix)
            .filter(|tail| tail.is_empty() || tail.starts_with('/'))
    }

    match pattern {
        "*" | "/**" => true,
        _ => {
            if let Some(prefix) = pattern.strip_suffix("/**") {
                tail_under(path, prefix).is_some()
            } else if let Some(prefix) = pattern.strip_suffix("/*") {
                // The remainder may hold only the separator that follows the prefix.
                tail_under(path, prefix)
                    .map_or(false, |tail| tail.bytes().filter(|&b| b == b'/').count() <= 1)
            } else {
                pattern == path
            }
        }
    }
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a header-less GET request from `ip` to `path`.
    fn make_req(ip: &str, path: &str) -> WafRequest {
        make_req_with_headers(ip, path, Vec::new())
    }

    /// Build a GET request from `ip` to `path` carrying `headers`.
    fn make_req_with_headers(ip: &str, path: &str, headers: Vec<(&str, &str)>) -> WafRequest {
        WafRequest {
            client_ip: ip.parse().unwrap(),
            method: "GET".into(),
            path: path.into(),
            query: None,
            headers: headers
                .into_iter()
                .map(|(name, value)| (name.into(), value.into()))
                .collect(),
            body: None,
            user_agent: Some("Mozilla/5.0".into()),
        }
    }

    /// Limiter holding a single `NoDelay` rule with the given parameters.
    fn single_rule_limiter(
        name: &str,
        pattern: &str,
        rpm: u32,
        burst: u32,
        key_source: KeySource,
    ) -> EnhancedRateLimiter {
        EnhancedRateLimiter::new(vec![RateLimitRule {
            name: name.into(),
            pattern: pattern.into(),
            rpm,
            burst,
            key_source,
            delay_mode: DelayMode::NoDelay,
        }])
    }

    /// Value of the response header called `name`, if present.
    fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
        headers
            .iter()
            .find(|(k, _)| k == name)
            .map(|(_, v)| v.as_str())
    }

    #[test]
    fn no_rules_allows_all() {
        let limiter = EnhancedRateLimiter::new(Vec::new());
        assert!(limiter.check(&make_req("10.0.0.1", "/api/data")).is_none());
    }

    #[test]
    fn within_limit_allows() {
        let limiter = single_rule_limiter("api", "/api/**", 60, 10, KeySource::Ip);
        assert!(limiter.check(&make_req("10.0.0.1", "/api/data")).is_none());
    }

    #[test]
    fn exceeds_limit_blocks() {
        let limiter = single_rule_limiter("strict", "/api/**", 2, 0, KeySource::Ip);
        let req = make_req("10.0.0.1", "/api/data");

        // The bucket starts with rpm + burst = 2 tokens.
        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_none());

        // The third request finds the bucket empty.
        let (decision, headers) = limiter
            .check(&req)
            .expect("third request should be rate limited");
        assert!(matches!(decision, WafDecision::RateLimit { .. }));
        for name in ["Retry-After", "RateLimit-Limit", "RateLimit-Remaining"] {
            assert!(header_value(&headers, name).is_some(), "missing {name}");
        }
    }

    #[test]
    fn burst_allows_extra() {
        let limiter = single_rule_limiter("burst-test", "/**", 2, 3, KeySource::Ip);
        let req = make_req("10.0.0.1", "/page");

        // rpm + burst = 5 requests pass ...
        for attempt in 0..5 {
            assert!(limiter.check(&req).is_none(), "request {attempt} was limited");
        }

        // ... and the sixth is limited.
        assert!(limiter.check(&req).is_some());
    }

    #[test]
    fn different_ips_have_separate_limits() {
        let limiter = single_rule_limiter("per-ip", "/**", 1, 0, KeySource::Ip);
        let first = make_req("10.0.0.1", "/");
        let second = make_req("10.0.0.2", "/");

        // The first client spends its only token.
        assert!(limiter.check(&first).is_none());
        assert!(limiter.check(&first).is_some());

        // The second client draws from its own bucket.
        assert!(limiter.check(&second).is_none());
    }

    #[test]
    fn non_matching_path_skipped() {
        let limiter = single_rule_limiter("api-only", "/api/**", 1, 0, KeySource::Ip);
        let req = make_req("10.0.0.1", "/static/file.js");

        // /static/ is outside /api/**, so no bucket is ever touched.
        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_none());
    }

    #[test]
    fn header_key_source() {
        let limiter = single_rule_limiter(
            "by-api-key",
            "/**",
            1,
            0,
            KeySource::Header("X-API-Key".into()),
        );
        let key_a = make_req_with_headers("10.0.0.1", "/api", vec![("X-API-Key", "key-a")]);
        let key_b = make_req_with_headers("10.0.0.2", "/api", vec![("X-API-Key", "key-b")]);

        // key-a spends its token.
        assert!(limiter.check(&key_a).is_none());
        assert!(limiter.check(&key_a).is_some());

        // key-b is limited independently.
        assert!(limiter.check(&key_b).is_none());
    }

    #[test]
    fn cookie_key_source() {
        let limiter = single_rule_limiter(
            "by-session",
            "/**",
            1,
            0,
            KeySource::Cookie("session_id".into()),
        );
        let req = make_req_with_headers(
            "10.0.0.1",
            "/",
            vec![("Cookie", "session_id=abc123; other=val")],
        );

        assert!(limiter.check(&req).is_none());
        assert!(limiter.check(&req).is_some());
    }

    #[test]
    fn path_matching() {
        let cases = [
            ("/api/**", "/api/users", true),
            ("/api/**", "/api/users/123/details", true),
            ("/api/**", "/api", true),
            ("/api/**", "/static/file", false),
            ("/**", "/anything", true),
            ("*", "/anything", true),
            ("/health", "/health", true),
            ("/health", "/healthz", false),
        ];
        for (pattern, path, expected) in cases {
            assert_eq!(
                path_matches(pattern, path),
                expected,
                "pattern {pattern:?} against path {path:?}"
            );
        }
    }

    #[test]
    fn extract_cookie_works() {
        let headers: HashMap<String, String> =
            [("Cookie".to_string(), "a=1; session_id=abc; b=2".to_string())]
                .into_iter()
                .collect();

        assert_eq!(extract_cookie(&headers, "session_id"), Some("abc".into()));
        assert_eq!(extract_cookie(&headers, "a"), Some("1".into()));
        assert_eq!(extract_cookie(&headers, "missing"), None);
    }

    #[test]
    fn cleanup_removes_stale_buckets() {
        let limiter = single_rule_limiter("test", "/**", 60, 0, KeySource::Ip);
        limiter.check(&make_req("10.0.0.1", "/"));
        assert!(!limiter.buckets.is_empty());

        // A zero max age makes every bucket stale.
        limiter.cleanup(Duration::from_secs(0));
        assert!(limiter.buckets.is_empty());
    }

    #[test]
    fn rate_limit_headers_correct() {
        let limiter = single_rule_limiter("strict", "/**", 1, 0, KeySource::Ip);
        let req = make_req("10.0.0.1", "/");

        // Spend the only token, then inspect the rejection.
        limiter.check(&req);
        let (_, headers) = limiter.check(&req).unwrap();

        // rpm + burst = 1
        assert_eq!(header_value(&headers, "RateLimit-Limit").unwrap(), "1");
        assert_eq!(header_value(&headers, "RateLimit-Remaining").unwrap(), "0");
        assert!(!header_value(&headers, "Retry-After").unwrap().is_empty());
    }
}
542
543        let retry = headers.iter().find(|(k, _)| k == "Retry-After").unwrap();
544        assert!(!retry.1.is_empty());
545    }
546}