1use http::{HeaderName, Request, Response, StatusCode};
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8use tower::{Layer, Service};
9
10#[derive(Clone, Debug)]
28pub enum KeyExtractor {
29 XForwardedFor,
35 Header(HeaderName),
40}
41
/// Bookkeeping for one key's current fixed window.
#[derive(Debug)]
struct WindowState {
    /// Requests admitted since `window_start`; reset to zero when a new window opens.
    count: u64,
    /// Instant at which the current window began.
    window_start: Instant,
}
47
48#[derive(Clone, Debug)]
67pub struct RateLimitConfig {
68 requests: u64,
69 window: Duration,
70 status: StatusCode,
71 key: KeyExtractor,
72}
73
74impl Default for RateLimitConfig {
75 fn default() -> Self {
76 Self {
77 requests: 100,
78 window: Duration::from_secs(60),
79 status: StatusCode::TOO_MANY_REQUESTS,
80 key: KeyExtractor::XForwardedFor,
81 }
82 }
83}
84
85impl RateLimitConfig {
86 pub fn new() -> Self {
89 Self::default()
90 }
91
92 pub fn requests(mut self, n: u64) -> Self {
96 self.requests = n;
97 self
98 }
99
100 pub fn window(mut self, d: Duration) -> Self {
104 self.window = d;
105 self
106 }
107
108 pub fn status(mut self, s: StatusCode) -> Self {
112 self.status = s;
113 self
114 }
115
116 pub fn key(mut self, k: KeyExtractor) -> Self {
120 self.key = k;
121 self
122 }
123
124 fn extract_key<B>(&self, req: &Request<B>) -> Option<String> {
125 match &self.key {
126 KeyExtractor::XForwardedFor => req
127 .headers()
128 .get("x-forwarded-for")
129 .and_then(|v| v.to_str().ok())
130 .map(|s| s.split(',').next().unwrap_or(s).trim().to_owned()),
131 KeyExtractor::Header(name) => req
132 .headers()
133 .get(name)
134 .and_then(|v| v.to_str().ok())
135 .map(|s| s.to_owned()),
136 }
137 }
138}
139
140#[derive(Clone, Debug)]
166pub struct RateLimitLayer {
167 config: RateLimitConfig,
168 store: Arc<Mutex<HashMap<String, WindowState>>>,
169}
170
171impl Default for RateLimitLayer {
172 fn default() -> Self {
173 Self::new(RateLimitConfig::default())
174 }
175}
176
177impl RateLimitLayer {
178 pub fn new(config: RateLimitConfig) -> Self {
180 Self {
181 config,
182 store: Arc::new(Mutex::new(HashMap::new())),
183 }
184 }
185}
186
187impl<S> Layer<S> for RateLimitLayer {
188 type Service = RateLimitService<S>;
189
190 fn layer(&self, inner: S) -> Self::Service {
191 RateLimitService {
192 inner,
193 config: self.config.clone(),
194 store: Arc::clone(&self.store),
195 }
196 }
197}
198
199#[derive(Clone, Debug)]
201pub struct RateLimitService<S> {
202 inner: S,
203 config: RateLimitConfig,
204 store: Arc<Mutex<HashMap<String, WindowState>>>,
205}
206
207impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RateLimitService<S>
208where
209 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
210 S::Future: Send + 'static,
211 S::Error: Send + 'static,
212 ResBody: Default + Send + 'static,
213{
214 type Response = Response<ResBody>;
215 type Error = S::Error;
216 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
217
218 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
219 self.inner.poll_ready(cx)
220 }
221
222 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
223 let retry_after_secs: Option<u64> = self.config.extract_key(&req).and_then(|key| {
225 let mut store = self.store.lock().expect("rate limit store lock poisoned");
226 let now = Instant::now();
227 let window = self.config.window;
228 let limit = self.config.requests;
229
230 let state = store.entry(key).or_insert_with(|| WindowState {
231 count: 0,
232 window_start: now,
233 });
234
235 if now.duration_since(state.window_start) >= window {
236 state.window_start = now;
237 state.count = 0;
238 }
239
240 if state.count >= limit {
241 let remaining = window
242 .saturating_sub(now.duration_since(state.window_start))
243 .as_secs()
244 .max(1);
245 return Some(remaining);
246 }
247
248 state.count += 1;
249 None
250 });
251
252 if let Some(retry_secs) = retry_after_secs {
253 let status = self.config.status;
254 return Box::pin(async move {
255 Ok(Response::builder()
256 .status(status)
257 .header("retry-after", retry_secs.to_string())
258 .body(ResBody::default())
259 .expect("error response is valid"))
260 });
261 }
262
263 Box::pin(self.inner.call(req))
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Builds a single-route app (`GET /` -> "ok") behind a fresh limiter.
    fn build_app(config: RateLimitConfig) -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RateLimitLayer::new(config))
    }

    /// Sends one request through `app`. Tests pass `app.clone()` so that later
    /// requests go through the same app and limiter.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// Builds an empty-bodied `GET /` request carrying the given headers.
    fn get_with_headers(headers: &[(&str, &str)]) -> http::Request<Body> {
        let mut builder = http::Request::builder().method("GET").uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    fn req_with_ip(ip: &str) -> http::Request<Body> {
        get_with_headers(&[("x-forwarded-for", ip)])
    }

    fn req_with_key(header: &str, value: &str) -> http::Request<Body> {
        get_with_headers(&[(header, value)])
    }

    fn req_bare() -> http::Request<Body> {
        get_with_headers(&[])
    }

    #[tokio::test]
    async fn passes_within_limit() {
        let app = build_app(RateLimitConfig::new().requests(3));
        for _ in 0..3 {
            let response = send(app.clone(), req_with_ip("1.2.3.4")).await;
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn rejects_when_limit_exceeded() {
        let app = build_app(RateLimitConfig::new().requests(2));
        send(app.clone(), req_with_ip("1.2.3.4")).await;
        send(app.clone(), req_with_ip("1.2.3.4")).await;
        let response = send(app.clone(), req_with_ip("1.2.3.4")).await;
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn retry_after_header_present_on_rejection() {
        let app = build_app(RateLimitConfig::new().requests(1));
        send(app.clone(), req_with_ip("1.2.3.4")).await;
        let response = send(app.clone(), req_with_ip("1.2.3.4")).await;
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        let retry = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap();
        assert!(retry >= 1);
    }

    #[tokio::test]
    async fn different_ips_have_independent_counters() {
        let app = build_app(RateLimitConfig::new().requests(1));
        send(app.clone(), req_with_ip("1.1.1.1")).await;
        let response = send(app.clone(), req_with_ip("2.2.2.2")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn zero_limit_rejects_immediately() {
        let app = build_app(RateLimitConfig::new().requests(0));
        let response = send(app, req_with_ip("1.2.3.4")).await;
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn request_without_key_passes_uncounted() {
        let app = build_app(RateLimitConfig::new().requests(0));
        let response = send(app, req_bare()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn header_key_extractor() {
        let config = RateLimitConfig::new()
            .requests(1)
            .key(KeyExtractor::Header(HeaderName::from_static("x-api-key")));
        let app = build_app(config);

        let response = send(app.clone(), req_with_key("x-api-key", "token-a")).await;
        assert_eq!(response.status(), StatusCode::OK);

        // A different token has its own budget.
        let response = send(app.clone(), req_with_key("x-api-key", "token-b")).await;
        assert_eq!(response.status(), StatusCode::OK);

        // token-a has already spent its single request.
        let response = send(app.clone(), req_with_key("x-api-key", "token-a")).await;
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn custom_rejection_status() {
        let config = RateLimitConfig::new()
            .requests(0)
            .status(StatusCode::SERVICE_UNAVAILABLE);
        let response = send(build_app(config), req_with_ip("1.2.3.4")).await;
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn window_resets_after_expiry() {
        let config = RateLimitConfig::new()
            .requests(1)
            .window(Duration::from_millis(50));
        let app = build_app(config);

        let r = send(app.clone(), req_with_ip("10.0.0.1")).await;
        assert_eq!(r.status(), StatusCode::OK);

        let r = send(app.clone(), req_with_ip("10.0.0.1")).await;
        assert_eq!(r.status(), StatusCode::TOO_MANY_REQUESTS);

        // Sleep past the 50ms window so the next request opens a fresh one.
        tokio::time::sleep(Duration::from_millis(60)).await;

        let r = send(app.clone(), req_with_ip("10.0.0.1")).await;
        assert_eq!(r.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn x_forwarded_for_uses_first_ip() {
        let app = build_app(RateLimitConfig::new().requests(1));
        send(app.clone(), req_with_ip("5.5.5.5, 10.0.0.1, 172.16.0.1")).await;
        // Same left-most address behind a different proxy chain shares the counter.
        let response = send(app.clone(), req_with_ip("5.5.5.5, 192.168.1.1")).await;
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn default_layer_uses_100_req_per_60s() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RateLimitLayer::default());

        for _ in 0..100 {
            let r = send(app.clone(), req_with_ip("9.9.9.9")).await;
            assert_eq!(r.status(), StatusCode::OK);
        }
        let r = send(app.clone(), req_with_ip("9.9.9.9")).await;
        assert_eq!(r.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}