acton_htmx/middleware/
rate_limit.rs

1//! Rate limiting middleware
2//!
3//! Provides request rate limiting with support for both Redis-backed (distributed)
4//! and in-memory (single-instance) backends. Rate limits can be configured per
5//! authenticated user, per IP address, and per specific route patterns.
6//!
//! # Features
//!
//! - **Multiple Identifiers**: Rate limit by user ID when authenticated, otherwise by client IP address
//! - **Route-Specific Limits**: Apply stricter limits to sensitive endpoints (e.g., `/login`, `/register`)
//! - **Redis Backend**: Distributed rate limiting for multi-instance deployments (requires the `redis` feature)
//! - **In-Memory Fallback**: Automatic fallback to in-memory rate limiting if Redis is unavailable
//! - **Failure Modes**: `RateLimitConfig` carries a fail-open/fail-closed `failure_mode`; note that this
//!   middleware does not currently consult it and always falls back to the in-memory backend on Redis errors
//! - **Fixed Window**: A key's window opens with its first request and its counter resets once
//!   `window_secs` seconds have elapsed
15//!
16//! # Example
17//!
18//! ```rust,no_run
19//! use acton_htmx::middleware::rate_limit::RateLimit;
20//! use acton_htmx::config::RateLimitConfig;
21//! use axum::{Router, routing::get};
22//!
23//! # async fn example() -> anyhow::Result<()> {
24//! let config = RateLimitConfig::default();
25//!
26//! #[cfg(feature = "redis")]
27//! let redis_pool = acton_htmx::database::redis::create_pool("redis://localhost:6379").await?;
28//!
29//! #[cfg(feature = "redis")]
30//! let rate_limit = RateLimit::new(config, Some(redis_pool));
31//!
32//! #[cfg(not(feature = "redis"))]
33//! let rate_limit = RateLimit::new(config, None);
34//!
35//! let app = Router::new()
36//!     .route("/", get(|| async { "Hello" }))
37//!     .layer(axum::middleware::from_fn_with_state(
38//!         rate_limit,
39//!         RateLimit::middleware,
40//!     ));
41//! # Ok(())
42//! # }
43//! ```
44
45#[cfg(feature = "redis")]
46use deadpool_redis::Pool as RedisPool;
47
48use axum::{
49    extract::{ConnectInfo, Request, State},
50    http::StatusCode,
51    middleware::Next,
52    response::{IntoResponse, Response},
53};
54use std::collections::HashMap;
55use std::net::SocketAddr;
56use std::sync::Arc;
57use std::time::{Duration, Instant};
58use tokio::sync::RwLock;
59use tracing::{debug, warn};
60
61use crate::config::RateLimitConfig;
62
/// Counter state for a single rate-limit key in the in-memory backend.
#[derive(Clone, Debug)]
struct RateLimitEntry {
    /// Number of requests counted since `window_start`.
    count: u32,
    /// Instant at which the current counting window was opened.
    window_start: Instant,
}
71
72/// In-memory rate limit store
73type InMemoryStore = Arc<RwLock<HashMap<String, RateLimitEntry>>>;
74
75/// Rate limiting middleware
76///
77/// Enforces configurable rate limits per user, IP address, and route.
78/// Supports both Redis-backed (distributed) and in-memory (single-instance) storage.
79#[derive(Clone)]
80pub struct RateLimit {
81    config: RateLimitConfig,
82    #[cfg(feature = "redis")]
83    redis_pool: Option<RedisPool>,
84    in_memory_store: InMemoryStore,
85}
86
87impl RateLimit {
88    /// Create a new rate limiting middleware
89    ///
90    /// # Arguments
91    ///
92    /// * `config` - Rate limit configuration
93    /// * `redis_pool` - Optional Redis pool for distributed rate limiting (requires `redis` feature)
94    ///
95    /// # Example
96    ///
97    /// ```rust,no_run
98    /// use acton_htmx::middleware::rate_limit::RateLimit;
99    /// use acton_htmx::config::RateLimitConfig;
100    ///
101    /// # async fn example() -> anyhow::Result<()> {
102    /// let config = RateLimitConfig::default();
103    ///
104    /// #[cfg(feature = "redis")]
105    /// let redis_pool = acton_htmx::database::redis::create_pool("redis://localhost:6379").await?;
106    ///
107    /// #[cfg(feature = "redis")]
108    /// let rate_limit = RateLimit::new(config, Some(redis_pool));
109    ///
110    /// #[cfg(not(feature = "redis"))]
111    /// let rate_limit = RateLimit::new(config, None);
112    /// # Ok(())
113    /// # }
114    /// ```
115    #[must_use]
116    #[cfg(feature = "redis")]
117    pub fn new(config: RateLimitConfig, redis_pool: Option<RedisPool>) -> Self {
118        Self {
119            config,
120            redis_pool,
121            in_memory_store: Arc::new(RwLock::new(HashMap::new())),
122        }
123    }
124
125    /// Create a new rate limiting middleware without Redis
126    #[must_use]
127    #[cfg(not(feature = "redis"))]
128    pub fn new(config: RateLimitConfig, _redis_pool: Option<()>) -> Self {
129        Self {
130            config,
131            in_memory_store: Arc::new(RwLock::new(HashMap::new())),
132        }
133    }
134
135    /// Middleware function to enforce rate limits
136    ///
137    /// This middleware:
138    /// 1. Extracts user ID from session (if authenticated) or IP address
139    /// 2. Checks if request path matches strict route patterns
140    /// 3. Applies appropriate rate limit (per-user, per-IP, or per-route)
141    /// 4. Returns 429 Too Many Requests if limit exceeded
142    ///
143    /// # Errors
144    ///
145    /// Returns [`RateLimitError::TooManyRequests`] if the rate limit is exceeded.
146    /// Returns [`RateLimitError::Redis`] if Redis operations fail (when using Redis feature).
147    pub async fn middleware(
148        State(rate_limit): State<Self>,
149        request: Request,
150        next: Next,
151    ) -> Result<Response, RateLimitError> {
152        // Skip if rate limiting is disabled
153        if !rate_limit.config.enabled {
154            return Ok(next.run(request).await);
155        }
156
157        // Extract user ID from request extensions (set by session middleware)
158        let user_id: Option<i64> = request.extensions().get::<i64>().copied();
159
160        // Extract IP address from connection info
161        let ip_addr = request
162            .extensions()
163            .get::<ConnectInfo<SocketAddr>>()
164            .map(|ConnectInfo(addr)| addr.ip().to_string());
165
166        // Determine rate limit key and limit
167        let path = request.uri().path();
168        let (key, limit) = rate_limit.determine_key_and_limit(user_id, ip_addr.as_deref(), path);
169
170        debug!(
171            key = %key,
172            limit = limit,
173            path = %path,
174            user_id = ?user_id,
175            "Checking rate limit"
176        );
177
178        // Check rate limit
179        rate_limit.check_rate_limit(&key, limit).await?;
180
181        Ok(next.run(request).await)
182    }
183
184    /// Determine rate limit key and limit based on user, IP, and path
185    fn determine_key_and_limit(
186        &self,
187        user_id: Option<i64>,
188        ip_addr: Option<&str>,
189        path: &str,
190    ) -> (String, u32) {
191        // Check if path matches strict routes
192        let is_strict_route = self
193            .config
194            .strict_routes
195            .iter()
196            .any(|route| path.starts_with(route));
197
198        if is_strict_route {
199            // Use stricter per-route limit
200            let key = user_id.map_or_else(|| {
201                ip_addr.map_or_else(|| "ratelimit:route:unknown".to_string(), |ip| format!("ratelimit:route:ip:{ip}"))
202            }, |uid| format!("ratelimit:route:user:{uid}"));
203            (key, self.config.per_route_rpm)
204        } else if let Some(uid) = user_id {
205            // Authenticated user
206            (
207                format!("ratelimit:user:{uid}"),
208                self.config.per_user_rpm,
209            )
210        } else if let Some(ip) = ip_addr {
211            // Anonymous by IP
212            (format!("ratelimit:ip:{ip}"), self.config.per_ip_rpm)
213        } else {
214            // Fallback
215            ("ratelimit:unknown".to_string(), self.config.per_ip_rpm)
216        }
217    }
218
219    /// Check rate limit for a key
220    async fn check_rate_limit(&self, key: &str, limit: u32) -> Result<(), RateLimitError> {
221        // Try Redis first if enabled
222        #[cfg(feature = "redis")]
223        if self.config.redis_enabled {
224            if let Some(ref redis_pool) = self.redis_pool {
225                match self.check_rate_limit_redis(redis_pool, key, limit).await {
226                    Ok(()) => return Ok(()),
227                    Err(e) => {
228                        warn!(
229                            error = %e,
230                            key = %key,
231                            "Redis rate limit check failed, falling back to in-memory"
232                        );
233                        // Fall through to in-memory
234                    }
235                }
236            }
237        }
238
239        // Use in-memory rate limiting
240        self.check_rate_limit_memory(key, limit).await
241    }
242
243    /// Check rate limit using Redis backend
244    #[cfg(feature = "redis")]
245    async fn check_rate_limit_redis(
246        &self,
247        redis_pool: &RedisPool,
248        key: &str,
249        limit: u32,
250    ) -> Result<(), RateLimitError> {
251        let mut conn = redis_pool.get().await.map_err(|e| {
252            RateLimitError::Backend(format!("Failed to get Redis connection: {e}"))
253        })?;
254
255        // Use INCR and EXPIRE for sliding window rate limiting
256        let count: u32 = redis::cmd("INCR")
257            .arg(key)
258            .query_async(&mut *conn)
259            .await
260            .map_err(|e| RateLimitError::Backend(format!("Redis INCR failed: {e}")))?;
261
262        // Set expiration on first request
263        if count == 1 {
264            // Convert window_secs to i64, saturating at i64::MAX to avoid wrapping
265            let expire_secs = i64::try_from(self.config.window_secs).unwrap_or(i64::MAX);
266            let _: () = redis::cmd("EXPIRE")
267                .arg(key)
268                .arg(expire_secs)
269                .query_async(&mut *conn)
270                .await
271                .map_err(|e| RateLimitError::Backend(format!("Redis EXPIRE failed: {e}")))?;
272        }
273
274        // Check if limit exceeded
275        if count > limit {
276            warn!(
277                key = %key,
278                count = count,
279                limit = limit,
280                window_secs = self.config.window_secs,
281                "Rate limit exceeded"
282            );
283            return Err(RateLimitError::Exceeded {
284                limit,
285                window: Duration::from_secs(self.config.window_secs),
286            });
287        }
288
289        debug!(
290            key = %key,
291            count = count,
292            limit = limit,
293            "Rate limit check passed (Redis)"
294        );
295
296        Ok(())
297    }
298
299    /// Check rate limit using in-memory backend
300    async fn check_rate_limit_memory(&self, key: &str, limit: u32) -> Result<(), RateLimitError> {
301        let now = Instant::now();
302        let window_duration = Duration::from_secs(self.config.window_secs);
303
304        // Acquire lock, update entry, extract count, then immediately release lock
305        let mut store = self.in_memory_store.write().await;
306
307        // Get or create entry
308        let entry = store.entry(key.to_string()).or_insert_with(|| RateLimitEntry {
309            count: 0,
310            window_start: now,
311        });
312
313        // Check if window has expired
314        if now.duration_since(entry.window_start) >= window_duration {
315            // Reset window
316            entry.count = 1;
317            entry.window_start = now;
318        } else {
319            // Increment count
320            entry.count += 1;
321        }
322
323        let count = entry.count;
324        drop(store); // Explicitly drop the lock before any logging or error handling
325
326        // Check if limit exceeded (after releasing lock)
327        if count > limit {
328            warn!(
329                key = %key,
330                count = count,
331                limit = limit,
332                window_secs = self.config.window_secs,
333                "Rate limit exceeded"
334            );
335            return Err(RateLimitError::Exceeded {
336                limit,
337                window: window_duration,
338            });
339        }
340
341        debug!(
342            key = %key,
343            count = count,
344            limit = limit,
345            "Rate limit check passed (in-memory)"
346        );
347
348        Ok(())
349    }
350
351    /// Cleanup expired entries from in-memory store
352    ///
353    /// Should be called periodically to prevent memory leaks.
354    /// Returns the number of entries removed.
355    pub async fn cleanup_expired(&self) -> usize {
356        let now = Instant::now();
357        let window_duration = Duration::from_secs(self.config.window_secs);
358
359        let removed = {
360            let mut store = self.in_memory_store.write().await;
361            let before_count = store.len();
362
363            store.retain(|_, entry| now.duration_since(entry.window_start) < window_duration);
364
365            before_count - store.len()
366        }; // Drop the write lock here
367
368        if removed > 0 {
369            debug!(removed = removed, "Cleaned up expired rate limit entries");
370        }
371
372        removed
373    }
374}
375
376/// Rate limit errors
377#[derive(Debug, thiserror::Error)]
378pub enum RateLimitError {
379    /// Rate limit exceeded
380    #[error("Rate limit exceeded: {limit} requests per {window:?}")]
381    Exceeded {
382        /// Maximum requests allowed
383        limit: u32,
384        /// Time window
385        window: Duration,
386    },
387
388    /// Backend error (Redis, etc.)
389    #[error("Rate limit backend error: {0}")]
390    Backend(String),
391}
392
393impl IntoResponse for RateLimitError {
394    fn into_response(self) -> Response {
395        match self {
396            Self::Exceeded { limit, window } => {
397                let retry_after = window.as_secs();
398                (
399                    StatusCode::TOO_MANY_REQUESTS,
400                    [
401                        ("Retry-After", retry_after.to_string()),
402                        (
403                            "X-RateLimit-Limit",
404                            limit.to_string(),
405                        ),
406                    ],
407                    format!(
408                        "Rate limit exceeded. Maximum {} requests per {} seconds.",
409                        limit,
410                        window.as_secs()
411                    ),
412                )
413                    .into_response()
414            }
415            Self::Backend(msg) => {
416                warn!(error = %msg, "Rate limit backend error");
417                // Return 500 for backend errors in fail-closed mode
418                // (fail-open mode would skip rate limiting and never reach here)
419                (
420                    StatusCode::INTERNAL_SERVER_ERROR,
421                    "Rate limiting temporarily unavailable",
422                )
423                    .into_response()
424            }
425        }
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::RateLimitFailureMode;

    /// Config with small, distinct limits (user 5, IP 3, route 2), no strict routes,
    /// Redis disabled, and the given window length in seconds.
    fn test_config(window_secs: u64) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            per_user_rpm: 5,
            per_ip_rpm: 3,
            per_route_rpm: 2,
            window_secs,
            redis_enabled: false,
            failure_mode: RateLimitFailureMode::Closed,
            strict_routes: vec![],
        }
    }

    #[test]
    fn test_rate_limit_creation() {
        let config = RateLimitConfig::default();
        let rate_limit = RateLimit::new(config, None);

        assert!(rate_limit.config.enabled);
        assert_eq!(rate_limit.config.per_user_rpm, 120);
        assert_eq!(rate_limit.config.per_ip_rpm, 60);
        assert_eq!(rate_limit.config.per_route_rpm, 30);
    }

    #[test]
    fn test_determine_key_and_limit_authenticated() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);

        let (key, limit) = rate_limit.determine_key_and_limit(Some(123), Some("192.168.1.1"), "/posts");
        assert_eq!(key, "ratelimit:user:123");
        assert_eq!(limit, 120);
    }

    #[test]
    fn test_determine_key_and_limit_anonymous() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);

        let (key, limit) = rate_limit.determine_key_and_limit(None, Some("192.168.1.1"), "/posts");
        assert_eq!(key, "ratelimit:ip:192.168.1.1");
        assert_eq!(limit, 60);
    }

    #[test]
    fn test_determine_key_and_limit_strict_route_authenticated() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);

        let (key, limit) = rate_limit.determine_key_and_limit(Some(123), Some("192.168.1.1"), "/login");
        assert_eq!(key, "ratelimit:route:user:123");
        assert_eq!(limit, 30);
    }

    #[test]
    fn test_determine_key_and_limit_strict_route_anonymous() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);

        let (key, limit) = rate_limit.determine_key_and_limit(None, Some("192.168.1.1"), "/register");
        assert_eq!(key, "ratelimit:route:ip:192.168.1.1");
        assert_eq!(limit, 30);
    }

    #[test]
    fn test_determine_key_and_limit_without_identity() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);

        // Neither a user nor an IP: such requests share one fallback bucket.
        let (key, limit) = rate_limit.determine_key_and_limit(None, None, "/posts");
        assert_eq!(key, "ratelimit:unknown");
        assert_eq!(limit, 60);

        let (key, limit) = rate_limit.determine_key_and_limit(None, None, "/login");
        assert_eq!(key, "ratelimit:route:unknown");
        assert_eq!(limit, 30);
    }

    #[test]
    fn test_strict_route_matches_by_prefix() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);

        let (key, limit) =
            rate_limit.determine_key_and_limit(None, Some("10.0.0.1"), "/login/callback");
        assert_eq!(key, "ratelimit:route:ip:10.0.0.1");
        assert_eq!(limit, 30);
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_within_limit() {
        let rate_limit = RateLimit::new(test_config(60), None);

        // Three requests stay well under a limit of 5
        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 5).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_exceeded() {
        let rate_limit = RateLimit::new(test_config(60), None);

        // Should allow 3 requests
        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
            assert!(result.is_ok());
        }

        // 4th request should fail
        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(matches!(result, Err(RateLimitError::Exceeded { limit: 3, .. })));
    }

    #[tokio::test]
    async fn test_in_memory_keys_are_independent() {
        let rate_limit = RateLimit::new(test_config(60), None);

        assert!(rate_limit.check_rate_limit_memory("key_a", 1).await.is_ok());
        assert!(matches!(
            rate_limit.check_rate_limit_memory("key_a", 1).await,
            Err(RateLimitError::Exceeded { limit: 1, .. })
        ));

        // Exhausting one key must not affect another.
        assert!(rate_limit.check_rate_limit_memory("key_b", 1).await.is_ok());
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_window_reset() {
        // 1 second window for testing
        let rate_limit = RateLimit::new(test_config(1), None);

        // Use up the limit
        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
            assert!(result.is_ok());
        }

        // Should fail
        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_err());

        // Wait for window to expire (the backend uses std::time::Instant, so real time must pass)
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Should work again
        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        // 1 second window for testing
        let rate_limit = RateLimit::new(test_config(1), None);

        // Create some entries
        for i in 0..5 {
            let key = format!("test_key_{i}");
            let _ = rate_limit.check_rate_limit_memory(&key, 10).await;
        }

        // Verify entries exist
        let len = rate_limit.in_memory_store.read().await.len();
        assert_eq!(len, 5);

        // Wait for window to expire
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Cleanup
        let removed = rate_limit.cleanup_expired().await;
        assert_eq!(removed, 5);

        // Verify entries removed
        let len = rate_limit.in_memory_store.read().await.len();
        assert_eq!(len, 0);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_active_entries() {
        let rate_limit = RateLimit::new(test_config(60), None);

        for i in 0..3 {
            let _ = rate_limit
                .check_rate_limit_memory(&format!("active_{i}"), 10)
                .await;
        }

        // Nothing has expired within a 60 second window.
        assert_eq!(rate_limit.cleanup_expired().await, 0);
        let len = rate_limit.in_memory_store.read().await.len();
        assert_eq!(len, 3);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let error = RateLimitError::Exceeded {
            limit: 100,
            window: Duration::from_secs(60),
        };
        assert!(error.to_string().contains("100"));
        assert!(error.to_string().contains("60"));

        let error = RateLimitError::Backend("Redis connection failed".to_string());
        assert!(error.to_string().contains("Redis connection failed"));
    }

    #[test]
    fn test_exceeded_into_response() {
        let response = RateLimitError::Exceeded {
            limit: 10,
            window: Duration::from_secs(60),
        }
        .into_response();

        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(response.headers()["Retry-After"], "60");
        assert_eq!(response.headers()["X-RateLimit-Limit"], "10");
    }

    #[test]
    fn test_backend_into_response() {
        let response = RateLimitError::Backend("boom".to_string()).into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}