use crate::auth::AuthProvider;
use crate::config::AppConfig;
use axum::extract::Request;
use axum::middleware::Next;
use http::header::{CONTENT_TYPE, HeaderName};
use std::sync::Arc;
use std::time::Duration;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::propagate_header::PropagateHeaderLayer;
use tower_http::request_id::{MakeRequestUuid, SetRequestIdLayer};
use tower_http::trace::TraceLayer;
pub static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
pub(crate) fn apply_middleware(
router: axum::Router,
config: &AppConfig,
auth: Option<Arc<dyn AuthProvider>>,
) -> axum::Router {
let cors = build_cors_layer(config);
let router = router
.layer(PropagateHeaderLayer::new(X_REQUEST_ID.clone()))
.layer(cors);
let router = if let Some(provider) = auth {
router.layer(axum::middleware::from_fn(move |req, next| {
auth_middleware(provider.clone(), req, next)
}))
} else {
router
};
router
.layer(http_trace_layer())
.layer(SetRequestIdLayer::new(
X_REQUEST_ID.clone(),
MakeRequestUuid,
))
}
async fn auth_middleware(
provider: Arc<dyn AuthProvider>,
req: Request,
next: Next,
) -> axum::response::Response {
let (parts, body) = req.into_parts();
match provider.authenticate(&parts) {
Ok(()) => {
let req = Request::from_parts(parts, body);
next.run(req).await
}
Err(rejection) => *rejection,
}
}
fn http_trace_layer() -> TraceLayer<
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
impl Fn(&Request) -> tracing::Span + Clone,
impl Fn(&Request, &tracing::Span) + Clone,
impl Fn(&http::Response<axum::body::Body>, Duration, &tracing::Span) + Clone,
> {
TraceLayer::new_for_http()
.make_span_with(|request: &Request| {
let method = request.method().as_str();
let path = request.uri().path();
let request_id = request
.headers()
.get(&X_REQUEST_ID)
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
make_http_span(method, path, request_id)
})
.on_request(|_request: &Request, _span: &tracing::Span| {
tracing::info!(target: "polaris::http", "started processing request");
})
.on_response(
|response: &http::Response<axum::body::Body>,
latency: Duration,
span: &tracing::Span| {
let status = response.status().as_u16();
record_status(span, status);
tracing::info!(
target: "polaris::http",
latency_ms = latency.as_millis() as u64,
"finished processing request"
);
},
)
}
#[cfg(feature = "otel")]
fn make_http_span(method: &str, path: &str, request_id: &str) -> tracing::Span {
tracing::info_span!(
target: "polaris::http",
"HTTP",
otel.name = %format_args!("{method} {path}"),
otel.kind = "Server",
http.request.method = method,
url.path = path,
http.response.status_code = tracing::field::Empty,
polaris.http.request_id = %request_id,
)
}
#[cfg(not(feature = "otel"))]
fn make_http_span(method: &str, path: &str, request_id: &str) -> tracing::Span {
tracing::info_span!(
target: "polaris::http",
"polaris.http.request",
polaris.http.method = method,
polaris.http.path = path,
polaris.http.status_code = tracing::field::Empty,
polaris.http.request_id = %request_id,
)
}
#[cfg(feature = "otel")]
fn record_status(span: &tracing::Span, status: u16) {
span.record("http.response.status_code", status);
}
#[cfg(not(feature = "otel"))]
fn record_status(span: &tracing::Span, status: u16) {
span.record("polaris.http.status_code", status);
}
fn build_cors_layer(config: &AppConfig) -> CorsLayer {
let origins = config.cors_origins();
let allow_origin = if origins.is_empty() {
AllowOrigin::any()
} else {
let parsed: Vec<_> = origins
.iter()
.filter_map(|origin| match origin.parse() {
Ok(val) => Some(val),
Err(err) => {
tracing::warn!(
origin = %origin,
error = %err,
"ignoring invalid CORS origin"
);
None
}
})
.collect();
AllowOrigin::list(parsed)
};
CorsLayer::new()
.allow_origin(allow_origin)
.allow_methods(tower_http::cors::Any)
.allow_headers([CONTENT_TYPE])
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_allows_any_origin() {
let config = AppConfig::default();
let _cors = build_cors_layer(&config);
}
#[test]
fn specific_origins_are_accepted() {
let config = AppConfig::new()
.with_cors_origin("http://localhost:3000")
.with_cors_origin("https://example.com");
let _cors = build_cors_layer(&config);
}
}