Skip to main content

orca_proxy/
rate_limit.rs

1//! Simple per-IP token bucket rate limiter.
2
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex, PoisonError};
use std::time::{Duration, Instant};
7
/// Maximum tokens (requests) per IP per second.
///
/// This is both the bucket capacity (the largest burst an IP may send after
/// being idle) and the refill rate in tokens per second.
const MAX_TOKENS: u32 = 100;

/// Time an empty bucket needs to refill completely.
///
/// Tokens accrue at `MAX_TOKENS` per second, so a bucket idle for this long is
/// full again and indistinguishable from a freshly created one.
const FULL_REFILL: Duration = Duration::from_secs(1);

/// Minimum time between sweeps that drop idle buckets from the map.
const PRUNE_INTERVAL: Duration = Duration::from_secs(1);

const NANOS_PER_SEC: u128 = 1_000_000_000;

/// Token state for a single client IP.
#[derive(Debug, Clone, Copy)]
struct Bucket {
    /// Tokens currently available; always `<= MAX_TOKENS`.
    tokens: u32,
    /// Refill has been credited up to this instant. While the bucket is not
    /// full it only advances by whole-token steps, so a fraction of a token
    /// earned between requests carries over instead of being discarded.
    last_refill: Instant,
}

impl Bucket {
    /// A bucket holding the full burst allowance, as seen by a new IP.
    fn full(now: Instant) -> Self {
        Self {
            tokens: MAX_TOKENS,
            last_refill: now,
        }
    }

    /// Credits the tokens earned between `last_refill` and `now`.
    fn refill(&mut self, now: Instant) {
        let elapsed = now.saturating_duration_since(self.last_refill).as_nanos();
        // u128 arithmetic: cannot overflow even after years of idleness.
        let earned = elapsed * u128::from(MAX_TOKENS) / NANOS_PER_SEC;
        let missing = u128::from(MAX_TOKENS - self.tokens);

        if earned >= missing {
            // Full; a full bucket cannot hold more, so leftover time is dropped.
            self.tokens = MAX_TOKENS;
            self.last_refill = now;
        } else if earned > 0 {
            // earned < missing <= MAX_TOKENS, so the narrowing cast is lossless.
            self.tokens += earned as u32;
            // Credit only the time actually converted into tokens. `spent` is
            // <= elapsed (so `last_refill` never passes `now`) and < 1 s here
            // (because earned < MAX_TOKENS), so it fits in u64.
            let spent = earned * NANOS_PER_SEC / u128::from(MAX_TOKENS);
            self.last_refill += Duration::from_nanos(spent as u64);
        }
    }

    /// Consumes one token if any is available, reporting whether it did.
    fn try_take(&mut self) -> bool {
        match self.tokens.checked_sub(1) {
            Some(left) => {
                self.tokens = left;
                true
            }
            None => false,
        }
    }
}

/// Everything guarded by the limiter's mutex.
#[derive(Debug)]
struct State {
    buckets: HashMap<IpAddr, Bucket>,
    /// When `prune_idle` last swept the map.
    last_prune: Instant,
}

impl State {
    /// Drops buckets that have been idle for at least `FULL_REFILL`, at most
    /// once per `PRUNE_INTERVAL`.
    ///
    /// Removal is invisible to callers: such a bucket would refill to
    /// `MAX_TOKENS` on its next use, which is exactly what a fresh bucket
    /// holds. This keeps the map limited to IPs seen within roughly the last
    /// `FULL_REFILL + PRUNE_INTERVAL`, instead of every IP ever seen.
    fn prune_idle(&mut self, now: Instant) {
        if now.saturating_duration_since(self.last_prune) < PRUNE_INTERVAL {
            return;
        }
        self.buckets
            .retain(|_, bucket| now.saturating_duration_since(bucket.last_refill) < FULL_REFILL);
        self.last_prune = now;
    }
}

/// Per-IP token bucket rate limiter.
///
/// Each IP may burst up to `MAX_TOKENS` requests and regains tokens at
/// `MAX_TOKENS` per second. Clones share the same buckets.
///
/// State lives behind a `std::sync::Mutex`. The lock is taken and released
/// inside [`RateLimiter::check`], covering a map lookup, a little integer
/// arithmetic and, at most once per `PRUNE_INTERVAL`, a linear sweep of idle
/// buckets.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    state: Arc<Mutex<State>>,
}

impl Default for RateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl RateLimiter {
    /// Create a new rate limiter with no buckets.
    pub fn new() -> Self {
        Self {
            state: Arc::new(Mutex::new(State {
                buckets: HashMap::new(),
                last_prune: Instant::now(),
            })),
        }
    }

    /// Check if a request from the given IP is allowed.
    ///
    /// Returns `true` (consuming one token) if the IP's bucket has a token
    /// left, or `false` if the request should be rejected with 429 Too Many
    /// Requests. An IP without a bucket starts with a full one.
    pub fn check(&self, ip: IpAddr) -> bool {
        // Buckets are plain counters, so the data behind a poisoned lock is
        // still usable; recovering keeps one panic elsewhere from turning into
        // a panic on every later request.
        let mut state = self.state.lock().unwrap_or_else(PoisonError::into_inner);
        let now = Instant::now();
        state.prune_idle(now);

        let bucket = state
            .buckets
            .entry(ip)
            .or_insert_with(|| Bucket::full(now));
        bucket.refill(now);
        bucket.try_take()
    }
}
54
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::time::Duration;

    /// Shorthand for building an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Spends the full initial allowance for `ip`, asserting every request is admitted.
    fn exhaust(limiter: &RateLimiter, ip: IpAddr) {
        for i in 0..MAX_TOKENS {
            assert!(limiter.check(ip), "request {i} should be allowed");
        }
    }

    #[test]
    fn allows_requests_under_limit() {
        let limiter = RateLimiter::new();
        exhaust(&limiter, v4(127, 0, 0, 1));
    }

    #[test]
    fn rejects_requests_over_limit() {
        let limiter = RateLimiter::new();
        let ip = v4(10, 0, 0, 1);
        exhaust(&limiter, ip);
        assert!(!limiter.check(ip));
    }

    #[test]
    fn separate_buckets_per_ip() {
        let limiter = RateLimiter::new();
        let ip1 = v4(10, 0, 0, 1);
        let ip2 = v4(10, 0, 0, 2);
        exhaust(&limiter, ip1);
        assert!(!limiter.check(ip1));
        assert!(limiter.check(ip2));
    }

    #[test]
    fn test_rate_limiter_allows_under_limit() {
        let limiter = RateLimiter::new();
        let ip = v4(192, 168, 1, 1);
        for i in 0..50 {
            assert!(limiter.check(ip), "request {i} should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::new();
        let ip = v4(192, 168, 1, 2);
        exhaust(&limiter, ip);
        assert!(
            !limiter.check(ip),
            "request 101 should be blocked after exhausting tokens"
        );
    }

    #[test]
    fn test_rate_limiter_allows_different_ips() {
        let limiter = RateLimiter::new();
        let ip1 = v4(10, 1, 0, 1);
        let ip2 = v4(10, 1, 0, 2);
        // Interleave the two IPs: each has its own full allowance.
        for i in 0..MAX_TOKENS {
            assert!(limiter.check(ip1), "ip1 request {i} should be allowed");
            assert!(limiter.check(ip2), "ip2 request {i} should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_refills_after_time() {
        let limiter = RateLimiter::new();
        let ip = v4(192, 168, 1, 3);
        exhaust(&limiter, ip);
        assert!(!limiter.check(ip), "should be blocked after exhaustion");

        // One token is earned every `1 s / MAX_TOKENS`; waiting five of those
        // intervals is enough without stalling the suite for a whole second.
        std::thread::sleep(Duration::from_secs(1) / MAX_TOKENS * 5);

        assert!(
            limiter.check(ip),
            "should be allowed again after token refill"
        );
    }

    #[test]
    fn test_rate_limiter_stress() {
        let limiter = RateLimiter::new();
        let ip = v4(10, 99, 0, 1);
        let total = 200;
        let allowed = (0..total).filter(|_| limiter.check(ip)).count();
        let blocked = total - allowed;
        // MAX_TOKENS is 100, so the first 100 pass and the rest are blocked.
        // Refill during the loop may shift the counts slightly.
        assert!(
            allowed >= 100,
            "expected at least 100 allowed, got {allowed}"
        );
        assert!(blocked >= 50, "expected at least 50 blocked, got {blocked}");
    }

    #[test]
    fn clones_share_buckets() {
        let limiter = RateLimiter::new();
        let clone = limiter.clone();
        let ip = v4(172, 16, 0, 1);
        exhaust(&limiter, ip);
        assert!(
            !clone.check(ip),
            "a clone must see tokens spent through the original"
        );
    }

    #[test]
    fn default_behaves_like_new() {
        let limiter = RateLimiter::default();
        let ip = v4(172, 16, 0, 2);
        exhaust(&limiter, ip);
        assert!(!limiter.check(ip));
    }
}