#[cfg(feature = "http-api")]
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
#[cfg(feature = "http-api")]
use subtle::ConstantTimeEq;
#[cfg(feature = "http-api")]
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
#[cfg(feature = "http-api")]
use std::{
net::IpAddr,
num::NonZeroU32,
sync::{Arc, OnceLock},
};
#[cfg(feature = "http-api")]
use dashmap::DashMap;
#[cfg(feature = "http-api")]
use std::env;
/// Axum middleware that requires an `Authorization: Bearer <token>` header whose
/// token matches the `SYMBIONT_API_TOKEN` environment variable.
///
/// The environment variable is re-read on every request. The token comparison uses
/// `subtle::ConstantTimeEq`. Note that `ct_eq` on slices returns early when the
/// lengths differ, so only the token *length* is observable through timing.
///
/// # Errors
///
/// Returns `401 Unauthorized` when:
/// - the `authorization` header is missing or not visible ASCII;
/// - the scheme is not `Bearer` (compared case-insensitively, per RFC 7235);
/// - the token does not match;
/// - `SYMBIONT_API_TOKEN` is unset or empty (fail closed).
#[cfg(feature = "http-api")]
pub async fn auth_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
    let auth_value = request
        .headers()
        .get("authorization")
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_str()
        .map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Split "<scheme> <credentials>" on the first space. The remainder is kept
    // verbatim, exactly as the old `[7..]` slice did.
    let token = match auth_value.split_once(' ') {
        Some((scheme, token)) if scheme.eq_ignore_ascii_case("Bearer") => token,
        _ => return Err(StatusCode::UNAUTHORIZED),
    };

    let expected_token = env::var("SYMBIONT_API_TOKEN").map_err(|_| {
        tracing::error!("SYMBIONT_API_TOKEN environment variable not set");
        StatusCode::UNAUTHORIZED
    })?;

    // An empty configured token would otherwise match "Bearer " with an empty
    // credential, i.e. let anyone in. Treat it as a misconfiguration instead.
    if expected_token.is_empty() {
        tracing::error!("SYMBIONT_API_TOKEN environment variable is empty; rejecting request");
        return Err(StatusCode::UNAUTHORIZED);
    }

    if !bool::from(token.as_bytes().ct_eq(expected_token.as_bytes())) {
        tracing::warn!("Authentication failed: invalid token provided");
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(request).await)
}
/// Shared handle to one un-keyed (`NotKeyed`) governor rate limiter.
/// One of these is kept per client IP.
#[cfg(feature = "http-api")]
type IpRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;

/// Lazily initialised map from client IP to that IP's rate limiter.
///
/// NOTE(review): entries are never removed, so the map grows with the number of
/// distinct client IPs seen over the life of the process. Consider periodic
/// eviction if untrusted clients can vary their apparent IP.
// This must be feature-gated like the imports it depends on (`OnceLock`,
// `DashMap`, `IpAddr`) and the `IpRateLimiter` alias. Otherwise the crate fails
// to compile when `http-api` is disabled.
#[cfg(feature = "http-api")]
static RATE_LIMITERS: OnceLock<DashMap<IpAddr, IpRateLimiter>> = OnceLock::new();
/// Returns the rate limiter for `ip`, creating one on first use.
///
/// Each IP gets its own limiter with a quota of `REQUESTS_PER_MINUTE` per minute.
/// Creation goes through `DashMap::entry`, which holds the shard lock for the
/// whole check-and-insert. Concurrent first requests from the same IP therefore
/// all receive the same limiter, rather than one request's limiter silently
/// replacing another's.
#[cfg(feature = "http-api")]
fn get_rate_limiter_for_ip(ip: IpAddr) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
    const REQUESTS_PER_MINUTE: u32 = 100;

    let limiters = RATE_LIMITERS.get_or_init(DashMap::new);

    // Fast path: only a shard read lock. The `Ref` guard is dropped at the end of
    // this `if let`, before `entry` below takes the write lock on the same shard.
    if let Some(existing) = limiters.get(&ip) {
        return Arc::clone(existing.value());
    }

    let slot = limiters.entry(ip).or_insert_with(|| {
        let per_minute =
            NonZeroU32::new(REQUESTS_PER_MINUTE).expect("REQUESTS_PER_MINUTE is non-zero");
        Arc::new(RateLimiter::direct(Quota::per_minute(per_minute)))
    });
    Arc::clone(slot.value())
}
/// Best-effort client IP for rate limiting and logging.
///
/// Resolution order:
/// 1. The rightmost entry of `X-Forwarded-For`, taken across *all* header lines.
///    Per RFC 7230 §3.2.2, repeated lines are equivalent to one comma-joined
///    value, so a proxy-appended line wins over any client-supplied earlier line.
/// 2. `X-Real-IP` (whitespace-trimmed).
/// 3. The TCP peer address from axum's `ConnectInfo<SocketAddr>` extension. This
///    is only present when the server was started with connect-info enabled.
/// 4. `127.0.0.1`.
///
/// NOTE(review): both headers are trusted unconditionally. If this service can be
/// reached without a trusted reverse proxy that overwrites them, clients can forge
/// them to evade per-IP rate limiting.
#[cfg(feature = "http-api")]
fn extract_client_ip(request: &Request) -> IpAddr {
    let headers = request.headers();

    let forwarded = headers
        .get_all("x-forwarded-for")
        .iter()
        .last()
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.rsplit(',').next())
        .and_then(|ip| ip.trim().parse::<IpAddr>().ok());
    if let Some(ip) = forwarded {
        return ip;
    }

    let real_ip = headers
        .get("x-real-ip")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.trim().parse::<IpAddr>().ok());
    if let Some(ip) = real_ip {
        return ip;
    }

    request
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|info| info.0.ip())
        .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST))
}
/// Axum middleware enforcing the per-client-IP request quota.
///
/// The client IP comes from `extract_client_ip`. Requests are forwarded when that
/// IP's limiter has capacity.
///
/// # Errors
///
/// Returns `429 Too Many Requests` (after logging a warning) when the limiter
/// rejects the request.
#[cfg(feature = "http-api")]
pub async fn rate_limit_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
    let ip = extract_client_ip(&request);

    if get_rate_limiter_for_ip(ip).check().is_err() {
        tracing::warn!("Rate limit exceeded for IP: {}", ip);
        return Err(StatusCode::TOO_MANY_REQUESTS);
    }

    Ok(next.run(request).await)
}
/// Axum middleware that wraps each request in an `http_request` tracing span.
///
/// The span carries method, URI and client IP. After the handler finishes, the
/// span's status code, latency (ms) and response size (from `content-length`,
/// 0 if absent or unparseable) fields are recorded. "Processing request" and
/// "Request completed" events are emitted inside the span.
///
/// The downstream future is attached with `Instrument::instrument` rather than by
/// holding a `Span::enter` guard across `.await`. The tracing docs warn that
/// holding the guard across an await leaves the span entered on the worker thread
/// while other tasks run, which produces incorrect traces.
///
/// # Errors
///
/// Never returns `Err`; the `Result` matches axum's middleware signature.
#[cfg(feature = "http-api")]
pub async fn logging_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
    use std::time::Instant;
    use tracing::Instrument;

    let method = request.method().clone();
    let uri = request.uri().clone();
    let client_ip = extract_client_ip(&request);
    let span = tracing::info_span!(
        "http_request",
        method = %method,
        uri = %uri,
        client_ip = %client_ip,
        status_code = tracing::field::Empty,
        latency_ms = tracing::field::Empty,
        response_size = tracing::field::Empty,
    );

    let start_time = Instant::now();
    span.in_scope(|| tracing::info!("Processing request"));
    let response = next.run(request).instrument(span.clone()).await;

    // `as_millis` is u128; saturate instead of silently truncating with `as`.
    let latency_ms = u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
    let status_code = response.status();
    let response_size = response
        .headers()
        .get("content-length")
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(0);

    span.record("status_code", status_code.as_u16());
    span.record("latency_ms", latency_ms);
    span.record("response_size", response_size);
    span.in_scope(|| {
        tracing::info!(
            status_code = status_code.as_u16(),
            latency_ms = latency_ms,
            response_size = response_size,
            "Request completed"
        );
    });

    Ok(response)
}
/// Axum middleware that stamps a fixed set of security headers onto every response.
///
/// The headers are set with `insert`, which replaces any value the handler already
/// set:
/// - HSTS: two-year max-age, subdomains, preload.
/// - `X-Content-Type-Options: nosniff`.
/// - `X-Frame-Options: DENY`.
/// - A same-origin-only CSP that also forbids framing.
///
/// # Errors
///
/// Never returns `Err`; the `Result` matches axum's middleware signature.
#[cfg(feature = "http-api")]
pub async fn security_headers_middleware(
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    use axum::http::HeaderValue;

    // (header name, static value) pairs, applied in this order.
    const SECURITY_HEADERS: [(&str, &str); 4] = [
        (
            "strict-transport-security",
            "max-age=63072000; includeSubDomains; preload",
        ),
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        (
            "content-security-policy",
            "default-src 'self'; frame-ancestors 'none'",
        ),
    ];

    let mut response = next.run(request).await;
    let response_headers = response.headers_mut();
    for (name, value) in SECURITY_HEADERS {
        response_headers.insert(name, HeaderValue::from_static(value));
    }
    Ok(response)
}