1use std::{
2 net::IpAddr,
3 sync::Arc,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
9use anyhow::{Error, anyhow};
10use axum::{body::Body, extract::Request, response::IntoResponse, response::Response};
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use http::{HeaderName, HeaderValue, StatusCode};
14use tower::{Layer, Service};
15use tower_governor::{
16 GovernorError, GovernorLayer,
17 governor::{Governor, GovernorConfig, GovernorConfigBuilder},
18 key_extractor::{GlobalKeyExtractor, KeyExtractor},
19};
20
21use crate::{hname, http::middleware::extract_ip_from_request};
22
/// `GovernorLayer` specialised to the no-op governor middleware and axum's [`Body`].
pub type GovernorLayerAxum<K> = GovernorLayer<K, NoOpMiddleware<QuantaInstant>, Body>;

/// Request header carrying a token that, when its value matches the configured
/// bypass token, sends the request straight to the inner service without rate limiting.
const BYPASS_HEADER: HeaderName = hname!("x-ratelimit-bypass-token");
26
27#[derive(Clone)]
29pub struct IpKeyExtractor;
30
31impl KeyExtractor for IpKeyExtractor {
32 type Key = IpAddr;
33
34 fn extract<B>(&self, req: &Request<B>) -> Result<Self::Key, GovernorError> {
35 extract_ip_from_request(req).ok_or(GovernorError::UnableToExtractKey)
36 }
37}
38
/// Tower service that rate-limits requests before they reach the wrapped service.
///
/// Requests whose [`BYPASS_HEADER`] value matches `bypass_token` are passed
/// directly to `inner`. All other requests go through `governor`.
#[derive(Clone)]
pub struct RateLimiter<S, K: KeyExtractor> {
    /// Rate-limited path. It wraps its own clone of the inner service
    /// (see `RateLimiterLayer::layer`).
    governor: Governor<K, NoOpMiddleware<QuantaInstant>, S, Body>,
    /// Secret token that lets a request skip rate limiting; `None` disables bypass.
    bypass_token: Option<String>,
    /// Unlimited path, used only for bypassed requests.
    inner: S,
}
46
47impl<S, K> Service<Request> for RateLimiter<S, K>
49where
50 S: Service<Request, Response = Response> + Send + 'static,
51 S::Future: Send + 'static,
52 K: KeyExtractor,
53{
54 type Response = S::Response;
55 type Error = S::Error;
56 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
57
58 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59 self.inner.poll_ready(cx)
60 }
61
62 fn call(&mut self, request: Request) -> Self::Future {
63 let bypass = request
65 .headers()
66 .get(BYPASS_HEADER)
67 .zip(self.bypass_token.as_ref())
68 .map(|(hdr, token)| hdr.as_bytes() == token.as_bytes())
69 == Some(true);
70
71 if bypass {
73 let fut = self.inner.call(request);
74 return Box::pin(fut);
75 }
76
77 let fut = self.governor.call(request);
79 Box::pin(fut)
80 }
81}
82
/// Tower [`Layer`] that wraps services in a [`RateLimiter`].
///
/// The `new` constructor is generated by `derive_new` and takes the fields in
/// declaration order, so reordering the fields changes its signature.
#[derive(Clone, derive_new::new)]
pub struct RateLimiterLayer<K: KeyExtractor, R> {
    /// Shared governor configuration (period, burst size and key extractor, as
    /// built by `layer`).
    config: Arc<GovernorConfig<K, NoOpMiddleware<QuantaInstant>>>,
    /// Response sent to clients that exceed the rate limit.
    rate_limited_response: R,
    /// Token that lets requests skip rate limiting via the bypass header;
    /// `None` disables bypass.
    bypass_token: Option<String>,
}
90
91impl<S, K, R> Layer<S> for RateLimiterLayer<K, R>
92where
93 S: Clone,
94 K: KeyExtractor,
95 R: IntoResponse + Clone + Send + Sync + 'static,
96{
97 type Service = RateLimiter<S, K>;
98
99 fn layer(&self, inner: S) -> Self::Service {
100 let rate_limited_response = self.rate_limited_response.clone();
101
102 let governor = Governor::new(inner.clone(), &self.config).error_handler(move |err| {
103 match err {
104 GovernorError::TooManyRequests { wait_time, headers: _ } => {
105 let mut response = rate_limited_response.clone().into_response();
106 let retry_secs = ((wait_time / 1000).max(1)) as u32;
109 let header_value = HeaderValue::from_maybe_shared(Bytes::from(retry_secs.to_string())).unwrap();
110 response.headers_mut().insert(http::header::RETRY_AFTER, header_value);
111 response
112 },
113 GovernorError::UnableToExtractKey => (
114 StatusCode::INTERNAL_SERVER_ERROR,
115 "Unable to extract rate limiting key",
116 )
117 .into_response(),
118 GovernorError::Other { code, msg, headers } => (
119 StatusCode::INTERNAL_SERVER_ERROR,
120 format!(
121 "Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}"
122 ),
123 )
124 .into_response()
125 }
126 });
127
128 RateLimiter {
129 governor,
130 bypass_token: self.bypass_token.clone(),
131 inner,
132 }
133 }
134}
135
136pub fn layer_global<R: IntoResponse + Clone + Send + Sync + 'static>(
138 rps: u32,
139 burst_size: u32,
140 rate_limited_response: R,
141 bypass_token: Option<String>,
142) -> Result<RateLimiterLayer<GlobalKeyExtractor, R>, Error> {
143 layer(
144 rps,
145 burst_size,
146 GlobalKeyExtractor,
147 rate_limited_response,
148 bypass_token,
149 )
150}
151
152pub fn layer_by_ip<R: IntoResponse + Clone + Send + Sync + 'static>(
154 rps: u32,
155 burst_size: u32,
156 rate_limited_response: R,
157 bypass_token: Option<String>,
158) -> Result<RateLimiterLayer<IpKeyExtractor, R>, Error> {
159 layer(
160 rps,
161 burst_size,
162 IpKeyExtractor,
163 rate_limited_response,
164 bypass_token,
165 )
166}
167
168pub fn layer<K: KeyExtractor, R: IntoResponse + Clone + Send + Sync + 'static>(
170 rps: u32,
171 burst_size: u32,
172 key_extractor: K,
173 rate_limited_response: R,
174 bypass_token: Option<String>,
175) -> Result<RateLimiterLayer<K, R>, Error> {
176 let period = Duration::from_secs(1)
177 .checked_div(rps)
178 .ok_or_else(|| anyhow!("RPS is zero"))?;
179
180 let config = GovernorConfigBuilder::default()
181 .period(period)
182 .burst_size(burst_size)
183 .key_extractor(key_extractor)
184 .finish()
185 .ok_or_else(|| anyhow!("unable to build governor config"))?;
186
187 Ok(RateLimiterLayer::new(
188 Arc::new(config),
189 rate_limited_response,
190 bypass_token,
191 ))
192}
193
#[cfg(test)]
mod test {
    use super::*;

    use axum::{
        Router,
        body::{Body, to_bytes},
        extract::Request,
        response::IntoResponse,
        routing::post,
    };
    use http::{Method, StatusCode};
    use ic_bn_lib_common::types::http::ConnInfo;
    use std::{sync::Arc, time::Duration};
    use tokio::time::sleep;
    use tower::Service;

    /// Trivial endpoint used behind the rate limiter in every test.
    async fn handler(_request: Request<Body>) -> impl IntoResponse {
        "test_call"
    }

    /// Sends an empty POST to `/` with a default `ConnInfo` attached as an extension.
    /// Presumably this is what `extract_ip_from_request` reads to obtain the client IP.
    async fn send_request(
        router: &mut Router,
    ) -> Result<http::Response<Body>, std::convert::Infallible> {
        let conn_info = ConnInfo::default();
        let mut request = Request::post("/").body(Body::from("".to_string())).unwrap();
        request.extensions_mut().insert(Arc::new(conn_info));
        router.call(request).await
    }

    /// Checks the burst and replenishment behaviour of a per-IP limiter, and that
    /// every 429 carries a sane `Retry-After` header. This test is timing-dependent:
    /// it relies on real sleeps.
    #[tokio::test]
    async fn test_rate_limiter_rps_limit() {
        let rps = 5;
        let burst_size = 5;
        let rate_limiter_mw = layer(
            rps,
            burst_size,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Slightly longer than the 200ms replenishment period implied by 5 RPS
        // (period = 1s / rps), to leave some slack for scheduling jitter.
        let delay_for_token_ms = 230;
        let test_cases = vec![
            // The full burst of 5 is available immediately...
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            // ...and the next back-to-back request is rejected.
            (0, StatusCode::TOO_MANY_REQUESTS),
            // After one period, exactly one request is allowed again.
            (delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::TOO_MANY_REQUESTS),
            // After two periods, two requests are allowed.
            (2 * delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::TOO_MANY_REQUESTS),
            // After five periods the full burst is available again, but no more
            // than `burst_size` requests.
            (5 * delay_for_token_ms, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::OK),
            (0, StatusCode::TOO_MANY_REQUESTS),
            (0, StatusCode::TOO_MANY_REQUESTS),
        ];

        for (idx, (delay_ms, expected_status)) in test_cases.into_iter().enumerate() {
            if delay_ms > 0 {
                sleep(Duration::from_millis(delay_ms)).await;
            }
            let result = send_request(&mut app).await.unwrap();
            assert_eq!(result.status(), expected_status, "test {idx} failed");

            // Every rejection must tell the client when to retry.
            if expected_status == StatusCode::TOO_MANY_REQUESTS {
                let retry_after = result.headers().get(http::header::RETRY_AFTER);
                assert!(
                    retry_after.is_some(),
                    "test {idx}: Retry-After header missing on 429 response"
                );

                // The value must parse as whole seconds within a sane range.
                if let Some(header_value) = retry_after {
                    let retry_secs: u32 = header_value.to_str().unwrap().parse().unwrap();
                    assert!(
                        retry_secs >= 1 && retry_secs <= 10,
                        "test {idx}: Retry-After value {retry_secs} is outside expected range [1, 10]"
                    );
                }
            }
        }
    }

    /// A request from which no IP key can be extracted yields a 500 with a fixed body.
    #[tokio::test]
    async fn test_rate_limiter_returns_server_error() {
        let rps = 1;
        let burst_size = 1;

        let rate_limiter_mw = layer(
            rps,
            burst_size,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Unlike `send_request`, no `ConnInfo` extension is attached, so the IP key
        // cannot be extracted.
        let request = Request::post("/").body(Body::from("".to_string())).unwrap();
        let result = app.call(request).await.unwrap();

        assert_eq!(result.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = to_bytes(result.into_body(), 1024).await.unwrap().to_vec();
        assert_eq!(body, b"Unable to extract rate limiting key");
    }

    /// Only the exact bypass token lets requests through once the quota is exhausted.
    #[tokio::test]
    async fn test_rate_limiter_bypass_token() {
        // 1 RPS with a burst of 10, shared by all requests (global key), so no
        // `ConnInfo` is needed.
        let rate_limiter_mw = layer(
            1,
            10,
            GlobalKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            Some("top_secret_token".into()),
        )
        .expect("failed to build middleware");

        let mut app = Router::new()
            .route("/", post(handler))
            .layer(rate_limiter_mw);

        // Exhaust the burst.
        for _ in 0..10 {
            let req = Request::builder()
                .method(Method::POST)
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        // Without a token, everything is now rejected.
        for _ in 0..100 {
            let req = Request::builder()
                .method(Method::POST)
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        }

        // The correct token skips the governor entirely.
        for _ in 0..100 {
            let req = Request::builder()
                .method(Method::POST)
                .header(BYPASS_HEADER, "top_secret_token")
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        // A wrong token is treated like no token at all.
        for _ in 0..100 {
            let req = Request::builder()
                .method(Method::POST)
                .header(BYPASS_HEADER, "not_very_secret")
                .body(Body::empty())
                .unwrap();
            let res = app.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        }
    }
}