1use std::collections::HashMap;
4use std::hash::Hash;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use axum::body::Body;
10use axum::http::{Request, Response, StatusCode};
11use futures_util::future::BoxFuture;
12use tokio::sync::Mutex;
13use tower::{Layer, Service};
14
/// Upper bound on the number of distinct client IPs tracked in the per-IP
/// counter map. Once the map holds this many entries, requests from IPs that
/// are not already tracked are answered with 429 instead of being inserted.
const MAX_DISTINCT_IPS: usize = 10_000;
/// Upper bound on the number of distinct actors tracked in the per-actor
/// counter map. Once the map holds this many entries, requests from actors
/// that are not already tracked are answered with 429 instead of being inserted.
const MAX_DISTINCT_ACTORS: usize = 5_000;
20
/// Tower layer that wraps a service with global, per-IP and per-actor
/// fixed-window request limits.
///
/// Cloning the layer clones the `Arc`, so every clone (and every service the
/// layer produces) shares the same counters.
#[derive(Clone)]
pub struct GlobalRateLimitLayer {
    /// Counters shared between the layer and all services created from it.
    state: Arc<Mutex<RateLimitState>>,
    /// Maximum number of requests admitted per window across all callers.
    capacity: u64,
    /// Maximum number of requests admitted per window for a single client IP.
    per_ip_capacity: u64,
    /// Maximum number of requests admitted per window for a single actor.
    per_actor_capacity: u64,
    /// Length of the fixed window after which counters are reset.
    window: Duration,
    /// Networks whose addresses are accepted as proxies when resolving the
    /// client IP from forwarding headers.
    trusted_proxy_cidrs: Vec<IpCidr>,
}
31
/// Mutable counters guarded by the layer's mutex.
struct RateLimitState {
    /// Requests admitted in the current global window.
    count: u64,
    /// Instant at which the current global window started.
    window_start: Instant,
    /// Per client IP: (requests admitted, start of that IP's own window).
    per_ip: HashMap<IpAddr, (u64, Instant)>,
    /// Per actor id: (requests admitted, start of that actor's own window).
    per_actor: HashMap<String, (u64, Instant)>,
    /// Requests rejected for exceeding the per-IP capacity, keyed by IP;
    /// cleared whenever the global window rolls over.
    throttled_per_ip: HashMap<IpAddr, u64>,
    /// Requests rejected for exceeding the per-actor capacity, keyed by actor;
    /// cleared whenever the global window rolls over.
    throttled_per_actor: HashMap<String, u64>,
    /// Requests rejected for exceeding the global capacity in the current window.
    throttled_global: u64,
}
41
/// An IPv4 or IPv6 network expressed as a base address plus prefix length.
#[derive(Clone, Debug)]
struct IpCidr {
    /// Base address of the network; its family decides which addresses can match.
    network: IpAddr,
    /// Number of leading bits that must agree for an address to be contained.
    prefix_len: u8,
}
47
48impl GlobalRateLimitLayer {
49 pub fn new(capacity: u64, window: Duration) -> Self {
51 Self {
52 state: Arc::new(Mutex::new(RateLimitState {
53 count: 0,
54 window_start: Instant::now(),
55 per_ip: HashMap::new(),
56 per_actor: HashMap::new(),
57 throttled_per_ip: HashMap::new(),
58 throttled_per_actor: HashMap::new(),
59 throttled_global: 0,
60 })),
61 capacity,
62 per_ip_capacity: 300,
63 per_actor_capacity: 200,
64 window,
65 trusted_proxy_cidrs: Vec::new(),
66 }
67 }
68
69 pub fn with_per_ip_capacity(mut self, per_ip_capacity: u64) -> Self {
70 self.per_ip_capacity = per_ip_capacity;
71 self
72 }
73
74 pub fn with_per_actor_capacity(mut self, per_actor_capacity: u64) -> Self {
75 self.per_actor_capacity = per_actor_capacity;
76 self
77 }
78
79 pub fn with_trusted_proxy_cidrs(mut self, cidrs: &[String]) -> Self {
80 self.trusted_proxy_cidrs = cidrs
81 .iter()
82 .filter_map(|c| IpCidr::parse(c))
83 .collect::<Vec<_>>();
84 self
85 }
86
87 fn evict_stale<K>(counter: &mut HashMap<K, (u64, Instant)>, window: Duration)
88 where
89 K: Eq + Hash,
90 {
91 let now = Instant::now();
92 counter.retain(|_, (_, start)| now.duration_since(*start) < window);
93 }
94
95 pub async fn snapshot(&self) -> ThrottleSnapshot {
100 let guard = self.state.lock().await;
101
102 let mut top_ips: Vec<_> = guard
103 .throttled_per_ip
104 .iter()
105 .map(|(ip, &count)| (ip.to_string(), count))
106 .collect();
107 top_ips.sort_by(|a, b| b.1.cmp(&a.1));
108 top_ips.truncate(10);
109
110 let mut top_actors: Vec<_> = guard
111 .throttled_per_actor
112 .iter()
113 .map(|(actor, &count)| (actor.clone(), count))
114 .collect();
115 top_actors.sort_by(|a, b| b.1.cmp(&a.1));
116 top_actors.truncate(10);
117
118 ThrottleSnapshot {
119 window_secs: self.window.as_secs(),
120 global_count: guard.count,
121 global_capacity: self.capacity,
122 per_ip_capacity: self.per_ip_capacity,
123 per_actor_capacity: self.per_actor_capacity,
124 throttled_global: guard.throttled_global,
125 active_ips: guard.per_ip.len(),
126 active_actors: guard.per_actor.len(),
127 top_throttled_ips: top_ips,
128 top_throttled_actors: top_actors,
129 }
130 }
131}
132
/// Serializable point-in-time view of the rate limiter, as produced by
/// `GlobalRateLimitLayer::snapshot`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ThrottleSnapshot {
    /// Configured window length in whole seconds (sub-second part truncated).
    pub window_secs: u64,
    /// Requests admitted in the current global window.
    pub global_count: u64,
    /// Configured global capacity per window.
    pub global_capacity: u64,
    /// Configured per-IP capacity per window.
    pub per_ip_capacity: u64,
    /// Configured per-actor capacity per window.
    pub per_actor_capacity: u64,
    /// Requests rejected by the global limit in the current window.
    pub throttled_global: u64,
    /// Number of client IPs currently present in the per-IP counter map.
    pub active_ips: usize,
    /// Number of actors currently present in the per-actor counter map.
    pub active_actors: usize,
    /// Up to ten `(ip, rejections)` pairs, highest rejection count first.
    pub top_throttled_ips: Vec<(String, u64)>,
    /// Up to ten `(actor, rejections)` pairs, highest rejection count first.
    pub top_throttled_actors: Vec<(String, u64)>,
}
147
148impl<S> Layer<S> for GlobalRateLimitLayer {
149 type Service = GlobalRateLimitService<S>;
150
151 fn layer(&self, inner: S) -> Self::Service {
152 GlobalRateLimitService {
153 inner,
154 state: self.state.clone(),
155 capacity: self.capacity,
156 per_ip_capacity: self.per_ip_capacity,
157 per_actor_capacity: self.per_actor_capacity,
158 window: self.window,
159 trusted_proxy_cidrs: self.trusted_proxy_cidrs.clone(),
160 }
161 }
162}
163
/// Service produced by `GlobalRateLimitLayer`: checks the limits for each
/// request and forwards admitted requests to `inner`.
#[derive(Clone)]
pub struct GlobalRateLimitService<S> {
    /// The wrapped service that handles admitted requests.
    inner: S,
    /// Counters shared with the originating layer and sibling services.
    state: Arc<Mutex<RateLimitState>>,
    /// Maximum number of requests admitted per window across all callers.
    capacity: u64,
    /// Maximum number of requests admitted per window for a single client IP.
    per_ip_capacity: u64,
    /// Maximum number of requests admitted per window for a single actor.
    per_actor_capacity: u64,
    /// Length of the fixed window after which counters are reset.
    window: Duration,
    /// Networks whose addresses are accepted as proxies when resolving the
    /// client IP from forwarding headers.
    trusted_proxy_cidrs: Vec<IpCidr>,
}
174
175fn too_many_requests_response(limit: u64, window_secs: u64) -> Response<Body> {
176 let body = serde_json::json!({
177 "type": "about:blank",
178 "title": "Too Many Requests",
179 "status": 429,
180 "detail": "rate_limit_exceeded"
181 });
182 let body_bytes = serde_json::to_vec(&body)
183 .unwrap_or_else(|_| br#"{"type":"about:blank","title":"Too Many Requests","status":429,"detail":"rate_limit_exceeded"}"#.to_vec());
184 match Response::builder()
185 .status(StatusCode::TOO_MANY_REQUESTS)
186 .header("content-type", "application/problem+json")
187 .header("ratelimit-limit", limit.to_string())
188 .header("ratelimit-remaining", "0")
189 .header("ratelimit-reset", window_secs.to_string())
190 .header("retry-after", window_secs.to_string())
191 .body(Body::from(body_bytes))
192 {
193 Ok(resp) => resp,
194 Err(_) => {
195 let mut resp = Response::new(Body::from(
196 br#"{"type":"about:blank","title":"Too Many Requests","status":429,"detail":"rate_limit_exceeded"}"#.as_slice().to_vec(),
197 ));
198 *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
199 resp
200 }
201 }
202}
203
204fn inject_rate_limit_headers(
206 resp: &mut Response<Body>,
207 limit: u64,
208 remaining: u64,
209 reset_secs: u64,
210) {
211 let headers = resp.headers_mut();
212 headers.insert(
213 "ratelimit-limit",
214 limit.to_string().parse().expect("numeric header value"),
215 );
216 headers.insert(
217 "ratelimit-remaining",
218 remaining.to_string().parse().expect("numeric header value"),
219 );
220 headers.insert(
221 "ratelimit-reset",
222 reset_secs
223 .to_string()
224 .parse()
225 .expect("numeric header value"),
226 );
227}
228
229fn stable_token_fingerprint(raw: &str) -> String {
230 use sha2::{Digest, Sha256};
231 let hash = Sha256::digest(raw.as_bytes());
232 hex::encode(&hash[..8])
235}
236
237fn extract_actor_id(req: &Request<Body>) -> Option<String> {
238 let principal = crate::auth::extract_auth_principal(req);
239 if let Some(v) = req.headers().get("x-api-key")
240 && let Ok(raw) = v.to_str()
241 && !raw.is_empty()
242 {
243 return Some(format!("api_key:{}", stable_token_fingerprint(raw)));
244 }
245 if let Some(v) = req.headers().get("authorization")
246 && let Ok(raw) = v.to_str()
247 && let Some(token) = raw.strip_prefix("Bearer ")
248 && !token.is_empty()
249 {
250 return Some(format!("bearer:{}", stable_token_fingerprint(token)));
251 }
252 principal
255}
256
/// Parses `s` as an IPv4 or IPv6 address after trimming surrounding
/// whitespace; returns `None` for anything that is not a bare address.
fn parse_ip(s: &str) -> Option<IpAddr> {
    let candidate = s.trim();
    match candidate.parse::<IpAddr>() {
        Ok(addr) => Some(addr),
        Err(_) => None,
    }
}
260
261fn forwarded_ip(req: &Request<Body>) -> Option<IpAddr> {
262 req.headers()
263 .get("x-forwarded-for")
264 .and_then(|v| v.to_str().ok())
265 .and_then(|s| s.split(',').next())
266 .and_then(parse_ip)
267}
268
269fn real_ip(req: &Request<Body>) -> Option<IpAddr> {
270 req.headers()
271 .get("x-real-ip")
272 .and_then(|v| v.to_str().ok())
273 .and_then(parse_ip)
274}
275
276fn trust_forwarded_headers(proxy_ip: IpAddr, trusted_proxy_cidrs: &[IpCidr]) -> bool {
277 trusted_proxy_cidrs
278 .iter()
279 .any(|cidr| cidr.contains(proxy_ip))
280}
281
282fn resolve_client_ip(req: &Request<Body>, trusted_proxy_cidrs: &[IpCidr]) -> IpAddr {
283 let forwarded = forwarded_ip(req);
284 let real = real_ip(req);
285
286 if let (Some(client_ip), Some(proxy_ip)) = (forwarded, real)
287 && trust_forwarded_headers(proxy_ip, trusted_proxy_cidrs)
288 {
289 return client_ip;
290 }
291
292 if let Some(proxy_ip) = real {
293 return proxy_ip;
294 }
295
296 use axum::extract::ConnectInfo;
300 use std::net::SocketAddr;
301 req.extensions()
302 .get::<ConnectInfo<SocketAddr>>()
303 .map(|ci| ci.0.ip())
304 .unwrap_or(IpAddr::from([127, 0, 0, 1]))
305}
306
307impl IpCidr {
308 fn parse(raw: &str) -> Option<Self> {
309 let (ip, prefix) = raw.split_once('/')?;
310 let network = ip.parse::<IpAddr>().ok()?;
311 let prefix_len = prefix.parse::<u8>().ok()?;
312 let max = match network {
313 IpAddr::V4(_) => 32,
314 IpAddr::V6(_) => 128,
315 };
316 if prefix_len > max {
317 return None;
318 }
319 Some(Self {
320 network,
321 prefix_len,
322 })
323 }
324
325 fn contains(&self, ip: IpAddr) -> bool {
326 match (self.network, ip) {
327 (IpAddr::V4(net), IpAddr::V4(candidate)) => {
328 cidr_match_v4(net, candidate, self.prefix_len)
329 }
330 (IpAddr::V6(net), IpAddr::V6(candidate)) => {
331 cidr_match_v6(net, candidate, self.prefix_len)
332 }
333 _ => false,
334 }
335 }
336}
337
/// Reports whether `network` and `candidate` agree on their first
/// `prefix_len` bits.
///
/// A `prefix_len` of 0 matches every address. A `prefix_len` greater than 32
/// is not a valid IPv4 prefix and never matches (rather than underflowing the
/// shift amount).
fn cidr_match_v4(network: Ipv4Addr, candidate: Ipv4Addr, prefix_len: u8) -> bool {
    if prefix_len > 32 {
        return false;
    }
    let host_bits = 32 - u32::from(prefix_len);
    // Shifting by the full width (prefix 0) is an overflow; that case is the
    // all-zero mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
    (u32::from(network) & mask) == (u32::from(candidate) & mask)
}
346
/// Reports whether `network` and `candidate` agree on their first
/// `prefix_len` bits.
///
/// A `prefix_len` of 0 matches every address. A `prefix_len` greater than 128
/// is not a valid IPv6 prefix and never matches (rather than underflowing the
/// shift amount).
fn cidr_match_v6(network: Ipv6Addr, candidate: Ipv6Addr, prefix_len: u8) -> bool {
    if prefix_len > 128 {
        return false;
    }
    let host_bits = 128 - u32::from(prefix_len);
    // Shifting by the full width (prefix 0) is an overflow; that case is the
    // all-zero mask.
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
    let net = u128::from_be_bytes(network.octets());
    let cand = u128::from_be_bytes(candidate.octets());
    (net & mask) == (cand & mask)
}
357
358impl<S> Service<Request<Body>> for GlobalRateLimitService<S>
359where
360 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
361 S::Future: Send + 'static,
362{
363 type Response = Response<Body>;
364 type Error = S::Error;
365 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
366
367 fn poll_ready(
368 &mut self,
369 cx: &mut std::task::Context<'_>,
370 ) -> std::task::Poll<Result<(), Self::Error>> {
371 self.inner.poll_ready(cx)
372 }
373
374 fn call(&mut self, req: Request<Body>) -> Self::Future {
375 let mut inner = self.inner.clone();
376 let state = self.state.clone();
377 let capacity = self.capacity;
378 let per_ip_capacity = self.per_ip_capacity;
379 let per_actor_capacity = self.per_actor_capacity;
380 let window = self.window;
381 let trusted_proxy_cidrs = self.trusted_proxy_cidrs.clone();
382 let ip = resolve_client_ip(&req, &trusted_proxy_cidrs);
383 let actor = extract_actor_id(&req);
384
385 Box::pin(async move {
386 let now = Instant::now();
387 let mut guard = state.lock().await;
388 if now.duration_since(guard.window_start) >= window {
389 guard.window_start = now;
390 guard.count = 0;
391 GlobalRateLimitLayer::evict_stale(&mut guard.per_ip, window);
392 GlobalRateLimitLayer::evict_stale(&mut guard.per_actor, window);
393 guard.throttled_per_ip.clear();
394 guard.throttled_per_actor.clear();
395 guard.throttled_global = 0;
396 }
397
398 let elapsed = now.duration_since(guard.window_start);
400 let reset_secs = window.as_secs().saturating_sub(elapsed.as_secs());
401
402 if guard.count >= capacity {
403 guard.throttled_global += 1;
404 return Ok(too_many_requests_response(capacity, reset_secs));
405 }
406
407 let per_ip_cap = per_ip_capacity;
409 if !guard.per_ip.contains_key(&ip) && guard.per_ip.len() >= MAX_DISTINCT_IPS {
410 return Ok(too_many_requests_response(per_ip_cap, reset_secs));
411 }
412 let ip_entry = guard.per_ip.entry(ip).or_insert((0, now));
413 if now.duration_since(ip_entry.1) >= window {
414 *ip_entry = (0, now);
415 }
416 if ip_entry.0 >= per_ip_cap {
417 *guard.throttled_per_ip.entry(ip).or_insert(0) += 1;
418 return Ok(too_many_requests_response(per_ip_cap, reset_secs));
419 }
420 ip_entry.0 += 1;
421
422 if let Some(ref actor_id) = actor {
424 if !guard.per_actor.contains_key(actor_id)
425 && guard.per_actor.len() >= MAX_DISTINCT_ACTORS
426 {
427 return Ok(too_many_requests_response(per_actor_capacity, reset_secs));
428 }
429 let actor_entry = guard.per_actor.entry(actor_id.clone()).or_insert((0, now));
430 if now.duration_since(actor_entry.1) >= window {
431 *actor_entry = (0, now);
432 }
433 if actor_entry.0 >= per_actor_capacity {
434 *guard
435 .throttled_per_actor
436 .entry(actor_id.clone())
437 .or_insert(0) += 1;
438 return Ok(too_many_requests_response(per_actor_capacity, reset_secs));
439 }
440 actor_entry.0 += 1;
441 }
442
443 let remaining = capacity.saturating_sub(guard.count + 1);
445 guard.count += 1;
446
447 drop(guard);
448
449 let mut resp = inner.call(req).await?;
450 inject_rate_limit_headers(&mut resp, capacity, remaining, reset_secs);
451 Ok(resp)
452 })
453 }
454}
455
// Unit tests for the rate-limiting layer, the CIDR helpers and client-IP
// resolution.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::{Service, ServiceExt};

    /// Minimal router answering `GET /` with "ok"; used as the wrapped service.
    fn dummy_service() -> axum::routing::Router {
        axum::routing::Router::new().route("/", axum::routing::get(|| async { "ok" }))
    }

    /// Exactly `capacity` requests in one window are all admitted.
    #[tokio::test]
    async fn allows_requests_within_capacity() {
        let layer = GlobalRateLimitLayer::new(5, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..5 {
            let req = Request::builder().uri("/").body(Body::empty()).unwrap();
            let resp = svc.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }

    /// The request after the global capacity is used up gets a 429.
    #[tokio::test]
    async fn returns_429_when_capacity_exceeded() {
        let layer = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..2 {
            let req = Request::builder().uri("/").body(Body::empty()).unwrap();
            let _ = svc.ready().await.unwrap().call(req).await.unwrap();
        }
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    /// After the window (50 ms) has elapsed, requests are admitted again.
    #[tokio::test]
    async fn window_resets_after_expiry() {
        let layer = GlobalRateLimitLayer::new(1, Duration::from_millis(50));
        let mut svc = layer.layer(dummy_service().into_service());
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        // Sleep slightly longer than the window so the next request starts a new one.
        tokio::time::sleep(Duration::from_millis(60)).await;
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// With the default per-IP capacity (300), the 301st request from one IP
    /// is throttled while a different IP is still admitted.
    #[tokio::test]
    async fn per_ip_limits_enforced() {
        let layer = GlobalRateLimitLayer::new(1000, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..300 {
            let req = Request::builder()
                .uri("/")
                .header("x-real-ip", "1.2.3.4")
                .body(Body::empty())
                .unwrap();
            let resp = svc.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }
        let req = Request::builder()
            .uri("/")
            .header("x-real-ip", "1.2.3.4")
            .body(Body::empty())
            .unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        // A second IP has its own counter and is unaffected.
        let req = Request::builder()
            .uri("/")
            .header("x-real-ip", "5.6.7.8")
            .body(Body::empty())
            .unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// A /8 network contains addresses sharing its first octet and no others.
    #[test]
    fn cidr_parse_and_contains() {
        let cidr = IpCidr::parse("10.0.0.0/8").expect("cidr");
        assert!(cidr.contains("10.1.2.3".parse().unwrap()));
        assert!(!cidr.contains("11.1.2.3".parse().unwrap()));
    }

    /// When `x-real-ip` is inside a trusted network, `x-forwarded-for` wins.
    #[test]
    fn trusted_proxy_resolution_prefers_forwarded_when_proxy_trusted() {
        let req = Request::builder()
            .header("x-forwarded-for", "1.2.3.4")
            .header("x-real-ip", "10.0.0.5")
            .body(Body::empty())
            .unwrap();
        let cidr = IpCidr::parse("10.0.0.0/8").unwrap();
        let ip = resolve_client_ip(&req, &[cidr]);
        assert_eq!(ip, "1.2.3.4".parse::<IpAddr>().unwrap());
    }

    /// With no trusted networks, `x-real-ip` is used and `x-forwarded-for` ignored.
    #[test]
    fn untrusted_proxy_resolution_uses_direct_ip() {
        let req = Request::builder()
            .header("x-forwarded-for", "1.2.3.4")
            .header("x-real-ip", "198.51.100.2")
            .body(Body::empty())
            .unwrap();
        let ip = resolve_client_ip(&req, &[]);
        assert_eq!(ip, "198.51.100.2".parse::<IpAddr>().unwrap());
    }

    /// `x-forwarded-for` alone is ignored; without connect info the result
    /// falls back to 127.0.0.1.
    #[test]
    fn forwarded_header_without_trusted_proxy_is_ignored() {
        let req = Request::builder()
            .header("x-forwarded-for", "1.2.3.4")
            .body(Body::empty())
            .unwrap();
        let ip = resolve_client_ip(&req, &[]);
        assert_eq!(ip, "127.0.0.1".parse::<IpAddr>().unwrap());
    }

    /// With a per-actor capacity of 2, the third request carrying the same
    /// bearer token is throttled.
    #[tokio::test]
    async fn actor_limits_enforced() {
        let layer =
            GlobalRateLimitLayer::new(1000, Duration::from_secs(60)).with_per_actor_capacity(2);
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..2 {
            let req = Request::builder()
                .uri("/")
                .header("authorization", "Bearer actor-token")
                .body(Body::empty())
                .unwrap();
            let resp = svc.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }
        let req = Request::builder()
            .uri("/")
            .header("authorization", "Bearer actor-token")
            .body(Body::empty())
            .unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    /// The snapshot reports the admitted count, the configured capacity and
    /// at least one globally throttled request after the limit is exceeded.
    #[tokio::test]
    async fn snapshot_reflects_throttle_state() {
        let layer = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());

        // Use up the global capacity.
        for _ in 0..2 {
            let req = Request::builder().uri("/").body(Body::empty()).unwrap();
            let _ = svc.ready().await.unwrap().call(req).await.unwrap();
        }
        // One more request is rejected and recorded as globally throttled.
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);

        let snap = layer.snapshot().await;
        assert_eq!(snap.global_count, 2);
        assert_eq!(snap.global_capacity, 2);
        assert!(
            snap.throttled_global >= 1,
            "should record ≥1 throttled global"
        );
        assert_eq!(snap.window_secs, 60);
    }
}