use crate::{middleware::KeyExtractor, storage::RateLimitStore};
use axum::{
http::{Request, StatusCode},
response::{IntoResponse, Response},
};
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;
/// Derives a rate-limit key from the `X-Forwarded-For` request header.
///
/// `X-Forwarded-For` may carry a comma-separated chain such as
/// `client, proxy1, proxy2`. Only the first (left-most) entry is used, with
/// surrounding whitespace trimmed. This means the same originating client
/// maps to the same key no matter how many proxies appended to the chain.
///
/// Returns `None` in three cases:
/// - the header is absent;
/// - the header is not representable as a visible-ASCII string
///   (`HeaderValue::to_str` fails);
/// - its first entry is empty.
///
/// NOTE(review): `X-Forwarded-For` is client-controlled unless a trusted
/// reverse proxy overwrites it. A client reaching this service directly can
/// pick its own key, or omit the header. Only rely on this extractor behind a
/// proxy that sets the header — confirm against the deployment.
#[derive(Clone)]
pub struct AxumIpExtractor;

impl<B> KeyExtractor<Request<B>> for AxumIpExtractor {
    fn extract(&self, req: &Request<B>) -> Option<String> {
        let header = req.headers().get("x-forwarded-for")?.to_str().ok()?;
        first_forwarded_for_entry(header).map(str::to_owned)
    }
}

/// Returns the trimmed first comma-separated entry of an `X-Forwarded-For`
/// value, or `None` if that entry is empty (e.g. an empty header or `", x"`).
fn first_forwarded_for_entry(value: &str) -> Option<&str> {
    let first = value.split(',').next()?.trim();
    if first.is_empty() {
        None
    } else {
        Some(first)
    }
}
/// Tower service that checks a [`RateLimitStore`] before forwarding a request
/// to `inner`.
///
/// For each request, [`AxumIpExtractor`] derives a key.
/// - If the store's `allow` returns `false` for that key, the request is
///   answered immediately with `429 Too Many Requests` and is never passed
///   to `inner`.
/// - If no key can be extracted, the request is forwarded without consulting
///   the store.
#[derive(Clone)]
pub struct AxumRateLimitService<S, Store> {
    /// Wrapped service that handles admitted requests.
    inner: S,
    /// Admission backend, queried once per keyed request.
    store: Store,
    /// Derives the rate-limit key from each request.
    extractor: AxumIpExtractor,
}

impl<S, Store, B> Service<Request<B>> for AxumRateLimitService<S, Store>
where
    S: Service<Request<B>, Response = Response, Error = Infallible> + Send,
    S::Future: Send + 'static,
    Store: RateLimitStore,
    B: Send + 'static,
{
    type Response = Response;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to `inner`; the store is not consulted here.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Rejects the request with 429 if the store denies its key; otherwise
    /// forwards it to `inner`.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        if let Some(key) = self.extractor.extract(&req) {
            if !self.store.allow(&key) {
                // NOTE(review): `inner.poll_ready` has already succeeded but
                // `inner.call` is skipped here. Any capacity `inner` reserved
                // stays held until the next admitted request. Confirm this is
                // acceptable for the wrapped service.
                let response =
                    (StatusCode::TOO_MANY_REQUESTS, "rate limit exceeded").into_response();
                // The response is already built, so a ready future is enough.
                return Box::pin(std::future::ready(Ok(response)));
            }
        }
        // `S::Future` is `Send + 'static` and already has the right output
        // type. Box it directly rather than wrapping it in a forwarding
        // `async` block.
        Box::pin(self.inner.call(req))
    }
}
/// [`tower::Layer`] that wraps a service in an [`AxumRateLimitService`].
///
/// Every service this layer produces receives its own `Clone` of `store`.
/// Whether those clones share limiter state depends on the store type's
/// `Clone` implementation.
#[derive(Clone)]
pub struct AxumRateLimitLayer<Store> {
    /// Store cloned into each wrapped service.
    store: Store,
}

impl<S, Store> tower::Layer<S> for AxumRateLimitLayer<Store>
where
    Store: Clone,
{
    type Service = AxumRateLimitService<S, Store>;

    fn layer(&self, service: S) -> Self::Service {
        let store = Store::clone(&self.store);
        AxumRateLimitService {
            extractor: AxumIpExtractor,
            store,
            inner: service,
        }
    }
}
/// Builds an [`AxumRateLimitLayer`] backed by `store`.
///
/// The returned layer hands a clone of `store` to every service it wraps.
pub fn axum_rate_limit_layer<Store: RateLimitStore + Clone>(
    store: Store,
) -> AxumRateLimitLayer<Store> {
    AxumRateLimitLayer { store }
}