use std::{
collections::HashMap,
net::IpAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
time::Instant,
};
use axum::http::{Method, Uri};
use crate::api::ApiResponse;
/// Metadata about one in-flight request, handed to every request hook.
///
/// Construct with [`RequestContext::new`] and refine with the `with_*`
/// builder methods, each of which consumes the context and returns it updated.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// HTTP method of the request.
    pub method: Method,
    /// Request URI as given to [`RequestContext::new`].
    pub uri: Uri,
    /// Client IP address, when known.
    pub ip: Option<IpAddr>,
    /// Numeric user id, when one has been attached (presumably the
    /// authenticated user — confirm with the code that sets it).
    pub user_id: Option<i64>,
    /// Client user-agent string, when known.
    pub user_agent: Option<String>,
    /// Moment the context was created by [`RequestContext::new`].
    pub started_at: Instant,
}

impl RequestContext {
    /// Starts a context for `method` and `uri`.
    ///
    /// All optional fields start out as `None`, and `started_at` is the
    /// current instant.
    pub fn new(method: Method, uri: Uri) -> Self {
        let started_at = Instant::now();
        Self {
            method,
            uri,
            ip: None,
            user_id: None,
            user_agent: None,
            started_at,
        }
    }

    /// Returns the context with `ip` recorded as the client address.
    pub fn with_ip(self, ip: IpAddr) -> Self {
        Self { ip: Some(ip), ..self }
    }

    /// Returns the context with `user_id` attached.
    pub fn with_user_id(self, user_id: i64) -> Self {
        Self {
            user_id: Some(user_id),
            ..self
        }
    }

    /// Returns the context with `agent` recorded as the user agent.
    pub fn with_user_agent(self, agent: String) -> Self {
        Self {
            user_agent: Some(agent),
            ..self
        }
    }
}
/// Hook run before a request is handled.
///
/// [`HookChain::run_before`] calls registered hooks in insertion order and
/// stops at the first `Err`, returning that error to its caller.
pub trait BeforeHandler: Send + Sync {
    /// Inspects the request context.
    ///
    /// Return `Ok(())` to let the request continue, or `Err(response)` to
    /// reject it. Presumably the caller then sends `response` instead of
    /// running the handler — confirm at the call site.
    fn before(&self, cx: &RequestContext) -> Result<(), ApiResponse>;
}
/// Hook that post-processes a response.
///
/// [`HookChain::run_after`] threads the response through registered hooks in
/// reverse insertion order, each hook receiving the previous hook's output.
pub trait AfterHandler: Send + Sync {
    /// Receives the request context and the current response, and returns the
    /// response to pass on. Returning `response` unchanged is allowed.
    fn after(&self, cx: &RequestContext, response: ApiResponse) -> ApiResponse;
}
/// Hook that logs each request when it arrives and again when it completes.
///
/// Records are emitted through `tracing` only when the `tracing` feature is
/// enabled. Without it, both hooks pass everything through untouched.
#[derive(Debug, Clone, Copy, Default)]
pub struct LogRequest;

impl LogRequest {
    /// Creates the logging hook (equivalent to [`LogRequest::default`]).
    pub fn new() -> Self {
        Self
    }
}

impl BeforeHandler for LogRequest {
    /// Logs method, path, client IP and user agent. Never rejects a request.
    fn before(&self, cx: &RequestContext) -> Result<(), ApiResponse> {
        // Log only the path. The field is named `path`, and the query string
        // can carry tokens or other secrets that should not reach log storage.
        #[cfg(feature = "tracing")]
        tracing::info!(
            method = %cx.method,
            path = %cx.uri.path(),
            ip = ?cx.ip,
            user_agent = ?cx.user_agent,
            "incoming request"
        );
        // Keeps `cx` "used" when the `tracing` feature is compiled out.
        let _ = cx;
        Ok(())
    }
}

impl AfterHandler for LogRequest {
    /// Logs the response status and the time elapsed since `cx.started_at`,
    /// then returns `response` unchanged.
    fn after(&self, cx: &RequestContext, response: ApiResponse) -> ApiResponse {
        #[cfg(feature = "tracing")]
        {
            let status = response.status_code();
            let elapsed = cx.started_at.elapsed();
            tracing::info!(
                method = %cx.method,
                path = %cx.uri.path(),
                status = status.as_u16(),
                duration_ms = elapsed.as_secs_f64() * 1000.0,
                "request completed"
            );
        }
        // Keeps `cx` "used" when the `tracing` feature is compiled out.
        let _ = cx;
        response
    }
}
/// Per-IP fixed-window rate limiter.
///
/// Each client IP may make at most `max_requests` requests per window of
/// `window_secs` seconds. Further requests in the same window are rejected
/// with `E_RATE_LIMIT_EXCEEDED` / 429. Requests whose context carries no IP
/// are never throttled.
pub struct ThrottleInterceptor {
    max_requests: u64,
    window_secs: u64,
    state: Mutex<ThrottleState>,
}

/// Mutable limiter state, guarded by a single mutex.
struct ThrottleState {
    /// Counter per client IP.
    buckets: HashMap<IpAddr, RateBucket>,
    /// When expired buckets were last purged from `buckets`.
    last_sweep: Instant,
}

/// Request counter for one IP within its current window.
struct RateBucket {
    // Only ever touched while the `ThrottleState` mutex is held.
    count: AtomicU64,
    window_start: Instant,
}

impl ThrottleInterceptor {
    /// Creates a limiter allowing `max_requests` requests per IP in each
    /// `window_secs`-second window.
    ///
    /// A `max_requests` of 0 rejects every request from a known IP. A
    /// `window_secs` of 0 makes every request start a fresh window.
    pub fn new(max_requests: u64, window_secs: u64) -> Self {
        Self {
            max_requests,
            window_secs,
            state: Mutex::new(ThrottleState {
                buckets: HashMap::new(),
                last_sweep: Instant::now(),
            }),
        }
    }

    /// Returns `true` once `window_secs` whole seconds have passed from `since` to `now`.
    fn window_elapsed(&self, now: Instant, since: Instant) -> bool {
        now.duration_since(since).as_secs() >= self.window_secs
    }
}

impl BeforeHandler for ThrottleInterceptor {
    /// Counts the request against its IP's current window and rejects it once
    /// the window's quota is used up.
    fn before(&self, cx: &RequestContext) -> Result<(), ApiResponse> {
        let ip = match cx.ip {
            Some(ip) => ip,
            None => return Ok(()),
        };
        let now = Instant::now();

        let over_limit = {
            // If another thread panicked while holding the lock, keep using the
            // counters. Otherwise every later request would panic as well.
            let mut state = self
                .state
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());

            // Purge expired buckets at most once per window, so the map cannot
            // grow without bound as new client IPs show up.
            if self.window_elapsed(now, state.last_sweep) {
                state
                    .buckets
                    .retain(|_, bucket| !self.window_elapsed(now, bucket.window_start));
                state.last_sweep = now;
            }

            let bucket = state.buckets.entry(ip).or_insert_with(|| RateBucket {
                count: AtomicU64::new(0),
                window_start: now,
            });
            if self.window_elapsed(now, bucket.window_start) {
                // Start a new window. The request is counted below like any
                // other, so `max_requests` is enforced in every window.
                bucket.count.store(0, Ordering::SeqCst);
                bucket.window_start = now;
            }
            bucket.count.fetch_add(1, Ordering::SeqCst) >= self.max_requests
        };
        // The lock is released before the rejection response is built.

        if over_limit {
            return Err(ApiResponse::error(
                "E_RATE_LIMIT_EXCEEDED",
                "Too many requests. Please try again later.",
                429,
            ));
        }
        Ok(())
    }
}
/// Body-validation hook for request bodies of type `T`.
///
/// NOTE(review): `before` currently performs no validation and accepts every
/// request. [`RequestContext`] carries no body, so real validation presumably
/// has to happen elsewhere — confirm before relying on this hook.
pub struct ValidateBody<T> {
    // `fn() -> T` rather than `T`: no `T` is ever stored, so the marker must not
    // make this type inherit `T`'s `Send`/`Sync`/drop-check properties.
    _marker: std::marker::PhantomData<fn() -> T>,
}

impl<T> ValidateBody<T> {
    /// Creates the hook (equivalent to [`ValidateBody::default`]).
    pub fn new() -> Self {
        Self {
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T> Default for ValidateBody<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> BeforeHandler for ValidateBody<T> {
    /// Currently accepts every request unconditionally.
    fn before(&self, _cx: &RequestContext) -> Result<(), ApiResponse> {
        Ok(())
    }
}
/// Ordered collection of request hooks.
///
/// Before-hooks run in registration order, and the first rejection ends the
/// chain. After-hooks run in reverse registration order, each one receiving
/// the response produced by the hook before it.
#[derive(Default)]
pub struct HookChain {
    before_hooks: Vec<Arc<dyn BeforeHandler>>,
    after_hooks: Vec<Arc<dyn AfterHandler>>,
}

impl HookChain {
    /// Creates a chain with no hooks registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `hook` as the last before-hook and returns the chain.
    pub fn push_before(mut self, hook: impl BeforeHandler + 'static) -> Self {
        let shared: Arc<dyn BeforeHandler> = Arc::new(hook);
        self.before_hooks.push(shared);
        self
    }

    /// Registers `hook` as an after-hook and returns the chain.
    ///
    /// After-hooks run newest-first, so this hook runs first.
    pub fn push_after(mut self, hook: impl AfterHandler + 'static) -> Self {
        let shared: Arc<dyn AfterHandler> = Arc::new(hook);
        self.after_hooks.push(shared);
        self
    }

    /// Runs every before-hook in registration order.
    ///
    /// Stops at and returns the first `Err` produced; otherwise `Ok(())`.
    pub fn run_before(&self, cx: &RequestContext) -> Result<(), ApiResponse> {
        self.before_hooks.iter().try_for_each(|hook| hook.before(cx))
    }

    /// Threads `response` through the after-hooks, newest registration first,
    /// and returns the final response.
    pub fn run_after(&self, cx: &RequestContext, response: ApiResponse) -> ApiResponse {
        self.after_hooks
            .iter()
            .rev()
            .fold(response, |current, hook| hook.after(cx, current))
    }
}