Skip to main content

actix_security_core/http/security/
rate_limit.rs

1//! Rate Limiting middleware for brute-force protection.
2//!
3//! Provides configurable rate limiting to protect against brute-force attacks,
4//! DDoS attempts, and API abuse.
5//!
6//! # Spring Security Equivalent
7//! Similar to Spring Security's `RateLimiter` and integration with Bucket4j.
8//!
9//! # Example
10//!
//! ```ignore
//! use actix_security::http::security::rate_limit::{
//!     RateLimitAlgorithm, RateLimitConfig, RateLimiter,
//! };
//! use actix_web::{web, App, HttpServer};
//!
//! // `burst_size` only takes effect with the token-bucket algorithm.
//! let rate_limiter = RateLimiter::new(
//!     RateLimitConfig::new()
//!         .requests_per_second(10)
//!         .burst_size(20)
//!         .algorithm(RateLimitAlgorithm::TokenBucket),
//! );
//!
//! HttpServer::new(move || {
//!     App::new()
//!         .wrap(rate_limiter.clone())
//!         .route("/api/login", web::post().to(login)) // `login`: your handler
//! })
//! ```
27
28use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
29use actix_web::http::header::{HeaderName, HeaderValue};
30use actix_web::http::StatusCode;
31use actix_web::{Error, HttpResponse};
32use futures_util::future::{ok, LocalBoxFuture, Ready};
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tokio::sync::RwLock;
37
/// Rate limit exceeded error response.
///
/// Produced when a client key has used up its quota. It is handed to a custom
/// `error_response` handler when one is configured.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitExceeded {
    /// Seconds the client should wait before retrying.
    pub retry_after: u64,
    /// Human-readable explanation, used as the default rejection body.
    pub message: String,
}

impl Default for RateLimitExceeded {
    /// A generic rejection asking the client to retry after 60 seconds.
    fn default() -> Self {
        Self {
            retry_after: 60,
            message: "Too many requests. Please try again later.".to_string(),
        }
    }
}

impl std::fmt::Display for RateLimitExceeded {
    /// Formats as the contained `message`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RateLimitExceeded {}
55
56/// Type alias for custom key extractor function.
57pub type KeyExtractorFn = Arc<dyn Fn(&ServiceRequest) -> Option<String> + Send + Sync>;
58
59/// Strategy for identifying clients for rate limiting.
60#[derive(Clone, Default)]
61pub enum KeyExtractor {
62    /// Rate limit by IP address (default)
63    #[default]
64    IpAddress,
65    /// Rate limit by authenticated user
66    User,
67    /// Rate limit by custom header
68    Header(String),
69    /// Rate limit by IP + endpoint combination
70    IpAndEndpoint,
71    /// Custom key extractor function
72    Custom(KeyExtractorFn),
73}
74
75impl std::fmt::Debug for KeyExtractor {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            KeyExtractor::IpAddress => write!(f, "IpAddress"),
79            KeyExtractor::User => write!(f, "User"),
80            KeyExtractor::Header(h) => write!(f, "Header({})", h),
81            KeyExtractor::IpAndEndpoint => write!(f, "IpAndEndpoint"),
82            KeyExtractor::Custom(_) => write!(f, "Custom(<fn>)"),
83        }
84    }
85}
86
87impl KeyExtractor {
88    /// Extract the rate limit key from a request.
89    pub fn extract(&self, req: &ServiceRequest) -> Option<String> {
90        match self {
91            KeyExtractor::IpAddress => req
92                .connection_info()
93                .realip_remote_addr()
94                .map(|s| s.to_string()),
95            KeyExtractor::User => req
96                .headers()
97                .get("Authorization")
98                .and_then(|h| h.to_str().ok())
99                .map(|s| s.to_string()),
100            KeyExtractor::Header(name) => req
101                .headers()
102                .get(name.as_str())
103                .and_then(|h| h.to_str().ok())
104                .map(|s| s.to_string()),
105            KeyExtractor::IpAndEndpoint => {
106                let ip = req.connection_info().realip_remote_addr()?.to_string();
107                let path = req.path().to_string();
108                Some(format!("{}:{}", ip, path))
109            }
110            KeyExtractor::Custom(f) => f(req),
111        }
112    }
113}
114
/// Rate limiting algorithm.
///
/// A fieldless enum, so it is `Copy` and comparable. Callers can test the
/// configured algorithm with `==` instead of pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RateLimitAlgorithm {
    /// Fixed window counter: counts requests per window and resets the count
    /// when the window has elapsed. Uses the least memory.
    #[default]
    FixedWindow,
    /// Sliding window log: remembers the instant of each admitted request
    /// within the window. More accurate than a fixed window, at the cost of
    /// memory per request.
    SlidingWindow,
    /// Token bucket: tokens refill continuously up to `burst_size`, and each
    /// request consumes one. Allows short bursts while smoothing the rate.
    TokenBucket,
}
126
127/// Rate limit configuration.
128#[derive(Clone)]
129pub struct RateLimitConfig {
130    /// Maximum requests per window
131    pub max_requests: u64,
132    /// Time window duration
133    pub window: Duration,
134    /// Burst size (for token bucket)
135    pub burst_size: u64,
136    /// Algorithm to use
137    pub algorithm: RateLimitAlgorithm,
138    /// Key extractor
139    pub key_extractor: KeyExtractor,
140    /// Paths to exclude from rate limiting
141    pub excluded_paths: Vec<String>,
142    /// Whether to add rate limit headers to response
143    pub add_headers: bool,
144    /// Custom error response
145    pub error_response: Option<Arc<dyn Fn(RateLimitExceeded) -> HttpResponse + Send + Sync>>,
146}
147
148impl std::fmt::Debug for RateLimitConfig {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("RateLimitConfig")
151            .field("max_requests", &self.max_requests)
152            .field("window", &self.window)
153            .field("burst_size", &self.burst_size)
154            .field("algorithm", &self.algorithm)
155            .field("key_extractor", &self.key_extractor)
156            .field("excluded_paths", &self.excluded_paths)
157            .field("add_headers", &self.add_headers)
158            .field("error_response", &self.error_response.as_ref().map(|_| "<fn>"))
159            .finish()
160    }
161}
162
163impl Default for RateLimitConfig {
164    fn default() -> Self {
165        Self {
166            max_requests: 100,
167            window: Duration::from_secs(60),
168            burst_size: 10,
169            algorithm: RateLimitAlgorithm::default(),
170            key_extractor: KeyExtractor::default(),
171            excluded_paths: vec![],
172            add_headers: true,
173            error_response: None,
174        }
175    }
176}
177
178impl RateLimitConfig {
179    /// Create a new rate limit configuration.
180    pub fn new() -> Self {
181        Self::default()
182    }
183
184    /// Set maximum requests per window.
185    pub fn max_requests(mut self, max: u64) -> Self {
186        self.max_requests = max;
187        self
188    }
189
190    /// Set requests per second (convenience method).
191    pub fn requests_per_second(mut self, rps: u64) -> Self {
192        self.max_requests = rps;
193        self.window = Duration::from_secs(1);
194        self
195    }
196
197    /// Set requests per minute.
198    pub fn requests_per_minute(mut self, rpm: u64) -> Self {
199        self.max_requests = rpm;
200        self.window = Duration::from_secs(60);
201        self
202    }
203
204    /// Set time window.
205    pub fn window(mut self, window: Duration) -> Self {
206        self.window = window;
207        self
208    }
209
210    /// Set burst size for token bucket algorithm.
211    pub fn burst_size(mut self, size: u64) -> Self {
212        self.burst_size = size;
213        self
214    }
215
216    /// Set the rate limiting algorithm.
217    pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
218        self.algorithm = algo;
219        self
220    }
221
222    /// Set the key extractor.
223    pub fn key_extractor(mut self, extractor: KeyExtractor) -> Self {
224        self.key_extractor = extractor;
225        self
226    }
227
228    /// Exclude paths from rate limiting.
229    pub fn exclude_paths(mut self, paths: Vec<&str>) -> Self {
230        self.excluded_paths = paths.into_iter().map(String::from).collect();
231        self
232    }
233
234    /// Whether to add rate limit headers.
235    pub fn add_headers(mut self, add: bool) -> Self {
236        self.add_headers = add;
237        self
238    }
239
240    /// Set custom error response handler.
241    pub fn error_response<F>(mut self, handler: F) -> Self
242    where
243        F: Fn(RateLimitExceeded) -> HttpResponse + Send + Sync + 'static,
244    {
245        self.error_response = Some(Arc::new(handler));
246        self
247    }
248
249    /// Create a strict configuration for login endpoints.
250    pub fn strict_login() -> Self {
251        Self::new()
252            .requests_per_minute(5)
253            .burst_size(3)
254            .algorithm(RateLimitAlgorithm::SlidingWindow)
255    }
256
257    /// Create a lenient configuration for API endpoints.
258    pub fn lenient_api() -> Self {
259        Self::new()
260            .requests_per_minute(1000)
261            .burst_size(100)
262            .algorithm(RateLimitAlgorithm::TokenBucket)
263    }
264}
265
266/// Rate limit entry for tracking requests.
267#[derive(Debug, Clone)]
268struct RateLimitEntry {
269    /// Request count in current window
270    count: u64,
271    /// Window start time
272    window_start: Instant,
273    /// Request timestamps for sliding window
274    timestamps: Vec<Instant>,
275    /// Available tokens for token bucket
276    tokens: f64,
277    /// Last token refill time
278    last_refill: Instant,
279}
280
281impl RateLimitEntry {
282    fn new(config: &RateLimitConfig) -> Self {
283        Self {
284            count: 0,
285            window_start: Instant::now(),
286            timestamps: Vec::new(),
287            tokens: config.burst_size as f64,
288            last_refill: Instant::now(),
289        }
290    }
291}
292
293/// Rate limiter state.
294#[derive(Clone)]
295pub struct RateLimiterState {
296    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
297    config: RateLimitConfig,
298}
299
300impl RateLimiterState {
301    /// Create new rate limiter state.
302    pub fn new(config: RateLimitConfig) -> Self {
303        Self {
304            entries: Arc::new(RwLock::new(HashMap::new())),
305            config,
306        }
307    }
308
309    /// Check if a request should be allowed.
310    pub async fn check(&self, key: &str) -> Result<RateLimitInfo, RateLimitExceeded> {
311        let mut entries = self.entries.write().await;
312        let now = Instant::now();
313
314        let entry = entries
315            .entry(key.to_string())
316            .or_insert_with(|| RateLimitEntry::new(&self.config));
317
318        match self.config.algorithm {
319            RateLimitAlgorithm::FixedWindow => self.check_fixed_window(entry, now),
320            RateLimitAlgorithm::SlidingWindow => self.check_sliding_window(entry, now),
321            RateLimitAlgorithm::TokenBucket => self.check_token_bucket(entry, now),
322        }
323    }
324
325    fn check_fixed_window(
326        &self,
327        entry: &mut RateLimitEntry,
328        now: Instant,
329    ) -> Result<RateLimitInfo, RateLimitExceeded> {
330        // Reset window if expired
331        if now.duration_since(entry.window_start) >= self.config.window {
332            entry.count = 0;
333            entry.window_start = now;
334        }
335
336        if entry.count >= self.config.max_requests {
337            let reset_time = entry.window_start + self.config.window;
338            let retry_after = reset_time.saturating_duration_since(now).as_secs();
339            return Err(RateLimitExceeded {
340                retry_after,
341                message: "Rate limit exceeded".to_string(),
342            });
343        }
344
345        entry.count += 1;
346
347        let reset_time = entry.window_start + self.config.window;
348        Ok(RateLimitInfo {
349            limit: self.config.max_requests,
350            remaining: self.config.max_requests.saturating_sub(entry.count),
351            reset: reset_time.saturating_duration_since(now).as_secs(),
352        })
353    }
354
355    fn check_sliding_window(
356        &self,
357        entry: &mut RateLimitEntry,
358        now: Instant,
359    ) -> Result<RateLimitInfo, RateLimitExceeded> {
360        // Remove expired timestamps
361        let window_start = now - self.config.window;
362        entry.timestamps.retain(|&t| t > window_start);
363
364        if entry.timestamps.len() as u64 >= self.config.max_requests {
365            let oldest = entry.timestamps.first().copied().unwrap_or(now);
366            let retry_after = (oldest + self.config.window)
367                .saturating_duration_since(now)
368                .as_secs();
369            return Err(RateLimitExceeded {
370                retry_after,
371                message: "Rate limit exceeded".to_string(),
372            });
373        }
374
375        entry.timestamps.push(now);
376
377        Ok(RateLimitInfo {
378            limit: self.config.max_requests,
379            remaining: self.config.max_requests.saturating_sub(entry.timestamps.len() as u64),
380            reset: self.config.window.as_secs(),
381        })
382    }
383
384    fn check_token_bucket(
385        &self,
386        entry: &mut RateLimitEntry,
387        now: Instant,
388    ) -> Result<RateLimitInfo, RateLimitExceeded> {
389        // Refill tokens based on time elapsed
390        let elapsed = now.duration_since(entry.last_refill).as_secs_f64();
391        let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
392        let new_tokens = elapsed * refill_rate;
393
394        entry.tokens = (entry.tokens + new_tokens).min(self.config.burst_size as f64);
395        entry.last_refill = now;
396
397        if entry.tokens < 1.0 {
398            let tokens_needed = 1.0 - entry.tokens;
399            let retry_after = (tokens_needed / refill_rate).ceil() as u64;
400            return Err(RateLimitExceeded {
401                retry_after,
402                message: "Rate limit exceeded".to_string(),
403            });
404        }
405
406        entry.tokens -= 1.0;
407
408        Ok(RateLimitInfo {
409            limit: self.config.max_requests,
410            remaining: entry.tokens as u64,
411            reset: self.config.window.as_secs(),
412        })
413    }
414
415    /// Clean up expired entries (call periodically).
416    pub async fn cleanup(&self) {
417        let mut entries = self.entries.write().await;
418        let now = Instant::now();
419        let window = self.config.window * 2; // Keep entries for 2x window
420
421        entries.retain(|_, entry| now.duration_since(entry.window_start) < window);
422    }
423}
424
/// Rate limit information for headers.
///
/// Returned for admitted requests. All fields are plain integers, so the
/// struct is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitInfo {
    /// Maximum requests allowed per window.
    pub limit: u64,
    /// Remaining requests in the current window.
    pub remaining: u64,
    /// Seconds until the rate limit resets.
    pub reset: u64,
}
435
436/// Rate limiter middleware.
437#[derive(Clone)]
438pub struct RateLimiter {
439    state: RateLimiterState,
440}
441
442impl RateLimiter {
443    /// Create a new rate limiter with the given configuration.
444    pub fn new(config: RateLimitConfig) -> Self {
445        Self {
446            state: RateLimiterState::new(config),
447        }
448    }
449
450    /// Create a rate limiter for login endpoints (strict).
451    pub fn for_login() -> Self {
452        Self::new(RateLimitConfig::strict_login())
453    }
454
455    /// Create a rate limiter for API endpoints (lenient).
456    pub fn for_api() -> Self {
457        Self::new(RateLimitConfig::lenient_api())
458    }
459
460    /// Get the underlying state for manual operations.
461    pub fn state(&self) -> &RateLimiterState {
462        &self.state
463    }
464}
465
466impl<S, B> Transform<S, ServiceRequest> for RateLimiter
467where
468    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
469    B: 'static,
470{
471    type Response = ServiceResponse<B>;
472    type Error = Error;
473    type Transform = RateLimiterMiddleware<S>;
474    type InitError = ();
475    type Future = Ready<Result<Self::Transform, Self::InitError>>;
476
477    fn new_transform(&self, service: S) -> Self::Future {
478        ok(RateLimiterMiddleware {
479            service,
480            state: self.state.clone(),
481        })
482    }
483}
484
485/// Rate limiter middleware service.
486pub struct RateLimiterMiddleware<S> {
487    service: S,
488    state: RateLimiterState,
489}
490
491impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
492where
493    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
494    B: 'static,
495{
496    type Response = ServiceResponse<B>;
497    type Error = Error;
498    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
499
500    forward_ready!(service);
501
502    fn call(&self, req: ServiceRequest) -> Self::Future {
503        let state = self.state.clone();
504        let config = state.config.clone();
505
506        // Check if path is excluded
507        let path = req.path().to_string();
508        if config.excluded_paths.iter().any(|p| path.starts_with(p)) {
509            let fut = self.service.call(req);
510            return Box::pin(fut);
511        }
512
513        // Extract key
514        let key = match config.key_extractor.extract(&req) {
515            Some(k) => k,
516            None => {
517                // Can't identify client, allow request
518                let fut = self.service.call(req);
519                return Box::pin(fut);
520            }
521        };
522
523        let fut = self.service.call(req);
524        let add_headers = config.add_headers;
525        let error_handler = config.error_response.clone();
526
527        Box::pin(async move {
528            // Check rate limit
529            match state.check(&key).await {
530                Ok(info) => {
531                    let mut resp = fut.await?;
532
533                    // Add rate limit headers
534                    if add_headers {
535                        let headers = resp.headers_mut();
536                        if let Ok(v) = HeaderValue::from_str(&info.limit.to_string()) {
537                            headers.insert(
538                                HeaderName::from_static("x-ratelimit-limit"),
539                                v,
540                            );
541                        }
542                        if let Ok(v) = HeaderValue::from_str(&info.remaining.to_string()) {
543                            headers.insert(
544                                HeaderName::from_static("x-ratelimit-remaining"),
545                                v,
546                            );
547                        }
548                        if let Ok(v) = HeaderValue::from_str(&info.reset.to_string()) {
549                            headers.insert(
550                                HeaderName::from_static("x-ratelimit-reset"),
551                                v,
552                            );
553                        }
554                    }
555
556                    Ok(resp)
557                }
558                Err(exceeded) => {
559                    // Rate limit exceeded
560                    let response = if let Some(handler) = error_handler {
561                        handler(exceeded.clone())
562                    } else {
563                        HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
564                            .insert_header(("Retry-After", exceeded.retry_after.to_string()))
565                            .insert_header(("X-RateLimit-Limit", config.max_requests.to_string()))
566                            .insert_header(("X-RateLimit-Remaining", "0"))
567                            .body(exceeded.message)
568                    };
569
570                    // We need to return the same response type
571                    // This is a workaround - the response body type doesn't match
572                    Err(actix_web::error::InternalError::from_response(
573                        std::io::Error::new(std::io::ErrorKind::Other, "Rate limit exceeded"),
574                        response,
575                    )
576                    .into())
577                }
578            }
579        })
580    }
581}
582
583/// Builder for endpoint-specific rate limits.
584#[derive(Clone, Default)]
585pub struct RateLimitBuilder {
586    rules: Vec<(String, RateLimitConfig)>,
587    default: Option<RateLimitConfig>,
588}
589
590impl RateLimitBuilder {
591    /// Create a new rate limit builder.
592    pub fn new() -> Self {
593        Self::default()
594    }
595
596    /// Add a rate limit rule for a path pattern.
597    pub fn add_rule(mut self, pattern: &str, config: RateLimitConfig) -> Self {
598        self.rules.push((pattern.to_string(), config));
599        self
600    }
601
602    /// Set the default rate limit for unmatched paths.
603    pub fn default_limit(mut self, config: RateLimitConfig) -> Self {
604        self.default = Some(config);
605        self
606    }
607
608    /// Add strict rate limiting for login endpoints.
609    pub fn protect_login(self, path: &str) -> Self {
610        self.add_rule(path, RateLimitConfig::strict_login())
611    }
612
613    /// Add lenient rate limiting for API endpoints.
614    pub fn protect_api(self, path: &str) -> Self {
615        self.add_rule(path, RateLimitConfig::lenient_api())
616    }
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether a single check for `key` is admitted.
    async fn admitted(state: &RateLimiterState, key: &str) -> bool {
        state.check(key).await.is_ok()
    }

    /// Assert that exactly `quota` consecutive checks for `key` pass and the
    /// next one is rejected.
    async fn assert_quota(state: &RateLimiterState, key: &str, quota: usize) {
        for attempt in 0..quota {
            assert!(
                admitted(state, key).await,
                "attempt {attempt} for {key} should pass"
            );
        }
        assert!(
            !admitted(state, key).await,
            "attempt {quota} for {key} should be rejected"
        );
    }

    #[tokio::test]
    async fn test_fixed_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60)),
        );

        // Three requests fit in the window; the fourth is rejected.
        assert_quota(&state, "test-key", 3).await;
    }

    #[tokio::test]
    async fn test_sliding_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60))
                .algorithm(RateLimitAlgorithm::SlidingWindow),
        );

        // Three requests fit in the window; the fourth is rejected.
        assert_quota(&state, "test-key", 3).await;
    }

    #[tokio::test]
    async fn test_token_bucket_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(10)
                .window(Duration::from_secs(1))
                .burst_size(3)
                .algorithm(RateLimitAlgorithm::TokenBucket),
        );

        // The bucket starts with three tokens; a fourth immediate request finds it empty.
        assert_quota(&state, "test-key", 3).await;
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(2)
                .window(Duration::from_secs(60)),
        );

        // Exhausting one key leaves the other's quota untouched.
        assert_quota(&state, "key-a", 2).await;
        assert_quota(&state, "key-b", 2).await;
    }

    #[test]
    fn test_rate_limit_info() {
        let info = RateLimitInfo {
            limit: 100,
            remaining: 50,
            reset: 30,
        };

        assert_eq!(
            (info.limit, info.remaining, info.reset),
            (100, 50, 30)
        );
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::new()
            .requests_per_minute(60)
            .burst_size(10)
            .add_headers(true)
            .exclude_paths(vec!["/health", "/metrics"]);

        assert_eq!(config.max_requests, 60);
        assert_eq!(config.burst_size, 10);
        assert!(config.add_headers);
        assert_eq!(config.excluded_paths.len(), 2);
    }

    #[test]
    fn test_strict_login_config() {
        let config = RateLimitConfig::strict_login();
        assert_eq!(config.max_requests, 5);
        assert_eq!(config.window, Duration::from_secs(60));
    }

    #[test]
    fn test_lenient_api_config() {
        assert_eq!(RateLimitConfig::lenient_api().max_requests, 1000);
    }
}