avl_console/middleware/
rate_limit.rs1use axum::{
4 extract::Request,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::Mutex;
12use tower::{Layer, Service};
13
/// Tuning knobs for the sliding-window rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Maximum number of accepted requests per client key within `window`.
    pub max_requests: usize,
    /// Length of the sliding window over which requests are counted.
    pub window: Duration,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window: Duration::from_secs(60),
        }
    }
}

/// Sliding-window request counter keyed by a client identifier string.
///
/// Each key maps to the timestamps of its accepted requests, pushed in
/// chronological order (oldest first).
struct RateLimiter {
    requests: HashMap<String, Vec<Instant>>,
    config: RateLimitConfig,
    /// When idle keys were last purged from `requests`.
    last_sweep: Instant,
}

impl RateLimiter {
    fn new(config: RateLimitConfig) -> Self {
        Self {
            requests: HashMap::new(),
            config,
            last_sweep: Instant::now(),
        }
    }

    /// Records a request for `key` and reports whether it is allowed.
    ///
    /// Returns `true` (and records the request) when `key` has fewer than
    /// `max_requests` accepted requests inside the last `window`; returns
    /// `false` without recording anything otherwise.
    fn check_rate_limit(&mut self, key: &str) -> bool {
        let now = Instant::now();
        let window = self.config.window;
        // `Instant::duration_since` saturates at zero instead of panicking,
        // whereas `now - window` may panic when `window` reaches back before
        // the earliest `Instant` the platform can represent.
        let in_window = |t: &Instant| now.duration_since(*t) < window;

        // Purge keys with no activity inside the window, at most once per
        // window, so the map cannot grow without bound as new keys appear.
        if now.duration_since(self.last_sweep) >= window {
            self.requests
                .retain(|_, times| matches!(times.last(), Some(t) if in_window(t)));
            self.last_sweep = now;
        }

        let times = self.requests.entry(key.to_owned()).or_default();
        times.retain(in_window);

        if times.len() >= self.config.max_requests {
            return false;
        }

        times.push(now);
        true
    }
}
64
65#[derive(Clone)]
67pub struct RateLimitLayer {
68 limiter: Arc<Mutex<RateLimiter>>,
69}
70
71impl RateLimitLayer {
72 pub fn new() -> Self {
73 Self::with_config(RateLimitConfig::default())
74 }
75
76 pub fn with_config(config: RateLimitConfig) -> Self {
77 Self {
78 limiter: Arc::new(Mutex::new(RateLimiter::new(config))),
79 }
80 }
81}
82
83impl Default for RateLimitLayer {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl<S> Layer<S> for RateLimitLayer {
90 type Service = RateLimitMiddleware<S>;
91
92 fn layer(&self, inner: S) -> Self::Service {
93 RateLimitMiddleware {
94 inner,
95 limiter: self.limiter.clone(),
96 }
97 }
98}
99
100#[derive(Clone)]
101pub struct RateLimitMiddleware<S> {
102 inner: S,
103 limiter: Arc<Mutex<RateLimiter>>,
104}
105
106impl<S> Service<Request> for RateLimitMiddleware<S>
107where
108 S: Service<Request, Response = Response> + Send + 'static,
109 S::Future: Send + 'static,
110{
111 type Response = S::Response;
112 type Error = S::Error;
113 type Future = std::pin::Pin<
114 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
115 >;
116
117 fn poll_ready(
118 &mut self,
119 cx: &mut std::task::Context<'_>,
120 ) -> std::task::Poll<Result<(), Self::Error>> {
121 self.inner.poll_ready(cx)
122 }
123
124 fn call(&mut self, req: Request) -> Self::Future {
125 let limiter = self.limiter.clone();
126
127 let client_id = req
129 .headers()
130 .get("x-forwarded-for")
131 .and_then(|v| v.to_str().ok())
132 .unwrap_or("unknown")
133 .to_string();
134
135 let future = self.inner.call(req);
136
137 Box::pin(async move {
138
139 let mut limiter = limiter.lock().await;
141 if !limiter.check_rate_limit(&client_id) {
142 return Ok((
143 StatusCode::TOO_MANY_REQUESTS,
144 "Rate limit exceeded. Please try again later.",
145 )
146 .into_response());
147 }
148 drop(limiter);
149
150 future.await
151 })
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Limiter with a 60-second window and the given request cap.
    fn limiter_with(max_requests: usize) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            max_requests,
            window: Duration::from_secs(60),
        })
    }

    // The limiter is synchronous, so plain `#[test]` suffices; no runtime needed.
    #[test]
    fn test_rate_limiter() {
        let mut limiter = limiter_with(5);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit("test_user"));
        }

        assert!(!limiter.check_rate_limit("test_user"));
    }

    #[test]
    fn limits_are_tracked_per_key() {
        let mut limiter = limiter_with(1);

        assert!(limiter.check_rate_limit("alice"));
        assert!(!limiter.check_rate_limit("alice"));
        // Exhausting one key must not affect another.
        assert!(limiter.check_rate_limit("bob"));
    }

    #[test]
    fn zero_limit_rejects_everything() {
        let mut limiter = limiter_with(0);

        assert!(!limiter.check_rate_limit("anyone"));
    }
}