use axum::http::{HeaderValue, Method};
use tower_http::cors::{AllowHeaders, AllowOrigin, Any, CorsLayer};
pub fn cors_layer(origin: Option<&str>) -> CorsLayer {
let layer = CorsLayer::new()
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers(Any)
.allow_credentials(false);
match origin {
Some(o) => {
let origins: Vec<String> = o
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
if origins.is_empty() || origins.iter().any(|s| s == "*") {
layer.allow_origin(Any)
} else if origins.len() == 1 {
if let Ok(val) = origins[0].parse::<HeaderValue>() {
layer
.allow_credentials(true)
.allow_origin(AllowOrigin::exact(val))
} else {
layer.allow_origin(Any)
}
} else {
layer
.allow_credentials(true)
.allow_origin(AllowOrigin::predicate(move |value, _| {
value
.to_str()
.map(|v| origins.iter().any(|o| o == v))
.unwrap_or(false)
}))
}
}
None => layer.allow_origin(Any),
}
}
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
/// Per-client fixed-window request limiter.
///
/// Clones share the same counters through an `Arc`, so one limiter can be handed to
/// every request as axum state.
#[derive(Clone)]
pub struct RateLimiter {
    state: Arc<Mutex<LimiterState>>,
    max_requests: u64,
    window_secs: u64,
}

/// Mutable limiter state, guarded by a single lock.
struct LimiterState {
    /// Client key -> (requests counted in the current window, window start).
    clients: HashMap<String, (u64, Instant)>,
    /// When expired entries were last purged from `clients`.
    last_sweep: Instant,
}

impl RateLimiter {
    /// Creates a limiter allowing at most `max_requests` per client key within each
    /// window of `window_secs` seconds.
    pub fn new(max_requests: u64, window_secs: u64) -> Self {
        Self {
            state: Arc::new(Mutex::new(LimiterState {
                clients: HashMap::new(),
                last_sweep: Instant::now(),
            })),
            max_requests,
            window_secs,
        }
    }

    /// Returns `true` once a window that began at `start` has fully elapsed at `now`.
    fn window_expired(&self, start: Instant, now: Instant) -> bool {
        now.duration_since(start).as_secs() >= self.window_secs
    }

    /// Counts one request for `ip`. Returns `true` while the key is still within
    /// `max_requests` for its current window.
    async fn check(&self, ip: &str) -> bool {
        let mut guard = self.state.lock().await;
        let state = &mut *guard;
        let now = Instant::now();

        // Without pruning, the map gains one entry per distinct key forever, and
        // keys can come from a client-supplied header. An expired entry would be
        // reset on its next use anyway, so dropping it early is unobservable.
        if self.window_expired(state.last_sweep, now) {
            state
                .clients
                .retain(|_, (_, start)| !self.window_expired(*start, now));
            state.last_sweep = now;
        }

        // Look up first so an existing key does not allocate a fresh `String`.
        let count = match state.clients.get_mut(ip) {
            Some((count, start)) => {
                if self.window_expired(*start, now) {
                    *count = 0;
                    *start = now;
                }
                *count = count.saturating_add(1);
                *count
            }
            None => {
                state.clients.insert(ip.to_owned(), (1, now));
                1
            }
        };
        count <= self.max_requests
    }
}
/// Axum middleware that enforces [`RateLimiter`] per client.
///
/// A request over the limit gets `429 Too Many Requests` with a JSON body. Any
/// other request is passed on to `next`.
///
/// NOTE(review): the client key prefers the *leftmost* `X-Forwarded-For` hop.
/// Unless the fronting proxy replaces that header (many only append to it), its
/// value is whatever the client sent. Clients may then evade the limit by rotating
/// it. Confirm the proxy setup before relying on this for abuse protection.
pub async fn rate_limit_middleware(
    axum::extract::State(limiter): axum::extract::State<RateLimiter>,
    req: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let ip = client_ip(&req);
    if !limiter.check(&ip).await {
        tracing::warn!("Rate limited: {}", ip);
        return (
            StatusCode::TOO_MANY_REQUESTS,
            axum::Json(json!({
                "code": 429,
                "msg": "Too many requests, please slow down",
            })),
        )
            .into_response();
    }
    next.run(req).await
}

/// Derives the rate-limit key for `req`, in order of preference:
/// 1. the first `X-Forwarded-For` hop, if it is a valid IP address;
/// 2. the peer IP from `ConnectInfo`;
/// 3. otherwise `"unknown"`.
///
/// An empty or malformed header no longer becomes the key itself.
fn client_ip(req: &Request<axum::body::Body>) -> String {
    req.headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.split(',').next())
        .and_then(parse_forwarded_ip)
        .or_else(|| {
            req.extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|ci| ci.0.ip().to_string())
        })
        .unwrap_or_else(|| "unknown".to_string())
}

/// Parses one `X-Forwarded-For` hop into a canonical IP string.
///
/// Accepts a bare address (`203.0.113.7`, `2001:db8::1`) or one with a port
/// (`203.0.113.7:443`, `[2001:db8::1]:443`), since some proxies include the port.
/// Returns `None` for anything else, including an empty hop.
fn parse_forwarded_ip(hop: &str) -> Option<String> {
    let hop = hop.trim();
    hop.parse::<std::net::IpAddr>()
        .or_else(|_| hop.parse::<SocketAddr>().map(|addr| addr.ip()))
        .ok()
        .map(|ip| ip.to_string())
}