1use std::collections::HashMap;
2use std::future::Future;
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::RwLock;
7use std::task::{Context, Poll};
8use std::time::Instant;
9
10use axum::body::Body;
11use axum::response::IntoResponse;
12use http::{Request, Response};
13use serde::Deserialize;
14use tokio_util::sync::CancellationToken;
15use tower::{Layer, Service};
16
17#[non_exhaustive]
28#[derive(Debug, Clone, Deserialize)]
29#[serde(default)]
30pub struct RateLimitConfig {
31 pub per_second: u64,
33 pub burst_size: u32,
35 pub use_headers: bool,
37 pub cleanup_interval_secs: u64,
39 pub max_keys: usize,
42}
43
44impl Default for RateLimitConfig {
45 fn default() -> Self {
46 Self {
47 per_second: 1,
48 burst_size: 10,
49 use_headers: true,
50 cleanup_interval_secs: 60,
51 max_keys: 10_000,
52 }
53 }
54}
55
/// State of one token bucket: the tokens currently available and the moment
/// the bucket was last brought up to date.
struct TokenBucket {
    /// Fractional token count; a request consumes exactly `1.0`.
    tokens: f64,
    /// Instant of the most recent [`TokenBucket::check`] call (or creation).
    last_refill: Instant,
}

/// Outcome of charging one request against a bucket.
enum CheckResult {
    /// A token was consumed; `remaining` is the whole number of tokens left.
    Allowed { remaining: u32 },
    /// Less than one token was available; `retry_after_secs` is the time
    /// needed to accumulate the missing fraction at the configured rate.
    Rejected { retry_after_secs: f64 },
}

impl TokenBucket {
    /// Creates a full bucket holding `burst_size` tokens, stamped with the
    /// current time.
    fn new(burst_size: u32) -> Self {
        let tokens = burst_size as f64;
        let last_refill = Instant::now();
        Self { tokens, last_refill }
    }

    /// Refills the bucket for the time elapsed since the previous call and
    /// then tries to take one token.
    ///
    /// `per_second` is the refill rate and `burst_size` the cap applied after
    /// refilling. `last_refill` is advanced on every call, whether or not the
    /// request is allowed.
    ///
    /// NOTE: with `per_second == 0` a rejected request reports an infinite
    /// `retry_after_secs`, because the deficit is divided by a zero rate.
    fn check(&mut self, per_second: u64, burst_size: u32) -> CheckResult {
        let rate = per_second as f64;
        let capacity = burst_size as f64;

        // Advance the timestamp first and measure how long the bucket sat idle.
        let now = Instant::now();
        let previous = std::mem::replace(&mut self.last_refill, now);
        let idle_secs = now.duration_since(previous).as_secs_f64();

        let available = (self.tokens + idle_secs * rate).min(capacity);

        if available >= 1.0 {
            self.tokens = available - 1.0;
            return CheckResult::Allowed {
                // Truncation is intended: only whole tokens are reported.
                remaining: self.tokens as u32,
            };
        }

        // Keep the partial refill so it counts towards the next attempt.
        self.tokens = available;
        CheckResult::Rejected {
            retry_after_secs: (1.0 - available) / rate,
        }
    }
}
100
101const DEFAULT_SHARDS: usize = 16;
106
107struct ShardedMap {
108 shards: Vec<RwLock<HashMap<String, TokenBucket>>>,
109}
110
111impl ShardedMap {
112 fn new(num_shards: usize) -> Self {
113 let mut shards = Vec::with_capacity(num_shards);
114 for _ in 0..num_shards {
115 shards.push(RwLock::new(HashMap::new()));
116 }
117 Self { shards }
118 }
119
120 fn shard_index(&self, key: &str) -> usize {
121 let mut hasher = DefaultHasher::new();
122 key.hash(&mut hasher);
123 hasher.finish() as usize % self.shards.len()
124 }
125
126 fn check_or_insert(
127 &self,
128 key: &str,
129 per_second: u64,
130 burst_size: u32,
131 max_keys: usize,
132 ) -> CheckResult {
133 let idx = self.shard_index(key);
134 let shard = &self.shards[idx];
135
136 {
138 let read = shard.read().expect("rate limit shard lock poisoned");
139 if read.contains_key(key) {
140 drop(read);
141 let mut write = shard.write().expect("rate limit shard lock poisoned");
143 if let Some(bucket) = write.get_mut(key) {
144 return bucket.check(per_second, burst_size);
145 }
146 }
147 }
148
149 if max_keys > 0 {
151 let total: usize = self
152 .shards
153 .iter()
154 .map(|s| s.read().expect("rate limit shard lock poisoned").len())
155 .sum();
156 if total >= max_keys {
157 return CheckResult::Rejected {
158 retry_after_secs: 1.0,
159 };
160 }
161 }
162
163 let mut write = shard.write().expect("rate limit shard lock poisoned");
165 if let Some(bucket) = write.get_mut(key) {
167 return bucket.check(per_second, burst_size);
168 }
169
170 let mut bucket = TokenBucket::new(burst_size);
171 let result = bucket.check(per_second, burst_size);
172 write.insert(key.to_string(), bucket);
173 result
174 }
175
176 fn cleanup(&self, per_second: u64, burst_size: u32) {
177 let max_idle = if per_second > 0 {
178 std::time::Duration::from_secs_f64(burst_size as f64 / per_second as f64)
179 } else {
180 std::time::Duration::from_secs(3600)
181 };
182 let now = Instant::now();
183
184 for shard in &self.shards {
185 let mut write = shard.write().expect("rate limit shard lock poisoned");
186 write.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle);
187 }
188 }
189}
190
191pub trait KeyExtractor: Clone + Send + Sync + 'static {
201 fn extract<B>(&self, req: &Request<B>) -> Option<String>;
204}
205
206#[derive(Debug, Clone)]
212pub struct PeerIpKeyExtractor;
213
214impl KeyExtractor for PeerIpKeyExtractor {
215 fn extract<B>(&self, req: &Request<B>) -> Option<String> {
216 req.extensions()
217 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
218 .map(|ci| ci.0.ip().to_string())
219 }
220}
221
222#[derive(Debug, Clone)]
226pub struct GlobalKeyExtractor;
227
228impl KeyExtractor for GlobalKeyExtractor {
229 fn extract<B>(&self, _req: &Request<B>) -> Option<String> {
230 Some("__global__".to_string())
231 }
232}
233
234pub struct RateLimitLayer<K> {
243 state: Arc<ShardedMap>,
244 config: RateLimitConfig,
245 extractor: K,
246}
247
248impl<K: Clone> Clone for RateLimitLayer<K> {
249 fn clone(&self) -> Self {
250 Self {
251 state: self.state.clone(),
252 config: self.config.clone(),
253 extractor: self.extractor.clone(),
254 }
255 }
256}
257
258impl<S, K: KeyExtractor> Layer<S> for RateLimitLayer<K> {
259 type Service = RateLimitService<S, K>;
260
261 fn layer(&self, inner: S) -> Self::Service {
262 RateLimitService {
263 inner,
264 state: self.state.clone(),
265 config: self.config.clone(),
266 extractor: self.extractor.clone(),
267 }
268 }
269}
270
271pub struct RateLimitService<S, K> {
277 inner: S,
278 state: Arc<ShardedMap>,
279 config: RateLimitConfig,
280 extractor: K,
281}
282
283impl<S: Clone, K: Clone> Clone for RateLimitService<S, K> {
284 fn clone(&self) -> Self {
285 Self {
286 inner: self.inner.clone(),
287 state: self.state.clone(),
288 config: self.config.clone(),
289 extractor: self.extractor.clone(),
290 }
291 }
292}
293
294impl<S, K> Service<Request<Body>> for RateLimitService<S, K>
295where
296 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
297 S::Future: Send,
298 K: KeyExtractor,
299{
300 type Response = Response<Body>;
301 type Error = S::Error;
302 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
303
304 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
305 self.inner.poll_ready(cx)
306 }
307
308 fn call(&mut self, req: Request<Body>) -> Self::Future {
309 let Some(key) = self.extractor.extract(&req) else {
310 let response =
312 crate::error::Error::internal("unable to extract rate-limit key").into_response();
313 return Box::pin(async move { Ok(response) });
314 };
315
316 let result = self.state.check_or_insert(
317 &key,
318 self.config.per_second,
319 self.config.burst_size,
320 self.config.max_keys,
321 );
322
323 match result {
324 CheckResult::Rejected { retry_after_secs } => {
325 let retry_secs = retry_after_secs.ceil() as u64;
326 let error =
327 crate::error::Error::too_many_requests(format!("retry after {retry_secs}s"))
328 .with_details(serde_json::json!({"retry_after": retry_secs}));
329 let mut response = error.into_response();
330
331 if self.config.use_headers {
332 let headers = response.headers_mut();
333 headers.insert("retry-after", retry_secs.into());
334 headers.insert("x-ratelimit-limit", self.config.burst_size.into());
335 headers.insert("x-ratelimit-remaining", 0u32.into());
336 }
337
338 Box::pin(async move { Ok(response) })
339 }
340 CheckResult::Allowed { remaining } => {
341 let use_headers = self.config.use_headers;
342 let burst_size = self.config.burst_size;
343 let per_second = self.config.per_second;
344 let mut inner = self.inner.clone();
345
346 Box::pin(async move {
347 let mut response = inner.call(req).await?;
348
349 if use_headers {
350 let headers = response.headers_mut();
351 if !headers.contains_key("x-ratelimit-limit") {
352 headers.insert("x-ratelimit-limit", burst_size.into());
353 }
354 if !headers.contains_key("x-ratelimit-remaining") {
355 headers.insert("x-ratelimit-remaining", remaining.into());
356 }
357 if !headers.contains_key("x-ratelimit-reset") {
358 let reset_secs =
359 (burst_size as u64)
360 .checked_div(per_second)
361 .map_or(0, |delta| {
362 let now = std::time::SystemTime::now()
363 .duration_since(std::time::UNIX_EPOCH)
364 .unwrap()
365 .as_secs();
366 now + delta
367 });
368 headers.insert("x-ratelimit-reset", reset_secs.into());
369 }
370 }
371
372 Ok(response)
373 })
374 }
375 }
376 }
377}
378
379pub fn rate_limit(
393 config: &RateLimitConfig,
394 cancel: CancellationToken,
395) -> RateLimitLayer<PeerIpKeyExtractor> {
396 rate_limit_with(config, PeerIpKeyExtractor, cancel)
397}
398
399pub fn rate_limit_with<K: KeyExtractor>(
408 config: &RateLimitConfig,
409 extractor: K,
410 cancel: CancellationToken,
411) -> RateLimitLayer<K> {
412 let state = Arc::new(ShardedMap::new(DEFAULT_SHARDS));
413 let cleanup_state = state.clone();
414 let per_second = config.per_second;
415 let burst_size = config.burst_size;
416 let interval = std::time::Duration::from_secs(config.cleanup_interval_secs);
417
418 tokio::spawn(async move {
419 loop {
420 tokio::select! {
421 _ = cancel.cancelled() => break,
422 _ = tokio::time::sleep(interval) => {
423 cleanup_state.cleanup(per_second, burst_size);
424 }
425 }
426 }
427 });
428
429 RateLimitLayer {
430 state,
431 config: config.clone(),
432 extractor,
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    // ---- TokenBucket ----------------------------------------------------

    #[test]
    fn token_bucket_allows_within_burst() {
        let mut bucket = TokenBucket::new(3);
        for _ in 0..3 {
            assert!(matches!(bucket.check(1, 3), CheckResult::Allowed { .. }));
        }
    }

    #[test]
    fn token_bucket_rejects_over_burst() {
        let mut bucket = TokenBucket::new(2);
        // Drain both tokens; the third request must be refused.
        bucket.check(1, 2);
        bucket.check(1, 2);
        assert!(matches!(bucket.check(1, 2), CheckResult::Rejected { .. }));
    }

    #[test]
    fn token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(1);
        bucket.check(10, 1);
        // Pretend the last access happened one second ago.
        bucket.last_refill = Instant::now() - std::time::Duration::from_secs(1);
        assert!(matches!(bucket.check(10, 1), CheckResult::Allowed { .. }));
    }

    #[test]
    fn token_bucket_remaining_count() {
        let mut bucket = TokenBucket::new(5);
        match bucket.check(1, 5) {
            CheckResult::Allowed { remaining } => assert_eq!(remaining, 4),
            _ => panic!("expected Allowed"),
        }
    }

    #[test]
    fn token_bucket_retry_after_positive() {
        let mut bucket = TokenBucket::new(1);
        bucket.check(1, 1);
        match bucket.check(1, 1) {
            CheckResult::Rejected { retry_after_secs } => {
                assert!(retry_after_secs > 0.0);
            }
            _ => panic!("expected Rejected"),
        }
    }

    // ---- ShardedMap -----------------------------------------------------

    #[test]
    fn sharded_map_allows_new_key() {
        let map = ShardedMap::new(4);
        assert!(matches!(
            map.check_or_insert("ip1", 1, 5, 100),
            CheckResult::Allowed { .. }
        ));
    }

    #[test]
    fn sharded_map_tracks_per_key() {
        let map = ShardedMap::new(4);
        map.check_or_insert("a", 1, 1, 100);
        // "a" has used its only token ...
        assert!(matches!(
            map.check_or_insert("a", 1, 1, 100),
            CheckResult::Rejected { .. }
        ));
        // ... while "b" owns an independent bucket.
        assert!(matches!(
            map.check_or_insert("b", 1, 1, 100),
            CheckResult::Allowed { .. }
        ));
    }

    #[test]
    fn sharded_map_max_keys_rejects_new() {
        let map = ShardedMap::new(2);
        map.check_or_insert("a", 1, 5, 2);
        map.check_or_insert("b", 1, 5, 2);
        // The cap of two keys is reached, so a third key is refused.
        assert!(matches!(
            map.check_or_insert("c", 1, 5, 2),
            CheckResult::Rejected { .. }
        ));
    }

    #[test]
    fn sharded_map_cleanup_removes_stale() {
        let map = ShardedMap::new(2);
        map.check_or_insert("a", 1, 1, 100);
        let idx = map.shard_index("a");
        {
            let mut shard = map.shards[idx].write().unwrap();
            let bucket = shard
                .get_mut("a")
                .expect("bucket for key a was inserted above");
            bucket.last_refill = Instant::now() - std::time::Duration::from_secs(10);
        }

        map.cleanup(1, 1);

        // Inspect the shard directly: a bucket aged ten seconds would have
        // refilled anyway, so an `Allowed` result alone cannot prove that
        // cleanup removed the entry.
        assert!(!map.shards[idx].read().unwrap().contains_key("a"));
        assert!(matches!(
            map.check_or_insert("a", 1, 1, 100),
            CheckResult::Allowed { .. }
        ));
    }

    #[test]
    fn sharded_map_cleanup_keeps_recent() {
        let map = ShardedMap::new(2);
        map.check_or_insert("a", 1, 10, 100);

        // The idle limit is 10 / 1 = 10 seconds; the bucket was touched just now.
        map.cleanup(1, 10);

        let idx = map.shard_index("a");
        assert!(map.shards[idx].read().unwrap().contains_key("a"));
    }
}