1use std::collections::HashMap;
4use std::hash::Hash;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use axum::body::Body;
10use axum::http::{Request, Response, StatusCode};
11use futures_util::future::BoxFuture;
12use tokio::sync::Mutex;
13use tower::{Layer, Service};
14
/// Upper bound on the number of distinct client IPs tracked at once.
///
/// When the per-IP table already holds this many entries, a request from an
/// IP that is not yet tracked is answered with 429 instead of growing the table.
const MAX_DISTINCT_IPS: usize = 10_000;
/// Upper bound on the number of distinct actor ids tracked at once.
///
/// When the per-actor table already holds this many entries, a request from an
/// actor that is not yet tracked is answered with 429 instead of growing the table.
const MAX_DISTINCT_ACTORS: usize = 5_000;
20
21#[derive(Clone)]
23pub struct GlobalRateLimitLayer {
24 state: Arc<Mutex<RateLimitState>>,
25 capacity: u64,
26 per_ip_capacity: u64,
27 per_actor_capacity: u64,
28 window: Duration,
29 trusted_proxy_cidrs: Vec<IpCidr>,
30}
31
/// Mutable counters guarded by the layer's mutex.
struct RateLimitState {
    /// Requests admitted in the current global window.
    count: u64,
    /// Instant at which the current global window began.
    window_start: Instant,
    /// Per-IP `(admitted count, window start)` pairs.
    per_ip: HashMap<IpAddr, (u64, Instant)>,
    /// Per-actor `(admitted count, window start)` pairs.
    per_actor: HashMap<String, (u64, Instant)>,
    /// Rejections caused by the per-IP limit, keyed by IP.
    throttled_per_ip: HashMap<IpAddr, u64>,
    /// Rejections caused by the per-actor limit, keyed by actor id.
    throttled_per_actor: HashMap<String, u64>,
    /// Rejections caused by the global limit.
    throttled_global: u64,
}
41
/// An IP network expressed as a base address plus a prefix length.
#[derive(Clone, Debug)]
struct IpCidr {
    /// Base address of the network (IPv4 or IPv6).
    network: IpAddr,
    /// Number of leading bits that must match for an address to be a member.
    prefix_len: u8,
}
47
48impl GlobalRateLimitLayer {
49 pub fn new(capacity: u64, window: Duration) -> Self {
51 Self {
52 state: Arc::new(Mutex::new(RateLimitState {
53 count: 0,
54 window_start: Instant::now(),
55 per_ip: HashMap::new(),
56 per_actor: HashMap::new(),
57 throttled_per_ip: HashMap::new(),
58 throttled_per_actor: HashMap::new(),
59 throttled_global: 0,
60 })),
61 capacity,
62 per_ip_capacity: 300,
63 per_actor_capacity: 200,
64 window,
65 trusted_proxy_cidrs: Vec::new(),
66 }
67 }
68
69 pub fn with_per_ip_capacity(mut self, per_ip_capacity: u64) -> Self {
70 self.per_ip_capacity = per_ip_capacity;
71 self
72 }
73
74 pub fn with_per_actor_capacity(mut self, per_actor_capacity: u64) -> Self {
75 self.per_actor_capacity = per_actor_capacity;
76 self
77 }
78
79 pub fn with_trusted_proxy_cidrs(mut self, cidrs: &[String]) -> Self {
80 self.trusted_proxy_cidrs = cidrs
81 .iter()
82 .filter_map(|c| IpCidr::parse(c))
83 .collect::<Vec<_>>();
84 self
85 }
86
87 fn evict_stale<K>(counter: &mut HashMap<K, (u64, Instant)>, window: Duration)
88 where
89 K: Eq + Hash,
90 {
91 let now = Instant::now();
92 counter.retain(|_, (_, start)| now.duration_since(*start) < window);
93 }
94
95 pub async fn snapshot(&self) -> ThrottleSnapshot {
100 let guard = self.state.lock().await;
101
102 let mut top_ips: Vec<_> = guard
103 .throttled_per_ip
104 .iter()
105 .map(|(ip, &count)| (ip.to_string(), count))
106 .collect();
107 top_ips.sort_by(|a, b| b.1.cmp(&a.1));
108 top_ips.truncate(10);
109
110 let mut top_actors: Vec<_> = guard
111 .throttled_per_actor
112 .iter()
113 .map(|(actor, &count)| (actor.clone(), count))
114 .collect();
115 top_actors.sort_by(|a, b| b.1.cmp(&a.1));
116 top_actors.truncate(10);
117
118 ThrottleSnapshot {
119 window_secs: self.window.as_secs(),
120 global_count: guard.count,
121 global_capacity: self.capacity,
122 per_ip_capacity: self.per_ip_capacity,
123 per_actor_capacity: self.per_actor_capacity,
124 throttled_global: guard.throttled_global,
125 active_ips: guard.per_ip.len(),
126 active_actors: guard.per_actor.len(),
127 top_throttled_ips: top_ips,
128 top_throttled_actors: top_actors,
129 }
130 }
131}
132
133#[derive(Debug, Clone, serde::Serialize)]
135pub struct ThrottleSnapshot {
136 pub window_secs: u64,
137 pub global_count: u64,
138 pub global_capacity: u64,
139 pub per_ip_capacity: u64,
140 pub per_actor_capacity: u64,
141 pub throttled_global: u64,
142 pub active_ips: usize,
143 pub active_actors: usize,
144 pub top_throttled_ips: Vec<(String, u64)>,
145 pub top_throttled_actors: Vec<(String, u64)>,
146}
147
148impl<S> Layer<S> for GlobalRateLimitLayer {
149 type Service = GlobalRateLimitService<S>;
150
151 fn layer(&self, inner: S) -> Self::Service {
152 GlobalRateLimitService {
153 inner,
154 state: self.state.clone(),
155 capacity: self.capacity,
156 per_ip_capacity: self.per_ip_capacity,
157 per_actor_capacity: self.per_actor_capacity,
158 window: self.window,
159 trusted_proxy_cidrs: self.trusted_proxy_cidrs.clone(),
160 }
161 }
162}
163
164#[derive(Clone)]
165pub struct GlobalRateLimitService<S> {
166 inner: S,
167 state: Arc<Mutex<RateLimitState>>,
168 capacity: u64,
169 per_ip_capacity: u64,
170 per_actor_capacity: u64,
171 window: Duration,
172 trusted_proxy_cidrs: Vec<IpCidr>,
173}
174
175fn too_many_requests_response() -> Response<Body> {
176 let body = serde_json::json!({
177 "error": "rate_limit_exceeded",
178 "message": "Too many requests, please try again later"
179 });
180 let body_bytes = serde_json::to_vec(&body)
181 .unwrap_or_else(|_| br#"{"error":"rate_limit_exceeded"}"#.to_vec());
182 match Response::builder()
183 .status(StatusCode::TOO_MANY_REQUESTS)
184 .header("content-type", "application/json")
185 .body(Body::from(body_bytes))
186 {
187 Ok(resp) => resp,
188 Err(_) => {
189 let mut resp = Response::new(Body::from(
190 br#"{"error":"rate_limit_exceeded"}"#.as_slice().to_vec(),
191 ));
192 *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
193 resp
194 }
195 }
196}
197
198fn stable_token_fingerprint(raw: &str) -> String {
199 use sha2::{Digest, Sha256};
200 let hash = Sha256::digest(raw.as_bytes());
201 hex::encode(&hash[..8])
204}
205
206fn extract_actor_id(req: &Request<Body>) -> Option<String> {
207 let principal = crate::auth::extract_auth_principal(req);
208 if let Some(v) = req.headers().get("x-api-key")
209 && let Ok(raw) = v.to_str()
210 && !raw.is_empty()
211 {
212 return Some(format!("api_key:{}", stable_token_fingerprint(raw)));
213 }
214 if let Some(v) = req.headers().get("authorization")
215 && let Ok(raw) = v.to_str()
216 && let Some(token) = raw.strip_prefix("Bearer ")
217 && !token.is_empty()
218 {
219 return Some(format!("bearer:{}", stable_token_fingerprint(token)));
220 }
221 principal
224}
225
/// Parses an IP address out of a header fragment.
///
/// Surrounding whitespace is ignored. Besides a plain IPv4/IPv6 literal, the
/// forms `ip:port`, `[v6]:port` and `[v6]` are accepted and reduced to the
/// address, since some proxies decorate forwarded addresses that way.
/// Returns `None` for anything else.
fn parse_ip(s: &str) -> Option<IpAddr> {
    let text = s.trim();

    if let Ok(ip) = text.parse::<IpAddr>() {
        return Some(ip);
    }
    // `1.2.3.4:8080` or `[2001:db8::1]:443`
    if let Ok(addr) = text.parse::<std::net::SocketAddr>() {
        return Some(addr.ip());
    }
    // `[2001:db8::1]` without a port
    text.strip_prefix('[')?
        .strip_suffix(']')?
        .parse::<Ipv6Addr>()
        .ok()
        .map(IpAddr::V6)
}
229
230fn forwarded_ip(req: &Request<Body>) -> Option<IpAddr> {
231 req.headers()
232 .get("x-forwarded-for")
233 .and_then(|v| v.to_str().ok())
234 .and_then(|s| s.split(',').next())
235 .and_then(parse_ip)
236}
237
238fn real_ip(req: &Request<Body>) -> Option<IpAddr> {
239 req.headers()
240 .get("x-real-ip")
241 .and_then(|v| v.to_str().ok())
242 .and_then(parse_ip)
243}
244
245fn trust_forwarded_headers(proxy_ip: IpAddr, trusted_proxy_cidrs: &[IpCidr]) -> bool {
246 trusted_proxy_cidrs
247 .iter()
248 .any(|cidr| cidr.contains(proxy_ip))
249}
250
251fn resolve_client_ip(req: &Request<Body>, trusted_proxy_cidrs: &[IpCidr]) -> IpAddr {
252 let forwarded = forwarded_ip(req);
253 let real = real_ip(req);
254
255 if let (Some(client_ip), Some(proxy_ip)) = (forwarded, real)
256 && trust_forwarded_headers(proxy_ip, trusted_proxy_cidrs)
257 {
258 return client_ip;
259 }
260
261 if let Some(proxy_ip) = real {
262 return proxy_ip;
263 }
264
265 use axum::extract::ConnectInfo;
269 use std::net::SocketAddr;
270 req.extensions()
271 .get::<ConnectInfo<SocketAddr>>()
272 .map(|ci| ci.0.ip())
273 .unwrap_or(IpAddr::from([127, 0, 0, 1]))
274}
275
276impl IpCidr {
277 fn parse(raw: &str) -> Option<Self> {
278 let (ip, prefix) = raw.split_once('/')?;
279 let network = ip.parse::<IpAddr>().ok()?;
280 let prefix_len = prefix.parse::<u8>().ok()?;
281 let max = match network {
282 IpAddr::V4(_) => 32,
283 IpAddr::V6(_) => 128,
284 };
285 if prefix_len > max {
286 return None;
287 }
288 Some(Self {
289 network,
290 prefix_len,
291 })
292 }
293
294 fn contains(&self, ip: IpAddr) -> bool {
295 match (self.network, ip) {
296 (IpAddr::V4(net), IpAddr::V4(candidate)) => {
297 cidr_match_v4(net, candidate, self.prefix_len)
298 }
299 (IpAddr::V6(net), IpAddr::V6(candidate)) => {
300 cidr_match_v6(net, candidate, self.prefix_len)
301 }
302 _ => false,
303 }
304 }
305}
306
/// Reports whether `candidate` shares its first `prefix_len` bits with `network`.
///
/// `prefix_len == 0` matches every address. A `prefix_len` above 32 is not a
/// valid IPv4 prefix and matches nothing, rather than overflowing the shift
/// computation.
fn cidr_match_v4(network: Ipv4Addr, candidate: Ipv4Addr, prefix_len: u8) -> bool {
    let Some(host_bits) = 32u32.checked_sub(u32::from(prefix_len)) else {
        return false;
    };
    // Shifting by the full width (prefix_len == 0) yields `None`, i.e. an empty mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
    ((u32::from(network) ^ u32::from(candidate)) & mask) == 0
}
315
/// Reports whether `candidate` shares its first `prefix_len` bits with `network`.
///
/// `prefix_len == 0` matches every address. A `prefix_len` above 128 is not a
/// valid IPv6 prefix and matches nothing, rather than overflowing the shift
/// computation.
fn cidr_match_v6(network: Ipv6Addr, candidate: Ipv6Addr, prefix_len: u8) -> bool {
    let Some(host_bits) = 128u32.checked_sub(u32::from(prefix_len)) else {
        return false;
    };
    // Shifting by the full width (prefix_len == 0) yields `None`, i.e. an empty mask.
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
    let net = u128::from_be_bytes(network.octets());
    let cand = u128::from_be_bytes(candidate.octets());
    ((net ^ cand) & mask) == 0
}
326
327impl<S> Service<Request<Body>> for GlobalRateLimitService<S>
328where
329 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
330 S::Future: Send + 'static,
331{
332 type Response = Response<Body>;
333 type Error = S::Error;
334 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
335
336 fn poll_ready(
337 &mut self,
338 cx: &mut std::task::Context<'_>,
339 ) -> std::task::Poll<Result<(), Self::Error>> {
340 self.inner.poll_ready(cx)
341 }
342
343 fn call(&mut self, req: Request<Body>) -> Self::Future {
344 let mut inner = self.inner.clone();
345 let state = self.state.clone();
346 let capacity = self.capacity;
347 let per_ip_capacity = self.per_ip_capacity;
348 let per_actor_capacity = self.per_actor_capacity;
349 let window = self.window;
350 let trusted_proxy_cidrs = self.trusted_proxy_cidrs.clone();
351 let ip = resolve_client_ip(&req, &trusted_proxy_cidrs);
352 let actor = extract_actor_id(&req);
353
354 Box::pin(async move {
355 let now = Instant::now();
356 let mut guard = state.lock().await;
357 if now.duration_since(guard.window_start) >= window {
358 guard.window_start = now;
359 guard.count = 0;
360 GlobalRateLimitLayer::evict_stale(&mut guard.per_ip, window);
361 GlobalRateLimitLayer::evict_stale(&mut guard.per_actor, window);
362 guard.throttled_per_ip.clear();
363 guard.throttled_per_actor.clear();
364 guard.throttled_global = 0;
365 }
366 if guard.count >= capacity {
367 guard.throttled_global += 1;
368 return Ok(too_many_requests_response());
369 }
370
371 let per_ip_cap = per_ip_capacity;
373 if !guard.per_ip.contains_key(&ip) && guard.per_ip.len() >= MAX_DISTINCT_IPS {
374 return Ok(too_many_requests_response());
375 }
376 let ip_entry = guard.per_ip.entry(ip).or_insert((0, now));
377 if now.duration_since(ip_entry.1) >= window {
378 *ip_entry = (0, now);
379 }
380 if ip_entry.0 >= per_ip_cap {
381 *guard.throttled_per_ip.entry(ip).or_insert(0) += 1;
382 return Ok(too_many_requests_response());
383 }
384 ip_entry.0 += 1;
385
386 if let Some(ref actor_id) = actor {
388 if !guard.per_actor.contains_key(actor_id)
389 && guard.per_actor.len() >= MAX_DISTINCT_ACTORS
390 {
391 return Ok(too_many_requests_response());
392 }
393 let actor_entry = guard.per_actor.entry(actor_id.clone()).or_insert((0, now));
394 if now.duration_since(actor_entry.1) >= window {
395 *actor_entry = (0, now);
396 }
397 if actor_entry.0 >= per_actor_capacity {
398 *guard
399 .throttled_per_actor
400 .entry(actor_id.clone())
401 .or_insert(0) += 1;
402 return Ok(too_many_requests_response());
403 }
404 actor_entry.0 += 1;
405 }
406
407 guard.count += 1;
409
410 drop(guard);
411
412 inner.call(req).await
413 })
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::{Service, ServiceExt};

    /// Minimal router answering `GET /` with `"ok"`.
    fn dummy_service() -> axum::routing::Router {
        axum::routing::Router::new().route("/", axum::routing::get(|| async { "ok" }))
    }

    /// Builds an empty-bodied request for `/` carrying the given headers.
    fn root_request(headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("/");
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Waits for readiness, sends `req`, and returns the response status.
    async fn status_of<S>(svc: &mut S, req: Request<Body>) -> StatusCode
    where
        S: Service<Request<Body>, Response = Response<Body>>,
        S::Error: std::fmt::Debug,
    {
        svc.ready().await.unwrap().call(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn allows_requests_within_capacity() {
        let layer = GlobalRateLimitLayer::new(5, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..5 {
            assert_eq!(status_of(&mut svc, root_request(&[])).await, StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn returns_429_when_capacity_exceeded() {
        let layer = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..2 {
            let _ = status_of(&mut svc, root_request(&[])).await;
        }
        assert_eq!(
            status_of(&mut svc, root_request(&[])).await,
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[tokio::test]
    async fn window_resets_after_expiry() {
        let layer = GlobalRateLimitLayer::new(1, Duration::from_millis(50));
        let mut svc = layer.layer(dummy_service().into_service());

        assert_eq!(status_of(&mut svc, root_request(&[])).await, StatusCode::OK);
        assert_eq!(
            status_of(&mut svc, root_request(&[])).await,
            StatusCode::TOO_MANY_REQUESTS
        );

        // Outlast the 50 ms window so the next request starts a fresh one.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert_eq!(status_of(&mut svc, root_request(&[])).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn per_ip_limits_enforced() {
        let layer = GlobalRateLimitLayer::new(1000, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        let first_client = [("x-real-ip", "1.2.3.4")];
        let second_client = [("x-real-ip", "5.6.7.8")];

        // The default per-IP budget is 300 requests.
        for _ in 0..300 {
            assert_eq!(
                status_of(&mut svc, root_request(&first_client)).await,
                StatusCode::OK
            );
        }
        assert_eq!(
            status_of(&mut svc, root_request(&first_client)).await,
            StatusCode::TOO_MANY_REQUESTS
        );
        // A different IP has its own budget.
        assert_eq!(
            status_of(&mut svc, root_request(&second_client)).await,
            StatusCode::OK
        );
    }

    #[test]
    fn cidr_parse_and_contains() {
        let cidr = IpCidr::parse("10.0.0.0/8").expect("cidr");
        assert!(cidr.contains("10.1.2.3".parse().unwrap()));
        assert!(!cidr.contains("11.1.2.3".parse().unwrap()));
    }

    #[test]
    fn trusted_proxy_resolution_prefers_forwarded_when_proxy_trusted() {
        let req = root_request(&[("x-forwarded-for", "1.2.3.4"), ("x-real-ip", "10.0.0.5")]);
        let cidr = IpCidr::parse("10.0.0.0/8").unwrap();
        let ip = resolve_client_ip(&req, &[cidr]);
        assert_eq!(ip, "1.2.3.4".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn untrusted_proxy_resolution_uses_direct_ip() {
        let req = root_request(&[
            ("x-forwarded-for", "1.2.3.4"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        let ip = resolve_client_ip(&req, &[]);
        assert_eq!(ip, "198.51.100.2".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn forwarded_header_without_trusted_proxy_is_ignored() {
        let req = root_request(&[("x-forwarded-for", "1.2.3.4")]);
        let ip = resolve_client_ip(&req, &[]);
        assert_eq!(ip, "127.0.0.1".parse::<IpAddr>().unwrap());
    }

    #[tokio::test]
    async fn actor_limits_enforced() {
        let layer =
            GlobalRateLimitLayer::new(1000, Duration::from_secs(60)).with_per_actor_capacity(2);
        let mut svc = layer.layer(dummy_service().into_service());
        let as_actor = [("authorization", "Bearer actor-token")];

        for _ in 0..2 {
            assert_eq!(
                status_of(&mut svc, root_request(&as_actor)).await,
                StatusCode::OK
            );
        }
        assert_eq!(
            status_of(&mut svc, root_request(&as_actor)).await,
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[tokio::test]
    async fn snapshot_reflects_throttle_state() {
        let layer = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());

        // Use up the global budget, then trigger one rejection.
        for _ in 0..2 {
            let _ = status_of(&mut svc, root_request(&[])).await;
        }
        assert_eq!(
            status_of(&mut svc, root_request(&[])).await,
            StatusCode::TOO_MANY_REQUESTS
        );

        let snap = layer.snapshot().await;
        assert_eq!(snap.global_count, 2);
        assert_eq!(snap.global_capacity, 2);
        assert!(
            snap.throttled_global >= 1,
            "should record ≥1 throttled global"
        );
        assert_eq!(snap.window_secs, 60);
    }
}