Skip to main content

varpulis_cluster/
rate_limit.rs

//! Rate limiting module for Varpulis API endpoints.
//!
//! Provides token bucket rate limiting with per-IP tracking for HTTP APIs.
//! Used by both the coordinator API and the CLI WebSocket server.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::extract::ConnectInfo;
12use axum::http::StatusCode;
13use axum::response::{IntoResponse, Response};
14use tokio::sync::RwLock;
15
// =============================================================================
// Configuration
// =============================================================================
19
/// Configuration for the per-client token-bucket rate limiter.
///
/// Rate limiting is opt-in: the [`Default`] configuration is disabled and
/// lets every request through without tracking any client state.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Master switch; when `false` every request is allowed.
    pub enabled: bool,
    /// Sustained refill rate, in requests (tokens) per second per client.
    pub requests_per_second: u32,
    /// Bucket capacity, i.e. the largest burst a single client may send.
    pub burst_size: u32,
    /// Upper bound on the number of client IPs tracked at once, so that a
    /// flood of distinct addresses cannot exhaust memory.
    pub max_tracked_ips: usize,
}

impl RateLimitConfig {
    /// Number of client IPs tracked when no explicit limit is configured.
    const DEFAULT_MAX_TRACKED_IPS: usize = 10_000;

    /// Build a configuration that performs no rate limiting at all.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::with_burst(0, 0)
        }
    }

    /// Build an enabled configuration whose burst capacity is twice the
    /// per-second rate (saturating at `u32::MAX`).
    pub fn new(requests_per_second: u32) -> Self {
        let doubled_burst = requests_per_second.saturating_mul(2);
        Self::with_burst(requests_per_second, doubled_burst)
    }

    /// Build an enabled configuration with an explicit burst capacity.
    pub fn with_burst(requests_per_second: u32, burst_size: u32) -> Self {
        Self {
            enabled: true,
            requests_per_second,
            burst_size,
            max_tracked_ips: Self::DEFAULT_MAX_TRACKED_IPS,
        }
    }
}

impl Default for RateLimitConfig {
    /// Equivalent to [`RateLimitConfig::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}
74
// =============================================================================
// Token Bucket
// =============================================================================
78
/// Token bucket state for a single client.
///
/// The bucket holds at most `max_tokens` tokens and regains `refill_rate`
/// tokens per second of elapsed time (measured with [`Instant`]). Each allowed
/// request consumes exactly one token.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Current number of tokens (fractional, so partial refills accumulate).
    tokens: f64,
    /// Time of the most recent refill.
    last_update: Instant,
    /// Maximum tokens (burst capacity).
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
}

impl TokenBucket {
    /// Create a full bucket with capacity `max_tokens` that refills at
    /// `refill_rate` tokens per second.
    fn new(max_tokens: u32, refill_rate: u32) -> Self {
        Self {
            tokens: max_tokens as f64,
            last_update: Instant::now(),
            max_tokens: max_tokens as f64,
            refill_rate: refill_rate as f64,
        }
    }

    /// Refill, then try to take one token.
    ///
    /// Returns `true` if a token was consumed, or `false` if fewer than one
    /// whole token was available (the bucket is left unchanged apart from
    /// the refill).
    fn try_consume(&mut self) -> bool {
        self.refill();

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Add the tokens earned since `last_update` (capped at `max_tokens`) and
    /// move `last_update` to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update);
        let new_tokens = elapsed.as_secs_f64() * self.refill_rate;

        self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
        self.last_update = now;
    }

    /// Whole tokens currently available (the fractional part is truncated).
    fn remaining(&self) -> u32 {
        self.tokens as u32
    }

    /// Time until at least one whole token is available, measured from the
    /// last refill (this method does not refill first).
    ///
    /// Returns [`Duration::ZERO`] when a token is already available, and
    /// [`Duration::MAX`] when the bucket never refills (`refill_rate == 0`).
    fn reset_after(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        // With a zero refill rate the division below yields +inf, and
        // `Duration::from_secs_f64` panics on non-finite input.
        if self.refill_rate <= 0.0 {
            return Duration::MAX;
        }
        let tokens_needed = 1.0 - self.tokens;
        Duration::from_secs_f64(tokens_needed / self.refill_rate)
    }
}
140
// =============================================================================
// Rate Limiter
// =============================================================================
144
145/// Rate limiter with per-IP tracking
146#[derive(Debug)]
147pub struct RateLimiter {
148    config: RateLimitConfig,
149    buckets: RwLock<HashMap<IpAddr, TokenBucket>>,
150}
151
152impl RateLimiter {
153    /// Create a new rate limiter
154    pub fn new(config: RateLimitConfig) -> Self {
155        Self {
156            config,
157            buckets: RwLock::new(HashMap::new()),
158        }
159    }
160
161    /// Check if a request from the given IP should be allowed
162    pub async fn check(&self, ip: IpAddr) -> RateLimitResult {
163        if !self.config.enabled {
164            return RateLimitResult::Allowed {
165                remaining: u32::MAX,
166                reset_after: Duration::ZERO,
167            };
168        }
169
170        let mut buckets = self.buckets.write().await;
171
172        // Evict oldest entry if at capacity and this is a new IP
173        if !buckets.contains_key(&ip) && buckets.len() >= self.config.max_tracked_ips {
174            let oldest_ip = buckets
175                .iter()
176                .min_by_key(|(_, b)| b.last_update)
177                .map(|(ip, _)| *ip);
178            if let Some(ip_to_evict) = oldest_ip {
179                buckets.remove(&ip_to_evict);
180            }
181        }
182
183        let bucket = buckets.entry(ip).or_insert_with(|| {
184            TokenBucket::new(self.config.burst_size, self.config.requests_per_second)
185        });
186
187        if bucket.try_consume() {
188            RateLimitResult::Allowed {
189                remaining: bucket.remaining(),
190                reset_after: bucket.reset_after(),
191            }
192        } else {
193            RateLimitResult::Limited {
194                retry_after: bucket.reset_after(),
195            }
196        }
197    }
198
199    /// Clean up old buckets (call periodically)
200    pub async fn cleanup(&self, max_age: Duration) {
201        let now = Instant::now();
202        let mut buckets = self.buckets.write().await;
203        buckets.retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
204    }
205
206    /// Get number of tracked clients
207    pub async fn client_count(&self) -> usize {
208        self.buckets.read().await.len()
209    }
210}
211
/// Outcome of [`RateLimiter::check`] for a single request.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The request may proceed.
    ///
    /// When rate limiting is disabled, `remaining` is `u32::MAX` and
    /// `reset_after` is zero.
    Allowed {
        /// Whole tokens left in the client's bucket after this request.
        remaining: u32,
        /// Time until the next whole token is available; zero while at least
        /// one token remains in the bucket.
        reset_after: Duration,
    },
    /// The request must be rejected.
    Limited {
        /// Time until the client's bucket holds a whole token again.
        retry_after: Duration,
    },
}
228
// =============================================================================
// Axum Middleware
// =============================================================================
232
233/// Check rate limiting for a request, returning an error response if limited.
234///
235/// This is used as an axum middleware function via `axum::middleware::from_fn_with_state`.
236/// It extracts the client IP from [`ConnectInfo`] and checks the rate limiter.
237pub async fn rate_limit_middleware(
238    connect_info: Option<ConnectInfo<std::net::SocketAddr>>,
239    limiter: Option<Arc<RateLimiter>>,
240    req: axum::extract::Request,
241    next: axum::middleware::Next,
242) -> Response {
243    if let Some(ref limiter) = limiter {
244        let ip = connect_info
245            .map(|ci| ci.0.ip())
246            .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
247
248        match limiter.check(ip).await {
249            RateLimitResult::Allowed { .. } => {}
250            RateLimitResult::Limited { retry_after } => {
251                let retry_after_secs = retry_after.as_secs().max(1);
252                return rate_limit_error_response(retry_after_secs);
253            }
254        }
255    }
256
257    next.run(req).await
258}
259
260/// Build a 429 Too Many Requests response with JSON body and Retry-After header.
261pub fn rate_limit_error_response(retry_after_secs: u64) -> Response {
262    let body = serde_json::json!({
263        "error": "rate_limited",
264        "message": "Too many requests",
265        "retry_after_seconds": retry_after_secs,
266    });
267    (
268        StatusCode::TOO_MANY_REQUESTS,
269        [("retry-after", retry_after_secs.to_string())],
270        axum::Json(body),
271    )
272        .into_response()
273}
274
// =============================================================================
// Tests
// =============================================================================
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a literal IP address used as a test client identity.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test IP literal must be valid")
    }

    #[test]
    fn test_config_disabled() {
        assert!(!RateLimitConfig::disabled().enabled);
    }

    #[test]
    fn test_config_new() {
        let cfg = RateLimitConfig::new(100);
        assert!(cfg.enabled);
        assert_eq!((cfg.requests_per_second, cfg.burst_size), (100, 200));
    }

    #[test]
    fn test_config_with_burst() {
        let cfg = RateLimitConfig::with_burst(100, 50);
        assert!(cfg.enabled);
        assert_eq!((cfg.requests_per_second, cfg.burst_size), (100, 50));
    }

    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(10, 10);
        assert_eq!(bucket.remaining(), 10);

        // The full burst of ten is granted...
        assert!((0..10).all(|_| bucket.try_consume()));

        // ...and the eleventh request is refused.
        assert!(!bucket.try_consume());
    }

    #[test]
    fn test_token_bucket_refill() {
        // A 10-token bucket refilling at 100 tokens/second.
        let mut bucket = TokenBucket::new(10, 100);
        (0..10).for_each(|_| {
            bucket.try_consume();
        });

        // Pretend the last refill was 100 ms ago: 100 ms at 100/s is ~10 tokens.
        bucket.last_update = Instant::now() - Duration::from_millis(100);
        bucket.refill();
        assert!(bucket.remaining() >= 9); // allow some margin
    }

    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let limiter = RateLimiter::new(RateLimitConfig::disabled());

        match limiter.check(addr("127.0.0.1")).await {
            RateLimitResult::Allowed { remaining, .. } => assert_eq!(remaining, u32::MAX),
            RateLimitResult::Limited { .. } => panic!("Should not be limited"),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_burst() {
        // Burst of 5, refilling at 10 tokens/second.
        let limiter = RateLimiter::new(RateLimitConfig::with_burst(10, 5));
        let client = addr("127.0.0.1");

        for request in 0..5u32 {
            match limiter.check(client).await {
                RateLimitResult::Allowed { remaining, .. } => assert_eq!(remaining, 4 - request),
                RateLimitResult::Limited { .. } => {
                    panic!("Should not be limited at request {}", request)
                }
            }
        }

        // The sixth request exceeds the burst; a token returns within 100 ms.
        match limiter.check(client).await {
            RateLimitResult::Limited { retry_after } => assert!(retry_after.as_millis() <= 100),
            RateLimitResult::Allowed { .. } => panic!("Should be limited"),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_per_ip() {
        let limiter = RateLimiter::new(RateLimitConfig::with_burst(10, 2));
        let (first, second) = (addr("127.0.0.1"), addr("127.0.0.2"));

        // Use up the two-token burst of the first client only.
        for _ in 0..2 {
            limiter.check(first).await;
        }

        assert!(
            matches!(limiter.check(first).await, RateLimitResult::Limited { .. }),
            "ip1 should be limited"
        );
        assert!(
            matches!(limiter.check(second).await, RateLimitResult::Allowed { .. }),
            "ip2 should not be limited"
        );
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        let limiter = RateLimiter::new(RateLimitConfig::new(10));
        limiter.check(addr("127.0.0.1")).await;
        assert_eq!(limiter.client_count().await, 1);

        // A 1 ns max age removes any bucket not updated in the last nanosecond.
        limiter.cleanup(Duration::from_nanos(1)).await;
        assert_eq!(limiter.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_rate_limiter_bounded() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_tracked_ips: 3,
            ..RateLimitConfig::new(10)
        });

        // Fill the map to its capacity of three clients.
        for last_octet in 1..=3u8 {
            limiter.check(addr(&format!("10.0.0.{}", last_octet))).await;
        }
        assert_eq!(limiter.client_count().await, 3);

        // A fourth client forces an eviction instead of growing the map.
        limiter.check(addr("10.0.0.4")).await;
        assert_eq!(limiter.client_count().await, 3);
    }

    #[test]
    fn test_reset_after_calculation() {
        // 10 tokens, refilling at 10/second.
        let mut bucket = TokenBucket::new(10, 10);
        assert_eq!(bucket.reset_after(), Duration::ZERO);

        for _ in 0..10 {
            bucket.try_consume();
        }

        // One token at 10/second takes ~100 ms.
        let millis = bucket.reset_after().as_millis();
        assert!((90..=110).contains(&millis));
    }
}