Skip to main content

ic_bn_lib/http/middleware/
rate_limiter.rs

1use std::{
2    net::IpAddr,
3    sync::Arc,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
9use anyhow::{Error, anyhow};
10use axum::{body::Body, extract::Request, response::IntoResponse, response::Response};
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use http::{HeaderName, HeaderValue, StatusCode};
14use tower::{Layer, Service};
15use tower_governor::{
16    GovernorError, GovernorLayer,
17    governor::{Governor, GovernorConfig, GovernorConfigBuilder},
18    key_extractor::{GlobalKeyExtractor, KeyExtractor},
19};
20
21use crate::{hname, http::middleware::extract_ip_from_request};
22
23pub type GovernorLayerAxum<K> = GovernorLayer<K, NoOpMiddleware<QuantaInstant>, Body>;
24
25const BYPASS_HEADER: HeaderName = hname!("x-ratelimit-bypass-token");
26
27/// Extracts an IP from the request as a rate-limiting key
28#[derive(Clone)]
29pub struct IpKeyExtractor;
30
31impl KeyExtractor for IpKeyExtractor {
32    type Key = IpAddr;
33
34    fn extract<B>(&self, req: &Request<B>) -> Result<Self::Key, GovernorError> {
35        extract_ip_from_request(req).ok_or(GovernorError::UnableToExtractKey)
36    }
37}
38
39/// Ratelimiter that implements Tower Service
40#[derive(Clone)]
41pub struct RateLimiter<S, K: KeyExtractor> {
42    governor: Governor<K, NoOpMiddleware<QuantaInstant>, S, Body>,
43    bypass_token: Option<String>,
44    inner: S,
45}
46
47/// Implement Tower Service for RateLimiter
48impl<S, K> Service<Request> for RateLimiter<S, K>
49where
50    S: Service<Request, Response = Response> + Send + 'static,
51    S::Future: Send + 'static,
52    K: KeyExtractor,
53{
54    type Response = S::Response;
55    type Error = S::Error;
56    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
57
58    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59        self.inner.poll_ready(cx)
60    }
61
62    fn call(&mut self, request: Request) -> Self::Future {
63        // Check that bypass token is configured, header was sent and it matches
64        let bypass = request
65            .headers()
66            .get(BYPASS_HEADER)
67            .zip(self.bypass_token.as_ref())
68            .map(|(hdr, token)| hdr.as_bytes() == token.as_bytes())
69            == Some(true);
70
71        // If bypassing - call the wrapped service directly
72        if bypass {
73            let fut = self.inner.call(request);
74            return Box::pin(fut);
75        }
76
77        // Otherwise go through Governor
78        let fut = self.governor.call(request);
79        Box::pin(fut)
80    }
81}
82
83/// Layer usable as an Axum middleware
84#[derive(Clone, derive_new::new)]
85pub struct RateLimiterLayer<K: KeyExtractor, R> {
86    config: Arc<GovernorConfig<K, NoOpMiddleware<QuantaInstant>>>,
87    rate_limited_response: R,
88    bypass_token: Option<String>,
89}
90
91impl<S, K, R> Layer<S> for RateLimiterLayer<K, R>
92where
93    S: Clone,
94    K: KeyExtractor,
95    R: IntoResponse + Clone + Send + Sync + 'static,
96{
97    type Service = RateLimiter<S, K>;
98
99    fn layer(&self, inner: S) -> Self::Service {
100        let rate_limited_response = self.rate_limited_response.clone();
101
102        let governor = Governor::new(inner.clone(), &self.config).error_handler(move |err| {
103            match err {
104                GovernorError::TooManyRequests { wait_time, headers: _ } => {
105                    let mut response = rate_limited_response.clone().into_response();
106                    // Add Retry-After header using timing from governor
107                    // wait_time is in milliseconds, convert to seconds (minimum 1 second)
108                    let retry_secs = ((wait_time / 1000).max(1)) as u32;
109                    let header_value = HeaderValue::from_maybe_shared(Bytes::from(retry_secs.to_string())).unwrap();
110                    response.headers_mut().insert(http::header::RETRY_AFTER, header_value);
111                    response
112                },
113                GovernorError::UnableToExtractKey => (
114                    StatusCode::INTERNAL_SERVER_ERROR,
115                    "Unable to extract rate limiting key",
116                )
117                    .into_response(),
118                GovernorError::Other { code, msg, headers } => (
119                    StatusCode::INTERNAL_SERVER_ERROR,
120                    format!(
121                        "Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}"
122                    ),
123                )
124                    .into_response()
125            }
126        });
127
128        RateLimiter {
129            governor,
130            bypass_token: self.bypass_token.clone(),
131            inner,
132        }
133    }
134}
135
136/// Create unkeyed rate-limiter
137pub fn layer_global<R: IntoResponse + Clone + Send + Sync + 'static>(
138    rps: u32,
139    burst_size: u32,
140    rate_limited_response: R,
141    bypass_token: Option<String>,
142) -> Result<RateLimiterLayer<GlobalKeyExtractor, R>, Error> {
143    layer(
144        rps,
145        burst_size,
146        GlobalKeyExtractor,
147        rate_limited_response,
148        bypass_token,
149    )
150}
151
152/// Create ratelimiter keyed by IP
153pub fn layer_by_ip<R: IntoResponse + Clone + Send + Sync + 'static>(
154    rps: u32,
155    burst_size: u32,
156    rate_limited_response: R,
157    bypass_token: Option<String>,
158) -> Result<RateLimiterLayer<IpKeyExtractor, R>, Error> {
159    layer(
160        rps,
161        burst_size,
162        IpKeyExtractor,
163        rate_limited_response,
164        bypass_token,
165    )
166}
167
168/// Create a ratelimiter with a provided key extractor
169pub fn layer<K: KeyExtractor, R: IntoResponse + Clone + Send + Sync + 'static>(
170    rps: u32,
171    burst_size: u32,
172    key_extractor: K,
173    rate_limited_response: R,
174    bypass_token: Option<String>,
175) -> Result<RateLimiterLayer<K, R>, Error> {
176    let period = Duration::from_secs(1)
177        .checked_div(rps)
178        .ok_or_else(|| anyhow!("RPS is zero"))?;
179
180    let config = GovernorConfigBuilder::default()
181        .period(period)
182        .burst_size(burst_size)
183        .key_extractor(key_extractor)
184        .finish()
185        .ok_or_else(|| anyhow!("unable to build governor config"))?;
186
187    Ok(RateLimiterLayer::new(
188        Arc::new(config),
189        rate_limited_response,
190        bypass_token,
191    ))
192}
193
#[cfg(test)]
mod test {
    use super::*;

    use axum::{
        Router,
        body::{Body, to_bytes},
        extract::Request,
        response::IntoResponse,
        routing::post,
    };
    use http::{Method, StatusCode};
    use ic_bn_lib_common::types::http::ConnInfo;
    use std::{sync::Arc, time::Duration};
    use tokio::time::sleep;
    use tower::Service;

    /// Handler that ignores the request and answers with a fixed body.
    async fn handler(_request: Request<Body>) -> impl IntoResponse {
        "test_call"
    }

    /// Sends a POST to `/` with a default `ConnInfo` attached as an extension.
    async fn send_request(
        router: &mut Router,
    ) -> Result<http::Response<Body>, std::convert::Infallible> {
        let mut request = Request::post("/")
            .body(Body::from(String::new()))
            .unwrap();
        request
            .extensions_mut()
            .insert(Arc::new(ConnInfo::default()));
        router.call(request).await
    }

    /// Builds a body-less POST request, optionally carrying the bypass header.
    fn post_request(bypass_token: Option<&str>) -> Request<Body> {
        let builder = Request::builder().method(Method::POST);
        let builder = match bypass_token {
            Some(token) => builder.header(BYPASS_HEADER, token),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_rate_limiter_rps_limit() {
        let rps = 5;
        let burst_size = 5; // how many requests can go through at once (without delay)

        let rate_limiter_mw = layer(
            rps,
            burst_size,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // A token should become available every ~1000ms/rps = 200ms; 30ms of
        // slack is added to avoid flakiness.
        let delay_for_token_ms = 230;

        // Each step is (delay before the request in ms, expected status).
        let test_cases = vec![
            // Initial burst of 5 requests succeeds and uses the full burst capacity
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            // No tokens left for the 6th request => 429
            (0, StatusCode::TOO_MANY_REQUESTS),
            // Wait for 1 token to be available
            (delay_for_token_ms, StatusCode::OK),
            // Bucket is empty again, request should fail
            (0, StatusCode::TOO_MANY_REQUESTS),
            // Wait for 2 tokens to be available, next 2 requests succeed
            (2 * delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::OK),
            // Bucket is empty again, request should fail
            (0, StatusCode::TOO_MANY_REQUESTS),
            // Wait for 5 tokens, next 5 requests succeed
            (5 * delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            // Bucket is empty again, requests should fail
            (0, StatusCode::TOO_MANY_REQUESTS),
            (0, StatusCode::TOO_MANY_REQUESTS),
        ];

        for (idx, (delay_ms, expected_status)) in test_cases.into_iter().enumerate() {
            if delay_ms > 0 {
                sleep(Duration::from_millis(delay_ms)).await;
            }

            let result = send_request(&mut app).await.unwrap();
            assert_eq!(result.status(), expected_status, "test {idx} failed");

            // Only rate-limited responses are expected to carry Retry-After
            if expected_status != StatusCode::TOO_MANY_REQUESTS {
                continue;
            }

            let Some(header_value) = result.headers().get(http::header::RETRY_AFTER) else {
                panic!("test {idx}: Retry-After header missing on 429 response");
            };

            // The value must be a number within a sane range (1 to 10 seconds)
            let retry_secs: u32 = header_value.to_str().unwrap().parse().unwrap();
            assert!(
                (1..=10).contains(&retry_secs),
                "test {idx}: Retry-After value {retry_secs} is outside expected range [1, 10]"
            );
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_returns_server_error() {
        let rate_limiter_mw = layer(
            1, // rps
            1, // burst size
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Send request without connection info, i.e. without ip address.
        let request = Request::post("/")
            .body(Body::from(String::new()))
            .unwrap();
        let result = app.call(request).await.unwrap();
        assert_eq!(result.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = to_bytes(result.into_body(), 1024).await.unwrap().to_vec();
        assert_eq!(body, b"Unable to extract rate limiting key");
    }

    #[tokio::test]
    async fn test_rate_limiter_bypass_token() {
        let rate_limiter_mw = layer(
            1,
            10,
            GlobalKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            Some("top_secret_token".into()),
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Each phase is (number of requests, bypass header value, expected status).
        let phases = [
            // First 10 pass
            (10, None, StatusCode::OK),
            // Then all blocked
            (100, None, StatusCode::TOO_MANY_REQUESTS),
            // But pass with a token
            (100, Some("top_secret_token"), StatusCode::OK),
            // And doesn't work with a bad token
            (100, Some("not_very_secret"), StatusCode::TOO_MANY_REQUESTS),
        ];

        for (count, token, expected_status) in phases {
            for _ in 0..count {
                let res = app.call(post_request(token)).await.unwrap();
                assert_eq!(res.status(), expected_status);
            }
        }
    }
}