use std::time::Instant;
use axum::{
extract::Request,
http::{header, Method},
middleware::Next,
response::Response,
};
use dashmap::DashMap;
/// Settings for the per-key request rate limiter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per key within one window.
    pub max_requests: u32,
    /// Length of one rate-limit window, in seconds.
    pub window_secs: u64,
    /// When `false`, every request is admitted and nothing is tracked.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window, with limiting disabled.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window_secs: 60,
            enabled: false,
        }
    }
}
/// Fixed-window, per-key request rate limiter.
///
/// Each key passed to [`RateLimiter::check`] gets a request counter and the
/// `Instant` at which its current window started. A window restarts once more
/// than `window_secs` seconds have elapsed since it started. Within a window,
/// at most `max_requests` calls to `check` for that key return `true`.
///
/// `check` never removes entries. When keys are client-controlled (e.g. IPs),
/// call [`RateLimiter::purge_expired`] periodically to bound memory use.
#[derive(Debug)]
pub struct RateLimiter {
    config: RateLimitConfig,
    /// Per-key `(requests counted in the current window, window start)`.
    state: DashMap<String, (u32, Instant)>,
}

impl RateLimiter {
    /// Creates a limiter with the given configuration and no tracked keys.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            state: DashMap::new(),
        }
    }

    /// Records a request for `ip` and reports whether it is admitted.
    ///
    /// Returns `true` unconditionally (and records nothing) when the limiter
    /// is disabled. Otherwise returns `true` and counts the request if the
    /// key is still under `max_requests` in its current window, else `false`
    /// (rejected requests are not counted).
    pub fn check(&self, ip: &str) -> bool {
        if !self.config.enabled {
            return true;
        }
        let now = Instant::now();
        let window = self.window();
        let mut entry = self.state.entry(ip.to_owned()).or_insert((0, now));
        let (count, window_start) = entry.value_mut();
        // An expired window must behave exactly like a brand-new key: reset the
        // counter to zero and let the ordinary limit check below decide. Admitting
        // unconditionally here would let one request through per window even
        // when `max_requests == 0`.
        if now.duration_since(*window_start) > window {
            *count = 0;
            *window_start = now;
        }
        if *count < self.config.max_requests {
            // Cannot overflow: `*count < max_requests <= u32::MAX`.
            *count += 1;
            true
        } else {
            false
        }
    }

    /// Drops every key whose window has expired.
    ///
    /// This does not change limiting decisions: `check` treats an expired
    /// window exactly like an absent key. It only reclaims the memory held for
    /// keys that have gone quiet.
    pub fn purge_expired(&self) {
        let now = Instant::now();
        let window = self.window();
        self.state
            .retain(|_, entry| now.duration_since(entry.1) <= window);
    }

    /// Configured window length as a `Duration`.
    fn window(&self) -> std::time::Duration {
        std::time::Duration::from_secs(self.config.window_secs)
    }
}
/// Middleware that emits one `tracing` INFO event per request.
///
/// The request's method and URI are captured before the request is handed to
/// `next`. The event records them together with the response status and the
/// time spent in `next.run` (in whole milliseconds). The response is returned
/// unchanged.
pub async fn request_logging(request: Request, next: Next) -> Response {
    let (method, uri) = (request.method().clone(), request.uri().clone());
    let started = Instant::now();

    let response = next.run(request).await;

    let latency_ms = started.elapsed().as_millis();
    let status = response.status();
    tracing::info!(%method, %uri, %status, latency_ms, "HTTP request");

    response
}
/// Builds a permissive `tower_http` CORS layer.
///
/// Any origin is accepted. Only `GET`, `POST`, `PUT` and `DELETE` are
/// advertised as allowed methods, and only the `Content-Type` and
/// `Authorization` request headers are allowed.
pub fn cors_headers() -> tower_http::cors::CorsLayer {
    use tower_http::cors::{Any, CorsLayer};

    let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE];
    let headers = [header::CONTENT_TYPE, header::AUTHORIZATION];

    CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(methods)
        .allow_headers(headers)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled limiter with the given limit and window length.
    fn enabled_limiter(max_requests: u32, window_secs: u64) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            max_requests,
            window_secs,
            enabled: true,
        })
    }

    #[test]
    fn rate_limiter_allows_under_limit() {
        let limiter = enabled_limiter(5, 60);
        for _ in 0..5 {
            assert!(limiter.check("127.0.0.1"));
        }
    }

    #[test]
    fn rate_limiter_blocks_over_limit() {
        let limiter = enabled_limiter(2, 60);
        assert!(limiter.check("127.0.0.1"));
        assert!(limiter.check("127.0.0.1"));
        assert!(!limiter.check("127.0.0.1"));
    }

    #[test]
    fn rate_limiter_disabled_allows_all() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        for _ in 0..1000 {
            assert!(limiter.check("127.0.0.1"));
        }
    }

    #[test]
    fn rate_limiter_separate_ips() {
        let limiter = enabled_limiter(1, 60);
        assert!(limiter.check("1.1.1.1"));
        assert!(!limiter.check("1.1.1.1"));
        assert!(limiter.check("2.2.2.2"));
    }

    #[test]
    fn rate_limiter_zero_max_blocks_first_request() {
        let limiter = enabled_limiter(0, 60);
        assert!(!limiter.check("127.0.0.1"));
    }

    #[test]
    fn rate_limiter_resets_after_window() {
        let limiter = enabled_limiter(1, 0);
        assert!(limiter.check("127.0.0.1"));
        // With a zero-length window, any measurable elapsed time starts a new one.
        std::thread::sleep(std::time::Duration::from_millis(5));
        assert!(limiter.check("127.0.0.1"));
    }
}