1use std::collections::HashMap;
2use std::future::Future;
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::RwLock;
7use std::task::{Context, Poll};
8use std::time::Instant;
9
10use axum::body::Body;
11use axum::response::IntoResponse;
12use http::{Request, Response};
13use serde::Deserialize;
14use tokio_util::sync::CancellationToken;
15use tower::{Layer, Service};
16
/// Tunable settings for the token-bucket rate limiter defined in this file.
///
/// The struct is deserializable with serde; because of `#[serde(default)]`,
/// any field absent from the input takes its value from the [`Default`] impl.
#[non_exhaustive]
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct RateLimitConfig {
    /// Tokens credited back to each bucket per second of elapsed time.
    pub per_second: u64,
    /// Bucket capacity: the token count a new bucket starts with and the
    /// ceiling that refills are clamped to.
    pub burst_size: u32,
    /// When `true`, the service attaches `x-ratelimit-*` headers (and
    /// `retry-after` on rejections) to responses.
    pub use_headers: bool,
    /// Seconds between passes of the background task that evicts idle buckets.
    pub cleanup_interval_secs: u64,
    /// Maximum number of keys tracked across all shards before new keys are
    /// rejected; `0` disables the cap.
    pub max_keys: usize,
}
43
44impl Default for RateLimitConfig {
45 fn default() -> Self {
46 Self {
47 per_second: 1,
48 burst_size: 10,
49 use_headers: true,
50 cleanup_interval_secs: 60,
51 max_keys: 10_000,
52 }
53 }
54}
55
/// Token-bucket state for a single rate-limit key.
struct TokenBucket {
    /// Tokens currently available. Kept fractional so partial refills
    /// accumulate between checks.
    tokens: f64,
    /// Instant of the most recent `check` call (or of construction).
    last_refill: Instant,
}

/// Outcome of a single rate-limit check.
enum CheckResult {
    /// The request may proceed; `remaining` is the whole number of tokens
    /// left after this request consumed one.
    Allowed { remaining: u32 },
    /// The request is refused; `retry_after_secs` is the suggested wait
    /// before one full token is available again.
    Rejected { retry_after_secs: f64 },
}

impl TokenBucket {
    /// Retry hint used when `per_second` is zero, i.e. the bucket never
    /// refills. Dividing the deficit by a zero rate would yield infinity,
    /// which saturates to `u64::MAX` once converted to whole seconds for a
    /// header. One hour mirrors the idle window after which the cleanup pass
    /// evicts (and thereby resets) buckets of a zero-rate limiter.
    const NO_REFILL_RETRY_SECS: f64 = 3600.0;

    /// Creates a full bucket holding `burst_size` tokens.
    fn new(burst_size: u32) -> Self {
        Self {
            tokens: burst_size as f64,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// tries to take one token.
    ///
    /// * `per_second` - tokens credited per second of elapsed time.
    /// * `burst_size` - upper bound the token count is clamped to.
    ///
    /// Returns [`CheckResult::Allowed`] with the whole tokens left when a
    /// token was taken, otherwise [`CheckResult::Rejected`] with a finite,
    /// positive wait in seconds.
    fn check(&mut self, per_second: u64, burst_size: u32) -> CheckResult {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        // The timestamp advances on every call, rejected ones included, so
        // the next refill is measured from here.
        self.last_refill = now;

        let rate = per_second as f64;
        self.tokens = (self.tokens + elapsed * rate).min(burst_size as f64);

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return CheckResult::Allowed {
                // Truncation is intended: only whole tokens are reported.
                remaining: self.tokens as u32,
            };
        }

        let deficit = 1.0 - self.tokens;
        let retry_after_secs = if per_second == 0 {
            Self::NO_REFILL_RETRY_SECS
        } else {
            deficit / rate
        };
        CheckResult::Rejected { retry_after_secs }
    }
}
100
/// Number of shards used for the bucket map built by `rate_limit_with`.
const DEFAULT_SHARDS: usize = 16;
106
107struct ShardedMap {
108 shards: Vec<RwLock<HashMap<String, TokenBucket>>>,
109}
110
111impl ShardedMap {
112 fn new(num_shards: usize) -> Self {
113 let mut shards = Vec::with_capacity(num_shards);
114 for _ in 0..num_shards {
115 shards.push(RwLock::new(HashMap::new()));
116 }
117 Self { shards }
118 }
119
120 fn shard_index(&self, key: &str) -> usize {
121 let mut hasher = DefaultHasher::new();
122 key.hash(&mut hasher);
123 hasher.finish() as usize % self.shards.len()
124 }
125
126 fn check_or_insert(
127 &self,
128 key: &str,
129 per_second: u64,
130 burst_size: u32,
131 max_keys: usize,
132 ) -> CheckResult {
133 let idx = self.shard_index(key);
134 let shard = &self.shards[idx];
135
136 {
138 let read = shard.read().expect("rate limit shard lock poisoned");
139 if read.contains_key(key) {
140 drop(read);
141 let mut write = shard.write().expect("rate limit shard lock poisoned");
143 if let Some(bucket) = write.get_mut(key) {
144 return bucket.check(per_second, burst_size);
145 }
146 }
147 }
148
149 if max_keys > 0 {
151 let total: usize = self
152 .shards
153 .iter()
154 .map(|s| s.read().expect("rate limit shard lock poisoned").len())
155 .sum();
156 if total >= max_keys {
157 return CheckResult::Rejected {
158 retry_after_secs: 1.0,
159 };
160 }
161 }
162
163 let mut write = shard.write().expect("rate limit shard lock poisoned");
165 if let Some(bucket) = write.get_mut(key) {
167 return bucket.check(per_second, burst_size);
168 }
169
170 let mut bucket = TokenBucket::new(burst_size);
171 let result = bucket.check(per_second, burst_size);
172 write.insert(key.to_string(), bucket);
173 result
174 }
175
176 fn cleanup(&self, per_second: u64, burst_size: u32) {
177 let max_idle = if per_second > 0 {
178 std::time::Duration::from_secs_f64(burst_size as f64 / per_second as f64)
179 } else {
180 std::time::Duration::from_secs(3600)
181 };
182 let now = Instant::now();
183
184 for shard in &self.shards {
185 let mut write = shard.write().expect("rate limit shard lock poisoned");
186 write.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle);
187 }
188 }
189}
190
/// Derives the key under which an incoming request is rate limited.
///
/// Requests that map to the same key share one token bucket.
pub trait KeyExtractor: Clone + Send + Sync + 'static {
    /// Returns the rate-limit key for `req`, or `None` when no key can be
    /// derived. The service in this file answers a `None` with an internal
    /// error response instead of forwarding the request.
    fn extract<B>(&self, req: &Request<B>) -> Option<String>;
}
205
206#[derive(Debug, Clone)]
212pub struct PeerIpKeyExtractor;
213
214impl KeyExtractor for PeerIpKeyExtractor {
215 fn extract<B>(&self, req: &Request<B>) -> Option<String> {
216 req.extensions()
217 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
218 .map(|ci| ci.0.ip().to_string())
219 }
220}
221
222#[derive(Debug, Clone)]
226pub struct GlobalKeyExtractor;
227
228impl KeyExtractor for GlobalKeyExtractor {
229 fn extract<B>(&self, _req: &Request<B>) -> Option<String> {
230 Some("__global__".to_string())
231 }
232}
233
234pub struct RateLimitLayer<K> {
243 state: Arc<ShardedMap>,
244 config: RateLimitConfig,
245 extractor: K,
246}
247
248impl<K: Clone> Clone for RateLimitLayer<K> {
249 fn clone(&self) -> Self {
250 Self {
251 state: self.state.clone(),
252 config: self.config.clone(),
253 extractor: self.extractor.clone(),
254 }
255 }
256}
257
258impl<S, K: KeyExtractor> Layer<S> for RateLimitLayer<K> {
259 type Service = RateLimitService<S, K>;
260
261 fn layer(&self, inner: S) -> Self::Service {
262 RateLimitService {
263 inner,
264 state: self.state.clone(),
265 config: self.config.clone(),
266 extractor: self.extractor.clone(),
267 }
268 }
269}
270
271pub struct RateLimitService<S, K> {
277 inner: S,
278 state: Arc<ShardedMap>,
279 config: RateLimitConfig,
280 extractor: K,
281}
282
283impl<S: Clone, K: Clone> Clone for RateLimitService<S, K> {
284 fn clone(&self) -> Self {
285 Self {
286 inner: self.inner.clone(),
287 state: self.state.clone(),
288 config: self.config.clone(),
289 extractor: self.extractor.clone(),
290 }
291 }
292}
293
294impl<S, K> Service<Request<Body>> for RateLimitService<S, K>
295where
296 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
297 S::Future: Send,
298 K: KeyExtractor,
299{
300 type Response = Response<Body>;
301 type Error = S::Error;
302 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
303
304 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
305 self.inner.poll_ready(cx)
306 }
307
308 fn call(&mut self, req: Request<Body>) -> Self::Future {
309 let Some(key) = self.extractor.extract(&req) else {
310 let response =
312 crate::error::Error::internal("unable to extract rate-limit key").into_response();
313 return Box::pin(async move { Ok(response) });
314 };
315
316 let result = self.state.check_or_insert(
317 &key,
318 self.config.per_second,
319 self.config.burst_size,
320 self.config.max_keys,
321 );
322
323 match result {
324 CheckResult::Rejected { retry_after_secs } => {
325 let retry_secs = retry_after_secs.ceil() as u64;
326 let error =
327 crate::error::Error::too_many_requests(format!("retry after {retry_secs}s"))
328 .with_details(serde_json::json!({"retry_after": retry_secs}));
329 let mut response = error.into_response();
330
331 if self.config.use_headers {
332 let headers = response.headers_mut();
333 headers.insert("retry-after", retry_secs.into());
334 headers.insert("x-ratelimit-limit", self.config.burst_size.into());
335 headers.insert("x-ratelimit-remaining", 0u32.into());
336 }
337
338 Box::pin(async move { Ok(response) })
339 }
340 CheckResult::Allowed { remaining } => {
341 let use_headers = self.config.use_headers;
342 let burst_size = self.config.burst_size;
343 let per_second = self.config.per_second;
344 let mut inner = self.inner.clone();
345
346 Box::pin(async move {
347 let mut response = inner.call(req).await?;
348
349 if use_headers {
350 let headers = response.headers_mut();
351 if !headers.contains_key("x-ratelimit-limit") {
352 headers.insert("x-ratelimit-limit", burst_size.into());
353 }
354 if !headers.contains_key("x-ratelimit-remaining") {
355 headers.insert("x-ratelimit-remaining", remaining.into());
356 }
357 if !headers.contains_key("x-ratelimit-reset") {
358 let reset_secs = if per_second > 0 {
359 let now = std::time::SystemTime::now()
360 .duration_since(std::time::UNIX_EPOCH)
361 .unwrap()
362 .as_secs();
363 now + (burst_size as u64 / per_second)
364 } else {
365 0
366 };
367 headers.insert("x-ratelimit-reset", reset_secs.into());
368 }
369 }
370
371 Ok(response)
372 })
373 }
374 }
375 }
376}
377
378pub fn rate_limit(
392 config: &RateLimitConfig,
393 cancel: CancellationToken,
394) -> RateLimitLayer<PeerIpKeyExtractor> {
395 rate_limit_with(config, PeerIpKeyExtractor, cancel)
396}
397
398pub fn rate_limit_with<K: KeyExtractor>(
407 config: &RateLimitConfig,
408 extractor: K,
409 cancel: CancellationToken,
410) -> RateLimitLayer<K> {
411 let state = Arc::new(ShardedMap::new(DEFAULT_SHARDS));
412 let cleanup_state = state.clone();
413 let per_second = config.per_second;
414 let burst_size = config.burst_size;
415 let interval = std::time::Duration::from_secs(config.cleanup_interval_secs);
416
417 tokio::spawn(async move {
418 loop {
419 tokio::select! {
420 _ = cancel.cancelled() => break,
421 _ = tokio::time::sleep(interval) => {
422 cleanup_state.cleanup(per_second, burst_size);
423 }
424 }
425 }
426 });
427
428 RateLimitLayer {
429 state,
430 config: config.clone(),
431 extractor,
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// An instant `secs` seconds before now, used to simulate elapsed time.
    fn seconds_ago(secs: u64) -> Instant {
        Instant::now() - Duration::from_secs(secs)
    }

    // ---- TokenBucket -----------------------------------------------------

    /// A full bucket admits exactly `burst_size` back-to-back requests.
    #[test]
    fn token_bucket_allows_within_burst() {
        let mut bucket = TokenBucket::new(3);
        let every_request_allowed =
            (0..3).all(|_| matches!(bucket.check(1, 3), CheckResult::Allowed { .. }));
        assert!(every_request_allowed);
    }

    /// The request after the burst is exhausted is refused.
    #[test]
    fn token_bucket_rejects_over_burst() {
        let mut bucket = TokenBucket::new(2);
        for _ in 0..2 {
            let _ = bucket.check(1, 2);
        }
        let third = bucket.check(1, 2);
        assert!(matches!(third, CheckResult::Rejected { .. }));
    }

    /// Elapsed time credits tokens back to a drained bucket.
    #[test]
    fn token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(1);
        let _ = bucket.check(10, 1);
        // Pretend the last check happened a second ago.
        bucket.last_refill = seconds_ago(1);
        let after_wait = bucket.check(10, 1);
        assert!(matches!(after_wait, CheckResult::Allowed { .. }));
    }

    /// `remaining` reports the tokens left after the request took one.
    #[test]
    fn token_bucket_remaining_count() {
        let mut bucket = TokenBucket::new(5);
        let CheckResult::Allowed { remaining } = bucket.check(1, 5) else {
            panic!("expected Allowed");
        };
        assert_eq!(remaining, 4);
    }

    /// A rejection always carries a strictly positive wait.
    #[test]
    fn token_bucket_retry_after_positive() {
        let mut bucket = TokenBucket::new(1);
        let _ = bucket.check(1, 1);
        let CheckResult::Rejected { retry_after_secs } = bucket.check(1, 1) else {
            panic!("expected Rejected");
        };
        assert!(retry_after_secs > 0.0);
    }

    // ---- ShardedMap ------------------------------------------------------

    /// A key seen for the first time gets a fresh bucket and is admitted.
    #[test]
    fn sharded_map_allows_new_key() {
        let map = ShardedMap::new(4);
        let first = map.check_or_insert("ip1", 1, 5, 100);
        assert!(matches!(first, CheckResult::Allowed { .. }));
    }

    /// Exhausting one key's bucket does not affect another key.
    #[test]
    fn sharded_map_tracks_per_key() {
        let map = ShardedMap::new(4);
        let _ = map.check_or_insert("a", 1, 1, 100);

        let second_a = map.check_or_insert("a", 1, 1, 100);
        assert!(matches!(second_a, CheckResult::Rejected { .. }));

        let first_b = map.check_or_insert("b", 1, 1, 100);
        assert!(matches!(first_b, CheckResult::Allowed { .. }));
    }

    /// Once `max_keys` keys are tracked, an unseen key is refused.
    #[test]
    fn sharded_map_max_keys_rejects_new() {
        let map = ShardedMap::new(2);
        for key in ["a", "b"] {
            let _ = map.check_or_insert(key, 1, 5, 2);
        }
        let overflow = map.check_or_insert("c", 1, 5, 2);
        assert!(matches!(overflow, CheckResult::Rejected { .. }));
    }

    /// A bucket idle beyond the refill window is evicted by `cleanup`, and
    /// the key is admitted again afterwards.
    #[test]
    fn sharded_map_cleanup_removes_stale() {
        let map = ShardedMap::new(2);
        let _ = map.check_or_insert("a", 1, 1, 100);
        {
            // Age the bucket so it looks idle for ten seconds.
            let idx = map.shard_index("a");
            let mut shard = map.shards[idx].write().unwrap();
            if let Some(bucket) = shard.get_mut("a") {
                bucket.last_refill = seconds_ago(10);
            }
        }
        map.cleanup(1, 1);

        let after_cleanup = map.check_or_insert("a", 1, 1, 100);
        assert!(matches!(after_cleanup, CheckResult::Allowed { .. }));
    }
}