use super::{ClientIp, Error, LOG_TARGET};
use axum::extract::Request;
use axum::extract::State;
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum::response::Response;
use scopeguard::defer;
use std::net::IpAddr;
use std::time::Duration;
use tibba_cache::RedisCache;
use tibba_state::{AppState, CTX};
use tracing::debug;
type Result<T> = std::result::Result<T, tibba_error::Error>;
pub async fn processing_limit(
State(state): State<&AppState>,
req: Request,
next: Next,
) -> Result<impl IntoResponse> {
debug!(target: LOG_TARGET, "--> processing_limit");
defer!(debug!(target: LOG_TARGET, "<-- processing_limit"););
let limit = state.get_processing_limit();
if limit < 0 {
let res = next.run(req).await;
if res.status().as_u16() >= 400 {
state.inc_error_requests();
}
return Ok(res);
}
let count = state.inc_processing();
defer!(state.dec_processing(););
if count > limit {
state.inc_error_requests();
return Err(Error::TooManyRequests {
limit: limit as i64,
current: count as i64,
}
.into());
}
let res = next.run(req).await;
if res.status().as_u16() >= 400 {
state.inc_error_requests();
}
Ok(res)
}
/// Strategy used to derive the identity a rate-limit counter is keyed on.
#[derive(Debug, Clone, Default)]
pub enum LimitType {
    /// Use the client IP address. This is the default.
    #[default]
    Ip,
    /// Use the value of the request header with this name.
    Header(String),
    /// Use the account obtained from the request context (`CTX`).
    Account,
}

/// Settings for the Redis-backed `limiter` and `error_limiter` middlewares.
///
/// Build one with [`LimitParams::new`] and refine it with the `with_*` builder methods.
#[derive(Debug, Clone, Default)]
pub struct LimitParams {
    /// How the caller is identified.
    limit_type: LimitType,
    /// Optional key prefix; when non-empty the cache key becomes `category:identifier`.
    category: String,
    /// Threshold the stored counter is compared against.
    max: i64,
    /// Expiry passed along with the counter key; zero means "use the five-minute default".
    ttl: Duration,
}

impl LimitParams {
    /// Creates parameters with threshold `max`, keyed by client IP, with a
    /// five-minute TTL and no category prefix.
    pub fn new(max: i64) -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        Self {
            limit_type: LimitType::Ip,
            category: String::new(),
            max,
            ttl: FIVE_MINUTES,
        }
    }

    /// Sets the key prefix used to separate counters of different features.
    #[must_use]
    pub fn with_category(self, category: impl Into<String>) -> Self {
        Self {
            category: category.into(),
            ..self
        }
    }

    /// Sets the expiry applied to the counter.
    #[must_use]
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Selects how the caller is identified.
    #[must_use]
    pub fn with_limit_type(self, limit_type: LimitType) -> Self {
        Self { limit_type, ..self }
    }
}
/// Builds the cache key and expiry used by the Redis-backed limiters.
///
/// The identifier is chosen according to `params.limit_type`:
/// - `Ip`: the client IP address.
/// - `Header(name)`: the value of header `name`, or the client IP when the header is
///   missing or its value cannot be read as a string via `HeaderValue::to_str`.
/// - `Account`: the account from `CTX`, or the client IP when that account is empty.
///
/// When `params.category` is non-empty the key is `"{category}:{identifier}"`,
/// otherwise it is the bare identifier. A zero `params.ttl` is replaced by a
/// five-minute default.
fn get_limit_params(req: &Request, ip: IpAddr, params: &LimitParams) -> (String, Duration) {
    // Same default window as `LimitParams::new`; applies to `LimitParams::default()`.
    const DEFAULT_TTL: Duration = Duration::from_secs(5 * 60);

    let identifier = match &params.limit_type {
        LimitType::Header(header_name) => req
            .headers()
            .get(header_name)
            .and_then(|value| value.to_str().ok())
            .map_or_else(|| ip.to_string(), str::to_string),
        LimitType::Account => {
            let account = CTX.get().get_account();
            if account.is_empty() {
                ip.to_string()
            } else {
                account.to_string()
            }
        }
        LimitType::Ip => ip.to_string(),
    };
    let key = if params.category.is_empty() {
        identifier
    } else {
        format!("{}:{}", params.category, identifier)
    };
    let ttl = if params.ttl.is_zero() {
        DEFAULT_TTL
    } else {
        params.ttl
    };
    (key, ttl)
}
pub async fn error_limiter(
ClientIp(ip): ClientIp,
State(params): State<LimitParams>,
State(cache): State<&'static RedisCache>,
req: Request,
next: Next,
) -> Result<Response> {
let (key, ttl) = get_limit_params(&req, ip, ¶ms);
let current_count = cache.get::<i64>(&key).await.unwrap_or(0);
if current_count > params.max {
return Err(Error::TooManyRequests {
limit: params.max,
current: current_count,
}
.into());
}
let res = next.run(req).await;
if res.status().as_u16() >= 400 {
let _ = cache.incr(&key, 1, Some(ttl)).await;
}
Ok(res)
}
/// Middleware that enforces a request quota per caller.
///
/// Every request increments the cache counter for the key built by `get_limit_params`
/// (passing the computed `ttl` to `cache.incr`) before being forwarded. When the
/// incremented value exceeds `params.max` the request is rejected, so at most `max`
/// requests are admitted while the counter exists. Rejected requests still increment
/// the counter.
///
/// Unlike `error_limiter`, cache failures are not ignored: an error from `cache.incr`
/// is propagated and the request is not forwarded.
///
/// # Errors
///
/// - `Error::TooManyRequests` when the incremented count exceeds `params.max`.
/// - Any error returned by `cache.incr`, converted into `tibba_error::Error` by `?`.
pub async fn limiter(
    ClientIp(ip): ClientIp,
    State(params): State<LimitParams>,
    State(cache): State<&'static RedisCache>,
    req: Request,
    next: Next,
) -> Result<Response> {
    let (key, ttl) = get_limit_params(&req, ip, &params);
    let count = cache.incr(&key, 1, Some(ttl)).await?;
    if count > params.max {
        return Err(Error::TooManyRequests {
            limit: params.max,
            current: count,
        }
        .into());
    }
    Ok(next.run(req).await)
}