1use axum::{
8 extract::{Request, State},
9 http::StatusCode,
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex, RwLock};
15use std::time::Instant;
16use tokio::sync::Mutex as AsyncMutex;
17
/// A token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens, gains `rate` tokens per second
/// of elapsed time, and gives up one token per successful
/// [`TokenBucket::try_acquire`] call.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens currently available; fractional because refill is proportional
    /// to elapsed time.
    tokens: f64,
    /// Upper bound that `tokens` is clamped to on every refill.
    capacity: f64,
    /// Tokens added per second of elapsed time.
    rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`tokens == capacity`) and refills at
    /// `rate` tokens per second, measured from the moment of construction.
    pub fn new(capacity: f64, rate: f64) -> Self {
        let created_at = Instant::now();
        Self {
            tokens: capacity,
            capacity,
            rate,
            last_refill: created_at,
        }
    }

    /// Attempts to take one token.
    ///
    /// The bucket is refilled for the time elapsed since the previous call
    /// before the check is made.
    ///
    /// # Errors
    ///
    /// When less than one whole token is available, returns the number of
    /// seconds (at the configured rate) until the missing fraction of a token
    /// will have been refilled.
    pub fn try_acquire(&mut self) -> Result<(), f64> {
        self.refill();

        let has_whole_token = self.tokens >= 1.0;
        if !has_whole_token {
            let shortfall = 1.0 - self.tokens;
            return Err(shortfall / self.rate);
        }

        self.tokens -= 1.0;
        Ok(())
    }

    /// Credits the tokens earned since `last_refill`, clamps the total to
    /// `capacity`, and moves `last_refill` forward to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        let replenished = self.tokens + elapsed_secs * self.rate;
        self.tokens = replenished.min(self.capacity);
        self.last_refill = now;
    }
}
67
68#[derive(Clone)]
70pub struct RateLimiter(pub Arc<AsyncMutex<TokenBucket>>);
71
72impl RateLimiter {
73 pub fn new(capacity: f64, rate_per_second: f64) -> Self {
75 Self(Arc::new(AsyncMutex::new(TokenBucket::new(
76 capacity,
77 rate_per_second,
78 ))))
79 }
80}
81
82pub async fn rate_limit_middleware(
84 limiter: Option<axum::extract::Extension<RateLimiter>>,
85 request: Request,
86 next: Next,
87) -> Response {
88 let Some(axum::extract::Extension(limiter)) = limiter else {
89 return next.run(request).await;
90 };
91
92 let mut bucket = limiter.0.lock().await;
93 match bucket.try_acquire() {
94 Ok(()) => {
95 drop(bucket);
96 next.run(request).await
97 }
98 Err(retry_after) => {
99 drop(bucket);
100 let retry_secs = retry_after.ceil() as u64;
101 let body = serde_json::json!({
102 "error": {
103 "message": "Rate limit exceeded",
104 "type": "rate_limit_error",
105 }
106 });
107 let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
108 if let Ok(val) = retry_secs.to_string().parse() {
109 resp.headers_mut().insert("retry-after", val);
110 }
111 resp
112 }
113 }
114}
115
116#[derive(Debug)]
131pub struct PerKeyRateLimiter {
132 buckets: Arc<RwLock<HashMap<String, Mutex<TokenBucket>>>>,
133 default_capacity: f64,
134 default_rate: f64,
135 overrides: HashMap<String, (f64, f64)>,
136}
137
138impl PerKeyRateLimiter {
139 pub fn new(default_capacity: f64, default_rate: f64) -> Self {
141 Self {
142 buckets: Arc::new(RwLock::new(HashMap::new())),
143 default_capacity,
144 default_rate,
145 overrides: HashMap::new(),
146 }
147 }
148
149 pub fn with_overrides(mut self, overrides: HashMap<String, (f64, f64)>) -> Self {
153 self.overrides = overrides;
154 self
155 }
156
157 pub fn check_key(&self, key: &str) -> bool {
165 {
167 let map = self.buckets.read().unwrap_or_else(|e| e.into_inner());
168 if let Some(bucket_mutex) = map.get(key) {
169 let mut bucket = bucket_mutex.lock().unwrap_or_else(|e| e.into_inner());
170 return bucket.try_acquire().is_ok();
171 }
172 }
173
174 let (capacity, rate) = self
176 .overrides
177 .get(key)
178 .copied()
179 .unwrap_or((self.default_capacity, self.default_rate));
180
181 let mut map = self.buckets.write().unwrap_or_else(|e| e.into_inner());
182
183 let bucket_mutex = map
185 .entry(key.to_string())
186 .or_insert_with(|| Mutex::new(TokenBucket::new(capacity, rate)));
187
188 let bucket = bucket_mutex.get_mut().unwrap_or_else(|e| e.into_inner());
189 bucket.try_acquire().is_ok()
190 }
191}
192
193fn extract_key_from_request(request: &Request) -> Option<String> {
198 if let Some(auth) = request
200 .headers()
201 .get("authorization")
202 .and_then(|v| v.to_str().ok())
203 {
204 if let Some(token) = auth.strip_prefix("Bearer ") {
205 return Some(token.to_string());
206 }
207 }
208
209 request
211 .headers()
212 .get("x-api-key")
213 .and_then(|v| v.to_str().ok())
214 .map(|s| s.to_string())
215}
216
217pub async fn per_key_rate_limit_middleware(
223 State(limiter): State<Arc<PerKeyRateLimiter>>,
224 request: Request,
225 next: Next,
226) -> Response {
227 let key = extract_key_from_request(&request);
228
229 let allowed = match key.as_deref() {
232 None => true,
233 Some(k) => limiter.check_key(k),
234 };
235
236 if allowed {
237 next.run(request).await
238 } else {
239 let body = serde_json::json!({
240 "error": {
241 "message": "Per-key rate limit exceeded",
242 "type": "rate_limit_error",
243 }
244 });
245 let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
246 resp.headers_mut().insert(
247 "retry-after",
248 "1".parse().unwrap_or_else(|_| "1".parse().expect("static")),
249 );
250 resp
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// A full bucket grants exactly `capacity` acquisitions, then refuses.
    #[test]
    fn test_bucket_allows_within_capacity() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        (0..5).for_each(|_| assert!(bucket.try_acquire().is_ok()));
        assert!(bucket.try_acquire().is_err());
    }

    /// An emptied bucket becomes usable again once enough time has passed.
    #[test]
    fn test_bucket_refills() {
        // 1000 tokens/s: a 10 ms pause is ample to earn the single token back.
        let mut bucket = TokenBucket::new(1.0, 1000.0);
        assert!(bucket.try_acquire().is_ok());
        assert!(bucket.try_acquire().is_err());

        let pause = std::time::Duration::from_millis(10);
        std::thread::sleep(pause);

        assert!(bucket.try_acquire().is_ok());
    }

    /// A rejected acquisition reports a strictly positive wait.
    #[test]
    fn test_retry_after_is_positive() {
        let mut bucket = TokenBucket::new(1.0, 1.0);
        // Drain the only token; the result of this call is irrelevant.
        bucket.try_acquire().ok();

        let err = bucket.try_acquire().unwrap_err();
        assert!(err > 0.0, "retry_after should be positive");
    }

    /// With spare capacity the global middleware lets the request through.
    #[tokio::test]
    async fn test_rate_limit_middleware_allows() {
        use axum::{body::Body, http::Request as HttpRequest, middleware, routing::get, Router};
        use tower::ServiceExt;

        let shared_limiter = RateLimiter::new(10.0, 10.0);
        let router = Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(shared_limiter));

        let request = HttpRequest::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Exhausting one key's bucket leaves another key's bucket untouched.
    #[test]
    fn per_key_two_keys_are_independent() {
        let per_key = PerKeyRateLimiter::new(2.0, 1.0);

        assert!(per_key.check_key("key-a"), "key-a first hit should pass");
        assert!(per_key.check_key("key-a"), "key-a second hit should pass");
        assert!(
            !per_key.check_key("key-a"),
            "key-a third hit should be rejected"
        );

        assert!(
            per_key.check_key("key-b"),
            "key-b should be unaffected by key-a exhaustion"
        );
    }

    /// A burst up to capacity is allowed; the next hit is rejected.
    #[test]
    fn per_key_burst_then_rejected() {
        // The refill rate is tiny so no token is regained during the test.
        let per_key = PerKeyRateLimiter::new(3.0, 0.001);

        for i in 0..3 {
            assert!(per_key.check_key("burst-key"), "hit #{i} should be allowed");
        }
        assert!(
            !per_key.check_key("burst-key"),
            "4th hit should be rejected (bucket exhausted)"
        );
    }

    /// Keys with an override get its capacity; other keys get the default.
    #[test]
    fn per_key_override_applied() {
        let premium_limits = (10.0, 1.0);
        let mut overrides = HashMap::new();
        overrides.insert("premium-key".to_string(), premium_limits);

        let per_key = PerKeyRateLimiter::new(1.0, 1.0).with_overrides(overrides);

        assert!(
            per_key.check_key("default-key"),
            "default first hit allowed"
        );
        assert!(
            !per_key.check_key("default-key"),
            "default second hit rejected"
        );

        for i in 0..10 {
            assert!(
                per_key.check_key("premium-key"),
                "premium hit #{i} should be allowed"
            );
        }
        assert!(
            !per_key.check_key("premium-key"),
            "premium 11th hit rejected"
        );
    }

    /// A previously unseen key with spare capacity is allowed.
    ///
    /// NOTE(review): despite the name, this only calls `check_key` with a
    /// fresh key. The no-key ("anonymous") path is decided in
    /// `per_key_rate_limit_middleware` and is not exercised here.
    #[test]
    fn per_key_anonymous_request_allowed() {
        let per_key = PerKeyRateLimiter::new(5.0, 1.0);
        assert!(
            per_key.check_key("any-key"),
            "any key with capacity should be allowed"
        );
    }

    /// Repeated checks of one key reuse one bucket instead of inserting more.
    #[test]
    fn per_key_lazy_insert_idempotent() {
        let per_key = PerKeyRateLimiter::new(5.0, 1.0);

        for i in 0..5 {
            assert!(
                per_key.check_key("idempotent-key"),
                "hit #{i} should pass (capacity=5)"
            );
        }
        assert!(
            !per_key.check_key("idempotent-key"),
            "6th hit should be rejected"
        );

        let bucket_count = per_key.buckets.read().unwrap().len();
        assert_eq!(
            bucket_count,
            1,
            "only one bucket should be inserted for a single key"
        );
    }
}