use axum::{
extract::{Request, State},
http::header,
middleware::Next,
response::Response,
};
use std::time::Instant;
use crate::auth::{Claims, JwtAuth};
use crate::error::ApiError;
use crate::state::AppState;
/// Authenticates every non-public request with a bearer JWT.
///
/// Requests whose path is public (see [`is_public_path`]) are forwarded untouched.
/// For every other path, the `Authorization` header is read. The token is then extracted
/// with [`JwtAuth::extract_from_header`] and decoded with the application's JWT
/// configuration. The resulting claims are inserted into the request extensions so
/// later middleware and handlers can read them.
///
/// # Errors
///
/// Returns [`ApiError::Unauthorized`] when the header is absent or is not valid
/// visible ASCII. Also propagates whatever errors `extract_from_header` and `decode`
/// produce for malformed or invalid tokens.
pub async fn auth_middleware(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, ApiError> {
    if is_public_path(request.uri().path()) {
        return Ok(next.run(request).await);
    }

    let auth_header = request
        .headers()
        .get(header::AUTHORIZATION)
        .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?
        .to_str()
        // Distinguish a present-but-unreadable header from a missing one.
        .map_err(|_| ApiError::Unauthorized("Invalid Authorization header".to_string()))?;
    let token = JwtAuth::extract_from_header(auth_header)?;
    let claims = state.jwt_auth().decode(token)?;
    request.extensions_mut().insert(claims);
    Ok(next.run(request).await)
}

/// Returns `true` for routes that are served without authentication.
///
/// Directory-style prefixes only match on a path-segment boundary. For example,
/// `/api-docs` and `/api-docs/openapi.json` are public, but `/api-docs-internal`
/// is not, so a new route cannot become public just because of its name.
fn is_public_path(path: &str) -> bool {
    // `prefix` itself, or `prefix` followed by a `/`-separated sub-path.
    let under = |prefix: &str| {
        matches!(path.strip_prefix(prefix), Some(rest) if rest.is_empty() || rest.starts_with('/'))
    };
    path == "/health"
        || path == "/.well-known/agent.json"
        || path.starts_with("/public/")
        || under("/swagger-ui")
        || under("/api-docs")
}
/// Applies the per-tenant rate limiter before forwarding the request.
///
/// Setting `VEX_DISABLE_RATE_LIMIT` to `1` or `true` (ASCII case-insensitive)
/// bypasses limiting entirely. Otherwise the bucket key is chosen in this order:
/// 1. the JWT subject from the [`Claims`] extension, if present;
/// 2. the `x-client-id` request header, if present and valid ASCII;
/// 3. the literal `"anonymous"`.
///
/// NOTE(review): `x-client-id` is client-controlled, so an unauthenticated caller
/// can presumably rotate it to get a fresh bucket. All other unauthenticated callers
/// share the single `"anonymous"` bucket. Claims are only visible here if the auth
/// middleware runs first. Confirm the layer order and whether this is intended.
///
/// # Errors
///
/// Returns [`ApiError::RateLimited`] when the limiter rejects the key.
pub async fn rate_limit_middleware(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, ApiError> {
    let bypass = match std::env::var("VEX_DISABLE_RATE_LIMIT") {
        Ok(flag) => flag == "1" || flag.eq_ignore_ascii_case("true"),
        Err(_) => false,
    };
    if bypass {
        return Ok(next.run(request).await);
    }

    let subject = request
        .extensions()
        .get::<Claims>()
        .map(|claims| claims.sub.clone());
    let bucket_key = match subject {
        Some(sub) => sub,
        None => request
            .headers()
            .get("x-client-id")
            .and_then(|raw| raw.to_str().ok())
            .map_or_else(|| "anonymous".to_string(), str::to_owned),
    };

    if state.rate_limiter().check(&bucket_key).await.is_err() {
        return Err(ApiError::RateLimited);
    }
    Ok(next.run(request).await)
}
pub async fn tracing_middleware(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Response {
let start = Instant::now();
let method = request.method().clone();
let uri = request.uri().clone();
let path = uri.path().to_string();
let request_id = request
.extensions()
.get::<RequestId>()
.map(|id| id.0.clone())
.unwrap_or_else(|| "unknown".to_string());
let tenant_id = request
.extensions()
.get::<Claims>()
.map(|c| c.sub.clone())
.unwrap_or_else(|| "anonymous".to_string());
let span = tracing::info_span!(
"http_request",
method = %method,
path = %path,
request_id = %request_id,
tenant_id = %tenant_id,
status = tracing::field::Empty,
latency_ms = tracing::field::Empty,
);
let response = {
let _enter = span.enter();
next.run(request).await
};
let latency = start.elapsed();
let status = response.status();
state.metrics().record_llm_call(0, !status.is_success());
tracing::info!(
method = %method,
path = %path,
status = %status.as_u16(),
latency_ms = %latency.as_millis(),
"Request completed"
);
response
}
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
let request_id = uuid::Uuid::new_v4().to_string();
request
.extensions_mut()
.insert(RequestId(request_id.clone()));
let mut response = next.run(request).await;
response
.headers_mut()
.insert("X-Request-ID", request_id.parse().unwrap());
response
}
/// Per-request identifier stored in the request extensions.
///
/// `request_id_middleware` inserts it and `tracing_middleware` reads it. The inner
/// string is whatever id the inserting middleware generated.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
/// Builds the CORS layer from the `VEX_CORS_ORIGINS` environment variable.
///
/// `VEX_CORS_ORIGINS` is a comma-separated list of allowed origins. Entries are
/// trimmed, and empty entries (stray or trailing commas) are skipped. Entries that
/// are not valid header values are skipped with a warning. A wildcard `*` is
/// rejected with a warning, because `AllowOrigin::list` panics on it.
///
/// When the variable is unset, empty, or yields no usable origin, only
/// `https://localhost` is allowed.
///
/// The layer allows GET/POST/PUT/DELETE/OPTIONS and the `Authorization`,
/// `Content-Type` and `Accept` request headers. Preflight responses may be cached
/// for one hour.
pub fn cors_layer() -> tower_http::cors::CorsLayer {
    use axum::http::HeaderValue;
    use tower_http::cors::{AllowOrigin, CorsLayer};

    let localhost_only = || AllowOrigin::exact(HeaderValue::from_static("https://localhost"));

    let allow_origin = match std::env::var("VEX_CORS_ORIGINS") {
        Ok(origins_str) if !origins_str.is_empty() => {
            let mut origins: Vec<HeaderValue> = Vec::new();
            for entry in origins_str.split(',').map(str::trim) {
                if entry.is_empty() {
                    // Tolerate "a,,b", trailing commas and whitespace-only values.
                    continue;
                }
                if entry == "*" {
                    // `AllowOrigin::list` panics on a wildcard; refuse it instead of
                    // crashing the server at startup.
                    tracing::warn!(
                        "VEX_CORS_ORIGINS contains wildcard `*`, which is not supported; ignoring it"
                    );
                    continue;
                }
                match entry.parse::<HeaderValue>() {
                    Ok(origin) => origins.push(origin),
                    Err(_) => tracing::warn!("VEX_CORS_ORIGINS: ignoring invalid origin {entry:?}"),
                }
            }
            if origins.is_empty() {
                tracing::warn!("VEX_CORS_ORIGINS is set but contains no valid origins, using restrictive default");
                localhost_only()
            } else {
                tracing::info!("CORS configured for {} origin(s)", origins.len());
                AllowOrigin::list(origins)
            }
        }
        _ => {
            tracing::warn!("VEX_CORS_ORIGINS not set, using restrictive CORS (localhost only)");
            localhost_only()
        }
    };

    CorsLayer::new()
        .allow_origin(allow_origin)
        .allow_methods([
            axum::http::Method::GET,
            axum::http::Method::POST,
            axum::http::Method::PUT,
            axum::http::Method::DELETE,
            axum::http::Method::OPTIONS,
        ])
        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
        .max_age(std::time::Duration::from_secs(3600))
}
/// Builds a `tower_http` timeout layer configured with `duration`.
// NOTE(review): `#[allow(deprecated)]` presumably silences a deprecation of
// `TimeoutLayer::new` in the tower-http version in use. TODO: confirm, and migrate
// to the replacement constructor so the allow can be dropped.
#[allow(deprecated)]
pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
    use tower_http::timeout::TimeoutLayer;

    TimeoutLayer::new(duration)
}
/// Builds a `tower_http` request-body size limit layer capped at `limit` bytes.
pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
    use tower_http::limit::RequestBodyLimitLayer as BodyLimit;

    BodyLimit::new(limit)
}
pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
let mut response = next.run(request).await;
let headers = response.headers_mut();
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
headers.insert(
"Content-Security-Policy",
"default-src 'self'; frame-ancestors 'none'"
.parse()
.unwrap(),
);
if std::env::var("VEX_ENABLE_HSTS").is_ok() {
headers.insert(
"Strict-Transport-Security",
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
}
headers.insert(
"Referrer-Policy",
"strict-origin-when-cross-origin".parse().unwrap(),
);
headers.insert(
"Permissions-Policy",
"geolocation=(), microphone=(), camera=()".parse().unwrap(),
);
response
}
#[cfg(test)]
mod tests {
    use super::RequestId;

    /// Two generated request ids must differ.
    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` unwraps the header-value parse of a UUID string.
    /// This pins the invariant that makes that unwrap safe.
    #[test]
    fn uuid_string_is_a_valid_header_value() {
        let id = uuid::Uuid::new_v4().to_string();
        assert_eq!(id.len(), 36);
        assert!(id.parse::<axum::http::HeaderValue>().is_ok());
    }

    /// Cloning a `RequestId` (as extensions do) preserves the inner value.
    #[test]
    fn request_id_clone_keeps_value() {
        let original = RequestId("req-123".to_string());
        let copy = original.clone();
        assert_eq!(copy.0, original.0);
    }
}