use std::{
future::Future,
marker::PhantomData,
num::NonZeroU32,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time::Duration,
};
use axum::{
body::Body,
http::{Request, Response, StatusCode},
response::IntoResponse,
};
use dashmap::DashMap;
use governor::{
clock::{Clock, DefaultClock, QuantaClock, QuantaInstant},
middleware::NoOpMiddleware,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tower::{BoxError, Layer, Service};
use tracing::{trace_span, warn};
use crate::{
errors,
layers::util::{
ExtractionError, KeyExtractor, PeerIpKeyExtractor, SmartIpKeyExtractor, UserIdKeyExtractor,
},
};
/// Errors produced while applying a rate limit to a request.
///
/// The `IntoResponse` implementation in this file turns each variant into an
/// HTTP problem response.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum RateLimitError {
/// The key extractor failed, so no rate-limiting key could be derived from the request.
#[error(transparent)]
Extraction(#[from] ExtractionError),
/// The limiter rejected the request because the quota is currently exhausted.
#[error("Rate limit reached: available after {remaining_seconds} seconds")]
LimitReached {
/// Wait time reported by the limiter, expressed in whole seconds.
remaining_seconds: u64,
},
}
impl RateLimitError {
    /// Maps the error onto the HTTP status used in the response.
    ///
    /// An exhausted quota yields `429 Too Many Requests`; a failed key
    /// extraction yields `400 Bad Request`.
    fn http_status(&self) -> StatusCode {
        // Arms are spelled out so that adding a variant forces a decision here.
        match self {
            Self::Extraction(_) => StatusCode::BAD_REQUEST,
            Self::LimitReached { .. } => StatusCode::TOO_MANY_REQUESTS,
        }
    }
}
impl IntoResponse for RateLimitError {
    /// Renders the error as a problem-details response.
    ///
    /// The status comes from [`RateLimitError::http_status`], the title is the
    /// error's display text, and `LimitReached` additionally carries a
    /// `retry_after` value holding the remaining seconds.
    fn into_response(self) -> Response<Body> {
        let problem = problemdetails::new(self.http_status())
            .with_type(errors::TAG_UXUM_RATE_LIMIT)
            .with_title(self.to_string());
        match self {
            Self::LimitReached { remaining_seconds } => problem
                .with_value("retry_after", remaining_seconds)
                .into_response(),
            Self::Extraction(_) => problem.into_response(),
        }
    }
}
/// Rate-limit configuration consumed by `RateLimitLayer` and `RateLimit`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct HandlerRateLimitConfig {
/// Which key requests are grouped by; falls back to the `RateLimitKey` default when omitted.
#[serde(default)]
key: RateLimitKey,
/// Sustained rate in requests per second; `period()` derives the replenish interval as `1s / rps`.
rps: NonZeroU32,
/// Optional peak rate used to size the burst; `burst_size()` ignores it unless it is greater than `rps`.
#[serde(default, skip_serializing_if = "Option::is_none")]
burst_rps: Option<NonZeroU32>,
/// Time window multiplied by the peak rate to obtain the burst size.
/// Defaults to one second and is (de)serialized through `humantime_serde`.
#[serde(
default = "HandlerRateLimitConfig::default_burst_duration",
with = "humantime_serde"
)]
burst_duration: Duration,
}
impl HandlerRateLimitConfig {
    /// Default value for `burst_duration`: one second.
    #[must_use]
    #[inline]
    fn default_burst_duration() -> Duration {
        Duration::from_secs(1)
    }

    /// Number of cells that may be consumed in a single burst.
    ///
    /// Computed as `ceil(burst_duration * max(rps, burst_rps))`. When that
    /// product rounds to zero (e.g. a zero `burst_duration`), `rps` is
    /// returned instead so the result is always non-zero.
    pub fn burst_size(&self) -> NonZeroU32 {
        // `burst_rps` only counts when it exceeds the sustained rate.
        let peak = self.burst_rps.map_or(self.rps, |burst| burst.max(self.rps));
        let cells = (self.burst_duration.as_secs_f64() * f64::from(peak.get())).ceil();
        // Float-to-integer `as` casts saturate, so an oversized product clamps to `u32::MAX`.
        NonZeroU32::new(cells as u32).unwrap_or(self.rps)
    }

    /// Interval at which a single cell is replenished: `1s / rps`.
    ///
    /// `Duration` has nanosecond resolution, so for `rps > 1_000_000_000` the
    /// division truncates to zero. A zero period is rejected by
    /// `Quota::with_period`, which the limiter constructors unwrap, so the
    /// result is clamped to at least one nanosecond.
    pub fn period(&self) -> Duration {
        (Duration::from_secs(1) / self.rps.get()).max(Duration::from_nanos(1))
    }

    /// Builds a [`RateLimitLayer`] holding a copy of this configuration.
    pub fn make_layer<S, T>(&self) -> RateLimitLayer<S, T> {
        self.into()
    }
}
/// Selects how requests are grouped for rate limiting.
///
/// Serialized in `snake_case`; `Global` is the default.
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
enum RateLimitKey {
/// One shared, unkeyed limiter for every request.
#[default]
Global,
/// Separate limits per key produced by `PeerIpKeyExtractor`.
PeerIp,
/// Separate limits per key produced by `SmartIpKeyExtractor`.
SmartIp,
/// Separate limits per key produced by `UserIdKeyExtractor`.
UserId,
}
/// Tower layer that wraps a service in a `RateLimit` built from the stored configuration.
///
/// `S` is the wrapped service type and `T` the request body type; both are
/// carried only as phantom markers.
pub struct RateLimitLayer<S, T> {
// Configuration handed to every `RateLimit` this layer creates.
config: HandlerRateLimitConfig,
// Marker for the wrapped service type; holds no data.
_phantom_service: PhantomData<S>,
// Marker for the request body type; holds no data.
_phantom_request: PhantomData<T>,
}
impl<S, T> From<&HandlerRateLimitConfig> for RateLimitLayer<S, T> {
    /// Creates a layer that owns a clone of the given configuration.
    fn from(value: &HandlerRateLimitConfig) -> Self {
        let config = value.clone();
        Self {
            config,
            _phantom_service: PhantomData,
            _phantom_request: PhantomData,
        }
    }
}
impl<S, T> Layer<S> for RateLimitLayer<S, T>
where
    S: Service<Request<T>> + Send + 'static,
    T: Send + 'static,
{
    type Service = RateLimit<S, T>;

    /// Wraps `service` in a [`RateLimit`] configured from this layer.
    ///
    /// Each call builds a new limiter, so services produced by separate
    /// `layer` calls do not share rate-limit state.
    fn layer(&self, service: S) -> Self::Service {
        let config = &self.config;
        RateLimit::new(service, config)
    }
}
/// Service wrapper that checks a rate limit before forwarding a request to `inner`.
pub struct RateLimit<S, T> {
// The wrapped service that receives admitted requests.
inner: S,
// Type-erased limiter strategy; held in an `Arc` so clones of this service share it.
limiter: Arc<Box<dyn Limiter<T> + Send + Sync>>,
}
impl<S, T> Clone for RateLimit<S, T>
where
    S: Clone,
{
    /// Clones the inner service and shares the existing limiter.
    ///
    /// Only the `Arc` handle is duplicated, so every clone counts against the
    /// same rate-limit state.
    fn clone(&self) -> Self {
        let Self { inner, limiter } = self;
        Self {
            inner: inner.clone(),
            limiter: Arc::clone(limiter),
        }
    }
}
impl<S, T> Service<Request<T>> for RateLimit<S, T>
where
    S: Service<Request<T>>,
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = RateLimitFuture<S::Future>;

    /// Reports the readiness of the inner service, boxing its error type.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    /// Checks the limiter and either forwards the request or rejects it.
    ///
    /// The check runs synchronously inside a `rate` trace span. Admitted
    /// requests are passed to the inner service; rejected ones produce a
    /// future that resolves to the [`RateLimitError`], with a warning logged
    /// when the quota was exhausted.
    fn call(&mut self, req: Request<T>) -> Self::Future {
        let verdict = trace_span!("rate").in_scope(|| self.limiter.check_limit(&req));

        let error = match verdict {
            Ok(()) => {
                return RateLimitFuture::Positive {
                    inner: self.inner.call(req),
                }
            }
            Err(error) => error,
        };

        // Extraction failures are returned without a log line; only quota hits are warned about.
        if let RateLimitError::LimitReached { remaining_seconds } = &error {
            warn!(wait = remaining_seconds, "rate limit exceeded");
        }
        RateLimitFuture::Negative { error }
    }
}
impl<S, T> RateLimit<S, T>
where
    S: Service<Request<T>> + Send + 'static,
    T: Send + 'static,
{
    /// Wraps `inner` with a limiter chosen by `config.key`.
    ///
    /// `Global` uses a single unkeyed limiter; the other variants use a keyed
    /// limiter paired with the matching key extractor.
    #[must_use]
    pub fn new(inner: S, config: &HandlerRateLimitConfig) -> Self {
        let strategy: Box<dyn Limiter<T> + Send + Sync> = match config.key {
            RateLimitKey::Global => Box::new(GlobalLimiter::new(config)),
            RateLimitKey::PeerIp => Box::new(KeyedLimiter::new(PeerIpKeyExtractor, config)),
            RateLimitKey::SmartIp => Box::new(KeyedLimiter::new(SmartIpKeyExtractor, config)),
            RateLimitKey::UserId => Box::new(KeyedLimiter::new(UserIdKeyExtractor, config)),
        };
        let limiter = Arc::new(strategy);
        Self { inner, limiter }
    }
}
/// Future returned by `RateLimit::call`.
///
/// It either drives the inner service's future or resolves immediately to the
/// rate-limit error.
#[pin_project(project = ProjectedOutcome)]
pub enum RateLimitFuture<F> {
/// The request was admitted; polling is delegated to the inner future.
Positive {
/// Future produced by the wrapped service.
#[pin]
inner: F,
},
/// The request was rejected before reaching the wrapped service.
Negative {
/// Error reported when this future is polled.
error: RateLimitError,
},
}
impl<F, U, E> Future for RateLimitFuture<F>
where
F: Future<Output = Result<U, E>>,
E: Into<BoxError>,
{
type Output = Result<U, BoxError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
ProjectedOutcome::Positive { inner } => {
let resp = ready!(inner.poll(cx).map_err(Into::into))?;
Poll::Ready(Ok(resp))
}
ProjectedOutcome::Negative { error } => Poll::Ready(Err(Box::new(error.clone()))),
}
}
}
/// Strategy for deciding whether a request may proceed under the rate limit.
trait Limiter<T> {
/// Returns `Ok(())` when the request is admitted, or the `RateLimitError` explaining why it is not.
fn check_limit(&self, req: &Request<T>) -> Result<(), RateLimitError>;
}
/// Limiter that applies one unkeyed quota to every request.
struct GlobalLimiter {
// Unkeyed (`NotKeyed`) governor limiter holding its state in memory.
limiter: RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>,
}
impl<T> Limiter<T> for GlobalLimiter {
    /// Checks the shared quota; the request itself is not inspected.
    ///
    /// On rejection the wait time is rounded **up** to whole seconds. The
    /// replenish period is at most one second, so truncating would report `0`
    /// for almost every rejection and invite an immediate retry.
    fn check_limit(&self, _req: &Request<T>) -> Result<(), RateLimitError> {
        self.limiter.check().map_err(|neg| {
            let wait = neg.wait_time_from(DefaultClock::default().now());
            // Ceil to whole seconds: any fractional remainder counts as one more second.
            let remaining_seconds = wait
                .as_secs()
                .saturating_add(u64::from(wait.subsec_nanos() > 0));
            RateLimitError::LimitReached { remaining_seconds }
        })
    }
}
impl GlobalLimiter {
    /// Builds an unkeyed limiter from the configured period and burst size.
    ///
    /// # Panics
    ///
    /// Panics if `Quota::with_period` returns `None` for `config.period()`
    /// (governor documents this for a zero period).
    #[must_use]
    fn new(config: &HandlerRateLimitConfig) -> Self {
        let quota = Quota::with_period(config.period())
            .unwrap()
            .allow_burst(config.burst_size());
        let limiter = RateLimiter::direct(quota);
        Self { limiter }
    }
}
/// Limiter that tracks a separate quota for each key extracted from the request.
struct KeyedLimiter<K: KeyExtractor> {
// Derives the rate-limiting key from an incoming request.
extractor: K,
// Keyed governor limiter; per-key state is stored in a `DashMap`.
limiters: RateLimiter<
K::Key,
DashMap<K::Key, InMemoryState>,
QuantaClock,
NoOpMiddleware<QuantaInstant>,
>,
}
impl<T, K: KeyExtractor> Limiter<T> for KeyedLimiter<K> {
    /// Extracts the key from `req` and checks that key's quota.
    ///
    /// An extraction failure is returned as `RateLimitError::Extraction`. On
    /// rejection the wait time is rounded **up** to whole seconds. The
    /// replenish period is at most one second, so truncating would report `0`
    /// for almost every rejection and invite an immediate retry.
    fn check_limit(&self, req: &Request<T>) -> Result<(), RateLimitError> {
        let key = self.extractor.extract(req)?;
        self.limiters.check_key(&key).map_err(|neg| {
            let wait = neg.wait_time_from(DefaultClock::default().now());
            // Ceil to whole seconds: any fractional remainder counts as one more second.
            let remaining_seconds = wait
                .as_secs()
                .saturating_add(u64::from(wait.subsec_nanos() > 0));
            RateLimitError::LimitReached { remaining_seconds }
        })
    }
}
impl<K: KeyExtractor> KeyedLimiter<K> {
    /// Builds a keyed limiter that uses `extractor` to pick the key.
    ///
    /// # Panics
    ///
    /// Panics if `Quota::with_period` returns `None` for `config.period()`
    /// (governor documents this for a zero period).
    #[must_use]
    fn new(extractor: K, config: &HandlerRateLimitConfig) -> Self {
        let quota = Quota::with_period(config.period())
            .unwrap()
            .allow_burst(config.burst_size());
        let limiters = RateLimiter::keyed(quota);
        Self {
            extractor,
            limiters,
        }
    }
}