Skip to main content

api_gateway/middleware/
rate_limit.rs

1use crate::config::ApiGatewayConfig;
2use anyhow::{Context, Result, anyhow};
3use axum::http::{HeaderValue, Method, StatusCode, header};
4use axum::{
5    extract::Request,
6    middleware::Next,
7    response::{IntoResponse, Response},
8};
9use governor::clock::Clock;
10use governor::middleware::StateInformationMiddleware;
11use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
12use std::collections::HashMap;
13use std::num::NonZeroU32;
14use std::sync::Arc;
15use tokio::sync::Semaphore;
16
17use crate::middleware::common;
18
19type RateLimitKey = (Method, String);
20type BucketMap = Arc<HashMap<RateLimitKey, Arc<BucketMapEntry>>>;
21type InflightMap = Arc<HashMap<RateLimitKey, Arc<Semaphore>>>;
22
23#[derive(Default, Clone)]
24pub struct RateLimiterMap {
25    buckets: BucketMap,
26    inflight: InflightMap,
27}
28
29struct BucketMapEntry {
30    bucket: DefaultDirectRateLimiter<StateInformationMiddleware>,
31    policy: HeaderValue,
32    burst: HeaderValue,
33}
34
35impl BucketMapEntry {
36    pub fn new(rps: u32, burst: u32) -> Result<Self> {
37        let bucket = RateLimiter::direct(
38            Quota::per_second(NonZeroU32::new(rps).with_context(|| anyhow!("rps is zero"))?)
39                .allow_burst(NonZeroU32::new(burst).with_context(|| anyhow!("burst is zero"))?),
40        )
41        .with_middleware::<StateInformationMiddleware>();
42        let policy = HeaderValue::from_str(&format!("\"burst\";q={burst};w={rps}"))
43            .context("Failed to create rate limit policy")?;
44        Ok(Self {
45            bucket,
46            policy,
47            burst: burst.into(),
48        })
49    }
50}
51
52impl RateLimiterMap {
53    /// # Errors
54    /// Returns an error if any rate limit spec is 0.
55    pub fn from_specs(
56        specs: &Vec<modkit::api::OperationSpec>,
57        cfg: &ApiGatewayConfig,
58    ) -> Result<Self> {
59        let mut buckets = HashMap::new();
60        let mut inflight = HashMap::new();
61        // TODO: Add support for per-route rate limiting
62        for spec in specs {
63            let (rps, burst, max_in_flight) = spec.rate_limit.as_ref().map_or(
64                (
65                    cfg.defaults.rate_limit.rps,
66                    cfg.defaults.rate_limit.burst,
67                    cfg.defaults.rate_limit.in_flight,
68                ),
69                |r| (r.rps, r.burst, r.in_flight),
70            );
71
72            let key = (spec.method.clone(), spec.path.clone());
73            buckets.insert(
74                key.clone(),
75                Arc::new(
76                    BucketMapEntry::new(rps, burst)
77                        .with_context(|| anyhow!("RateLimit spec invalid {spec:?} invalid"))?,
78                ),
79            );
80            inflight.insert(key, Arc::new(Semaphore::new(max_in_flight as usize)));
81        }
82        Ok(Self {
83            buckets: Arc::new(buckets),
84            inflight: Arc::new(inflight),
85        })
86    }
87}
88
89// TODO: Use tower-governor instead of own implementation (upd: https://github.com/benwis/tower-governor/issues/59 )
90pub async fn rate_limit_middleware(map: RateLimiterMap, mut req: Request, next: Next) -> Response {
91    let method = req.method().clone();
92    // Use MatchedPath extension (set by Axum router) for accurate route matching
93    let path = req
94        .extensions()
95        .get::<axum::extract::MatchedPath>()
96        .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
97
98    let path = common::resolve_path(&req, path.as_str());
99
100    let key = (method, path);
101
102    if let Some(bucker_map_entry) = map.buckets.get(&key) {
103        let headers = req.headers_mut();
104        headers.insert("RateLimit-Policy", bucker_map_entry.policy.clone());
105        match bucker_map_entry.bucket.check() {
106            Ok(state) => {
107                headers.insert("RateLimit-Limit", bucker_map_entry.burst.clone());
108                headers.insert(
109                    "RateLimit-Limit-Remaining",
110                    state.remaining_burst_capacity().into(),
111                );
112                headers.insert("X-RateLimit-Limit", bucker_map_entry.burst.clone());
113                headers.insert(
114                    "X-RateLimit-Remaining",
115                    state.remaining_burst_capacity().into(),
116                );
117            }
118            Err(not_until) => {
119                let wait = not_until.wait_time_from(bucker_map_entry.bucket.clock().now());
120                headers.insert(header::RETRY_AFTER, wait.as_secs().into());
121                return StatusCode::TOO_MANY_REQUESTS.into_response();
122            }
123        }
124    }
125
126    if let Some(sem) = map.inflight.get(&key) {
127        match sem.clone().try_acquire_owned() {
128            Ok(_permit) => {
129                // Allow request; permit is dropped when response future completes
130                return next.run(req).await;
131            }
132            Err(_) => {
133                return StatusCode::SERVICE_UNAVAILABLE.into_response();
134            }
135        }
136    }
137
138    next.run(req).await
139}