ironflow_api/
rate_limit.rs1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::num::NonZeroU32;
10use std::sync::Arc;
11
12use axum::extract::{ConnectInfo, Request};
13use axum::http::{HeaderValue, StatusCode};
14use axum::middleware::Next;
15use axum::response::{IntoResponse, Response};
16use governor::clock::{Clock, DefaultClock};
17use governor::state::keyed::DashMapStateStore;
18use governor::{Quota, RateLimiter};
19use serde_json::json;
20
21type KeyedLimiter = RateLimiter<IpAddr, DashMapStateStore<IpAddr>, DefaultClock>;
23
24#[derive(Clone)]
26pub struct RateLimitState {
27 limiter: Arc<KeyedLimiter>,
28 burst: u32,
29}
30
31pub fn per_minute(requests_per_minute: u32) -> RateLimitState {
47 let quota = Quota::per_minute(NonZeroU32::new(requests_per_minute).expect("burst must be > 0"));
48 RateLimitState {
49 limiter: Arc::new(RateLimiter::keyed(quota)),
50 burst: requests_per_minute,
51 }
52}
53
54pub async fn rate_limit(req: Request, next: Next) -> Response {
60 let state = match req.extensions().get::<RateLimitState>() {
61 Some(s) => s.clone(),
62 None => return next.run(req).await,
63 };
64
65 let ip = extract_client_ip(&req);
66
67 match state.limiter.check_key(&ip) {
68 Ok(_) => {
69 let mut resp = next.run(req).await;
70 resp.headers_mut()
71 .insert("x-ratelimit-limit", HeaderValue::from(state.burst));
72 resp
73 }
74 Err(not_until) => {
75 let retry_after = not_until.wait_time_from(DefaultClock::default().now());
76 let retry_secs = retry_after.as_secs().max(1);
77
78 let body = json!({
79 "error": {
80 "code": "RATE_LIMIT_EXCEEDED",
81 "message": "Too many requests, please try again later",
82 "retry_after_secs": retry_secs,
83 }
84 });
85
86 let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
87 resp.headers_mut()
88 .insert("retry-after", HeaderValue::from(retry_secs));
89 resp.headers_mut()
90 .insert("x-ratelimit-limit", HeaderValue::from(state.burst));
91 resp
92 }
93 }
94}
95
96fn extract_client_ip(req: &Request) -> IpAddr {
101 if let Some(forwarded) = req
103 .headers()
104 .get("x-forwarded-for")
105 .and_then(|v| v.to_str().ok())
106 && let Some(first) = forwarded.split(',').next()
107 && let Ok(ip) = first.trim().parse::<IpAddr>()
108 {
109 return ip;
110 }
111
112 if let Some(real_ip) = req.headers().get("x-real-ip").and_then(|v| v.to_str().ok())
114 && let Ok(ip) = real_ip.trim().parse::<IpAddr>()
115 {
116 return ip;
117 }
118
119 req.extensions()
121 .get::<ConnectInfo<SocketAddr>>()
122 .map(|ci| ci.0.ip())
123 .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
124}
125
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::middleware as axum_mw;
    use axum::routing::get;
    use axum::{Extension, Router};
    use http_body_util::BodyExt;
    use serde_json::Value as JsonValue;
    use tower::ServiceExt;

    use super::*;

    /// Trivial handler standing in for a real route.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// A single `/test` route behind the rate-limit middleware, with `limiter` supplied
    /// through an `Extension` layer.
    fn test_app(limiter: RateLimitState) -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(axum_mw::from_fn(rate_limit))
            .layer(Extension(limiter))
    }

    /// Builds a `GET /test` request carrying exactly one header.
    fn request_with(header: &str, value: &str) -> Request<Body> {
        Request::builder()
            .uri("/test")
            .header(header, value)
            .body(Body::empty())
            .unwrap()
    }

    /// Default request for most tests: client `1.2.3.4` via `x-forwarded-for`.
    fn test_request() -> Request<Body> {
        request_with("x-forwarded-for", "1.2.3.4")
    }

    /// Sends `req` through a fresh app that shares `limiter`'s per-IP state.
    async fn send(limiter: &RateLimitState, req: Request<Body>) -> Response {
        test_app(limiter.clone()).oneshot(req).await.unwrap()
    }

    #[tokio::test]
    async fn allows_requests_within_limit() {
        let limiter = per_minute(5);

        let resp = send(&limiter, test_request()).await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().contains_key("x-ratelimit-limit"));
    }

    #[tokio::test]
    async fn rejects_when_limit_exceeded() {
        let limiter = per_minute(2);
        for _ in 0..2 {
            let _ = send(&limiter, test_request()).await;
        }

        let resp = send(&limiter, test_request()).await;
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);

        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        let payload: JsonValue = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(payload["error"]["code"], "RATE_LIMIT_EXCEEDED");
    }

    #[tokio::test]
    async fn includes_retry_after_header() {
        let limiter = per_minute(1);
        let _ = send(&limiter, test_request()).await;

        let resp = send(&limiter, test_request()).await;
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(resp.headers().contains_key("retry-after"));
    }

    #[tokio::test]
    async fn different_ips_have_separate_limits() {
        let limiter = per_minute(1);

        for client in ["10.0.0.1", "10.0.0.2"] {
            let resp = send(&limiter, request_with("x-forwarded-for", client)).await;
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn extracts_ip_from_x_real_ip() {
        let limiter = per_minute(1);

        let first = send(&limiter, request_with("x-real-ip", "192.168.1.1")).await;
        assert_eq!(first.status(), StatusCode::OK);

        let second = send(&limiter, request_with("x-real-ip", "192.168.1.1")).await;
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn general_limiter_allows_more_requests() {
        let limiter = per_minute(60);

        for _ in 0..10 {
            let resp = send(&limiter, test_request()).await;
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }

    #[test]
    fn extract_ip_x_forwarded_for_first_ip() {
        let req = request_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        let expected: IpAddr = "1.2.3.4".parse().unwrap();
        assert_eq!(extract_client_ip(&req), expected);
    }

    #[test]
    fn extract_ip_fallback_to_unspecified() {
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(extract_client_ip(&req), IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    }
}