use std::{
net::IpAddr,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
use anyhow::{Error, anyhow};
use axum::{body::Body, extract::Request, response::IntoResponse, response::Response};
use bytes::Bytes;
use futures::future::BoxFuture;
use http::{HeaderName, HeaderValue, StatusCode};
use tower::{Layer, Service};
use tower_governor::{
GovernorError, GovernorLayer,
governor::{Governor, GovernorConfig, GovernorConfigBuilder},
key_extractor::{GlobalKeyExtractor, KeyExtractor},
};
use crate::{hname, http::middleware::extract_ip_from_request};
/// Shorthand for a `GovernorLayer` parameterised with the no-op governor
/// middleware (`NoOpMiddleware<QuantaInstant>`) and axum's `Body`.
pub type GovernorLayerAxum<K> = GovernorLayer<K, NoOpMiddleware<QuantaInstant>, Body>;
/// Request header inspected by `RateLimiter::call`: when its value equals the
/// configured bypass token, the request is forwarded without rate limiting.
const BYPASS_HEADER: HeaderName = hname!("x-ratelimit-bypass-token");
/// Key extractor that uses the IP address returned by
/// `extract_ip_from_request` as the rate-limiting key.
#[derive(Clone)]
pub struct IpKeyExtractor;

impl KeyExtractor for IpKeyExtractor {
    type Key = IpAddr;

    /// Returns the request's IP address, or `GovernorError::UnableToExtractKey`
    /// when `extract_ip_from_request` yields nothing.
    fn extract<B>(&self, req: &Request<B>) -> Result<Self::Key, GovernorError> {
        match extract_ip_from_request(req) {
            Some(ip) => Ok(ip),
            None => Err(GovernorError::UnableToExtractKey),
        }
    }
}
/// Tower service that sends requests through a `Governor` for rate limiting,
/// unless the request presents the configured bypass token, in which case the
/// inner service is called directly.
#[derive(Clone)]
pub struct RateLimiter<S, K: KeyExtractor> {
/// Governor-wrapped service used for requests that are subject to rate limiting.
governor: Governor<K, NoOpMiddleware<QuantaInstant>, S, Body>,
/// Token that must match the bypass header value to skip rate limiting;
/// with `None` no request can bypass.
bypass_token: Option<String>,
/// Wrapped service: polled for readiness and called directly for bypassing requests.
inner: S,
}
/// Rate-limits requests through the governor, except for requests whose
/// `x-ratelimit-bypass-token` header matches the configured bypass token;
/// those are forwarded straight to the inner service.
impl<S, K> Service<Request> for RateLimiter<S, K>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
    K: KeyExtractor,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Reports readiness of the directly held inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // NOTE(review): only `inner` is polled here, while the non-bypass path in
        // `call` goes through `governor`, which is built around its own clone of
        // the service. Confirm this is acceptable for services whose readiness is
        // tracked per clone.
        self.inner.poll_ready(cx)
    }

    /// Forwards the request to the inner service when a valid bypass token is
    /// presented, otherwise hands it to the governor for rate limiting.
    fn call(&mut self, request: Request) -> Self::Future {
        /// Compares two byte strings without short-circuiting on the first
        /// mismatching byte, so the time taken does not reveal how long a
        /// matching prefix of the secret token the caller has guessed.
        /// (The length check may still reveal the token length.)
        fn tokens_match(presented: &[u8], expected: &[u8]) -> bool {
            presented.len() == expected.len()
                && presented
                    .iter()
                    .zip(expected)
                    .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                    == 0
        }

        // Bypass requires both a configured token and a header carrying that exact value.
        let bypass = match (
            request.headers().get(BYPASS_HEADER),
            self.bypass_token.as_deref(),
        ) {
            (Some(presented), Some(expected)) => {
                tokens_match(presented.as_bytes(), expected.as_bytes())
            }
            _ => false,
        };

        if bypass {
            return Box::pin(self.inner.call(request));
        }

        Box::pin(self.governor.call(request))
    }
}
/// Tower layer that wraps services in a `RateLimiter`.
///
/// `derive_new` generates `RateLimiterLayer::new(config, rate_limited_response, bypass_token)`.
#[derive(Clone, derive_new::new)]
pub struct RateLimiterLayer<K: KeyExtractor, R> {
/// Shared governor configuration passed to every `Governor` this layer builds.
config: Arc<GovernorConfig<K, NoOpMiddleware<QuantaInstant>>>,
/// Response that is cloned and returned (with a `Retry-After` header added)
/// when a request is rate limited.
rate_limited_response: R,
/// Optional bypass token copied into every `RateLimiter` this layer produces.
bypass_token: Option<String>,
}
/// Builds a `RateLimiter` around a service, installing an error handler that
/// turns governor errors into HTTP responses.
impl<S, K, R> Layer<S> for RateLimiterLayer<K, R>
where
    S: Clone,
    K: KeyExtractor,
    R: IntoResponse + Clone + Send + Sync + 'static,
{
    type Service = RateLimiter<S, K>;

    /// Wraps `inner` in a `RateLimiter`.
    ///
    /// Governor errors are mapped as follows:
    /// - `TooManyRequests`: a clone of the configured rate-limited response,
    ///   with a `Retry-After` header added.
    /// - `UnableToExtractKey`: `500 Internal Server Error` with a fixed message.
    /// - `Other`: `500 Internal Server Error` with the error details in the body.
    fn layer(&self, inner: S) -> Self::Service {
        let rate_limited_response = self.rate_limited_response.clone();

        // The governor gets its own clone of the service; the original is kept
        // in the `RateLimiter` for the bypass path.
        let governor = Governor::new(inner.clone(), &self.config).error_handler(move |err| {
            match err {
                GovernorError::TooManyRequests { wait_time, headers: _ } => {
                    let mut response = rate_limited_response.clone().into_response();

                    // `wait_time` is already expressed in whole seconds, which is
                    // the unit `Retry-After` expects, so it must not be scaled.
                    // Sub-second waits are reported as 0 and are rounded up to 1
                    // so clients never see `Retry-After: 0`.
                    let retry_secs = wait_time.max(1);
                    let header_value =
                        HeaderValue::from_maybe_shared(Bytes::from(retry_secs.to_string()))
                            .expect("decimal digits always form a valid header value");
                    response
                        .headers_mut()
                        .insert(http::header::RETRY_AFTER, header_value);

                    response
                }
                GovernorError::UnableToExtractKey => (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "Unable to extract rate limiting key",
                )
                    .into_response(),
                GovernorError::Other { code, msg, headers } => (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!(
                        "Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}"
                    ),
                )
                    .into_response(),
            }
        });

        RateLimiter {
            governor,
            bypass_token: self.bypass_token.clone(),
            inner,
        }
    }
}
/// Builds a rate-limiting layer keyed by tower_governor's `GlobalKeyExtractor`
/// (presumably a single quota shared by all requests — confirm against the
/// tower_governor documentation).
///
/// All other arguments are forwarded unchanged to [`layer`]; see it for the
/// error conditions.
pub fn layer_global<R: IntoResponse + Clone + Send + Sync + 'static>(
    rps: u32,
    burst_size: u32,
    rate_limited_response: R,
    bypass_token: Option<String>,
) -> Result<RateLimiterLayer<GlobalKeyExtractor, R>, Error> {
    let key_extractor = GlobalKeyExtractor;
    layer(rps, burst_size, key_extractor, rate_limited_response, bypass_token)
}
/// Builds a rate-limiting layer keyed by client IP address via
/// [`IpKeyExtractor`].
///
/// All other arguments are forwarded unchanged to [`layer`]; see it for the
/// error conditions.
pub fn layer_by_ip<R: IntoResponse + Clone + Send + Sync + 'static>(
    rps: u32,
    burst_size: u32,
    rate_limited_response: R,
    bypass_token: Option<String>,
) -> Result<RateLimiterLayer<IpKeyExtractor, R>, Error> {
    let key_extractor = IpKeyExtractor;
    layer(rps, burst_size, key_extractor, rate_limited_response, bypass_token)
}
/// Builds a [`RateLimiterLayer`] for the given key extractor.
///
/// * `rps` - requests per second; the governor period is set to `1s / rps`.
/// * `burst_size` - burst size passed to the governor configuration.
/// * `key_extractor` - extractor producing the rate-limiting key for a request.
/// * `rate_limited_response` - response returned when a request is rate limited.
/// * `bypass_token` - optional token that lets requests skip rate limiting.
///
/// # Errors
///
/// Returns an error when `rps` is zero, or when the governor configuration
/// builder fails to produce a configuration.
pub fn layer<K: KeyExtractor, R: IntoResponse + Clone + Send + Sync + 'static>(
    rps: u32,
    burst_size: u32,
    key_extractor: K,
    rate_limited_response: R,
    bypass_token: Option<String>,
) -> Result<RateLimiterLayer<K, R>, Error> {
    // Interval between two permitted requests; division fails only for `rps == 0`.
    let period = match Duration::from_secs(1).checked_div(rps) {
        Some(interval) => interval,
        None => return Err(anyhow!("RPS is zero")),
    };

    let built = GovernorConfigBuilder::default()
        .period(period)
        .burst_size(burst_size)
        .key_extractor(key_extractor)
        .finish();
    let Some(config) = built else {
        return Err(anyhow!("unable to build governor config"));
    };

    let shared_config = Arc::new(config);
    Ok(RateLimiterLayer::new(
        shared_config,
        rate_limited_response,
        bypass_token,
    ))
}
// Tests for the rate-limiting layer: token replenishment timing, the error
// response when no key can be extracted, and the bypass token.
#[cfg(test)]
mod test {
use super::*;
use axum::{
Router,
body::{Body, to_bytes},
extract::Request,
response::IntoResponse,
routing::post,
};
use http::{Method, StatusCode};
use ic_bn_lib_common::types::http::ConnInfo;
use std::{sync::Arc, time::Duration};
use tokio::time::sleep;
use tower::Service;
// Trivial handler: always answers with the body "test_call" (status 200).
async fn handler(_request: Request<Body>) -> impl IntoResponse {
"test_call"
}
// Sends an empty POST to "/" with a default `ConnInfo` stored in the request
// extensions (presumably what `extract_ip_from_request` reads the IP from —
// verify against that helper).
async fn send_request(
router: &mut Router,
) -> Result<http::Response<Body>, std::convert::Infallible> {
let conn_info = ConnInfo::default();
let mut request = Request::post("/").body(Body::from("".to_string())).unwrap();
request.extensions_mut().insert(Arc::new(conn_info));
router.call(request).await
}
// Verifies burst exhaustion and token replenishment over time for the
// IP-keyed limiter, and that every 429 carries a sane `Retry-After` header.
#[tokio::test]
async fn test_rate_limiter_rps_limit() {
let rps = 5;
let burst_size = 5;
let rate_limiter_mw = layer(
rps,
burst_size,
IpKeyExtractor,
(StatusCode::TOO_MANY_REQUESTS, "foo"),
None,
)
.expect("failed to build middleware");
let mut app = Router::new()
.route("/", post(handler))
.layer(rate_limiter_mw);
// 230 ms is slightly more than the 200 ms period (1 s / 5 rps) configured
// above, leaving some slack for timing jitter.
// Each test case is (delay before the request in ms, expected status).
let delay_for_token_ms = 230; let test_cases = vec![
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::TOO_MANY_REQUESTS),
(delay_for_token_ms, StatusCode::OK),
(0, StatusCode::TOO_MANY_REQUESTS),
(2 * delay_for_token_ms, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::TOO_MANY_REQUESTS),
(5 * delay_for_token_ms, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::OK),
(0, StatusCode::TOO_MANY_REQUESTS),
(0, StatusCode::TOO_MANY_REQUESTS),
];
for (idx, (delay_ms, expected_status)) in test_cases.into_iter().enumerate() {
if delay_ms > 0 {
sleep(Duration::from_millis(delay_ms)).await;
}
let result = send_request(&mut app).await.unwrap();
assert_eq!(result.status(), expected_status, "test {idx} failed");
// Rate-limited responses must carry a parseable `Retry-After` value.
if expected_status == StatusCode::TOO_MANY_REQUESTS {
let retry_after = result.headers().get(http::header::RETRY_AFTER);
assert!(
retry_after.is_some(),
"test {idx}: Retry-After header missing on 429 response"
);
if let Some(header_value) = retry_after {
let retry_secs: u32 = header_value.to_str().unwrap().parse().unwrap();
assert!(
(1..=10).contains(&retry_secs),
"test {idx}: Retry-After value {retry_secs} is outside expected range [1, 10]"
);
}
}
}
}
// Verifies that a request from which no key can be extracted gets a 500 with
// the fixed error message.
#[tokio::test]
async fn test_rate_limiter_returns_server_error() {
let rps = 1;
let burst_size = 1;
let rate_limiter_mw = layer(
rps,
burst_size,
IpKeyExtractor,
(StatusCode::TOO_MANY_REQUESTS, "foo"),
None,
)
.expect("failed to build middleware");
let mut app = Router::new()
.route("/", post(handler))
.layer(rate_limiter_mw);
// Unlike `send_request`, no `ConnInfo` extension is attached here, so the
// IP-based key extraction is expected to fail.
let request = Request::post("/").body(Body::from("".to_string())).unwrap();
let result = app.call(request).await.unwrap();
assert_eq!(result.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = to_bytes(result.into_body(), 1024).await.unwrap().to_vec();
assert_eq!(body, b"Unable to extract rate limiting key");
}
// Verifies that only requests carrying the exact bypass token skip rate
// limiting once the burst is exhausted.
#[tokio::test]
async fn test_rate_limiter_bypass_token() {
let rate_limiter_mw = layer(
1,
10,
GlobalKeyExtractor,
(StatusCode::TOO_MANY_REQUESTS, "foo"),
Some("top_secret_token".into()),
)
.expect("failed to build middleware");
let mut app = Router::new()
.route("/", post(handler))
.layer(rate_limiter_mw);
// The first 10 requests fit in the burst of 10.
for _ in 0..10 {
let req = Request::builder()
.method(Method::POST)
.body(Body::empty())
.unwrap();
let res = app.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
// With the burst used up, requests without a token are rejected.
for _ in 0..100 {
let req = Request::builder()
.method(Method::POST)
.body(Body::empty())
.unwrap();
let res = app.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
}
// Requests with the correct token are served despite the exhausted quota.
for _ in 0..100 {
let req = Request::builder()
.method(Method::POST)
.header(BYPASS_HEADER, "top_secret_token")
.body(Body::empty())
.unwrap();
let res = app.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
// Requests with a wrong token are still rate limited.
for _ in 0..100 {
let req = Request::builder()
.method(Method::POST)
.header(BYPASS_HEADER, "not_very_secret")
.body(Body::empty())
.unwrap();
let res = app.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
}
}
}