Skip to main content

roboticus_api/
rate_limit.rs

1//! Global API rate limiting (fixed window, Clone-friendly for axum Router).
2
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use axum::body::Body;
10use axum::http::{Request, Response, StatusCode};
11use futures_util::future::BoxFuture;
12use tokio::sync::Mutex;
13use tower::{Layer, Service};
14
/// Hard cap on distinct IPs tracked in the per-IP counter map.
/// Requests from new IPs beyond this limit are immediately rate-limited
/// to prevent unbounded memory growth during distributed floods.
const MAX_DISTINCT_IPS: usize = 10_000;
/// Hard cap on distinct actors tracked in the per-actor counter map.
/// Requests from new actors beyond this limit are rejected the same way.
const MAX_DISTINCT_ACTORS: usize = 5_000;
20
/// Tower layer enforcing fixed-window rate limits: at most `capacity`
/// requests per `window` globally, plus separate per-IP and per-actor caps.
///
/// Clones share the same counters because the state lives behind an `Arc`.
#[derive(Clone)]
pub struct GlobalRateLimitLayer {
    /// Shared counters; every service produced by `layer` gets a handle to it.
    state: Arc<Mutex<RateLimitState>>,
    /// Maximum requests admitted per window across all clients.
    capacity: u64,
    /// Maximum requests admitted per window for a single client IP.
    per_ip_capacity: u64,
    /// Maximum requests admitted per window for a single actor.
    per_actor_capacity: u64,
    /// Length of the fixed window.
    window: Duration,
    /// Networks consulted when deciding whether to honour `x-forwarded-for`.
    trusted_proxy_cidrs: Vec<IpCidr>,
}
31
/// Mutable counters shared by the layer and all services built from it.
struct RateLimitState {
    /// Requests admitted in the current global window.
    count: u64,
    /// Start of the current global window.
    window_start: Instant,
    /// Per-IP `(admitted count, start of that IP's window)`.
    per_ip: HashMap<IpAddr, (u64, Instant)>,
    /// Per-actor `(admitted count, start of that actor's window)`.
    per_actor: HashMap<String, (u64, Instant)>,
    /// Requests rejected by the per-IP cap, keyed by IP; cleared on global window reset.
    throttled_per_ip: HashMap<IpAddr, u64>,
    /// Requests rejected by the per-actor cap, keyed by actor; cleared on global window reset.
    throttled_per_actor: HashMap<String, u64>,
    /// Requests rejected by the global cap; zeroed on global window reset.
    throttled_global: u64,
}
41
/// An IP network in CIDR form (address plus prefix length).
#[derive(Clone, Debug)]
struct IpCidr {
    /// Network address as written in the CIDR string (not normalised).
    network: IpAddr,
    /// Number of leading bits that must match; at most 32 for IPv4, 128 for IPv6.
    prefix_len: u8,
}
47
48impl GlobalRateLimitLayer {
49    /// Allow at most `capacity` requests per `window` globally, and `per_ip` per IP.
50    pub fn new(capacity: u64, window: Duration) -> Self {
51        Self {
52            state: Arc::new(Mutex::new(RateLimitState {
53                count: 0,
54                window_start: Instant::now(),
55                per_ip: HashMap::new(),
56                per_actor: HashMap::new(),
57                throttled_per_ip: HashMap::new(),
58                throttled_per_actor: HashMap::new(),
59                throttled_global: 0,
60            })),
61            capacity,
62            per_ip_capacity: 300,
63            per_actor_capacity: 200,
64            window,
65            trusted_proxy_cidrs: Vec::new(),
66        }
67    }
68
69    pub fn with_per_ip_capacity(mut self, per_ip_capacity: u64) -> Self {
70        self.per_ip_capacity = per_ip_capacity;
71        self
72    }
73
74    pub fn with_per_actor_capacity(mut self, per_actor_capacity: u64) -> Self {
75        self.per_actor_capacity = per_actor_capacity;
76        self
77    }
78
79    pub fn with_trusted_proxy_cidrs(mut self, cidrs: &[String]) -> Self {
80        self.trusted_proxy_cidrs = cidrs
81            .iter()
82            .filter_map(|c| IpCidr::parse(c))
83            .collect::<Vec<_>>();
84        self
85    }
86
87    fn evict_stale<K>(counter: &mut HashMap<K, (u64, Instant)>, window: Duration)
88    where
89        K: Eq + Hash,
90    {
91        let now = Instant::now();
92        counter.retain(|_, (_, start)| now.duration_since(*start) < window);
93    }
94
95    /// Snapshot current throttle statistics for admin observability.
96    ///
97    /// Returns counts of throttled requests per-IP, per-actor, and globally
98    /// within the current window, plus top offenders (up to 10 each).
99    pub async fn snapshot(&self) -> ThrottleSnapshot {
100        let guard = self.state.lock().await;
101
102        let mut top_ips: Vec<_> = guard
103            .throttled_per_ip
104            .iter()
105            .map(|(ip, &count)| (ip.to_string(), count))
106            .collect();
107        top_ips.sort_by(|a, b| b.1.cmp(&a.1));
108        top_ips.truncate(10);
109
110        let mut top_actors: Vec<_> = guard
111            .throttled_per_actor
112            .iter()
113            .map(|(actor, &count)| (actor.clone(), count))
114            .collect();
115        top_actors.sort_by(|a, b| b.1.cmp(&a.1));
116        top_actors.truncate(10);
117
118        ThrottleSnapshot {
119            window_secs: self.window.as_secs(),
120            global_count: guard.count,
121            global_capacity: self.capacity,
122            per_ip_capacity: self.per_ip_capacity,
123            per_actor_capacity: self.per_actor_capacity,
124            throttled_global: guard.throttled_global,
125            active_ips: guard.per_ip.len(),
126            active_actors: guard.per_actor.len(),
127            top_throttled_ips: top_ips,
128            top_throttled_actors: top_actors,
129        }
130    }
131}
132
/// Snapshot of current throttle counters for observability.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ThrottleSnapshot {
    /// Window length in whole seconds (sub-second part truncated).
    pub window_secs: u64,
    /// Requests admitted in the current global window.
    pub global_count: u64,
    /// Configured global cap per window.
    pub global_capacity: u64,
    /// Configured per-IP cap per window.
    pub per_ip_capacity: u64,
    /// Configured per-actor cap per window.
    pub per_actor_capacity: u64,
    /// Requests rejected by the global cap since the last global window reset.
    pub throttled_global: u64,
    /// Number of IPs currently held in the per-IP counter map.
    pub active_ips: usize,
    /// Number of actors currently held in the per-actor counter map.
    pub active_actors: usize,
    /// Up to ten `(ip, rejected count)` pairs, highest count first.
    pub top_throttled_ips: Vec<(String, u64)>,
    /// Up to ten `(actor id, rejected count)` pairs, highest count first.
    pub top_throttled_actors: Vec<(String, u64)>,
}
147
148impl<S> Layer<S> for GlobalRateLimitLayer {
149    type Service = GlobalRateLimitService<S>;
150
151    fn layer(&self, inner: S) -> Self::Service {
152        GlobalRateLimitService {
153            inner,
154            state: self.state.clone(),
155            capacity: self.capacity,
156            per_ip_capacity: self.per_ip_capacity,
157            per_actor_capacity: self.per_actor_capacity,
158            window: self.window,
159            trusted_proxy_cidrs: self.trusted_proxy_cidrs.clone(),
160        }
161    }
162}
163
/// Service produced by [`GlobalRateLimitLayer`]; checks the limits and then
/// forwards admitted requests to `inner`.
#[derive(Clone)]
pub struct GlobalRateLimitService<S> {
    /// The wrapped downstream service.
    inner: S,
    /// Counters shared with the originating layer and its other services.
    state: Arc<Mutex<RateLimitState>>,
    /// Maximum requests admitted per window across all clients.
    capacity: u64,
    /// Maximum requests admitted per window for a single client IP.
    per_ip_capacity: u64,
    /// Maximum requests admitted per window for a single actor.
    per_actor_capacity: u64,
    /// Length of the fixed window.
    window: Duration,
    /// Networks consulted when deciding whether to honour `x-forwarded-for`.
    trusted_proxy_cidrs: Vec<IpCidr>,
}
174
175fn too_many_requests_response(limit: u64, window_secs: u64) -> Response<Body> {
176    let body = serde_json::json!({
177        "type": "about:blank",
178        "title": "Too Many Requests",
179        "status": 429,
180        "detail": "rate_limit_exceeded"
181    });
182    let body_bytes = serde_json::to_vec(&body)
183        .unwrap_or_else(|_| br#"{"type":"about:blank","title":"Too Many Requests","status":429,"detail":"rate_limit_exceeded"}"#.to_vec());
184    match Response::builder()
185        .status(StatusCode::TOO_MANY_REQUESTS)
186        .header("content-type", "application/problem+json")
187        .header("ratelimit-limit", limit.to_string())
188        .header("ratelimit-remaining", "0")
189        .header("ratelimit-reset", window_secs.to_string())
190        .header("retry-after", window_secs.to_string())
191        .body(Body::from(body_bytes))
192    {
193        Ok(resp) => resp,
194        Err(_) => {
195            let mut resp = Response::new(Body::from(
196                br#"{"type":"about:blank","title":"Too Many Requests","status":429,"detail":"rate_limit_exceeded"}"#.as_slice().to_vec(),
197            ));
198            *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
199            resp
200        }
201    }
202}
203
204/// Inject rate limit headers into a successful response.
205fn inject_rate_limit_headers(
206    resp: &mut Response<Body>,
207    limit: u64,
208    remaining: u64,
209    reset_secs: u64,
210) {
211    let headers = resp.headers_mut();
212    headers.insert(
213        "ratelimit-limit",
214        limit.to_string().parse().expect("numeric header value"),
215    );
216    headers.insert(
217        "ratelimit-remaining",
218        remaining.to_string().parse().expect("numeric header value"),
219    );
220    headers.insert(
221        "ratelimit-reset",
222        reset_secs
223            .to_string()
224            .parse()
225            .expect("numeric header value"),
226    );
227}
228
229fn stable_token_fingerprint(raw: &str) -> String {
230    use sha2::{Digest, Sha256};
231    let hash = Sha256::digest(raw.as_bytes());
232    // 8 bytes (64 bits) is plenty for rate-limit dedup — collision-resistant
233    // enough for bucket identity while keeping map keys small.
234    hex::encode(&hash[..8])
235}
236
237fn extract_actor_id(req: &Request<Body>) -> Option<String> {
238    let principal = crate::auth::extract_auth_principal(req);
239    if let Some(v) = req.headers().get("x-api-key")
240        && let Ok(raw) = v.to_str()
241        && !raw.is_empty()
242    {
243        return Some(format!("api_key:{}", stable_token_fingerprint(raw)));
244    }
245    if let Some(v) = req.headers().get("authorization")
246        && let Ok(raw) = v.to_str()
247        && let Some(token) = raw.strip_prefix("Bearer ")
248        && !token.is_empty()
249    {
250        return Some(format!("bearer:{}", stable_token_fingerprint(token)));
251    }
252    // x-user-id header is intentionally NOT used as an actor identity here.
253    // It is unauthenticated and would allow rate-limit bypass by cycling IDs.
254    principal
255}
256
/// Parses `s` as an IPv4 or IPv6 address after trimming surrounding
/// whitespace; returns `None` for anything else.
fn parse_ip(s: &str) -> Option<IpAddr> {
    match s.trim().parse::<IpAddr>() {
        Ok(addr) => Some(addr),
        Err(_) => None,
    }
}
260
261fn forwarded_ip(req: &Request<Body>) -> Option<IpAddr> {
262    req.headers()
263        .get("x-forwarded-for")
264        .and_then(|v| v.to_str().ok())
265        .and_then(|s| s.split(',').next())
266        .and_then(parse_ip)
267}
268
269fn real_ip(req: &Request<Body>) -> Option<IpAddr> {
270    req.headers()
271        .get("x-real-ip")
272        .and_then(|v| v.to_str().ok())
273        .and_then(parse_ip)
274}
275
276fn trust_forwarded_headers(proxy_ip: IpAddr, trusted_proxy_cidrs: &[IpCidr]) -> bool {
277    trusted_proxy_cidrs
278        .iter()
279        .any(|cidr| cidr.contains(proxy_ip))
280}
281
282fn resolve_client_ip(req: &Request<Body>, trusted_proxy_cidrs: &[IpCidr]) -> IpAddr {
283    let forwarded = forwarded_ip(req);
284    let real = real_ip(req);
285
286    if let (Some(client_ip), Some(proxy_ip)) = (forwarded, real)
287        && trust_forwarded_headers(proxy_ip, trusted_proxy_cidrs)
288    {
289        return client_ip;
290    }
291
292    if let Some(proxy_ip) = real {
293        return proxy_ip;
294    }
295
296    // Fall back to the actual TCP peer address from ConnectInfo rather than
297    // hardcoding 127.0.0.1, which would lump all headerless clients into
298    // a single rate-limit bucket.
299    use axum::extract::ConnectInfo;
300    use std::net::SocketAddr;
301    req.extensions()
302        .get::<ConnectInfo<SocketAddr>>()
303        .map(|ci| ci.0.ip())
304        .unwrap_or(IpAddr::from([127, 0, 0, 1]))
305}
306
307impl IpCidr {
308    fn parse(raw: &str) -> Option<Self> {
309        let (ip, prefix) = raw.split_once('/')?;
310        let network = ip.parse::<IpAddr>().ok()?;
311        let prefix_len = prefix.parse::<u8>().ok()?;
312        let max = match network {
313            IpAddr::V4(_) => 32,
314            IpAddr::V6(_) => 128,
315        };
316        if prefix_len > max {
317            return None;
318        }
319        Some(Self {
320            network,
321            prefix_len,
322        })
323    }
324
325    fn contains(&self, ip: IpAddr) -> bool {
326        match (self.network, ip) {
327            (IpAddr::V4(net), IpAddr::V4(candidate)) => {
328                cidr_match_v4(net, candidate, self.prefix_len)
329            }
330            (IpAddr::V6(net), IpAddr::V6(candidate)) => {
331                cidr_match_v6(net, candidate, self.prefix_len)
332            }
333            _ => false,
334        }
335    }
336}
337
/// Reports whether `candidate` shares its first `prefix_len` bits with
/// `network`.
///
/// A `prefix_len` of 0 matches every address. Values above 32 are treated as
/// 32 (exact match) instead of underflowing the shift amount.
fn cidr_match_v4(network: Ipv4Addr, candidate: Ipv4Addr, prefix_len: u8) -> bool {
    let host_bits = 32 - u32::from(prefix_len.min(32));
    // Shifting by the full width is not allowed with `<<`; `checked_shl`
    // yields `None` there, which corresponds to the all-zero /0 mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
    (u32::from(network) & mask) == (u32::from(candidate) & mask)
}
346
/// Reports whether `candidate` shares its first `prefix_len` bits with
/// `network`.
///
/// A `prefix_len` of 0 matches every address. Values above 128 are treated
/// as 128 (exact match) instead of underflowing the shift amount.
fn cidr_match_v6(network: Ipv6Addr, candidate: Ipv6Addr, prefix_len: u8) -> bool {
    let net = u128::from_be_bytes(network.octets());
    let cand = u128::from_be_bytes(candidate.octets());
    let host_bits = 128 - u32::from(prefix_len.min(128));
    // Shifting by the full width is not allowed with `<<`; `checked_shl`
    // yields `None` there, which corresponds to the all-zero /0 mask.
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
    (net & mask) == (cand & mask)
}
357
358impl<S> Service<Request<Body>> for GlobalRateLimitService<S>
359where
360    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
361    S::Future: Send + 'static,
362{
363    type Response = Response<Body>;
364    type Error = S::Error;
365    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
366
367    fn poll_ready(
368        &mut self,
369        cx: &mut std::task::Context<'_>,
370    ) -> std::task::Poll<Result<(), Self::Error>> {
371        self.inner.poll_ready(cx)
372    }
373
374    fn call(&mut self, req: Request<Body>) -> Self::Future {
375        let mut inner = self.inner.clone();
376        let state = self.state.clone();
377        let capacity = self.capacity;
378        let per_ip_capacity = self.per_ip_capacity;
379        let per_actor_capacity = self.per_actor_capacity;
380        let window = self.window;
381        let trusted_proxy_cidrs = self.trusted_proxy_cidrs.clone();
382        let ip = resolve_client_ip(&req, &trusted_proxy_cidrs);
383        let actor = extract_actor_id(&req);
384
385        Box::pin(async move {
386            let now = Instant::now();
387            let mut guard = state.lock().await;
388            if now.duration_since(guard.window_start) >= window {
389                guard.window_start = now;
390                guard.count = 0;
391                GlobalRateLimitLayer::evict_stale(&mut guard.per_ip, window);
392                GlobalRateLimitLayer::evict_stale(&mut guard.per_actor, window);
393                guard.throttled_per_ip.clear();
394                guard.throttled_per_actor.clear();
395                guard.throttled_global = 0;
396            }
397
398            // Time remaining in current window (for reset header).
399            let elapsed = now.duration_since(guard.window_start);
400            let reset_secs = window.as_secs().saturating_sub(elapsed.as_secs());
401
402            if guard.count >= capacity {
403                guard.throttled_global += 1;
404                return Ok(too_many_requests_response(capacity, reset_secs));
405            }
406
407            // Check per-IP limit.
408            let per_ip_cap = per_ip_capacity;
409            if !guard.per_ip.contains_key(&ip) && guard.per_ip.len() >= MAX_DISTINCT_IPS {
410                return Ok(too_many_requests_response(per_ip_cap, reset_secs));
411            }
412            let ip_entry = guard.per_ip.entry(ip).or_insert((0, now));
413            if now.duration_since(ip_entry.1) >= window {
414                *ip_entry = (0, now);
415            }
416            if ip_entry.0 >= per_ip_cap {
417                *guard.throttled_per_ip.entry(ip).or_insert(0) += 1;
418                return Ok(too_many_requests_response(per_ip_cap, reset_secs));
419            }
420            ip_entry.0 += 1;
421
422            // Check per-actor limit.
423            if let Some(ref actor_id) = actor {
424                if !guard.per_actor.contains_key(actor_id)
425                    && guard.per_actor.len() >= MAX_DISTINCT_ACTORS
426                {
427                    return Ok(too_many_requests_response(per_actor_capacity, reset_secs));
428                }
429                let actor_entry = guard.per_actor.entry(actor_id.clone()).or_insert((0, now));
430                if now.duration_since(actor_entry.1) >= window {
431                    *actor_entry = (0, now);
432                }
433                if actor_entry.0 >= per_actor_capacity {
434                    *guard
435                        .throttled_per_actor
436                        .entry(actor_id.clone())
437                        .or_insert(0) += 1;
438                    return Ok(too_many_requests_response(per_actor_capacity, reset_secs));
439                }
440                actor_entry.0 += 1;
441            }
442
443            // All per-IP/per-actor checks passed — now increment global counter.
444            let remaining = capacity.saturating_sub(guard.count + 1);
445            guard.count += 1;
446
447            drop(guard);
448
449            let mut resp = inner.call(req).await?;
450            inject_rate_limit_headers(&mut resp, capacity, remaining, reset_secs);
451            Ok(resp)
452        })
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::{Service, ServiceExt};

    /// Minimal router answering `GET /` with `"ok"`.
    fn dummy_service() -> axum::routing::Router {
        let handler = axum::routing::get(|| async { "ok" });
        axum::routing::Router::new().route("/", handler)
    }

    /// Builds an empty-bodied request for `/` carrying the given headers.
    fn request_with(headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn allows_requests_within_capacity() {
        let limiter = GlobalRateLimitLayer::new(5, Duration::from_secs(60));
        let mut svc = limiter.layer(dummy_service().into_service());

        for _ in 0..5 {
            let resp = svc
                .ready()
                .await
                .unwrap()
                .call(request_with(&[]))
                .await
                .unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn returns_429_when_capacity_exceeded() {
        let limiter = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = limiter.layer(dummy_service().into_service());

        for _ in 0..2 {
            let _ = svc
                .ready()
                .await
                .unwrap()
                .call(request_with(&[]))
                .await
                .unwrap();
        }

        let rejected = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&[]))
            .await
            .unwrap();
        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn window_resets_after_expiry() {
        let limiter = GlobalRateLimitLayer::new(1, Duration::from_millis(50));
        let mut svc = limiter.layer(dummy_service().into_service());

        let first = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&[]))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);

        let second = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&[]))
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);

        // Wait out the 50 ms window so the counter rolls over.
        tokio::time::sleep(Duration::from_millis(60)).await;

        let third = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&[]))
            .await
            .unwrap();
        assert_eq!(third.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn per_ip_limits_enforced() {
        let limiter = GlobalRateLimitLayer::new(1000, Duration::from_secs(60));
        let mut svc = limiter.layer(dummy_service().into_service());
        let noisy = [("x-real-ip", "1.2.3.4")];
        let quiet = [("x-real-ip", "5.6.7.8")];

        // The default per-IP cap is 300 requests per window.
        for _ in 0..300 {
            let resp = svc
                .ready()
                .await
                .unwrap()
                .call(request_with(&noisy))
                .await
                .unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        let over_limit = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&noisy))
            .await
            .unwrap();
        assert_eq!(over_limit.status(), StatusCode::TOO_MANY_REQUESTS);

        // A different IP has its own budget.
        let other_ip = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&quiet))
            .await
            .unwrap();
        assert_eq!(other_ip.status(), StatusCode::OK);
    }

    #[test]
    fn cidr_parse_and_contains() {
        let private_net = IpCidr::parse("10.0.0.0/8").expect("cidr");
        let inside: IpAddr = "10.1.2.3".parse().unwrap();
        let outside: IpAddr = "11.1.2.3".parse().unwrap();
        assert!(private_net.contains(inside));
        assert!(!private_net.contains(outside));
    }

    #[test]
    fn trusted_proxy_resolution_prefers_forwarded_when_proxy_trusted() {
        let req = request_with(&[
            ("x-forwarded-for", "1.2.3.4"),
            ("x-real-ip", "10.0.0.5"),
        ]);
        let trusted = IpCidr::parse("10.0.0.0/8").unwrap();
        let resolved = resolve_client_ip(&req, &[trusted]);
        assert_eq!(resolved, "1.2.3.4".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn untrusted_proxy_resolution_uses_direct_ip() {
        let req = request_with(&[
            ("x-forwarded-for", "1.2.3.4"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        let resolved = resolve_client_ip(&req, &[]);
        assert_eq!(resolved, "198.51.100.2".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn forwarded_header_without_trusted_proxy_is_ignored() {
        let req = request_with(&[("x-forwarded-for", "1.2.3.4")]);
        let resolved = resolve_client_ip(&req, &[]);
        assert_eq!(resolved, "127.0.0.1".parse::<IpAddr>().unwrap());
    }

    #[tokio::test]
    async fn actor_limits_enforced() {
        let limiter =
            GlobalRateLimitLayer::new(1000, Duration::from_secs(60)).with_per_actor_capacity(2);
        let mut svc = limiter.layer(dummy_service().into_service());
        let credentials = [("authorization", "Bearer actor-token")];

        for _ in 0..2 {
            let resp = svc
                .ready()
                .await
                .unwrap()
                .call(request_with(&credentials))
                .await
                .unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        let over_limit = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&credentials))
            .await
            .unwrap();
        assert_eq!(over_limit.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn snapshot_reflects_throttle_state() {
        let limiter = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = limiter.layer(dummy_service().into_service());

        // Use up the whole global budget.
        for _ in 0..2 {
            let _ = svc
                .ready()
                .await
                .unwrap()
                .call(request_with(&[]))
                .await
                .unwrap();
        }

        // The next request must be rejected.
        let rejected = svc
            .ready()
            .await
            .unwrap()
            .call(request_with(&[]))
            .await
            .unwrap();
        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);

        let snap = limiter.snapshot().await;
        assert_eq!(snap.global_count, 2);
        assert_eq!(snap.global_capacity, 2);
        assert!(
            snap.throttled_global >= 1,
            "should record ≥1 throttled global"
        );
        assert_eq!(snap.window_secs, 60);
    }
}
623}