use super::{ClientIp, Error};
use axum::extract::Request;
use axum::extract::State;
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum::response::Response;
use scopeguard::defer;
use std::net::IpAddr;
use std::time::Duration;
use tibba_cache::RedisCache;
use tibba_state::AppState;
use tracing::debug;
type Result<T> = std::result::Result<T, tibba_error::Error>;
pub async fn processing_limit(
State(state): State<&AppState>,
req: Request,
next: Next,
) -> Result<impl IntoResponse> {
debug!(category = "middleware", "--> processing_limit");
defer!(debug!(category = "middleware", "<-- processing_limit"););
let limit = state.get_processing_limit();
if limit < 0 {
let res = next.run(req).await;
return Ok(res);
}
let count = state.inc_processing() + 1;
defer!(state.dec_processing(););
if count > limit {
return Err(Error::TooManyRequests {
limit: limit as i64,
current: count as i64,
}
.into());
}
let res = next.run(req).await;
Ok(res)
}
/// Selects which client attribute a rate-limit counter is keyed on.
#[derive(Debug, Clone, Default)]
pub enum LimitType {
    /// Key on the client IP address (the default).
    #[default]
    Ip,
    /// Key on the value of the named request header. If the header cannot be
    /// used, the client IP is used instead.
    Header(String),
}
/// Configuration consumed by the `limiter` and `error_limiter` middleware.
#[derive(Debug, Clone, Default)]
pub struct LimitParams {
    /// How the client is identified (IP address or a request header).
    pub limit_type: LimitType,
    /// Prefix for the cache key; an empty string means no prefix.
    pub category: String,
    /// Upper bound on the per-client counter.
    pub max: i64,
    /// Lifetime passed to the cache counter; a zero value selects a default.
    pub ttl: Duration,
}
impl LimitParams {
    /// Builds parameters that key the counter on the client IP address.
    ///
    /// * `max` – upper bound on the per-client counter.
    /// * `secs` – counter lifetime in seconds.
    /// * `category` – cache key prefix (may be empty).
    pub fn new(max: i64, secs: u64, category: &str) -> Self {
        let ttl = Duration::from_secs(secs);
        Self {
            max,
            ttl,
            category: String::from(category),
            limit_type: LimitType::Ip,
        }
    }
}
/// Builds the cache key and TTL used by the rate-limiting middleware.
///
/// The client identifier is chosen as follows:
/// - If `params.limit_type` is `LimitType::Header`, use the value of the configured
///   header, provided it is present, convertible to `&str`, and non-empty.
/// - In every other case, use the client IP.
///
/// The key is `"{category}:{identifier}"`, or the bare identifier when
/// `params.category` is empty. A zero `params.ttl` is replaced by a five-minute
/// default.
fn get_limit_params(req: &Request, ip: IpAddr, params: &LimitParams) -> (String, Duration) {
    // Window applied when the caller left `ttl` unset (zero).
    const DEFAULT_TTL: Duration = Duration::from_secs(5 * 60);

    let identifier = match &params.limit_type {
        LimitType::Header(header_name) => req
            .headers()
            .get(header_name)
            // Values with non-visible-ASCII bytes fail `to_str`; fall back to the IP.
            .and_then(|value| value.to_str().ok())
            // An empty value would put every such client into one shared bucket.
            .filter(|value| !value.is_empty())
            .map(str::to_string)
            .unwrap_or_else(|| ip.to_string()),
        LimitType::Ip => ip.to_string(),
    };
    let key = if params.category.is_empty() {
        identifier
    } else {
        format!("{}:{}", params.category, identifier)
    };
    let ttl = if params.ttl.is_zero() {
        DEFAULT_TTL
    } else {
        params.ttl
    };
    (key, ttl)
}
/// Middleware that blocks clients which have produced too many error responses.
///
/// The client key and TTL come from `get_limit_params` (IP or configured header).
/// Before the request is handled, the stored error count for that key is read.
/// When it exceeds `params.max`, the request is rejected with
/// `Error::TooManyRequests` and the inner handler is not called. Otherwise the
/// request is forwarded. If the response status is 400 or higher, the counter is
/// incremented, passing the TTL along to the cache.
///
/// Cache problems never block traffic. An unusable lookup result counts as zero,
/// and a failed increment is ignored.
pub async fn error_limiter(
    ClientIp(ip): ClientIp,
    State(params): State<LimitParams>,
    State(cache): State<&'static RedisCache>,
    req: Request,
    next: Next,
) -> Result<Response> {
    let (key, ttl) = get_limit_params(&req, ip, &params);
    // Fail open: if the lookup yields no usable count, assume no errors so far.
    let current_count = cache.get::<i64>(&key).await.unwrap_or(0);
    // NOTE(review): with `>` a client can accumulate `max + 1` failures before
    // being blocked; confirm whether `>=` is the intended semantics.
    if current_count > params.max {
        return Err(Error::TooManyRequests {
            limit: params.max,
            current: current_count,
        }
        .into());
    }
    let res = next.run(req).await;
    if res.status().as_u16() >= 400 {
        // Best effort: failing to record the error must not fail the response.
        let _ = cache.incr(&key, 1, Some(ttl)).await;
    }
    Ok(res)
}
/// Counter-based request limiter backed by the Redis cache.
///
/// Every request increments the per-client counter keyed by `get_limit_params`
/// before the check, so rejected requests are counted as well. When the
/// incremented value exceeds `params.max`, the request is rejected with
/// `Error::TooManyRequests`. Otherwise it is forwarded to the inner handler.
///
/// # Errors
/// - Propagates the cache error if the increment fails; the request is not
///   forwarded in that case.
/// - Returns `TooManyRequests` when the limit is exceeded.
pub async fn limiter(
    ClientIp(ip): ClientIp,
    State(params): State<LimitParams>,
    State(cache): State<&'static RedisCache>,
    req: Request,
    next: Next,
) -> Result<Response> {
    let (key, ttl) = get_limit_params(&req, ip, &params);
    let count = cache.incr(&key, 1, Some(ttl)).await?;
    if count > params.max {
        return Err(Error::TooManyRequests {
            limit: params.max,
            current: count,
        }
        .into());
    }
    Ok(next.run(req).await)
}