use std::sync::Arc;
use axum::{
body::Body,
extract::{MatchedPath, Request},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::{
resil::{RedisPeriodLimiter, RedisTokenLimiter},
rest::{
RestError, RestRateLimiterConfig,
middleware::{MiddlewareMetrics, record_resilience_event},
},
};
/// Handle to the REST rate limiter used by the rate-limiting middleware.
///
/// The selected mode is stored behind an [`Arc`], so cloning this handle only
/// bumps a reference count and every clone shares the same mode.
#[derive(Debug, Clone)]
pub struct RestRateLimiter {
// Shared limiter mode, built once in `RestRateLimiter::new`.
mode: Arc<RestRateLimiterMode>,
}
/// Internal limiter strategy chosen from `RestRateLimiterConfig`.
///
/// For the Redis-backed variants the `limiter` field keeps the *result* of
/// constructing the limiter: `Ok` with the limiter itself, or `Err` with the
/// construction error rendered as a string. A failed construction is therefore
/// reported per request rather than when the limiter is created.
#[derive(Debug)]
enum RestRateLimiterMode {
/// Rate limiting is turned off; every request is allowed.
Disabled,
/// Limiting backed by `RedisTokenLimiter`.
Token {
/// The limiter, or the stringified error from `RedisTokenLimiter::new`.
limiter: Result<RedisTokenLimiter, String>,
/// When `true`, limiter failures let the request through instead of
/// rejecting it.
fail_open: bool,
},
/// Limiting backed by `RedisPeriodLimiter`.
Period {
/// The limiter, or the stringified error from `RedisPeriodLimiter::new`.
limiter: Result<RedisPeriodLimiter, String>,
/// When `true`, limiter failures let the request through instead of
/// rejecting it.
fail_open: bool,
},
}
impl RestRateLimiter {
    /// Builds a rate limiter from `config`.
    ///
    /// Construction never fails: if the underlying Redis limiter cannot be
    /// created, the error is kept as a string inside the mode and surfaces
    /// later, when a request is evaluated, according to `fail_open`.
    pub fn new(config: RestRateLimiterConfig) -> Self {
        let mode = match config {
            RestRateLimiterConfig::Disabled => RestRateLimiterMode::Disabled,
            RestRateLimiterConfig::RedisToken(token_config) => {
                // Read the flag first: the config is moved into the limiter below.
                let fail_open = token_config.fail_open;
                let limiter =
                    RedisTokenLimiter::new(token_config).map_err(|cause| cause.to_string());
                RestRateLimiterMode::Token { limiter, fail_open }
            }
            RestRateLimiterConfig::RedisPeriod(period_config) => {
                // Read the flag first: the config is moved into the limiter below.
                let fail_open = period_config.fail_open;
                let limiter =
                    RedisPeriodLimiter::new(period_config).map_err(|cause| cause.to_string());
                RestRateLimiterMode::Period { limiter, fail_open }
            }
        };

        Self {
            mode: Arc::new(mode),
        }
    }

    /// Returns `true` when rate limiting is turned off.
    pub fn is_disabled(&self) -> bool {
        matches!(&*self.mode, RestRateLimiterMode::Disabled)
    }

    /// Evaluates the limiter for `key` and reports the outcome.
    ///
    /// The disabled mode always allows; the Redis-backed modes delegate to
    /// `allow_token` / `allow_period`.
    async fn allow(&self, key: &str) -> RateLimitOutcome {
        match &*self.mode {
            RestRateLimiterMode::Disabled => RateLimitOutcome::Allowed,
            RestRateLimiterMode::Token { limiter, fail_open } => {
                allow_token(limiter, *fail_open, key).await
            }
            RestRateLimiterMode::Period { limiter, fail_open } => {
                allow_period(limiter, *fail_open, key).await
            }
        }
    }
}
/// Result of evaluating the rate limiter for a single request.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RateLimitOutcome {
/// The limiter admitted the request (or limiting is disabled).
Allowed,
/// The limiter refused the request.
Rejected,
/// The limiter failed, but `fail_open` is set so the request proceeds.
ErrorOpen,
/// The limiter failed and `fail_open` is not set; carries the error text.
ErrorClosed(String),
}
/// Evaluates the token limiter for `key`.
///
/// `limiter` is the stored construction result: `Err` means the limiter could
/// not be built and holds the reason. Both a construction failure and a
/// failure of `try_allow_n(key, 1)` are handled the same way: with
/// `fail_open` set the outcome is `ErrorOpen`, otherwise `ErrorClosed`
/// carrying the error text.
async fn allow_token(
    limiter: &Result<RedisTokenLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RateLimitOutcome {
    // Every non-failing path returns early; only a fail-closed error message
    // falls through to the end of the function.
    let failure = match limiter {
        Ok(active) => match active.try_allow_n(key, 1).await {
            Ok(true) => return RateLimitOutcome::Allowed,
            Ok(false) => return RateLimitOutcome::Rejected,
            Err(_) if fail_open => return RateLimitOutcome::ErrorOpen,
            Err(cause) => cause.to_string(),
        },
        Err(_) if fail_open => return RateLimitOutcome::ErrorOpen,
        Err(reason) => reason.clone(),
    };

    RateLimitOutcome::ErrorClosed(failure)
}
/// Evaluates the period limiter for `key`.
///
/// `limiter` is the stored construction result: `Err` means the limiter could
/// not be built and holds the reason. Both a construction failure and a
/// failure of `try_allow(key)` are handled the same way: with `fail_open` set
/// the outcome is `ErrorOpen`, otherwise `ErrorClosed` carrying the error
/// text.
async fn allow_period(
    limiter: &Result<RedisPeriodLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RateLimitOutcome {
    // Every non-failing path returns early; only a fail-closed error message
    // falls through to the end of the function.
    let failure = match limiter {
        Ok(active) => match active.try_allow(key).await {
            Ok(true) => return RateLimitOutcome::Allowed,
            Ok(false) => return RateLimitOutcome::Rejected,
            Err(_) if fail_open => return RateLimitOutcome::ErrorOpen,
            Err(cause) => cause.to_string(),
        },
        Err(_) if fail_open => return RateLimitOutcome::ErrorOpen,
        Err(reason) => reason.clone(),
    };

    RateLimitOutcome::ErrorClosed(failure)
}
pub async fn rate_limiter_middleware(
limiter: RestRateLimiter,
metrics: MiddlewareMetrics,
request: Request<Body>,
next: Next,
) -> Response {
let key = limiter_key(&request);
match limiter.allow(&key).await {
RateLimitOutcome::Allowed => {
record_resilience_event(&metrics, "limiter", "allowed");
next.run(request).await
}
RateLimitOutcome::Rejected => {
record_resilience_event(&metrics, "limiter", "rejected");
RestError::RateLimited.into_response()
}
RateLimitOutcome::ErrorOpen => {
record_resilience_event(&metrics, "limiter", "error_open");
next.run(request).await
}
RateLimitOutcome::ErrorClosed(error) => {
record_resilience_event(&metrics, "limiter", "error_closed");
RestError::ServiceUnavailable(format!("rate limiter unavailable: {error}"))
.into_response()
}
}
}
/// Builds the rate-limit key for `request` as `"<METHOD>:<route>"`.
///
/// The route is the `MatchedPath` stored in the request extensions; when none
/// is present the literal `"unknown"` is used instead.
fn limiter_key(request: &Request<Body>) -> String {
    let route = match request.extensions().get::<MatchedPath>() {
        Some(matched) => matched.as_str(),
        None => "unknown",
    };
    format!("{}:{}", request.method().as_str(), route)
}