use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::{Duration, Instant},
};
use axum::{body::Body, extract::Request};
use http::{HeaderValue, Response, StatusCode};
use tower::{Layer, Service};
use crate::context::{IdentityExtractor, RequestIdentity};
type IdentityFn =
Arc<dyn Fn(&http::request::Parts) -> Option<RequestIdentity> + Send + Sync + 'static>;
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
#[error("{message}")]
pub struct RateLimitError {
message: String,
}
impl RateLimitError {
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitDecision {
pub allowed: bool,
pub limit: u64,
pub remaining: u64,
pub reset_after: Duration,
}
pub trait RateLimitStore: Send + Sync + 'static {
fn check(
&self,
identity: &RequestIdentity,
limit: u64,
window: Duration,
) -> Result<RateLimitDecision, RateLimitError>;
}
#[derive(Clone, Default)]
pub struct InMemoryRateLimitStore {
state: Arc<Mutex<HashMap<String, WindowState>>>,
}
impl InMemoryRateLimitStore {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.state
.lock()
.map(|state| state.len())
.unwrap_or_default()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl RateLimitStore for InMemoryRateLimitStore {
fn check(
&self,
identity: &RequestIdentity,
limit: u64,
window: Duration,
) -> Result<RateLimitDecision, RateLimitError> {
let now = Instant::now();
let mut state = self
.state
.lock()
.map_err(|_| RateLimitError::new("rate limit store poisoned"))?;
state.retain(|_, window_state| now.duration_since(window_state.started_at) < window);
let window_state = state
.entry(identity.as_str().to_owned())
.or_insert_with(|| WindowState {
started_at: now,
count: 0,
});
if now.duration_since(window_state.started_at) >= window {
window_state.started_at = now;
window_state.count = 0;
}
let allowed = window_state.count < limit;
if allowed {
window_state.count += 1;
}
let remaining = limit.saturating_sub(window_state.count);
let reset_after = window.saturating_sub(now.duration_since(window_state.started_at));
Ok(RateLimitDecision {
allowed,
limit,
remaining,
reset_after,
})
}
}
#[derive(Clone)]
struct WindowState {
started_at: Instant,
count: u64,
}
#[derive(Clone)]
pub struct RateLimitConfig {
limit: u64,
window: Duration,
store: Arc<dyn RateLimitStore>,
identity: IdentityFn,
fail_open: bool,
}
impl RateLimitConfig {
pub fn new(limit: u64, window: Duration, store: impl RateLimitStore) -> Self {
Self {
limit,
window,
store: Arc::new(store),
identity: Arc::new(|_parts| Some(RequestIdentity::new("anonymous"))),
fail_open: true,
}
}
pub fn identity(mut self, extractor: impl IdentityExtractor) -> Self {
self.identity = Arc::new(move |parts| extractor.extract(parts));
self
}
pub fn fail_open(mut self) -> Self {
self.fail_open = true;
self
}
pub fn fail_closed(mut self) -> Self {
self.fail_open = false;
self
}
pub fn layer(self) -> RateLimitLayer {
RateLimitLayer { config: self }
}
}
#[derive(Clone)]
pub struct RateLimitLayer {
config: RateLimitConfig,
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
config: self.config.clone(),
}
}
}
#[derive(Clone)]
pub struct RateLimitService<S> {
inner: S,
config: RateLimitConfig,
}
impl<S> Service<Request> for RateLimitService<S>
where
S: Service<Request, Response = Response<Body>> + Send + 'static,
S::Future: Send + 'static,
S::Error: Send + 'static,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request) -> Self::Future {
let config = self.config.clone();
let (parts, body) = request.into_parts();
let identity =
(config.identity)(&parts).unwrap_or_else(|| RequestIdentity::new("anonymous"));
let decision = config
.store
.check(&identity, config.limit, config.window)
.unwrap_or(RateLimitDecision {
allowed: config.fail_open,
limit: config.limit,
remaining: if config.fail_open { config.limit } else { 0 },
reset_after: config.window,
});
if !decision.allowed {
return Box::pin(async move { Ok(rate_limited_response(decision)) });
}
let future = self.inner.call(Request::from_parts(parts, body));
Box::pin(async move {
let mut response = future.await?;
insert_rate_limit_headers(response.headers_mut(), &decision);
Ok(response)
})
}
}
fn rate_limited_response(decision: RateLimitDecision) -> Response<Body> {
let mut response = Response::new(Body::from("rate limit exceeded"));
*response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
insert_rate_limit_headers(response.headers_mut(), &decision);
response.headers_mut().insert(
http::header::RETRY_AFTER,
HeaderValue::from_str(&decision.reset_after.as_secs().max(1).to_string())
.expect("retry-after must be a valid header"),
);
response
}
fn insert_rate_limit_headers(headers: &mut http::HeaderMap, decision: &RateLimitDecision) {
headers.insert(
"ratelimit-limit",
HeaderValue::from_str(&decision.limit.to_string()).expect("limit header must be valid"),
);
headers.insert(
"ratelimit-remaining",
HeaderValue::from_str(&decision.remaining.to_string())
.expect("remaining header must be valid"),
);
headers.insert(
"ratelimit-reset",
HeaderValue::from_str(&decision.reset_after.as_secs().max(1).to_string())
.expect("reset header must be valid"),
);
}