Skip to main content

karbon_framework/http/middleware/
rate_limit.rs

1use axum::{
2    extract::Request,
3    http::StatusCode,
4    response::{IntoResponse, Response},
5};
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::Mutex;
11use tower::{Layer, Service};
12
13/// Simple in-memory rate limiter (per IP address).
14///
15/// # Security note
16/// Uses `X-Forwarded-For` / `X-Real-IP` headers to identify clients.
17/// **These headers can be spoofed** unless your reverse proxy (nginx, Cloudflare, etc.)
18/// strips and re-sets them. In production, always place this behind a trusted reverse proxy
19/// that sets `X-Forwarded-For` to the real client IP.
20///
21/// For distributed systems, consider a Redis-backed rate limiter instead.
22///
23/// # Example
24/// ```ignore
25/// let app = Router::new()
26///     .route("/api/login", post(login))
27///     .layer(RateLimitLayer::new(60, Duration::from_secs(60))); // 60 req/min
28/// ```
29#[derive(Clone)]
30pub struct RateLimitLayer {
31    max_requests: u32,
32    window: Duration,
33    store: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
34}
35
36impl RateLimitLayer {
37    pub fn new(max_requests: u32, window: Duration) -> Self {
38        Self {
39            max_requests,
40            window,
41            store: Arc::new(Mutex::new(HashMap::new())),
42        }
43    }
44
45    /// Per-minute shorthand
46    pub fn per_minute(max: u32) -> Self {
47        Self::new(max, Duration::from_secs(60))
48    }
49}
50
51impl<S> Layer<S> for RateLimitLayer {
52    type Service = RateLimitService<S>;
53
54    fn layer(&self, inner: S) -> Self::Service {
55        RateLimitService {
56            inner,
57            max_requests: self.max_requests,
58            window: self.window,
59            store: self.store.clone(),
60        }
61    }
62}
63
64#[derive(Clone)]
65pub struct RateLimitService<S> {
66    inner: S,
67    max_requests: u32,
68    window: Duration,
69    store: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
70}
71
72impl<S> Service<Request> for RateLimitService<S>
73where
74    S: Service<Request, Response = Response> + Clone + Send + 'static,
75    S::Future: Send,
76{
77    type Response = Response;
78    type Error = S::Error;
79    type Future = std::pin::Pin<
80        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
81    >;
82
83    fn poll_ready(
84        &mut self,
85        cx: &mut std::task::Context<'_>,
86    ) -> std::task::Poll<Result<(), Self::Error>> {
87        self.inner.poll_ready(cx)
88    }
89
90    fn call(&mut self, request: Request) -> Self::Future {
91        let max = self.max_requests;
92        let window = self.window;
93        let store = self.store.clone();
94        let mut inner = self.inner.clone();
95
96        Box::pin(async move {
97            // Extract IP from proxy headers or fallback to loopback
98            let ip: IpAddr = crate::util::HttpHelper::client_ip(
99                request.headers(),
100                "127.0.0.1".parse().unwrap(),
101            );
102
103            let mut map = store.lock().await;
104            let now = Instant::now();
105
106            let (count, started) = map.entry(ip).or_insert((0, now));
107
108            // Reset window if expired
109            if now.duration_since(*started) > window {
110                *count = 0;
111                *started = now;
112            }
113
114            *count += 1;
115
116            if *count > max {
117                drop(map);
118                return Ok((
119                    StatusCode::TOO_MANY_REQUESTS,
120                    "Rate limit exceeded",
121                )
122                    .into_response());
123            }
124            drop(map);
125
126            inner.call(request).await
127        })
128    }
129}