use std::time::Instant;
use axum::body::Body;
use axum::http::{HeaderName, HeaderValue, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
/// Name of the HTTP header that carries the per-request correlation ID.
///
/// Read from incoming requests and written onto both the forwarded request
/// and the outgoing response by `request_id_middleware`.
pub const REQUEST_ID_HEADER: &str = "x-request-id";

/// Route path intended for the [`health`] liveness handler.
pub const HEALTH_PATH: &str = "/health";

/// Route path intended for the [`ready`] readiness handler.
pub const READY_PATH: &str = "/ready";
pub async fn health() -> Response {
(StatusCode::OK, "ok").into_response()
}
pub async fn ready() -> Response {
(StatusCode::OK, "ready").into_response()
}
/// Axum middleware that assigns every request a correlation ID, logs one
/// summary event per request, and echoes the ID back to the client.
///
/// The ID is taken from the incoming [`REQUEST_ID_HEADER`] header when it is
/// well-formed (see `is_valid_request_id`). Otherwise a fresh one is
/// generated with `super::security::random_token`. The chosen ID is:
///
/// * stored in the request extensions as [`RequestId`] for downstream handlers;
/// * written into the request's [`REQUEST_ID_HEADER`] header, replacing any
///   rejected client value;
/// * recorded on an `info`-level `"request"` tracing event together with the
///   method, path, response status and latency in milliseconds;
/// * set on the response's [`REQUEST_ID_HEADER`] header.
pub async fn request_id_middleware(mut req: Request<Body>, next: Next) -> Response {
    // The ID is echoed into log lines and response headers, so only adopt the
    // client's value when it passes a strict charset/length check.
    let request_id = req
        .headers()
        .get(REQUEST_ID_HEADER)
        .and_then(|v| v.to_str().ok())
        .filter(|s| is_valid_request_id(s))
        .map(str::to_owned)
        .unwrap_or_else(super::security::random_token);

    let method = req.method().clone();
    let path = req.uri().path().to_owned();

    // Parse the header value once; it is reused for the request and the
    // response. Parsing can only fail if the generated token contains bytes
    // that are not legal in a header value, in which case the header is skipped.
    let header_value = HeaderValue::from_str(&request_id).ok();

    req.extensions_mut().insert(RequestId(request_id.clone()));
    if let Some(value) = header_value.clone() {
        req.headers_mut()
            .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
    }

    let started = Instant::now();
    let mut res = next.run(req).await;
    let latency_ms = started.elapsed().as_millis();
    let status = res.status().as_u16();

    tracing::info!(
        request_id = %request_id,
        method = %method,
        path = %path,
        status,
        latency_ms,
        "request"
    );

    if let Some(value) = header_value {
        res.headers_mut()
            .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
    }
    res
}

/// Returns `true` if a client-supplied request ID may be adopted as-is.
///
/// Accepted IDs are 1 to 128 bytes made only of ASCII alphanumerics or
/// `-`, `_`, `.`, `:`. That covers UUIDs, ULIDs and typical trace IDs.
/// Everything else is rejected, including spaces, tabs, quotes and `=`.
/// Otherwise a caller could forge extra fields in plain-text log lines.
fn is_valid_request_id(id: &str) -> bool {
    const MAX_LEN: usize = 128;
    (1..=MAX_LEN).contains(&id.len())
        && id
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b':'))
}
/// Correlation ID for the current request.
///
/// `request_id_middleware` inserts it into the request extensions, so
/// downstream handlers can read it from there.
#[derive(Clone, Debug)]
pub struct RequestId(pub String);
pub async fn shutdown_signal() {
let ctrl_c = async {
let _ = tokio::signal::ctrl_c().await;
};
#[cfg(unix)]
let terminate = async {
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
Ok(mut sig) => {
sig.recv().await;
}
Err(_) => std::future::pending::<()>().await,
}
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {}
_ = terminate => {}
}
tracing::info!("shutdown signal received; draining connections");
}