Skip to main content

ironflow_api/
rate_limit.rs

//! Per-IP rate limiting middleware backed by [`governor`].
//!
//! Use [`per_minute`] to create a limiter, register it on the router as an
//! `Extension`, then install the [`rate_limit`] middleware on the desired
//! router group. Responses to admitted requests carry an `X-RateLimit-Limit`
//! header; rejected requests get HTTP 429 with `X-RateLimit-Limit` and
//! `Retry-After` headers and a JSON error body.
7
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::num::NonZeroU32;
10use std::sync::Arc;
11
12use axum::extract::{ConnectInfo, Request};
13use axum::http::{HeaderValue, StatusCode};
14use axum::middleware::Next;
15use axum::response::{IntoResponse, Response};
16use governor::clock::{Clock, DefaultClock};
17use governor::state::keyed::DashMapStateStore;
18use governor::{Quota, RateLimiter};
19use serde_json::json;
20
21/// Keyed rate limiter: one bucket per IP address.
22type KeyedLimiter = RateLimiter<IpAddr, DashMapStateStore<IpAddr>, DefaultClock>;
23
24/// Shared rate limiter state, injected as an Axum extension.
25#[derive(Clone)]
26pub struct RateLimitState {
27    limiter: Arc<KeyedLimiter>,
28    burst: u32,
29}
30
31/// Build a per-IP rate limiter allowing `requests_per_minute` requests
32/// per minute.
33///
34/// # Panics
35///
36/// Panics if `requests_per_minute` is 0.
37///
38/// # Examples
39///
40/// ```
41/// use ironflow_api::rate_limit::per_minute;
42///
43/// let auth_limiter = per_minute(10);
44/// let general_limiter = per_minute(60);
45/// ```
46pub fn per_minute(requests_per_minute: u32) -> RateLimitState {
47    let quota = Quota::per_minute(NonZeroU32::new(requests_per_minute).expect("burst must be > 0"));
48    RateLimitState {
49        limiter: Arc::new(RateLimiter::keyed(quota)),
50        burst: requests_per_minute,
51    }
52}
53
54/// Axum middleware that enforces per-IP rate limiting.
55///
56/// Must be installed with a [`RateLimitState`] extension on the router.
57/// Returns 429 Too Many Requests when the limit is exceeded, with a JSON body
58/// and `X-RateLimit-Limit` / `Retry-After` headers.
59pub async fn rate_limit(req: Request, next: Next) -> Response {
60    let state = match req.extensions().get::<RateLimitState>() {
61        Some(s) => s.clone(),
62        None => return next.run(req).await,
63    };
64
65    let ip = extract_client_ip(&req);
66
67    match state.limiter.check_key(&ip) {
68        Ok(_) => {
69            let mut resp = next.run(req).await;
70            resp.headers_mut()
71                .insert("x-ratelimit-limit", HeaderValue::from(state.burst));
72            resp
73        }
74        Err(not_until) => {
75            let retry_after = not_until.wait_time_from(DefaultClock::default().now());
76            let retry_secs = retry_after.as_secs().max(1);
77
78            let body = json!({
79                "error": {
80                    "code": "RATE_LIMIT_EXCEEDED",
81                    "message": "Too many requests, please try again later",
82                    "retry_after_secs": retry_secs,
83                }
84            });
85
86            let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
87            resp.headers_mut()
88                .insert("retry-after", HeaderValue::from(retry_secs));
89            resp.headers_mut()
90                .insert("x-ratelimit-limit", HeaderValue::from(state.burst));
91            resp
92        }
93    }
94}
95
96/// Extract the client IP from the request.
97///
98/// Checks `X-Forwarded-For` first (first IP in the chain), then
99/// `X-Real-Ip`, then falls back to the connected peer address.
100fn extract_client_ip(req: &Request) -> IpAddr {
101    // X-Forwarded-For: client, proxy1, proxy2
102    if let Some(forwarded) = req
103        .headers()
104        .get("x-forwarded-for")
105        .and_then(|v| v.to_str().ok())
106        && let Some(first) = forwarded.split(',').next()
107        && let Ok(ip) = first.trim().parse::<IpAddr>()
108    {
109        return ip;
110    }
111
112    // X-Real-Ip
113    if let Some(real_ip) = req.headers().get("x-real-ip").and_then(|v| v.to_str().ok())
114        && let Ok(ip) = real_ip.trim().parse::<IpAddr>()
115    {
116        return ip;
117    }
118
119    // Fallback to connected peer (axum injects ConnectInfo)
120    req.extensions()
121        .get::<ConnectInfo<SocketAddr>>()
122        .map(|ci| ci.0.ip())
123        .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
124}
125
126#[cfg(test)]
127mod tests {
128    use axum::body::Body;
129    use axum::http::{Request, StatusCode};
130    use axum::middleware as axum_mw;
131    use axum::routing::get;
132    use axum::{Extension, Router};
133    use http_body_util::BodyExt;
134    use serde_json::Value as JsonValue;
135    use tower::ServiceExt;
136
137    use super::*;
138
139    async fn ok_handler() -> &'static str {
140        "ok"
141    }
142
143    fn test_app(limiter: RateLimitState) -> Router {
144        Router::new()
145            .route("/test", get(ok_handler))
146            .layer(axum_mw::from_fn(rate_limit))
147            .layer(Extension(limiter))
148    }
149
150    fn test_request() -> Request<Body> {
151        Request::builder()
152            .uri("/test")
153            .header("x-forwarded-for", "1.2.3.4")
154            .body(Body::empty())
155            .unwrap()
156    }
157
158    #[tokio::test]
159    async fn allows_requests_within_limit() {
160        let limiter = per_minute(5);
161        let app = test_app(limiter);
162
163        let resp = app.oneshot(test_request()).await.unwrap();
164        assert_eq!(resp.status(), StatusCode::OK);
165        assert!(resp.headers().contains_key("x-ratelimit-limit"));
166    }
167
168    #[tokio::test]
169    async fn rejects_when_limit_exceeded() {
170        let limiter = per_minute(2);
171        let app = test_app(limiter.clone());
172
173        // Exhaust the limiter from the same IP
174        let _ = app.oneshot(test_request()).await;
175        let app = test_app(limiter.clone());
176        let _ = app.oneshot(test_request()).await;
177
178        let app = test_app(limiter);
179        let resp = app.oneshot(test_request()).await.unwrap();
180        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
181
182        let body = resp.into_body().collect().await.unwrap().to_bytes();
183        let json_val: JsonValue = serde_json::from_slice(&body).unwrap();
184        assert_eq!(json_val["error"]["code"], "RATE_LIMIT_EXCEEDED");
185    }
186
187    #[tokio::test]
188    async fn includes_retry_after_header() {
189        let limiter = per_minute(1);
190        let app = test_app(limiter.clone());
191
192        // Use up the single allowed request
193        let _ = app.oneshot(test_request()).await;
194
195        let app = test_app(limiter);
196        let resp = app.oneshot(test_request()).await.unwrap();
197        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
198        assert!(resp.headers().contains_key("retry-after"));
199    }
200
201    #[tokio::test]
202    async fn different_ips_have_separate_limits() {
203        let limiter = per_minute(1);
204
205        let app = test_app(limiter.clone());
206        let req_ip1 = Request::builder()
207            .uri("/test")
208            .header("x-forwarded-for", "10.0.0.1")
209            .body(Body::empty())
210            .unwrap();
211        let resp = app.oneshot(req_ip1).await.unwrap();
212        assert_eq!(resp.status(), StatusCode::OK);
213
214        let app = test_app(limiter);
215        let req_ip2 = Request::builder()
216            .uri("/test")
217            .header("x-forwarded-for", "10.0.0.2")
218            .body(Body::empty())
219            .unwrap();
220        let resp = app.oneshot(req_ip2).await.unwrap();
221        assert_eq!(resp.status(), StatusCode::OK);
222    }
223
224    #[tokio::test]
225    async fn extracts_ip_from_x_real_ip() {
226        let limiter = per_minute(1);
227
228        let app = test_app(limiter.clone());
229        let req = Request::builder()
230            .uri("/test")
231            .header("x-real-ip", "192.168.1.1")
232            .body(Body::empty())
233            .unwrap();
234        let resp = app.oneshot(req).await.unwrap();
235        assert_eq!(resp.status(), StatusCode::OK);
236
237        let app = test_app(limiter);
238        let req = Request::builder()
239            .uri("/test")
240            .header("x-real-ip", "192.168.1.1")
241            .body(Body::empty())
242            .unwrap();
243        let resp = app.oneshot(req).await.unwrap();
244        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
245    }
246
247    #[tokio::test]
248    async fn general_limiter_allows_more_requests() {
249        let limiter = per_minute(60);
250
251        for _ in 0..10 {
252            let app = test_app(limiter.clone());
253            let resp = app.oneshot(test_request()).await.unwrap();
254            assert_eq!(resp.status(), StatusCode::OK);
255        }
256    }
257
258    #[test]
259    fn extract_ip_x_forwarded_for_first_ip() {
260        let req = Request::builder()
261            .uri("/test")
262            .header("x-forwarded-for", "1.2.3.4, 5.6.7.8")
263            .body(Body::empty())
264            .unwrap();
265        let ip = extract_client_ip(&req);
266        assert_eq!(ip, "1.2.3.4".parse::<IpAddr>().unwrap());
267    }
268
269    #[test]
270    fn extract_ip_fallback_to_unspecified() {
271        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
272        let ip = extract_client_ip(&req);
273        assert_eq!(ip, IpAddr::V4(Ipv4Addr::UNSPECIFIED));
274    }
275}