Skip to main content

ic_bn_lib/http/middleware/rate_limiter.rs

1use std::{
2    net::IpAddr,
3    sync::Arc,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
9use anyhow::{Error, anyhow};
10use axum::{body::Body, extract::Request, response::IntoResponse, response::Response};
11use futures::future::BoxFuture;
12use http::{HeaderName, StatusCode};
13use tower::{Layer, Service};
14use tower_governor::{
15    GovernorError, GovernorLayer,
16    governor::{Governor, GovernorConfig, GovernorConfigBuilder},
17    key_extractor::{GlobalKeyExtractor, KeyExtractor},
18};
19
20use crate::{hname, http::middleware::extract_ip_from_request};
21
/// [`GovernorLayer`] specialised for Axum: `NoOpMiddleware<QuantaInstant>` as the governor
/// middleware and Axum's [`Body`] as the response body type.
pub type GovernorLayerAxum<K> = GovernorLayer<K, NoOpMiddleware<QuantaInstant>, Body>;

/// Request header checked by [`RateLimiter`]: when its value matches the configured bypass
/// token, the request skips rate limiting.
const BYPASS_HEADER: HeaderName = hname!("x-ratelimit-bypass-token");
25
26/// Extracts an IP from the request as a rate-limiting key
27#[derive(Clone)]
28pub struct IpKeyExtractor;
29
30impl KeyExtractor for IpKeyExtractor {
31    type Key = IpAddr;
32
33    fn extract<B>(&self, req: &Request<B>) -> Result<Self::Key, GovernorError> {
34        extract_ip_from_request(req).ok_or(GovernorError::UnableToExtractKey)
35    }
36}
37
38/// Ratelimiter that implements Tower Service
39#[derive(Clone)]
40pub struct RateLimiter<S, K: KeyExtractor> {
41    governor: Governor<K, NoOpMiddleware<QuantaInstant>, S, Body>,
42    bypass_token: Option<String>,
43    inner: S,
44}
45
46/// Implement Tower Service for RateLimiter
47impl<S, K> Service<Request> for RateLimiter<S, K>
48where
49    S: Service<Request, Response = Response> + Send + 'static,
50    S::Future: Send + 'static,
51    K: KeyExtractor,
52{
53    type Response = S::Response;
54    type Error = S::Error;
55    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
56
57    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        self.inner.poll_ready(cx)
59    }
60
61    fn call(&mut self, request: Request) -> Self::Future {
62        // Check that bypass token is configured, header was sent and it matches
63        let bypass = request
64            .headers()
65            .get(BYPASS_HEADER)
66            .zip(self.bypass_token.as_ref())
67            .map(|(hdr, token)| hdr.as_bytes() == token.as_bytes())
68            == Some(true);
69
70        // If bypassing - call the wrapped service directly
71        if bypass {
72            let fut = self.inner.call(request);
73            return Box::pin(fut);
74        }
75
76        // Otherwise go through Governor
77        let fut = self.governor.call(request);
78        Box::pin(fut)
79    }
80}
81
/// Layer usable as an Axum middleware: wraps a service in a [`RateLimiter`].
///
/// `derive_new` generates `RateLimiterLayer::new(config, rate_limited_response, bypass_token)`
/// with parameters in field declaration order, so the field order below is part of the API.
#[derive(Clone, derive_new::new)]
pub struct RateLimiterLayer<K: KeyExtractor, R> {
    /// Shared governor configuration (period, burst size and key extractor; see `layer`).
    config: Arc<GovernorConfig<K, NoOpMiddleware<QuantaInstant>>>,
    /// Response returned to requests rejected for exceeding the rate limit.
    rate_limited_response: R,
    /// Optional token that lets requests skip rate limiting.
    bypass_token: Option<String>,
}
89
90impl<S, K, R> Layer<S> for RateLimiterLayer<K, R>
91where
92    S: Clone,
93    K: KeyExtractor,
94    R: IntoResponse + Clone + Send + Sync + 'static,
95{
96    type Service = RateLimiter<S, K>;
97
98    fn layer(&self, inner: S) -> Self::Service {
99        let rate_limited_response = self.rate_limited_response.clone();
100
101        let governor = Governor::new(inner.clone(), &self.config).error_handler(move |err| {
102            match err {
103                GovernorError::TooManyRequests { .. } => rate_limited_response.clone().into_response(),
104                GovernorError::UnableToExtractKey => (
105                    StatusCode::INTERNAL_SERVER_ERROR,
106                    "Unable to extract rate limiting key",
107                )
108                    .into_response(),
109                GovernorError::Other { code, msg, headers } => (
110                    StatusCode::INTERNAL_SERVER_ERROR,
111                    format!(
112                        "Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}"
113                    ),
114                )
115                    .into_response()
116            }
117        });
118
119        RateLimiter {
120            governor,
121            bypass_token: self.bypass_token.clone(),
122            inner,
123        }
124    }
125}
126
127/// Create unkeyed rate-limiter
128pub fn layer_global<R: IntoResponse + Clone + Send + Sync + 'static>(
129    rps: u32,
130    burst_size: u32,
131    rate_limited_response: R,
132    bypass_token: Option<String>,
133) -> Result<RateLimiterLayer<GlobalKeyExtractor, R>, Error> {
134    layer(
135        rps,
136        burst_size,
137        GlobalKeyExtractor,
138        rate_limited_response,
139        bypass_token,
140    )
141}
142
143/// Create ratelimiter keyed by IP
144pub fn layer_by_ip<R: IntoResponse + Clone + Send + Sync + 'static>(
145    rps: u32,
146    burst_size: u32,
147    rate_limited_response: R,
148    bypass_token: Option<String>,
149) -> Result<RateLimiterLayer<IpKeyExtractor, R>, Error> {
150    layer(
151        rps,
152        burst_size,
153        IpKeyExtractor,
154        rate_limited_response,
155        bypass_token,
156    )
157}
158
159/// Create a ratelimiter with a provided key extractor
160pub fn layer<K: KeyExtractor, R: IntoResponse + Clone + Send + Sync + 'static>(
161    rps: u32,
162    burst_size: u32,
163    key_extractor: K,
164    rate_limited_response: R,
165    bypass_token: Option<String>,
166) -> Result<RateLimiterLayer<K, R>, Error> {
167    let period = Duration::from_secs(1)
168        .checked_div(rps)
169        .ok_or_else(|| anyhow!("RPS is zero"))?;
170
171    let config = GovernorConfigBuilder::default()
172        .period(period)
173        .burst_size(burst_size)
174        .key_extractor(key_extractor)
175        .finish()
176        .ok_or_else(|| anyhow!("unable to build governor config"))?;
177
178    Ok(RateLimiterLayer::new(
179        Arc::new(config),
180        rate_limited_response,
181        bypass_token,
182    ))
183}
184
#[cfg(test)]
mod test {
    use super::*;

    use axum::{
        Router,
        body::{Body, to_bytes},
        extract::Request,
        response::IntoResponse,
        routing::post,
    };
    use http::{Method, StatusCode};
    use ic_bn_lib_common::types::http::ConnInfo;
    use std::{sync::Arc, time::Duration};
    use tokio::time::sleep;
    use tower::Service;

    /// Trivial handler that always succeeds, so any non-200 status comes from the rate limiter.
    async fn handler(_request: Request<Body>) -> impl IntoResponse {
        "test_call"
    }

    /// Sends an empty `POST /` through `router`.
    ///
    /// A default `ConnInfo` is attached as an `Arc` extension so that `IpKeyExtractor` can
    /// extract a key. Without it, extraction fails (see `test_rate_limiter_returns_server_error`).
    async fn send_request(
        router: &mut Router,
    ) -> Result<http::Response<Body>, std::convert::Infallible> {
        let conn_info = ConnInfo::default();
        let mut request = Request::post("/").body(Body::from("".to_string())).unwrap();
        request.extensions_mut().insert(Arc::new(conn_info));
        router.call(request).await
    }

    /// Verifies burst capacity and token replenishment at 5 RPS with a burst of 5.
    ///
    /// NOTE: this test uses real wall-clock sleeps; the 30 ms slack on each wait is the only
    /// guard against flakiness on a slow or loaded machine.
    #[tokio::test]
    async fn test_rate_limiter_rps_limit() {
        let rps = 5;
        let burst_size = 5; // how many requests can go through at once (without delay)

        let rate_limiter_mw = layer(
            rps,
            burst_size,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Test cases: (delay_ms, expected_status)
        // A token should become available every 1000 ms / rps = 200 ms; an extra 30 ms of slack
        // is added to reduce flakiness.
        let delay_for_token_ms = 230;
        let test_cases = vec![
            // Initial burst of 5 requests should succeed and use up the full burst capacity
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            // For 6th request no tokens left => 429
            (0, StatusCode::TOO_MANY_REQUESTS),
            // Wait for 1 token to be available
            (delay_for_token_ms, StatusCode::OK),
            // Bucket is empty again, request should fail
            (0, StatusCode::TOO_MANY_REQUESTS),
            // Wait for 2 tokens to be available, next 2 requests succeed
            (2 * delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::OK),
            // Bucket is empty again, request should fail
            (0, StatusCode::TOO_MANY_REQUESTS),
            // Wait for 5 tokens, next 5 requests succeed
            (5 * delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            // Bucket is empty again, requests should fail
            (0, StatusCode::TOO_MANY_REQUESTS),
            (0, StatusCode::TOO_MANY_REQUESTS),
        ];

        // Execute all tests in order; each case first sleeps `delay_ms` (if non-zero).
        for (idx, (delay_ms, expected_status)) in test_cases.into_iter().enumerate() {
            if delay_ms > 0 {
                sleep(Duration::from_millis(delay_ms)).await;
            }
            let result = send_request(&mut app).await.unwrap();
            assert_eq!(result.status(), expected_status, "test {idx} failed");
        }
    }

    /// A request without `ConnInfo` cannot be keyed by IP. It must get a 500 whose body is the
    /// key-extraction error message.
    #[tokio::test]
    async fn test_rate_limiter_returns_server_error() {
        let rps = 1;
        let burst_size = 1;

        let rate_limiter_mw = layer(
            rps,
            burst_size,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Send request without connection info, i.e. without ip address.
        let request = Request::post("/").body(Body::from("".to_string())).unwrap();
        let result = app.call(request).await.unwrap();

        assert_eq!(result.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = to_bytes(result.into_body(), 1024).await.unwrap().to_vec();
        assert_eq!(body, b"Unable to extract rate limiting key");
    }

    /// With a global key (1 RPS, burst 10), requests beyond the burst are limited unless they
    /// carry the configured bypass token; a wrong token grants no exemption.
    ///
    /// NOTE: assumes the 300 requests after the burst complete well within the ~1 s replenish
    /// period; otherwise a replenished token could let one "blocked" request through.
    #[tokio::test]
    async fn test_rate_limiter_bypass_token() {
        let rate_limiter_mw = layer(
            1,
            10,
            GlobalKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            Some("top_secret_token".into()),
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // First 10 pass
        for _ in 0..10 {
            let req = Request::builder()
                .method(Method::POST)
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        // Then all blocked
        for _ in 0..100 {
            let req = Request::builder()
                .method(Method::POST)
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        }

        // But pass with a token
        for _ in 0..100 {
            let req = Request::builder()
                .method(Method::POST)
                .header(BYPASS_HEADER, "top_secret_token")
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        // And doesn't work with a bad token
        for _ in 0..100 {
            let req = Request::builder()
                .method(Method::POST)
                .header(BYPASS_HEADER, "not_very_secret")
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        }
    }
}