Skip to main content

actix_security_core/http/security/
rate_limit.rs

//! Rate Limiting middleware for brute-force protection.
//!
//! Provides configurable rate limiting to protect against brute-force attacks,
//! DDoS attempts, and API abuse.
//!
//! # Spring Security Equivalent
//! Spring Security itself ships no rate limiter; in Spring applications this role is
//! usually filled by libraries such as Bucket4j or Resilience4j's `RateLimiter`.
//!
//! # Example
//!
//! ```ignore
//! use actix_security::http::security::rate_limit::{
//!     RateLimitAlgorithm, RateLimitConfig, RateLimiter,
//! };
//! use actix_web::{web, App, HttpServer};
//!
//! // `burst_size` only takes effect with the token-bucket algorithm.
//! let rate_limiter = RateLimiter::new(
//!     RateLimitConfig::new()
//!         .requests_per_second(10)
//!         .burst_size(20)
//!         .algorithm(RateLimitAlgorithm::TokenBucket),
//! );
//!
//! HttpServer::new(move || {
//!     App::new()
//!         .wrap(rate_limiter.clone())
//!         .route("/api/login", web::post().to(login))
//! })
//! ```
27
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue};
use actix_web::http::StatusCode;
use actix_web::{Error, HttpResponse};
use futures_util::future::{ok, LocalBoxFuture, Ready};
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
37
/// Details about a request that was rejected for exceeding its quota.
///
/// The middleware either hands this to the user-supplied
/// [`RateLimitConfig::error_response`] handler or turns it into a plain
/// `429 Too Many Requests` response.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Seconds the client should wait before retrying.
    pub retry_after: u64,
    /// Human-readable explanation, used as the default response body.
    pub message: String,
}

impl Default for RateLimitExceeded {
    /// A 60-second back-off with a generic message.
    fn default() -> Self {
        let message = String::from("Too many requests. Please try again later.");
        RateLimitExceeded {
            retry_after: 60,
            message,
        }
    }
}
55
56/// Type alias for custom key extractor function.
57pub type KeyExtractorFn = Arc<dyn Fn(&ServiceRequest) -> Option<String> + Send + Sync>;
58
59/// Strategy for identifying clients for rate limiting.
60#[derive(Clone, Default)]
61pub enum KeyExtractor {
62    /// Rate limit by IP address (default)
63    #[default]
64    IpAddress,
65    /// Rate limit by authenticated user
66    User,
67    /// Rate limit by custom header
68    Header(String),
69    /// Rate limit by IP + endpoint combination
70    IpAndEndpoint,
71    /// Custom key extractor function
72    Custom(KeyExtractorFn),
73}
74
75impl std::fmt::Debug for KeyExtractor {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            KeyExtractor::IpAddress => write!(f, "IpAddress"),
79            KeyExtractor::User => write!(f, "User"),
80            KeyExtractor::Header(h) => write!(f, "Header({})", h),
81            KeyExtractor::IpAndEndpoint => write!(f, "IpAndEndpoint"),
82            KeyExtractor::Custom(_) => write!(f, "Custom(<fn>)"),
83        }
84    }
85}
86
87impl KeyExtractor {
88    /// Extract the rate limit key from a request.
89    pub fn extract(&self, req: &ServiceRequest) -> Option<String> {
90        match self {
91            KeyExtractor::IpAddress => req
92                .connection_info()
93                .realip_remote_addr()
94                .map(|s| s.to_string()),
95            KeyExtractor::User => req
96                .headers()
97                .get("Authorization")
98                .and_then(|h| h.to_str().ok())
99                .map(|s| s.to_string()),
100            KeyExtractor::Header(name) => req
101                .headers()
102                .get(name.as_str())
103                .and_then(|h| h.to_str().ok())
104                .map(|s| s.to_string()),
105            KeyExtractor::IpAndEndpoint => {
106                let ip = req.connection_info().realip_remote_addr()?.to_string();
107                let path = req.path().to_string();
108                Some(format!("{}:{}", ip, path))
109            }
110            KeyExtractor::Custom(f) => f(req),
111        }
112    }
113}
114
/// Counting strategy applied per client key.
///
/// All variants are fieldless, so the type is `Copy` and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RateLimitAlgorithm {
    /// Fixed window counter (default; simplest, least memory).
    ///
    /// Admits at most `max_requests` per `window`; the counter restarts once a full
    /// window has elapsed since it began.
    #[default]
    FixedWindow,
    /// Sliding window log (more accurate, more memory).
    ///
    /// Records a timestamp per admitted request and admits a new one only if fewer
    /// than `max_requests` fall inside the trailing `window` (up to `max_requests`
    /// timestamps stored per key).
    SlidingWindow,
    /// Token bucket (smooth rate limiting).
    ///
    /// Holds up to `burst_size` tokens, refilled continuously at `max_requests` per
    /// `window`; each admitted request spends one token.
    TokenBucket,
}
126
127/// Rate limit configuration.
128#[derive(Clone)]
129pub struct RateLimitConfig {
130    /// Maximum requests per window
131    pub max_requests: u64,
132    /// Time window duration
133    pub window: Duration,
134    /// Burst size (for token bucket)
135    pub burst_size: u64,
136    /// Algorithm to use
137    pub algorithm: RateLimitAlgorithm,
138    /// Key extractor
139    pub key_extractor: KeyExtractor,
140    /// Paths to exclude from rate limiting
141    pub excluded_paths: Vec<String>,
142    /// Whether to add rate limit headers to response
143    pub add_headers: bool,
144    /// Custom error response
145    pub error_response: Option<Arc<dyn Fn(RateLimitExceeded) -> HttpResponse + Send + Sync>>,
146}
147
148impl std::fmt::Debug for RateLimitConfig {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("RateLimitConfig")
151            .field("max_requests", &self.max_requests)
152            .field("window", &self.window)
153            .field("burst_size", &self.burst_size)
154            .field("algorithm", &self.algorithm)
155            .field("key_extractor", &self.key_extractor)
156            .field("excluded_paths", &self.excluded_paths)
157            .field("add_headers", &self.add_headers)
158            .field(
159                "error_response",
160                &self.error_response.as_ref().map(|_| "<fn>"),
161            )
162            .finish()
163    }
164}
165
166impl Default for RateLimitConfig {
167    fn default() -> Self {
168        Self {
169            max_requests: 100,
170            window: Duration::from_secs(60),
171            burst_size: 10,
172            algorithm: RateLimitAlgorithm::default(),
173            key_extractor: KeyExtractor::default(),
174            excluded_paths: vec![],
175            add_headers: true,
176            error_response: None,
177        }
178    }
179}
180
181impl RateLimitConfig {
182    /// Create a new rate limit configuration.
183    pub fn new() -> Self {
184        Self::default()
185    }
186
187    /// Set maximum requests per window.
188    pub fn max_requests(mut self, max: u64) -> Self {
189        self.max_requests = max;
190        self
191    }
192
193    /// Set requests per second (convenience method).
194    pub fn requests_per_second(mut self, rps: u64) -> Self {
195        self.max_requests = rps;
196        self.window = Duration::from_secs(1);
197        self
198    }
199
200    /// Set requests per minute.
201    pub fn requests_per_minute(mut self, rpm: u64) -> Self {
202        self.max_requests = rpm;
203        self.window = Duration::from_secs(60);
204        self
205    }
206
207    /// Set time window.
208    pub fn window(mut self, window: Duration) -> Self {
209        self.window = window;
210        self
211    }
212
213    /// Set burst size for token bucket algorithm.
214    pub fn burst_size(mut self, size: u64) -> Self {
215        self.burst_size = size;
216        self
217    }
218
219    /// Set the rate limiting algorithm.
220    pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
221        self.algorithm = algo;
222        self
223    }
224
225    /// Set the key extractor.
226    pub fn key_extractor(mut self, extractor: KeyExtractor) -> Self {
227        self.key_extractor = extractor;
228        self
229    }
230
231    /// Exclude paths from rate limiting.
232    pub fn exclude_paths(mut self, paths: Vec<&str>) -> Self {
233        self.excluded_paths = paths.into_iter().map(String::from).collect();
234        self
235    }
236
237    /// Whether to add rate limit headers.
238    pub fn add_headers(mut self, add: bool) -> Self {
239        self.add_headers = add;
240        self
241    }
242
243    /// Set custom error response handler.
244    pub fn error_response<F>(mut self, handler: F) -> Self
245    where
246        F: Fn(RateLimitExceeded) -> HttpResponse + Send + Sync + 'static,
247    {
248        self.error_response = Some(Arc::new(handler));
249        self
250    }
251
252    /// Create a strict configuration for login endpoints.
253    pub fn strict_login() -> Self {
254        Self::new()
255            .requests_per_minute(5)
256            .burst_size(3)
257            .algorithm(RateLimitAlgorithm::SlidingWindow)
258    }
259
260    /// Create a lenient configuration for API endpoints.
261    pub fn lenient_api() -> Self {
262        Self::new()
263            .requests_per_minute(1000)
264            .burst_size(100)
265            .algorithm(RateLimitAlgorithm::TokenBucket)
266    }
267}
268
269/// Rate limit entry for tracking requests.
270#[derive(Debug, Clone)]
271struct RateLimitEntry {
272    /// Request count in current window
273    count: u64,
274    /// Window start time
275    window_start: Instant,
276    /// Request timestamps for sliding window
277    timestamps: Vec<Instant>,
278    /// Available tokens for token bucket
279    tokens: f64,
280    /// Last token refill time
281    last_refill: Instant,
282}
283
284impl RateLimitEntry {
285    fn new(config: &RateLimitConfig) -> Self {
286        Self {
287            count: 0,
288            window_start: Instant::now(),
289            timestamps: Vec::new(),
290            tokens: config.burst_size as f64,
291            last_refill: Instant::now(),
292        }
293    }
294}
295
296/// Rate limiter state.
297#[derive(Clone)]
298pub struct RateLimiterState {
299    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
300    config: RateLimitConfig,
301}
302
303impl RateLimiterState {
304    /// Create new rate limiter state.
305    pub fn new(config: RateLimitConfig) -> Self {
306        Self {
307            entries: Arc::new(RwLock::new(HashMap::new())),
308            config,
309        }
310    }
311
312    /// Check if a request should be allowed.
313    pub async fn check(&self, key: &str) -> Result<RateLimitInfo, RateLimitExceeded> {
314        let mut entries = self.entries.write().await;
315        let now = Instant::now();
316
317        let entry = entries
318            .entry(key.to_string())
319            .or_insert_with(|| RateLimitEntry::new(&self.config));
320
321        match self.config.algorithm {
322            RateLimitAlgorithm::FixedWindow => self.check_fixed_window(entry, now),
323            RateLimitAlgorithm::SlidingWindow => self.check_sliding_window(entry, now),
324            RateLimitAlgorithm::TokenBucket => self.check_token_bucket(entry, now),
325        }
326    }
327
328    fn check_fixed_window(
329        &self,
330        entry: &mut RateLimitEntry,
331        now: Instant,
332    ) -> Result<RateLimitInfo, RateLimitExceeded> {
333        // Reset window if expired
334        if now.duration_since(entry.window_start) >= self.config.window {
335            entry.count = 0;
336            entry.window_start = now;
337        }
338
339        if entry.count >= self.config.max_requests {
340            let reset_time = entry.window_start + self.config.window;
341            let retry_after = reset_time.saturating_duration_since(now).as_secs();
342            return Err(RateLimitExceeded {
343                retry_after,
344                message: "Rate limit exceeded".to_string(),
345            });
346        }
347
348        entry.count += 1;
349
350        let reset_time = entry.window_start + self.config.window;
351        Ok(RateLimitInfo {
352            limit: self.config.max_requests,
353            remaining: self.config.max_requests.saturating_sub(entry.count),
354            reset: reset_time.saturating_duration_since(now).as_secs(),
355        })
356    }
357
358    fn check_sliding_window(
359        &self,
360        entry: &mut RateLimitEntry,
361        now: Instant,
362    ) -> Result<RateLimitInfo, RateLimitExceeded> {
363        // Remove expired timestamps
364        let window_start = now - self.config.window;
365        entry.timestamps.retain(|&t| t > window_start);
366
367        if entry.timestamps.len() as u64 >= self.config.max_requests {
368            let oldest = entry.timestamps.first().copied().unwrap_or(now);
369            let retry_after = (oldest + self.config.window)
370                .saturating_duration_since(now)
371                .as_secs();
372            return Err(RateLimitExceeded {
373                retry_after,
374                message: "Rate limit exceeded".to_string(),
375            });
376        }
377
378        entry.timestamps.push(now);
379
380        Ok(RateLimitInfo {
381            limit: self.config.max_requests,
382            remaining: self
383                .config
384                .max_requests
385                .saturating_sub(entry.timestamps.len() as u64),
386            reset: self.config.window.as_secs(),
387        })
388    }
389
390    fn check_token_bucket(
391        &self,
392        entry: &mut RateLimitEntry,
393        now: Instant,
394    ) -> Result<RateLimitInfo, RateLimitExceeded> {
395        // Refill tokens based on time elapsed
396        let elapsed = now.duration_since(entry.last_refill).as_secs_f64();
397        let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
398        let new_tokens = elapsed * refill_rate;
399
400        entry.tokens = (entry.tokens + new_tokens).min(self.config.burst_size as f64);
401        entry.last_refill = now;
402
403        if entry.tokens < 1.0 {
404            let tokens_needed = 1.0 - entry.tokens;
405            let retry_after = (tokens_needed / refill_rate).ceil() as u64;
406            return Err(RateLimitExceeded {
407                retry_after,
408                message: "Rate limit exceeded".to_string(),
409            });
410        }
411
412        entry.tokens -= 1.0;
413
414        Ok(RateLimitInfo {
415            limit: self.config.max_requests,
416            remaining: entry.tokens as u64,
417            reset: self.config.window.as_secs(),
418        })
419    }
420
421    /// Clean up expired entries (call periodically).
422    pub async fn cleanup(&self) {
423        let mut entries = self.entries.write().await;
424        let now = Instant::now();
425        let window = self.config.window * 2; // Keep entries for 2x window
426
427        entries.retain(|_, entry| now.duration_since(entry.window_start) < window);
428    }
429}
430
/// Quota snapshot for an admitted request, used for the `X-RateLimit-*` headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitInfo {
    /// The configured `max_requests`.
    pub limit: u64,
    /// Requests still allowed: unused requests in the window, or whole tokens left
    /// in the bucket for the token-bucket algorithm.
    pub remaining: u64,
    /// Seconds until the quota resets.
    ///
    /// This is exact for the fixed window. For the sliding window and the token bucket
    /// it is the configured window length.
    pub reset: u64,
}
441
442/// Rate limiter middleware.
443#[derive(Clone)]
444pub struct RateLimiter {
445    state: RateLimiterState,
446}
447
448impl RateLimiter {
449    /// Create a new rate limiter with the given configuration.
450    pub fn new(config: RateLimitConfig) -> Self {
451        Self {
452            state: RateLimiterState::new(config),
453        }
454    }
455
456    /// Create a rate limiter for login endpoints (strict).
457    pub fn for_login() -> Self {
458        Self::new(RateLimitConfig::strict_login())
459    }
460
461    /// Create a rate limiter for API endpoints (lenient).
462    pub fn for_api() -> Self {
463        Self::new(RateLimitConfig::lenient_api())
464    }
465
466    /// Get the underlying state for manual operations.
467    pub fn state(&self) -> &RateLimiterState {
468        &self.state
469    }
470}
471
472impl<S, B> Transform<S, ServiceRequest> for RateLimiter
473where
474    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
475    B: 'static,
476{
477    type Response = ServiceResponse<B>;
478    type Error = Error;
479    type Transform = RateLimiterMiddleware<S>;
480    type InitError = ();
481    type Future = Ready<Result<Self::Transform, Self::InitError>>;
482
483    fn new_transform(&self, service: S) -> Self::Future {
484        ok(RateLimiterMiddleware {
485            service,
486            state: self.state.clone(),
487        })
488    }
489}
490
491/// Rate limiter middleware service.
492pub struct RateLimiterMiddleware<S> {
493    service: S,
494    state: RateLimiterState,
495}
496
497impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
498where
499    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
500    B: 'static,
501{
502    type Response = ServiceResponse<B>;
503    type Error = Error;
504    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
505
506    forward_ready!(service);
507
508    fn call(&self, req: ServiceRequest) -> Self::Future {
509        let state = self.state.clone();
510        let config = state.config.clone();
511
512        // Check if path is excluded
513        let path = req.path().to_string();
514        if config.excluded_paths.iter().any(|p| path.starts_with(p)) {
515            let fut = self.service.call(req);
516            return Box::pin(fut);
517        }
518
519        // Extract key
520        let key = match config.key_extractor.extract(&req) {
521            Some(k) => k,
522            None => {
523                // Can't identify client, allow request
524                let fut = self.service.call(req);
525                return Box::pin(fut);
526            }
527        };
528
529        let fut = self.service.call(req);
530        let add_headers = config.add_headers;
531        let error_handler = config.error_response.clone();
532
533        Box::pin(async move {
534            // Check rate limit
535            match state.check(&key).await {
536                Ok(info) => {
537                    let mut resp = fut.await?;
538
539                    // Add rate limit headers
540                    if add_headers {
541                        let headers = resp.headers_mut();
542                        if let Ok(v) = HeaderValue::from_str(&info.limit.to_string()) {
543                            headers.insert(HeaderName::from_static("x-ratelimit-limit"), v);
544                        }
545                        if let Ok(v) = HeaderValue::from_str(&info.remaining.to_string()) {
546                            headers.insert(HeaderName::from_static("x-ratelimit-remaining"), v);
547                        }
548                        if let Ok(v) = HeaderValue::from_str(&info.reset.to_string()) {
549                            headers.insert(HeaderName::from_static("x-ratelimit-reset"), v);
550                        }
551                    }
552
553                    Ok(resp)
554                }
555                Err(exceeded) => {
556                    // Rate limit exceeded
557                    let response = if let Some(handler) = error_handler {
558                        handler(exceeded.clone())
559                    } else {
560                        HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
561                            .insert_header(("Retry-After", exceeded.retry_after.to_string()))
562                            .insert_header(("X-RateLimit-Limit", config.max_requests.to_string()))
563                            .insert_header(("X-RateLimit-Remaining", "0"))
564                            .body(exceeded.message)
565                    };
566
567                    // We need to return the same response type
568                    // This is a workaround - the response body type doesn't match
569                    Err(actix_web::error::InternalError::from_response(
570                        std::io::Error::other("Rate limit exceeded"),
571                        response,
572                    )
573                    .into())
574                }
575            }
576        })
577    }
578}
579
580/// Builder for endpoint-specific rate limits.
581#[derive(Clone, Default)]
582pub struct RateLimitBuilder {
583    rules: Vec<(String, RateLimitConfig)>,
584    default: Option<RateLimitConfig>,
585}
586
587impl RateLimitBuilder {
588    /// Create a new rate limit builder.
589    pub fn new() -> Self {
590        Self::default()
591    }
592
593    /// Add a rate limit rule for a path pattern.
594    pub fn add_rule(mut self, pattern: &str, config: RateLimitConfig) -> Self {
595        self.rules.push((pattern.to_string(), config));
596        self
597    }
598
599    /// Set the default rate limit for unmatched paths.
600    pub fn default_limit(mut self, config: RateLimitConfig) -> Self {
601        self.default = Some(config);
602        self
603    }
604
605    /// Add strict rate limiting for login endpoints.
606    pub fn protect_login(self, path: &str) -> Self {
607        self.add_rule(path, RateLimitConfig::strict_login())
608    }
609
610    /// Add lenient rate limiting for API endpoints.
611    pub fn protect_api(self, path: &str) -> Self {
612        self.add_rule(path, RateLimitConfig::lenient_api())
613    }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Issue `n` consecutive checks for `key`, recording which were admitted.
    async fn admitted(state: &RateLimiterState, key: &str, n: usize) -> Vec<bool> {
        let mut results = Vec::with_capacity(n);
        for _ in 0..n {
            results.push(state.check(key).await.is_ok());
        }
        results
    }

    #[tokio::test]
    async fn test_fixed_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60)),
        );

        // Three requests fit in the window; the fourth is rejected.
        assert_eq!(
            admitted(&state, "test-key", 4).await,
            [true, true, true, false]
        );
    }

    #[tokio::test]
    async fn test_sliding_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60))
                .algorithm(RateLimitAlgorithm::SlidingWindow),
        );

        // Three requests fit in the window; the fourth is rejected.
        assert_eq!(
            admitted(&state, "test-key", 4).await,
            [true, true, true, false]
        );
    }

    #[tokio::test]
    async fn test_token_bucket_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(10)
                .window(Duration::from_secs(1))
                .burst_size(3)
                .algorithm(RateLimitAlgorithm::TokenBucket),
        );

        // The burst of three drains the bucket; the fourth request finds it empty.
        assert_eq!(
            admitted(&state, "test-key", 4).await,
            [true, true, true, false]
        );
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(2)
                .window(Duration::from_secs(60)),
        );

        // Exhausting one key must not affect another.
        for key in ["key-a", "key-b"] {
            assert_eq!(
                admitted(&state, key, 3).await,
                [true, true, false],
                "unexpected outcomes for {key}"
            );
        }
    }

    #[test]
    fn test_rate_limit_info() {
        let RateLimitInfo {
            limit,
            remaining,
            reset,
        } = RateLimitInfo {
            limit: 100,
            remaining: 50,
            reset: 30,
        };

        assert_eq!((limit, remaining, reset), (100, 50, 30));
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::new()
            .requests_per_minute(60)
            .burst_size(10)
            .add_headers(true)
            .exclude_paths(vec!["/health", "/metrics"]);

        assert_eq!((config.max_requests, config.burst_size), (60, 10));
        assert!(config.add_headers);
        assert_eq!(config.excluded_paths.len(), 2);
    }

    #[test]
    fn test_strict_login_config() {
        let config = RateLimitConfig::strict_login();
        assert_eq!(
            (config.max_requests, config.window),
            (5, Duration::from_secs(60))
        );
    }

    #[test]
    fn test_lenient_api_config() {
        assert_eq!(RateLimitConfig::lenient_api().max_requests, 1000);
    }
}