// ntex_ratelimiter/limiter.rs

1use dashmap::DashMap;
2use ntex::http::header::{HeaderName, HeaderValue};
3use ntex::{http::StatusCode, Middleware, ServiceCtx};
4use std::net::IpAddr;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8use ntex::{web, Service};
9use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
10
11#[cfg(feature = "tokio")]
12use tokio::time::interval;
13
14#[cfg(feature = "async-std")]
15use async_std::task;
16
17#[cfg(feature = "json")]
18use serde::{Deserialize, Serialize};
19
/// Response header: whole requests the client may still make right now.
const HEADER_RATELIMIT_REMAINING: &str = "x-ratelimit-remaining";
/// Response header: configured bucket capacity (maximum burst).
const HEADER_RATELIMIT_LIMIT: &str = "x-ratelimit-limit";
/// Response header: Unix timestamp (seconds) at which the client's bucket is full again.
const HEADER_RATELIMIT_RESET: &str = "x-ratelimit-reset";
23
24/// Token bucket algorithm implementation for rate limiting
25#[derive(Debug)]
26struct TokenBucket {
27    tokens: f64,
28    last_refill: Instant,
29}
30
31impl TokenBucket {
32    fn new(capacity: usize) -> Self {
33        Self {
34            tokens: capacity as f64,
35            last_refill: Instant::now(),
36        }
37    }
38
39    fn consume(&mut self, tokens: usize, now: Instant, config: &RateLimiterConfig) -> bool {
40        self.refill(now, config);
41        if self.tokens >= tokens as f64 {
42            self.tokens -= tokens as f64;
43            true
44        } else {
45            false
46        }
47    }
48
49    fn refill(&mut self, now: Instant, config: &RateLimiterConfig) {
50        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
51        let refill_rate = config.capacity as f64 / config.window as f64;
52        let new_tokens = elapsed * refill_rate;
53        self.tokens = (self.tokens + new_tokens).min(config.capacity as f64);
54        self.last_refill = now;
55    }
56
57    fn remaining_tokens(&self) -> u32 {
58        self.tokens.floor() as u32
59    }
60
61    fn reset_time(&self, _now: Instant, config: &RateLimiterConfig) -> u64 {
62        let now_secs = SystemTime::now()
63            .duration_since(UNIX_EPOCH)
64            .unwrap_or_default()
65            .as_secs();
66
67        if self.tokens >= config.capacity as f64 {
68            return now_secs;
69        }
70
71        let missing_tokens = config.capacity as f64 - self.tokens;
72        let refill_rate = config.capacity as f64 / config.window as f64;
73        let seconds_to_refill = missing_tokens / refill_rate;
74
75        now_secs + seconds_to_refill.ceil() as u64
76    }
77
78    /// Check if this bucket is stale (hasn't been used recently)
79    fn is_stale(&self, now: Instant, stale_threshold: Duration) -> bool {
80        now.duration_since(self.last_refill) > stale_threshold
81    }
82}
83
/// Tunables shared by every client bucket of a rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum tokens per client, i.e. the largest burst allowed.
    pub capacity: usize,
    /// Seconds it takes an empty bucket to refill completely.
    pub window: u64,
    /// How often the background task sweeps for idle buckets.
    pub cleanup_interval: Duration,
    /// How long a bucket may go unused before a sweep drops it.
    pub stale_threshold: Duration,
}

impl Default for RateLimiterConfig {
    /// 100 requests per 60-second window; sweep every 5 minutes, evicting buckets idle for
    /// more than an hour.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);

        RateLimiterConfig {
            capacity: 100,
            window: 60,
            cleanup_interval: FIVE_MINUTES,
            stale_threshold: ONE_HOUR,
        }
    }
}
103
104/// High-performance rate limiter using token bucket algorithm
105pub struct RateLimiter {
106    map: DashMap<IpAddr, TokenBucket>,
107    config: RateLimiterConfig,
108    last_cleanup: AtomicU64,
109}
110
111impl RateLimiter {
112    /// Create a new rate limiter with default configuration
113    pub fn new(capacity: usize, window: u64) -> Arc<Self> {
114        let config = RateLimiterConfig {
115            capacity,
116            window,
117            ..Default::default()
118        };
119        Self::with_config(config)
120    }
121
122    /// Create a new rate limiter with custom configuration
123    pub fn with_config(config: RateLimiterConfig) -> Arc<Self> {
124        assert!(config.window > 0, "RateLimiter window must be greater than zero");
125
126        let limiter = Arc::new(RateLimiter {
127            map: DashMap::new(),
128            config,
129            last_cleanup: AtomicU64::new(
130                SystemTime::now()
131                    .duration_since(UNIX_EPOCH)
132                    .unwrap_or_default()
133                    .as_secs(),
134            ),
135        });
136
137        // Start periodic cleanup if a runtime is enabled
138        #[cfg(any(feature = "tokio", feature = "async-std"))]
139        Self::start_cleanup_task(Arc::clone(&limiter));
140
141        limiter
142    }
143
144    #[cfg(feature = "tokio")]
145    fn start_cleanup_task(limiter: Arc<RateLimiter>) {
146        tokio::spawn(async move {
147            let mut interval = interval(limiter.config.cleanup_interval);
148            loop {
149                interval.tick().await;
150                limiter.cleanup().await;
151            }
152        });
153    }
154
155    #[cfg(feature = "async-std")]
156    fn start_cleanup_task(limiter: Arc<RateLimiter>) {
157        let cleanup_interval = limiter.config.cleanup_interval;
158        task::spawn(async move {
159            loop {
160                task::sleep(cleanup_interval).await;
161                limiter.cleanup().await;
162            }
163        });
164    }
165
166    /// Check rate limit for a given identifier (usually IP address)
167    pub fn check_rate_limit(&self, identifier: IpAddr) -> RateLimitResult {
168        let now = Instant::now();
169        let mut bucket = self
170            .map
171            .entry(identifier)
172            .or_insert_with(|| TokenBucket::new(self.config.capacity));
173
174        let allowed = bucket.consume(1, now, &self.config);
175        let remaining = bucket.remaining_tokens();
176        let reset = bucket.reset_time(now, &self.config);
177
178        RateLimitResult {
179            allowed,
180            remaining,
181            reset,
182            limit: self.config.capacity,
183        }
184    }
185
186    /// Clean up stale entries
187    async fn cleanup(&self) {
188        let now_secs = SystemTime::now()
189            .duration_since(UNIX_EPOCH)
190            .unwrap_or_default()
191            .as_secs();
192
193        let last_cleanup = self.last_cleanup.load(Ordering::Acquire);
194
195        // Check if enough time has passed since last cleanup
196        if now_secs.saturating_sub(last_cleanup) < self.config.cleanup_interval.as_secs() {
197            return;
198        }
199
200        // Try to update the last cleanup time atomically
201        if self
202            .last_cleanup
203            .compare_exchange(last_cleanup, now_secs, Ordering::AcqRel, Ordering::Relaxed)
204            .is_err()
205        {
206            // Another thread is doing cleanup, skip this one
207            return;
208        }
209
210        let now = Instant::now();
211        let stale_threshold = self.config.stale_threshold;
212
213        let initial_size = self.map.len();
214        self.map
215            .retain(|_, bucket| !bucket.is_stale(now, stale_threshold));
216        let final_size = self.map.len();
217
218        if cfg!(debug_assertions) && initial_size > final_size {
219            eprintln!(
220                "Cleaned {} stale rate limit entries",
221                initial_size - final_size
222            );
223        }
224    }
225
226    /// Get current statistics
227    pub fn stats(&self) -> RateLimiterStats {
228        RateLimiterStats {
229            active_entries: self.map.len(),
230            capacity: self.config.capacity,
231            window: self.config.window,
232        }
233    }
234}
235
/// Outcome of a single rate-limit check for one client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request may proceed (a token was spent).
    pub allowed: bool,
    /// Whole tokens left in the client's bucket after this check.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the bucket will be full again.
    pub reset: u64,
    /// Bucket capacity, i.e. the configured maximum burst.
    pub limit: usize,
}
244
/// Point-in-time statistics about a rate limiter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimiterStats {
    /// Number of client buckets currently tracked.
    pub active_entries: usize,
    /// Configured bucket capacity.
    pub capacity: usize,
    /// Configured refill window in seconds.
    pub window: u64,
}
252
253/// Rate limiting middleware
254pub struct RateLimit {
255    pub limiter: Arc<RateLimiter>,
256}
257
258impl RateLimit {
259    pub fn new(limiter: Arc<RateLimiter>) -> Self {
260        Self { limiter }
261    }
262}
263
264impl<S> Middleware<S> for RateLimit {
265    type Service = RateLimitMiddlewareService<S>;
266
267    fn create(&self, service: S) -> Self::Service {
268        RateLimitMiddlewareService {
269            service,
270            limiter: Arc::clone(&self.limiter),
271        }
272    }
273}
274
275pub struct RateLimitMiddlewareService<S> {
276    service: S,
277    limiter: Arc<RateLimiter>,
278}
279
280impl<S, Err> Service<web::WebRequest<Err>> for RateLimitMiddlewareService<S>
281where
282    S: Service<web::WebRequest<Err>, Response = web::WebResponse, Error = web::Error> + 'static,
283    Err: web::ErrorRenderer,
284{
285    type Response = web::WebResponse;
286    type Error = web::Error;
287
288    async fn call(
289        &self,
290        req: web::WebRequest<Err>,
291        ctx: ServiceCtx<'_, Self>,
292    ) -> Result<Self::Response, Self::Error> {
293        let ip = extract_client_ip(&req);
294
295        let result = self.limiter.check_rate_limit(ip);
296
297        if !result.allowed {
298            return Err(RateLimitError::from(result).into());
299        }
300
301        let mut response = ctx.call(&self.service, req).await?;
302
303        // Add rate limit headers to successful responses
304        add_rate_limit_headers(response.headers_mut(), &result);
305
306        Ok(response)
307    }
308}
309
310/// Extract client IP from request, considering proxy headers
311fn extract_client_ip<Err>(req: &web::WebRequest<Err>) -> IpAddr {
312    // Check X-Forwarded-For header first
313    if let Some(forwarded) = req.headers().get("x-forwarded-for") {
314        if let Ok(forwarded_str) = forwarded.to_str() {
315            if let Some(ip) = forwarded_str.split(',').next() {
316                let ip = ip.trim();
317                if let Ok(parsed_ip) = ip.parse::<IpAddr>() {
318                    return parsed_ip;
319                }
320            }
321        }
322    }
323
324    // Check X-Real-IP header
325    if let Some(real_ip) = req.headers().get("x-real-ip") {
326        if let Ok(ip_str) = real_ip.to_str() {
327            let ip = ip_str.trim();
328            if let Ok(parsed_ip) = ip.parse::<IpAddr>() {
329                return parsed_ip;
330            }
331        }
332    }
333
334    // Fallback to connection info - parse SocketAddr to get IP only
335    if let Some(addr_str) = req.connection_info().remote() {
336        if let Ok(sock_addr) = addr_str.parse::<std::net::SocketAddr>() {
337            return sock_addr.ip();
338        }
339    }
340
341    // Default to localhost if all else fails
342    IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
343}
344
345/// Add rate limit headers to response
346fn add_rate_limit_headers(headers: &mut ntex::http::HeaderMap, result: &RateLimitResult) {
347    if let Ok(value) = HeaderValue::from_str(&result.remaining.to_string()) {
348        headers.insert(HeaderName::from_static(HEADER_RATELIMIT_REMAINING), value);
349    }
350    if let Ok(value) = HeaderValue::from_str(&result.limit.to_string()) {
351        headers.insert(HeaderName::from_static(HEADER_RATELIMIT_LIMIT), value);
352    }
353    if let Ok(value) = HeaderValue::from_str(&result.reset.to_string()) {
354        headers.insert(HeaderName::from_static(HEADER_RATELIMIT_RESET), value);
355    }
356}
357
/// Quota figures carried by a rejection: tokens left, reset timestamp and capacity.
#[derive(Debug)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
struct RateLimitErrorData {
    /// Whole tokens left in the client's bucket.
    remaining: u32,
    /// Unix timestamp (seconds) at which the bucket is full again.
    reset: u64,
    /// Bucket capacity.
    limit: usize,
}

/// Envelope of the 429 response body: `{ code, message, data }`.
#[derive(Debug)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
struct RateLimitErrorResponse {
    /// Status code echoed in the body.
    code: u32,
    /// Human-readable explanation.
    message: String,
    /// Quota details.
    data: RateLimitErrorData,
}

/// Error produced by the middleware when a client has no tokens left.
#[derive(Debug)]
struct RateLimitError {
    /// Quota figures rendered into the response body and headers.
    data: RateLimitErrorData,
}
379
380impl From<RateLimitResult> for RateLimitError {
381    fn from(result: RateLimitResult) -> Self {
382        Self {
383            data: RateLimitErrorData {
384                remaining: result.remaining,
385                reset: result.reset,
386                limit: result.limit,
387            },
388        }
389    }
390}
391
392impl std::fmt::Display for RateLimitError {
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        write!(
395            f,
396            "Rate limit exceeded. Remaining: {}, Reset: {}, Limit: {}",
397            self.data.remaining, self.data.reset, self.data.limit
398        )
399    }
400}
401
402impl web::error::WebResponseError for RateLimitError {
403    fn error_response(&self, _: &ntex::web::HttpRequest) -> web::HttpResponse {
404        let error_response = RateLimitErrorResponse {
405            code: 429,
406            message: "Rate limit exceeded".to_string(),
407            data: RateLimitErrorData {
408                remaining: self.data.remaining,
409                reset: self.data.reset,
410                limit: self.data.limit,
411            },
412        };
413
414        #[cfg(feature = "json")]
415        let body = serde_json::to_string(&error_response)
416            .unwrap_or_else(|_| r#"{"code":429,"message":"Rate limit exceeded"}"#.to_string());
417
418        #[cfg(not(feature = "json"))]
419        let body = format!(
420            r#"{{"code":429,"message":"Rate limit exceeded","data":{{"remaining":{},"reset":{},"limit":{}}}}}"#,
421            self.data.remaining, self.data.reset, self.data.limit
422        );
423
424        web::HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
425            .set_header("content-type", "application/json")
426            .set_header(HEADER_RATELIMIT_REMAINING, self.data.remaining.to_string())
427            .set_header(HEADER_RATELIMIT_LIMIT, self.data.limit.to_string())
428            .set_header(HEADER_RATELIMIT_RESET, self.data.reset.to_string())
429            .body(body)
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket_basic() {
        let config = RateLimiterConfig {
            capacity: 5,
            window: 10,
            ..Default::default()
        };
        let mut bucket = TokenBucket::new(5);
        let now = Instant::now();

        // Should allow up to capacity
        for _ in 0..5 {
            assert!(bucket.consume(1, now, &config));
        }

        // Should deny when capacity exceeded
        assert!(!bucket.consume(1, now, &config));
        assert_eq!(bucket.remaining_tokens(), 0);
    }

    #[test]
    fn test_token_bucket_refill() {
        let config = RateLimiterConfig {
            capacity: 10,
            window: 10, // 1 token per second
            ..Default::default()
        };
        let mut bucket = TokenBucket::new(10);
        let now = Instant::now();

        // Consume all tokens
        for _ in 0..10 {
            assert!(bucket.consume(1, now, &config));
        }
        assert!(!bucket.consume(1, now, &config));

        // After 5 seconds, should have 5 tokens
        let later = now + Duration::from_secs(5);
        bucket.refill(later, &config);
        assert_eq!(bucket.remaining_tokens(), 5);

        // Should be able to consume 5 tokens
        for _ in 0..5 {
            assert!(bucket.consume(1, later, &config));
        }
        assert!(!bucket.consume(1, later, &config));
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        // These calls use the real clock. A 60 s window refills one token every 12 s, so a
        // slow or descheduled test thread cannot regain a token between them. A 1 s window
        // would refill one every 200 ms and make the counts below timing-dependent.
        let config = RateLimiterConfig {
            capacity: 5,
            window: 60,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let ip = "192.168.1.1".parse::<IpAddr>().unwrap();

        // Should allow up to capacity
        for (i, expected_remaining) in (0..5u32).rev().enumerate() {
            let result = limiter.check_rate_limit(ip);
            assert!(result.allowed, "Request {} should be allowed", i + 1);
            assert_eq!(result.remaining, expected_remaining);
        }

        // Should deny when capacity exceeded
        let result = limiter.check_rate_limit(ip);
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
    }

    #[tokio::test]
    async fn test_rate_limiter_different_ips() {
        let limiter = RateLimiter::new(2, 60);

        // Different IPs should have separate limits
        let ip1 = "192.168.1.1".parse::<IpAddr>().unwrap();
        let ip2 = "192.168.1.2".parse::<IpAddr>().unwrap();
        let result1 = limiter.check_rate_limit(ip1);
        let result2 = limiter.check_rate_limit(ip2);

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 1);
        assert_eq!(result2.remaining, 1);
    }
}