Skip to main content

roboticus_api/
rate_limit.rs

1//! Global API rate limiting (fixed window, Clone-friendly for axum Router).
2
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use axum::body::Body;
10use axum::http::{Request, Response, StatusCode};
11use futures_util::future::BoxFuture;
12use tokio::sync::Mutex;
13use tower::{Layer, Service};
14
/// Upper bound on how many distinct client IPs are tracked at once.
///
/// Once the per-IP table is this full, a request from an IP that is not yet
/// tracked is rejected outright, so a distributed flood cannot grow the table
/// without limit.
const MAX_DISTINCT_IPS: usize = 10_000;
/// Upper bound on how many distinct actors are tracked at once; requests from
/// untracked actors beyond this are rejected in the same way as for IPs.
const MAX_DISTINCT_ACTORS: usize = 5_000;
20
21/// Fixed-window rate limit state: at most `capacity` requests per `window`.
22#[derive(Clone)]
23pub struct GlobalRateLimitLayer {
24    state: Arc<Mutex<RateLimitState>>,
25    capacity: u64,
26    per_ip_capacity: u64,
27    per_actor_capacity: u64,
28    window: Duration,
29    trusted_proxy_cidrs: Vec<IpCidr>,
30}
31
/// Mutable counters behind the limiter's mutex.
struct RateLimitState {
    /// Requests admitted in the current global window.
    count: u64,
    /// Instant at which the current global window began.
    window_start: Instant,
    /// Per-IP `(admitted count, window start)` pairs.
    per_ip: HashMap<IpAddr, (u64, Instant)>,
    /// Per-actor `(admitted count, window start)` pairs.
    per_actor: HashMap<String, (u64, Instant)>,
    /// Number of requests rejected per IP for exceeding the per-IP budget.
    throttled_per_ip: HashMap<IpAddr, u64>,
    /// Number of requests rejected per actor for exceeding the per-actor budget.
    throttled_per_actor: HashMap<String, u64>,
    /// Number of requests rejected because the global budget was exhausted.
    throttled_global: u64,
}
41
/// An IP network in CIDR form: a base address plus the number of leading bits
/// that must match.
#[derive(Debug, Clone)]
struct IpCidr {
    /// Base address of the network (IPv4 or IPv6).
    network: IpAddr,
    /// Count of leading bits that are significant when matching.
    prefix_len: u8,
}
47
48impl GlobalRateLimitLayer {
49    /// Allow at most `capacity` requests per `window` globally, and `per_ip` per IP.
50    pub fn new(capacity: u64, window: Duration) -> Self {
51        Self {
52            state: Arc::new(Mutex::new(RateLimitState {
53                count: 0,
54                window_start: Instant::now(),
55                per_ip: HashMap::new(),
56                per_actor: HashMap::new(),
57                throttled_per_ip: HashMap::new(),
58                throttled_per_actor: HashMap::new(),
59                throttled_global: 0,
60            })),
61            capacity,
62            per_ip_capacity: 300,
63            per_actor_capacity: 200,
64            window,
65            trusted_proxy_cidrs: Vec::new(),
66        }
67    }
68
69    pub fn with_per_ip_capacity(mut self, per_ip_capacity: u64) -> Self {
70        self.per_ip_capacity = per_ip_capacity;
71        self
72    }
73
74    pub fn with_per_actor_capacity(mut self, per_actor_capacity: u64) -> Self {
75        self.per_actor_capacity = per_actor_capacity;
76        self
77    }
78
79    pub fn with_trusted_proxy_cidrs(mut self, cidrs: &[String]) -> Self {
80        self.trusted_proxy_cidrs = cidrs
81            .iter()
82            .filter_map(|c| IpCidr::parse(c))
83            .collect::<Vec<_>>();
84        self
85    }
86
87    fn evict_stale<K>(counter: &mut HashMap<K, (u64, Instant)>, window: Duration)
88    where
89        K: Eq + Hash,
90    {
91        let now = Instant::now();
92        counter.retain(|_, (_, start)| now.duration_since(*start) < window);
93    }
94
95    /// Snapshot current throttle statistics for admin observability.
96    ///
97    /// Returns counts of throttled requests per-IP, per-actor, and globally
98    /// within the current window, plus top offenders (up to 10 each).
99    pub async fn snapshot(&self) -> ThrottleSnapshot {
100        let guard = self.state.lock().await;
101
102        let mut top_ips: Vec<_> = guard
103            .throttled_per_ip
104            .iter()
105            .map(|(ip, &count)| (ip.to_string(), count))
106            .collect();
107        top_ips.sort_by(|a, b| b.1.cmp(&a.1));
108        top_ips.truncate(10);
109
110        let mut top_actors: Vec<_> = guard
111            .throttled_per_actor
112            .iter()
113            .map(|(actor, &count)| (actor.clone(), count))
114            .collect();
115        top_actors.sort_by(|a, b| b.1.cmp(&a.1));
116        top_actors.truncate(10);
117
118        ThrottleSnapshot {
119            window_secs: self.window.as_secs(),
120            global_count: guard.count,
121            global_capacity: self.capacity,
122            per_ip_capacity: self.per_ip_capacity,
123            per_actor_capacity: self.per_actor_capacity,
124            throttled_global: guard.throttled_global,
125            active_ips: guard.per_ip.len(),
126            active_actors: guard.per_actor.len(),
127            top_throttled_ips: top_ips,
128            top_throttled_actors: top_actors,
129        }
130    }
131}
132
133/// Snapshot of current throttle counters for observability.
134#[derive(Debug, Clone, serde::Serialize)]
135pub struct ThrottleSnapshot {
136    pub window_secs: u64,
137    pub global_count: u64,
138    pub global_capacity: u64,
139    pub per_ip_capacity: u64,
140    pub per_actor_capacity: u64,
141    pub throttled_global: u64,
142    pub active_ips: usize,
143    pub active_actors: usize,
144    pub top_throttled_ips: Vec<(String, u64)>,
145    pub top_throttled_actors: Vec<(String, u64)>,
146}
147
148impl<S> Layer<S> for GlobalRateLimitLayer {
149    type Service = GlobalRateLimitService<S>;
150
151    fn layer(&self, inner: S) -> Self::Service {
152        GlobalRateLimitService {
153            inner,
154            state: self.state.clone(),
155            capacity: self.capacity,
156            per_ip_capacity: self.per_ip_capacity,
157            per_actor_capacity: self.per_actor_capacity,
158            window: self.window,
159            trusted_proxy_cidrs: self.trusted_proxy_cidrs.clone(),
160        }
161    }
162}
163
164#[derive(Clone)]
165pub struct GlobalRateLimitService<S> {
166    inner: S,
167    state: Arc<Mutex<RateLimitState>>,
168    capacity: u64,
169    per_ip_capacity: u64,
170    per_actor_capacity: u64,
171    window: Duration,
172    trusted_proxy_cidrs: Vec<IpCidr>,
173}
174
175fn too_many_requests_response() -> Response<Body> {
176    let body = serde_json::json!({
177        "error": "rate_limit_exceeded",
178        "message": "Too many requests, please try again later"
179    });
180    let body_bytes = serde_json::to_vec(&body)
181        .unwrap_or_else(|_| br#"{"error":"rate_limit_exceeded"}"#.to_vec());
182    match Response::builder()
183        .status(StatusCode::TOO_MANY_REQUESTS)
184        .header("content-type", "application/json")
185        .body(Body::from(body_bytes))
186    {
187        Ok(resp) => resp,
188        Err(_) => {
189            let mut resp = Response::new(Body::from(
190                br#"{"error":"rate_limit_exceeded"}"#.as_slice().to_vec(),
191            ));
192            *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
193            resp
194        }
195    }
196}
197
198fn stable_token_fingerprint(raw: &str) -> String {
199    use sha2::{Digest, Sha256};
200    let hash = Sha256::digest(raw.as_bytes());
201    // 8 bytes (64 bits) is plenty for rate-limit dedup — collision-resistant
202    // enough for bucket identity while keeping map keys small.
203    hex::encode(&hash[..8])
204}
205
206fn extract_actor_id(req: &Request<Body>) -> Option<String> {
207    let principal = crate::auth::extract_auth_principal(req);
208    if let Some(v) = req.headers().get("x-api-key")
209        && let Ok(raw) = v.to_str()
210        && !raw.is_empty()
211    {
212        return Some(format!("api_key:{}", stable_token_fingerprint(raw)));
213    }
214    if let Some(v) = req.headers().get("authorization")
215        && let Ok(raw) = v.to_str()
216        && let Some(token) = raw.strip_prefix("Bearer ")
217        && !token.is_empty()
218    {
219        return Some(format!("bearer:{}", stable_token_fingerprint(token)));
220    }
221    // x-user-id header is intentionally NOT used as an actor identity here.
222    // It is unauthenticated and would allow rate-limit bypass by cycling IDs.
223    principal
224}
225
/// Parses `s` as an IPv4 or IPv6 address after stripping surrounding
/// whitespace; returns `None` when it is not a valid address.
fn parse_ip(s: &str) -> Option<IpAddr> {
    let candidate = s.trim();
    candidate.parse::<IpAddr>().ok()
}
229
230fn forwarded_ip(req: &Request<Body>) -> Option<IpAddr> {
231    req.headers()
232        .get("x-forwarded-for")
233        .and_then(|v| v.to_str().ok())
234        .and_then(|s| s.split(',').next())
235        .and_then(parse_ip)
236}
237
238fn real_ip(req: &Request<Body>) -> Option<IpAddr> {
239    req.headers()
240        .get("x-real-ip")
241        .and_then(|v| v.to_str().ok())
242        .and_then(parse_ip)
243}
244
245fn trust_forwarded_headers(proxy_ip: IpAddr, trusted_proxy_cidrs: &[IpCidr]) -> bool {
246    trusted_proxy_cidrs
247        .iter()
248        .any(|cidr| cidr.contains(proxy_ip))
249}
250
251fn resolve_client_ip(req: &Request<Body>, trusted_proxy_cidrs: &[IpCidr]) -> IpAddr {
252    let forwarded = forwarded_ip(req);
253    let real = real_ip(req);
254
255    if let (Some(client_ip), Some(proxy_ip)) = (forwarded, real)
256        && trust_forwarded_headers(proxy_ip, trusted_proxy_cidrs)
257    {
258        return client_ip;
259    }
260
261    if let Some(proxy_ip) = real {
262        return proxy_ip;
263    }
264
265    // Fall back to the actual TCP peer address from ConnectInfo rather than
266    // hardcoding 127.0.0.1, which would lump all headerless clients into
267    // a single rate-limit bucket.
268    use axum::extract::ConnectInfo;
269    use std::net::SocketAddr;
270    req.extensions()
271        .get::<ConnectInfo<SocketAddr>>()
272        .map(|ci| ci.0.ip())
273        .unwrap_or(IpAddr::from([127, 0, 0, 1]))
274}
275
276impl IpCidr {
277    fn parse(raw: &str) -> Option<Self> {
278        let (ip, prefix) = raw.split_once('/')?;
279        let network = ip.parse::<IpAddr>().ok()?;
280        let prefix_len = prefix.parse::<u8>().ok()?;
281        let max = match network {
282            IpAddr::V4(_) => 32,
283            IpAddr::V6(_) => 128,
284        };
285        if prefix_len > max {
286            return None;
287        }
288        Some(Self {
289            network,
290            prefix_len,
291        })
292    }
293
294    fn contains(&self, ip: IpAddr) -> bool {
295        match (self.network, ip) {
296            (IpAddr::V4(net), IpAddr::V4(candidate)) => {
297                cidr_match_v4(net, candidate, self.prefix_len)
298            }
299            (IpAddr::V6(net), IpAddr::V6(candidate)) => {
300                cidr_match_v6(net, candidate, self.prefix_len)
301            }
302            _ => false,
303        }
304    }
305}
306
/// Reports whether `candidate` shares its first `prefix_len` bits with
/// `network`.
///
/// A `prefix_len` of 0 matches every address. Values above 32 are treated as
/// 32 (exact match) rather than underflowing the shift width.
fn cidr_match_v4(network: Ipv4Addr, candidate: Ipv4Addr, prefix_len: u8) -> bool {
    let host_bits = 32u32.saturating_sub(u32::from(prefix_len));
    // Shifting by the full width (prefix 0) yields `None`, i.e. an empty mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
    (u32::from(network) & mask) == (u32::from(candidate) & mask)
}
315
/// Reports whether `candidate` shares its first `prefix_len` bits with
/// `network`.
///
/// A `prefix_len` of 0 matches every address. Values above 128 are treated as
/// 128 (exact match) rather than underflowing the shift width.
fn cidr_match_v6(network: Ipv6Addr, candidate: Ipv6Addr, prefix_len: u8) -> bool {
    let host_bits = 128u32.saturating_sub(u32::from(prefix_len));
    // Shifting by the full width (prefix 0) yields `None`, i.e. an empty mask.
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
    (u128::from(network) & mask) == (u128::from(candidate) & mask)
}
326
327impl<S> Service<Request<Body>> for GlobalRateLimitService<S>
328where
329    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
330    S::Future: Send + 'static,
331{
332    type Response = Response<Body>;
333    type Error = S::Error;
334    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
335
336    fn poll_ready(
337        &mut self,
338        cx: &mut std::task::Context<'_>,
339    ) -> std::task::Poll<Result<(), Self::Error>> {
340        self.inner.poll_ready(cx)
341    }
342
343    fn call(&mut self, req: Request<Body>) -> Self::Future {
344        let mut inner = self.inner.clone();
345        let state = self.state.clone();
346        let capacity = self.capacity;
347        let per_ip_capacity = self.per_ip_capacity;
348        let per_actor_capacity = self.per_actor_capacity;
349        let window = self.window;
350        let trusted_proxy_cidrs = self.trusted_proxy_cidrs.clone();
351        let ip = resolve_client_ip(&req, &trusted_proxy_cidrs);
352        let actor = extract_actor_id(&req);
353
354        Box::pin(async move {
355            let now = Instant::now();
356            let mut guard = state.lock().await;
357            if now.duration_since(guard.window_start) >= window {
358                guard.window_start = now;
359                guard.count = 0;
360                GlobalRateLimitLayer::evict_stale(&mut guard.per_ip, window);
361                GlobalRateLimitLayer::evict_stale(&mut guard.per_actor, window);
362                guard.throttled_per_ip.clear();
363                guard.throttled_per_actor.clear();
364                guard.throttled_global = 0;
365            }
366            if guard.count >= capacity {
367                guard.throttled_global += 1;
368                return Ok(too_many_requests_response());
369            }
370
371            // Check per-IP limit.
372            let per_ip_cap = per_ip_capacity;
373            if !guard.per_ip.contains_key(&ip) && guard.per_ip.len() >= MAX_DISTINCT_IPS {
374                return Ok(too_many_requests_response());
375            }
376            let ip_entry = guard.per_ip.entry(ip).or_insert((0, now));
377            if now.duration_since(ip_entry.1) >= window {
378                *ip_entry = (0, now);
379            }
380            if ip_entry.0 >= per_ip_cap {
381                *guard.throttled_per_ip.entry(ip).or_insert(0) += 1;
382                return Ok(too_many_requests_response());
383            }
384            ip_entry.0 += 1;
385
386            // Check per-actor limit.
387            if let Some(ref actor_id) = actor {
388                if !guard.per_actor.contains_key(actor_id)
389                    && guard.per_actor.len() >= MAX_DISTINCT_ACTORS
390                {
391                    return Ok(too_many_requests_response());
392                }
393                let actor_entry = guard.per_actor.entry(actor_id.clone()).or_insert((0, now));
394                if now.duration_since(actor_entry.1) >= window {
395                    *actor_entry = (0, now);
396                }
397                if actor_entry.0 >= per_actor_capacity {
398                    *guard
399                        .throttled_per_actor
400                        .entry(actor_id.clone())
401                        .or_insert(0) += 1;
402                    return Ok(too_many_requests_response());
403                }
404                actor_entry.0 += 1;
405            }
406
407            // All per-IP/per-actor checks passed — now increment global counter.
408            guard.count += 1;
409
410            drop(guard);
411
412            inner.call(req).await
413        })
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::{Service, ServiceExt};

    /// Minimal router answering `GET /` with "ok".
    fn dummy_service() -> axum::routing::Router {
        let handler = || async { "ok" };
        axum::routing::Router::new().route("/", axum::routing::get(handler))
    }

    /// Builds an empty-bodied request for `/` carrying the given headers.
    fn root_request(headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("/");
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Waits for `svc` to be ready, sends `req`, and returns the response status.
    async fn status_of<S>(svc: &mut S, req: Request<Body>) -> StatusCode
    where
        S: Service<Request<Body>, Response = axum::http::Response<Body>>,
        S::Error: std::fmt::Debug,
    {
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        resp.status()
    }

    #[tokio::test]
    async fn allows_requests_within_capacity() {
        let layer = GlobalRateLimitLayer::new(5, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..5 {
            let status = status_of(&mut svc, root_request(&[])).await;
            assert_eq!(status, StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn returns_429_when_capacity_exceeded() {
        let layer = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        for _ in 0..2 {
            let _ = status_of(&mut svc, root_request(&[])).await;
        }
        let status = status_of(&mut svc, root_request(&[])).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn window_resets_after_expiry() {
        let layer = GlobalRateLimitLayer::new(1, Duration::from_millis(50));
        let mut svc = layer.layer(dummy_service().into_service());

        let first = status_of(&mut svc, root_request(&[])).await;
        assert_eq!(first, StatusCode::OK);
        let second = status_of(&mut svc, root_request(&[])).await;
        assert_eq!(second, StatusCode::TOO_MANY_REQUESTS);

        // Outlast the 50 ms window so the next request opens a fresh one.
        tokio::time::sleep(Duration::from_millis(60)).await;
        let third = status_of(&mut svc, root_request(&[])).await;
        assert_eq!(third, StatusCode::OK);
    }

    #[tokio::test]
    async fn per_ip_limits_enforced() {
        let layer = GlobalRateLimitLayer::new(1000, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());
        let noisy = [("x-real-ip", "1.2.3.4")];
        let quiet = [("x-real-ip", "5.6.7.8")];

        // The default per-IP budget is 300 requests per window.
        for _ in 0..300 {
            let status = status_of(&mut svc, root_request(&noisy)).await;
            assert_eq!(status, StatusCode::OK);
        }
        let over_budget = status_of(&mut svc, root_request(&noisy)).await;
        assert_eq!(over_budget, StatusCode::TOO_MANY_REQUESTS);

        // A different IP has its own, untouched budget.
        let other_ip = status_of(&mut svc, root_request(&quiet)).await;
        assert_eq!(other_ip, StatusCode::OK);
    }

    #[test]
    fn cidr_parse_and_contains() {
        let cidr = IpCidr::parse("10.0.0.0/8").expect("cidr");
        let inside: IpAddr = "10.1.2.3".parse().unwrap();
        let outside: IpAddr = "11.1.2.3".parse().unwrap();
        assert!(cidr.contains(inside));
        assert!(!cidr.contains(outside));
    }

    #[test]
    fn trusted_proxy_resolution_prefers_forwarded_when_proxy_trusted() {
        let req = root_request(&[("x-forwarded-for", "1.2.3.4"), ("x-real-ip", "10.0.0.5")]);
        let trusted = [IpCidr::parse("10.0.0.0/8").unwrap()];
        let resolved = resolve_client_ip(&req, &trusted);
        assert_eq!(resolved, "1.2.3.4".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn untrusted_proxy_resolution_uses_direct_ip() {
        let req = root_request(&[
            ("x-forwarded-for", "1.2.3.4"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        let resolved = resolve_client_ip(&req, &[]);
        assert_eq!(resolved, "198.51.100.2".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn forwarded_header_without_trusted_proxy_is_ignored() {
        let req = root_request(&[("x-forwarded-for", "1.2.3.4")]);
        let resolved = resolve_client_ip(&req, &[]);
        assert_eq!(resolved, "127.0.0.1".parse::<IpAddr>().unwrap());
    }

    #[tokio::test]
    async fn actor_limits_enforced() {
        let layer =
            GlobalRateLimitLayer::new(1000, Duration::from_secs(60)).with_per_actor_capacity(2);
        let mut svc = layer.layer(dummy_service().into_service());
        let credentials = [("authorization", "Bearer actor-token")];

        for _ in 0..2 {
            let status = status_of(&mut svc, root_request(&credentials)).await;
            assert_eq!(status, StatusCode::OK);
        }
        let over_budget = status_of(&mut svc, root_request(&credentials)).await;
        assert_eq!(over_budget, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn snapshot_reflects_throttle_state() {
        let layer = GlobalRateLimitLayer::new(2, Duration::from_secs(60));
        let mut svc = layer.layer(dummy_service().into_service());

        // Use up the global budget.
        for _ in 0..2 {
            let _ = status_of(&mut svc, root_request(&[])).await;
        }
        // One more request must be turned away.
        let rejected = status_of(&mut svc, root_request(&[])).await;
        assert_eq!(rejected, StatusCode::TOO_MANY_REQUESTS);

        let snap = layer.snapshot().await;
        assert_eq!(snap.global_count, 2);
        assert_eq!(snap.global_capacity, 2);
        assert!(
            snap.throttled_global >= 1,
            "should record ≥1 throttled global"
        );
        assert_eq!(snap.window_secs, 60);
    }
}