avl_console/middleware/
rate_limit.rs

1//! Rate limiting middleware
2
3use axum::{
4    extract::Request,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::Mutex;
12use tower::{Layer, Service};
13
/// Configuration for the request rate limiter.
///
/// A single client key may have at most `max_requests` requests admitted
/// within any span of length `window`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per client key per window.
    pub max_requests: usize,
    /// Length of the time window over which requests are counted.
    pub window: Duration,
}
22
23impl Default for RateLimitConfig {
24    fn default() -> Self {
25        Self {
26            max_requests: 100,
27            window: Duration::from_secs(60),
28        }
29    }
30}
31
32/// Rate limiter state
33struct RateLimiter {
34    requests: HashMap<String, Vec<Instant>>,
35    config: RateLimitConfig,
36}
37
38impl RateLimiter {
39    fn new(config: RateLimitConfig) -> Self {
40        Self {
41            requests: HashMap::new(),
42            config,
43        }
44    }
45
46    fn check_rate_limit(&mut self, key: &str) -> bool {
47        let now = Instant::now();
48        let window_start = now - self.config.window;
49
50        // Clean old entries
51        let requests = self.requests.entry(key.to_string()).or_insert_with(Vec::new);
52        requests.retain(|&time| time > window_start);
53
54        // Check limit
55        if requests.len() >= self.config.max_requests {
56            return false;
57        }
58
59        // Add new request
60        requests.push(now);
61        true
62    }
63}
64
65/// Rate limiting layer
66#[derive(Clone)]
67pub struct RateLimitLayer {
68    limiter: Arc<Mutex<RateLimiter>>,
69}
70
71impl RateLimitLayer {
72    pub fn new() -> Self {
73        Self::with_config(RateLimitConfig::default())
74    }
75
76    pub fn with_config(config: RateLimitConfig) -> Self {
77        Self {
78            limiter: Arc::new(Mutex::new(RateLimiter::new(config))),
79        }
80    }
81}
82
83impl Default for RateLimitLayer {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl<S> Layer<S> for RateLimitLayer {
90    type Service = RateLimitMiddleware<S>;
91
92    fn layer(&self, inner: S) -> Self::Service {
93        RateLimitMiddleware {
94            inner,
95            limiter: self.limiter.clone(),
96        }
97    }
98}
99
100#[derive(Clone)]
101pub struct RateLimitMiddleware<S> {
102    inner: S,
103    limiter: Arc<Mutex<RateLimiter>>,
104}
105
106impl<S> Service<Request> for RateLimitMiddleware<S>
107where
108    S: Service<Request, Response = Response> + Send + 'static,
109    S::Future: Send + 'static,
110{
111    type Response = S::Response;
112    type Error = S::Error;
113    type Future = std::pin::Pin<
114        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
115    >;
116
117    fn poll_ready(
118        &mut self,
119        cx: &mut std::task::Context<'_>,
120    ) -> std::task::Poll<Result<(), Self::Error>> {
121        self.inner.poll_ready(cx)
122    }
123
124    fn call(&mut self, req: Request) -> Self::Future {
125        let limiter = self.limiter.clone();
126
127        // Extract client ID before moving req
128        let client_id = req
129            .headers()
130            .get("x-forwarded-for")
131            .and_then(|v| v.to_str().ok())
132            .unwrap_or("unknown")
133            .to_string();
134
135        let future = self.inner.call(req);
136
137        Box::pin(async move {
138
139            // Check rate limit
140            let mut limiter = limiter.lock().await;
141            if !limiter.check_rate_limit(&client_id) {
142                return Ok((
143                    StatusCode::TOO_MANY_REQUESTS,
144                    "Rate limit exceeded. Please try again later.",
145                )
146                    .into_response());
147            }
148            drop(limiter);
149
150            future.await
151        })
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter admitting `max_requests` requests per 60-second window.
    fn limiter_with_max(max_requests: usize) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            max_requests,
            window: Duration::from_secs(60),
        })
    }

    // The limiter is synchronous, so these are plain `#[test]`s; no async
    // runtime is needed.
    #[test]
    fn test_rate_limiter() {
        let mut limiter = limiter_with_max(5);

        // Should allow first 5 requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("test_user"));
        }

        // 6th request should be denied
        assert!(!limiter.check_rate_limit("test_user"));
    }

    #[test]
    fn test_keys_are_limited_independently() {
        let mut limiter = limiter_with_max(1);

        assert!(limiter.check_rate_limit("alice"));
        assert!(!limiter.check_rate_limit("alice"));

        // Exhausting one key must not affect another.
        assert!(limiter.check_rate_limit("bob"));
    }
}