1use std::{
2 net::IpAddr,
3 sync::Arc,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
9use anyhow::{Error, anyhow};
10use axum::{body::Body, extract::Request, response::IntoResponse, response::Response};
11use futures::future::BoxFuture;
12use http::{HeaderName, StatusCode};
13use tower::{Layer, Service};
14use tower_governor::{
15 GovernorError, GovernorLayer,
16 governor::{Governor, GovernorConfig, GovernorConfigBuilder},
17 key_extractor::{GlobalKeyExtractor, KeyExtractor},
18};
19
20use crate::{hname, http::middleware::extract_ip_from_request};
21
22pub type GovernorLayerAxum<K> = GovernorLayer<K, NoOpMiddleware<QuantaInstant>, Body>;
23
24const BYPASS_HEADER: HeaderName = hname!("x-ratelimit-bypass-token");
25
26#[derive(Clone)]
28pub struct IpKeyExtractor;
29
30impl KeyExtractor for IpKeyExtractor {
31 type Key = IpAddr;
32
33 fn extract<B>(&self, req: &Request<B>) -> Result<Self::Key, GovernorError> {
34 extract_ip_from_request(req).ok_or(GovernorError::UnableToExtractKey)
35 }
36}
37
38#[derive(Clone)]
40pub struct RateLimiter<S, K: KeyExtractor> {
41 governor: Governor<K, NoOpMiddleware<QuantaInstant>, S, Body>,
42 bypass_token: Option<String>,
43 inner: S,
44}
45
46impl<S, K> Service<Request> for RateLimiter<S, K>
48where
49 S: Service<Request, Response = Response> + Send + 'static,
50 S::Future: Send + 'static,
51 K: KeyExtractor,
52{
53 type Response = S::Response;
54 type Error = S::Error;
55 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
56
57 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58 self.inner.poll_ready(cx)
59 }
60
61 fn call(&mut self, request: Request) -> Self::Future {
62 let bypass = request
64 .headers()
65 .get(BYPASS_HEADER)
66 .zip(self.bypass_token.as_ref())
67 .map(|(hdr, token)| hdr.as_bytes() == token.as_bytes())
68 == Some(true);
69
70 if bypass {
72 let fut = self.inner.call(request);
73 return Box::pin(fut);
74 }
75
76 let fut = self.governor.call(request);
78 Box::pin(fut)
79 }
80}
81
82#[derive(Clone, derive_new::new)]
84pub struct RateLimiterLayer<K: KeyExtractor, R> {
85 config: Arc<GovernorConfig<K, NoOpMiddleware<QuantaInstant>>>,
86 rate_limited_response: R,
87 bypass_token: Option<String>,
88}
89
90impl<S, K, R> Layer<S> for RateLimiterLayer<K, R>
91where
92 S: Clone,
93 K: KeyExtractor,
94 R: IntoResponse + Clone + Send + Sync + 'static,
95{
96 type Service = RateLimiter<S, K>;
97
98 fn layer(&self, inner: S) -> Self::Service {
99 let rate_limited_response = self.rate_limited_response.clone();
100
101 let governor = Governor::new(inner.clone(), &self.config).error_handler(move |err| {
102 match err {
103 GovernorError::TooManyRequests { .. } => rate_limited_response.clone().into_response(),
104 GovernorError::UnableToExtractKey => (
105 StatusCode::INTERNAL_SERVER_ERROR,
106 "Unable to extract rate limiting key",
107 )
108 .into_response(),
109 GovernorError::Other { code, msg, headers } => (
110 StatusCode::INTERNAL_SERVER_ERROR,
111 format!(
112 "Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}"
113 ),
114 )
115 .into_response()
116 }
117 });
118
119 RateLimiter {
120 governor,
121 bypass_token: self.bypass_token.clone(),
122 inner,
123 }
124 }
125}
126
127pub fn layer_global<R: IntoResponse + Clone + Send + Sync + 'static>(
129 rps: u32,
130 burst_size: u32,
131 rate_limited_response: R,
132 bypass_token: Option<String>,
133) -> Result<RateLimiterLayer<GlobalKeyExtractor, R>, Error> {
134 layer(
135 rps,
136 burst_size,
137 GlobalKeyExtractor,
138 rate_limited_response,
139 bypass_token,
140 )
141}
142
143pub fn layer_by_ip<R: IntoResponse + Clone + Send + Sync + 'static>(
145 rps: u32,
146 burst_size: u32,
147 rate_limited_response: R,
148 bypass_token: Option<String>,
149) -> Result<RateLimiterLayer<IpKeyExtractor, R>, Error> {
150 layer(
151 rps,
152 burst_size,
153 IpKeyExtractor,
154 rate_limited_response,
155 bypass_token,
156 )
157}
158
159pub fn layer<K: KeyExtractor, R: IntoResponse + Clone + Send + Sync + 'static>(
161 rps: u32,
162 burst_size: u32,
163 key_extractor: K,
164 rate_limited_response: R,
165 bypass_token: Option<String>,
166) -> Result<RateLimiterLayer<K, R>, Error> {
167 let period = Duration::from_secs(1)
168 .checked_div(rps)
169 .ok_or_else(|| anyhow!("RPS is zero"))?;
170
171 let config = GovernorConfigBuilder::default()
172 .period(period)
173 .burst_size(burst_size)
174 .key_extractor(key_extractor)
175 .finish()
176 .ok_or_else(|| anyhow!("unable to build governor config"))?;
177
178 Ok(RateLimiterLayer::new(
179 Arc::new(config),
180 rate_limited_response,
181 bypass_token,
182 ))
183}
184
#[cfg(test)]
mod test {
    use super::*;

    use axum::{
        Router,
        body::{Body, to_bytes},
        extract::Request,
        response::IntoResponse,
        routing::post,
    };
    use http::{Method, StatusCode};
    use ic_bn_lib_common::types::http::ConnInfo;
    use std::{sync::Arc, time::Duration};
    use tokio::time::sleep;
    use tower::Service;

    /// Endpoint placed behind the limiter; always answers with a fixed body.
    async fn handler(_request: Request<Body>) -> impl IntoResponse {
        "test_call"
    }

    /// Sends an empty POST to `/` with a default `ConnInfo` extension attached.
    ///
    /// Without that extension the IP key cannot be extracted
    /// (see `test_rate_limiter_returns_server_error`).
    async fn send_request(
        router: &mut Router,
    ) -> Result<http::Response<Body>, std::convert::Infallible> {
        let mut request = Request::post("/").body(Body::from(String::new())).unwrap();
        request
            .extensions_mut()
            .insert(Arc::new(ConnInfo::default()));
        router.call(request).await
    }

    /// Builds an empty POST request, optionally carrying the bypass-token header.
    fn bypass_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST);
        if let Some(token) = token {
            builder = builder.header(BYPASS_HEADER, token);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_rate_limiter_rps_limit() {
        // 5 requests/second: one token every 200ms, bucket capped at 5 tokens.
        let limiter = layer(
            5,
            5,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");
        let mut app = Router::new().route("/", post(handler)).layer(limiter);

        // Slightly longer than one 200ms replenish interval, to absorb timer jitter.
        let token_ms: u64 = 230;
        let (ok, limited) = (StatusCode::OK, StatusCode::TOO_MANY_REQUESTS);

        let steps = [
            // Drain the initial burst, then get rejected.
            (0, ok),
            (0, ok),
            (0, ok),
            (0, ok),
            (0, ok),
            (0, limited),
            // One interval yields exactly one fresh token.
            (token_ms, ok),
            (0, limited),
            // Two intervals yield two tokens.
            (2 * token_ms, ok),
            (0, ok),
            (0, limited),
            // Five intervals refill the bucket to its 5-token cap.
            (5 * token_ms, ok),
            (0, ok),
            (0, ok),
            (0, ok),
            (0, ok),
            (0, limited),
            (0, limited),
        ];

        for (idx, (delay_ms, expected)) in steps.into_iter().enumerate() {
            if delay_ms > 0 {
                sleep(Duration::from_millis(delay_ms)).await;
            }
            let response = send_request(&mut app).await.unwrap();
            assert_eq!(response.status(), expected, "test {idx} failed");
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_returns_server_error() {
        let limiter = layer(
            1,
            1,
            IpKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            None,
        )
        .expect("failed to build middleware");
        let mut app = Router::new().route("/", post(handler)).layer(limiter);

        // No ConnInfo extension is attached, so key extraction must fail.
        let request = Request::post("/").body(Body::from(String::new())).unwrap();
        let response = app.call(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = to_bytes(response.into_body(), 1024).await.unwrap();
        assert_eq!(&body[..], b"Unable to extract rate limiting key");
    }

    #[tokio::test]
    async fn test_rate_limiter_bypass_token() {
        let limiter = layer(
            1,
            10,
            GlobalKeyExtractor,
            (StatusCode::TOO_MANY_REQUESTS, "foo"),
            Some("top_secret_token".into()),
        )
        .expect("failed to build middleware");
        let mut app = Router::new().route("/", post(handler)).layer(limiter);

        // (bypass header value, number of requests, expected status)
        let phases: [(Option<&str>, usize, StatusCode); 4] = [
            // The burst of 10 is served, then the global quota is exhausted.
            (None, 10, StatusCode::OK),
            (None, 100, StatusCode::TOO_MANY_REQUESTS),
            // The correct token skips the limiter.
            (Some("top_secret_token"), 100, StatusCode::OK),
            // A wrong token is treated like no token.
            (Some("not_very_secret"), 100, StatusCode::TOO_MANY_REQUESTS),
        ];

        for (token, count, expected) in phases {
            for _ in 0..count {
                let response = app.call(bypass_request(token)).await.unwrap();
                assert_eq!(response.status(), expected);
            }
        }
    }
}