use std::{
net::SocketAddr,
sync::Arc,
time::Duration,
};
use axum::{
body::Body,
extract::{ConnectInfo, State},
http::{header, HeaderMap, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use tower_http::cors::{Any, CorsLayer};
use crate::{error::ErrorResponse, AppContext, Config};
pub fn cors_layer(config: &Config) -> CorsLayer {
let cors = CorsLayer::new()
.allow_methods(Any)
.allow_headers(Any)
.max_age(Duration::from_secs(3600));
if config.cors_origins.contains(&"*".to_string()) {
cors.allow_origin(Any)
} else {
cors.allow_origin(Any) }
}
/// Shared handle to a single unkeyed (`NotKeyed`), in-memory `governor` rate limiter.
///
/// The limiter sits behind an `Arc`, so every clone of this handle draws from the same quota.
pub type SharedRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;
/// Creates a shared, unkeyed rate limiter that allows `rps` requests per second.
///
/// # Panics
///
/// Panics if `rps` is zero, because a `governor` quota must be non-zero. The value usually
/// comes from configuration, so the panic message states the requirement explicitly instead of
/// the opaque `Option::unwrap` message.
pub fn create_rate_limiter(rps: u32) -> SharedRateLimiter {
    let per_second = std::num::NonZeroU32::new(rps)
        .expect("rate limit must allow at least one request per second (rps > 0)");
    Arc::new(RateLimiter::direct(Quota::per_second(per_second)))
}
/// Forwards the request when the shared limiter still has capacity. Otherwise it answers with
/// `429 Too Many Requests` and a JSON `ErrorResponse`.
pub async fn rate_limit_middleware(
    State(limiter): State<SharedRateLimiter>,
    request: Request<Body>,
    next: Next,
) -> Response {
    if limiter.check().is_ok() {
        return next.run(request).await;
    }

    let rejection = ErrorResponse {
        error: "rate_limit_exceeded".into(),
        message: "Too many requests. Please slow down.".into(),
        details: None,
        request_id: None,
    };
    (StatusCode::TOO_MANY_REQUESTS, Json(rejection)).into_response()
}
/// Enforces bearer-token authentication when an API key is configured.
///
/// When `ctx.config.api_key` is `None`, every request passes through. Otherwise the
/// `Authorization` header must have the form `Bearer <api_key>`. The scheme is matched
/// case-insensitively (auth schemes are case-insensitive per RFC 7235 §2.1), and whitespace
/// around the token is ignored. A missing header, a malformed header or a wrong key each yields
/// a `401` JSON response.
pub async fn auth_middleware(
    State(ctx): State<AppContext>,
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    let Some(expected_key) = &ctx.config.api_key else {
        return next.run(request).await;
    };
    let Some(auth) = headers
        .get(header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok())
    else {
        return unauthorized_response("Missing Authorization header");
    };
    match parse_bearer_token(auth) {
        Some(provided_key) if keys_match(provided_key, expected_key) => next.run(request).await,
        Some(_) => unauthorized_response("Invalid API key"),
        None => unauthorized_response("Invalid authorization format. Use 'Bearer <api_key>'"),
    }
}

/// Extracts the token from a `Bearer <token>` header value, trimming surrounding whitespace.
///
/// Returns `None` when there is no space after the scheme or the scheme is not `Bearer`
/// (compared ASCII case-insensitively). Only ONE scheme prefix is removed. The previous
/// `trim_start_matches` also accepted `"Bearer Bearer <token>"`.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(token.trim())
}

/// Compares two secrets without stopping at the first differing byte, so response timing does
/// not reveal how much of the key a caller has guessed correctly.
///
/// The length check still returns early, which leaks only the key length. This is a
/// best-effort, std-only mitigation. NOTE(review): consider the `subtle` crate if a vetted
/// constant-time primitive is acceptable as a dependency.
fn keys_match(provided: &str, expected: &str) -> bool {
    let (provided, expected) = (provided.as_bytes(), expected.as_bytes());
    provided.len() == expected.len()
        && provided
            .iter()
            .zip(expected)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
}
/// Builds a `401 Unauthorized` response whose JSON body carries the `unauthorized` error code
/// and the given human-readable `message`.
fn unauthorized_response(message: &str) -> Response {
    let status = StatusCode::UNAUTHORIZED;
    let payload = ErrorResponse {
        request_id: None,
        details: None,
        message: message.into(),
        error: "unauthorized".into(),
    };
    (status, Json(payload)).into_response()
}
/// Emits one `tracing` info event per request after the inner handler completes.
///
/// The event records the method, URI, response status, wall-clock latency in milliseconds,
/// and the `x-request-id` header when it is present and valid UTF-8.
pub async fn logging_middleware(
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    let (method, uri) = (request.method().clone(), request.uri().clone());
    let request_id: Option<String> = headers
        .get("x-request-id")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);

    let started_at = std::time::Instant::now();
    let response = next.run(request).await;
    let elapsed = started_at.elapsed();

    tracing::info!(
        method = %method,
        uri = %uri,
        status = %response.status().as_u16(),
        latency_ms = %elapsed.as_millis(),
        request_id = ?request_id,
        "HTTP request"
    );
    response
}
/// Rejects `POST`/`PUT`/`PATCH` requests whose `Content-Type` is not JSON, responding with
/// `415 Unsupported Media Type`.
///
/// A media type counts as JSON when its essence (the part before any `;` parameters) is
/// `application/json`, or an `application/*+json` structured-syntax type, compared
/// ASCII case-insensitively. This mirrors the rule axum's `Json` extractor applies, so headers
/// such as `Application/JSON` or `application/problem+json` are no longer rejected here. Other
/// methods, and paths containing `/recordings`, pass through unchecked.
pub async fn json_content_type_middleware(
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    let has_body_method = matches!(request.method().as_str(), "POST" | "PUT" | "PATCH");
    // NOTE(review): presumably exempts recording uploads that carry non-JSON bodies. `contains`
    // matches the substring anywhere in the path, so confirm this breadth is intended.
    if !has_body_method || request.uri().path().contains("/recordings") {
        return next.run(request).await;
    }

    let content_type = headers
        .get(header::CONTENT_TYPE)
        .and_then(|h| h.to_str().ok());
    match content_type {
        Some(ct) if is_json_media_type(ct) => next.run(request).await,
        Some(ct) => unsupported_media_type_response(format!(
            "Expected application/json, got {}",
            ct
        )),
        None => unsupported_media_type_response("Missing Content-Type header".into()),
    }
}

/// Returns `true` if `content_type` is `application/json` or `application/<anything>+json`,
/// ignoring media-type parameters and ASCII case.
fn is_json_media_type(content_type: &str) -> bool {
    let essence = content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence)
        .trim();
    let Some((kind, subtype)) = essence.split_once('/') else {
        return false;
    };
    kind.eq_ignore_ascii_case("application")
        && (subtype.eq_ignore_ascii_case("json")
            || subtype
                .rsplit_once('+')
                .map_or(false, |(_, suffix)| suffix.eq_ignore_ascii_case("json")))
}

/// Builds the `415 Unsupported Media Type` JSON error response with the given message.
fn unsupported_media_type_response(message: String) -> Response {
    let body = ErrorResponse {
        error: "unsupported_media_type".into(),
        message,
        details: None,
        request_id: None,
    };
    (StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(body)).into_response()
}
/// Holds a maximum request-body size (presumably in bytes; confirm against the consumer).
///
/// Previously the value could be set but never read back, so the stored limit was dead data.
/// [`BodyLimitMiddleware::max_size`] now exposes it. NOTE(review): nothing visible in this
/// file applies the limit to requests; confirm it is wired into a body-limit layer elsewhere.
#[derive(Debug, Clone, Copy)]
pub struct BodyLimitMiddleware {
    max_size: usize,
}

impl BodyLimitMiddleware {
    /// Creates a limit holding `max_size`.
    #[must_use]
    pub const fn new(max_size: usize) -> Self {
        Self { max_size }
    }

    /// Returns the configured maximum body size.
    #[must_use]
    pub const fn max_size(&self) -> usize {
        self.max_size
    }
}
/// Best-effort client IP, in priority order:
///
/// 1. the first entry of `X-Forwarded-For`,
/// 2. `X-Real-IP`,
/// 3. the peer address from `connect_info`.
///
/// Header values are trimmed. A header that is missing, not valid visible ASCII, or blank is
/// skipped and the next source is tried. Previously an empty `X-Forwarded-For` produced
/// `Some("")` and suppressed the fallbacks.
///
/// NOTE(review): forwarding headers are client-supplied unless a trusted proxy overwrites
/// them. Don't base security decisions on this value without confirming the deployment.
pub fn extract_client_ip(
    headers: &HeaderMap,
    connect_info: Option<&ConnectInfo<SocketAddr>>,
) -> Option<String> {
    let forwarded_for = headers
        .get("x-forwarded-for")
        .and_then(|h| h.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    let real_ip = headers
        .get("x-real-ip")
        .and_then(|h| h.to_str().ok())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());

    forwarded_for
        .or(real_ip)
        .map(str::to_owned)
        .or_else(|| connect_info.map(|ci| ci.0.ip().to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cors_layer_creation() {
        // Smoke test: building the layer from the default config must not panic.
        let config = Config::default();
        let _layer = cors_layer(&config);
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = create_rate_limiter(100);
        assert!(limiter.check().is_ok());
    }

    #[test]
    fn test_rate_limiter_rejects_when_quota_exhausted() {
        // A 1-per-second quota allows a burst of one, so an immediate second check must fail.
        let limiter = create_rate_limiter(1);
        assert!(limiter.check().is_ok());
        assert!(limiter.check().is_err());
    }

    #[test]
    fn test_extract_client_ip_x_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap());
        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("1.2.3.4".to_string()));
    }

    #[test]
    fn test_extract_client_ip_x_real() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.1".parse().unwrap());
        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("10.0.0.1".to_string()));
    }

    #[test]
    fn test_extract_client_ip_prefers_forwarded_over_real_ip() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap());
        headers.insert("x-real-ip", "10.0.0.1".parse().unwrap());
        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("1.2.3.4".to_string()));
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_connect_info() {
        let headers = HeaderMap::new();
        let addr: SocketAddr = "192.168.1.5:4000".parse().unwrap();
        let ip = extract_client_ip(&headers, Some(&ConnectInfo(addr)));
        assert_eq!(ip, Some("192.168.1.5".to_string()));
    }

    #[test]
    fn test_extract_client_ip_none_without_any_source() {
        assert_eq!(extract_client_ip(&HeaderMap::new(), None), None);
    }
}