Skip to main content

pylon_runtime/
rate_limit.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
/// Per-IP rate limiter using a sliding window.
///
/// Each IP address gets a bucket of accepted-request timestamps. When a request
/// arrives, entries older than `window` are pruned and the remaining count is
/// checked against `max_requests`. If the limit has been reached,
/// [`check`](RateLimiter::check) returns `Err` with the number of whole seconds
/// the caller should wait before retrying.
///
/// All mutable state lives behind a single [`Mutex`], so the limiter can be
/// shared between threads. Buckets belonging to IPs that stop sending requests
/// are only reclaimed by [`cleanup`](RateLimiter::cleanup).
#[derive(Debug)]
pub struct RateLimiter {
    /// Length of the sliding window (always a whole number of seconds).
    window: Duration,
    /// Maximum number of accepted requests per IP within one window.
    max_requests: u32,
    /// Accepted-request timestamps per IP, appended in non-decreasing order.
    buckets: Mutex<HashMap<String, Vec<Instant>>>,
}

impl RateLimiter {
    /// Create a new rate limiter.
    ///
    /// - `max_requests`: maximum number of requests allowed within the window.
    ///   A value of `0` rejects every request.
    /// - `window_secs`: sliding window duration in seconds. With `0`, every
    ///   timestamp is expired immediately, so requests are only rejected when
    ///   `max_requests` is also `0`.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            window: Duration::from_secs(window_secs),
            max_requests,
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Lock the bucket map, recovering the guard if a previous holder panicked.
    ///
    /// The map only stores timestamps, so a poisoned lock still holds usable
    /// data; recovering keeps one stray panic from permanently breaking every
    /// later call on this limiter.
    fn lock_buckets(&self) -> std::sync::MutexGuard<'_, HashMap<String, Vec<Instant>>> {
        self.buckets
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Check if a request from this IP is allowed, recording it if so.
    ///
    /// Returns `Ok(())` if the request is within limits, or `Err(retry_after)`
    /// with the number of seconds (always at least 1) to wait before the next
    /// request will be accepted.
    pub fn check(&self, ip: &str) -> Result<(), u64> {
        let mut buckets = self.lock_buckets();
        // Read the clock only while holding the lock: timestamps are then
        // appended in non-decreasing order, so `first()` is truly the oldest.
        let now = Instant::now();
        let timestamps = buckets.entry(ip.to_string()).or_default();

        // Remove entries outside the sliding window.
        timestamps.retain(|&t| now.duration_since(t) < self.window);

        if timestamps.len() >= self.max_requests as usize {
            let retry_after = match timestamps.first() {
                // A slot frees up when the oldest in-window request expires.
                // `window` is whole seconds, so subtracting the truncated
                // elapsed time rounds the remaining wait up.
                Some(&oldest) => self
                    .window
                    .as_secs()
                    .saturating_sub(now.duration_since(oldest).as_secs()),
                // Only reachable when `max_requests == 0`: nothing is ever
                // accepted, so suggest waiting one full window.
                None => self.window.as_secs(),
            };
            // Ensure we always return at least 1 second.
            return Err(retry_after.max(1));
        }

        timestamps.push(now);
        Ok(())
    }

    /// Remove all expired entries from every bucket, dropping empty buckets.
    ///
    /// Call periodically (e.g., from a background thread) to prevent unbounded
    /// memory growth from IPs that stop sending requests.
    pub fn cleanup(&self) {
        let mut buckets = self.lock_buckets();
        let now = Instant::now();

        // Remove expired timestamps, then drop empty buckets entirely.
        buckets.retain(|_ip, timestamps| {
            timestamps.retain(|&t| now.duration_since(t) < self.window);
            !timestamps.is_empty()
        });
    }

    /// Get the current request count for an IP within the active window.
    ///
    /// Read-only: unlike `check`, this neither prunes nor creates buckets.
    pub fn current_count(&self, ip: &str) -> u32 {
        let buckets = self.lock_buckets();
        let now = Instant::now();
        buckets.get(ip).map_or(0, |timestamps| {
            timestamps
                .iter()
                .filter(|&&t| now.duration_since(t) < self.window)
                .count() as u32
        })
    }
}
83
84// ---------------------------------------------------------------------------
85// Tests
86// ---------------------------------------------------------------------------
87
#[cfg(test)]
mod tests {
    // Unit tests for `RateLimiter`. Windows are whole seconds, so the expiry
    // tests must sleep just past one second.
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn under_limit_passes() {
        let rl = RateLimiter::new(5, 60);
        for _ in 0..5 {
            assert!(rl.check("10.0.0.1").is_ok());
        }
    }

    #[test]
    fn over_limit_rejected() {
        let rl = RateLimiter::new(3, 60);
        for _ in 0..3 {
            assert!(rl.check("10.0.0.1").is_ok());
        }
        let err = rl.check("10.0.0.1").unwrap_err();
        // The wait can never exceed one full window, nor drop below 1 second.
        assert!(
            (1..=60).contains(&err),
            "retry_after should be between 1 and 60 seconds, got {err}"
        );
    }

    #[test]
    fn window_expiry_allows_new_requests() {
        // Use a very short window so the test finishes quickly.
        let rl = RateLimiter::new(2, 1);
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_err());

        // Wait for the window to expire.
        thread::sleep(Duration::from_millis(1100));

        // Should be allowed again.
        assert!(rl.check("10.0.0.1").is_ok());
    }

    #[test]
    fn different_ips_are_independent() {
        let rl = RateLimiter::new(2, 60);
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_err());

        // Different IP should still be allowed.
        assert!(rl.check("10.0.0.2").is_ok());
        assert!(rl.check("10.0.0.2").is_ok());
    }

    #[test]
    fn cleanup_removes_expired_buckets() {
        let rl = RateLimiter::new(10, 1);
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.2").is_ok());

        // Wait for expiry.
        thread::sleep(Duration::from_millis(1100));

        rl.cleanup();

        // `current_count` already ignores expired timestamps on its own, so it
        // cannot tell whether cleanup ran; inspect the map directly as well.
        assert!(rl.buckets.lock().unwrap().is_empty());
        assert_eq!(rl.current_count("10.0.0.1"), 0);
        assert_eq!(rl.current_count("10.0.0.2"), 0);
    }

    #[test]
    fn cleanup_keeps_live_buckets() {
        let rl = RateLimiter::new(10, 60);
        rl.check("10.0.0.1").unwrap();

        rl.cleanup();

        // A bucket with in-window timestamps must survive cleanup untouched.
        assert_eq!(rl.buckets.lock().unwrap().len(), 1);
        assert_eq!(rl.current_count("10.0.0.1"), 1);
    }

    #[test]
    fn current_count_reflects_active_requests() {
        let rl = RateLimiter::new(10, 60);
        assert_eq!(rl.current_count("10.0.0.1"), 0);

        rl.check("10.0.0.1").unwrap();
        assert_eq!(rl.current_count("10.0.0.1"), 1);

        rl.check("10.0.0.1").unwrap();
        rl.check("10.0.0.1").unwrap();
        assert_eq!(rl.current_count("10.0.0.1"), 3);
    }
}