use axum::{
extract::Request,
middleware::Next,
response::Response,
};
use std::time::Instant;
use tracing::{info, info_span, Instrument};
/// Zero-sized handle type for the request-logging middleware.
///
/// The type carries no state or configuration. In this file the logging work
/// itself is done by the free function [`log_request`]; this type only offers
/// construction via [`LoggingLayer::new`] or [`Default`].
///
/// NOTE(review): nothing visible here implements a tower `Layer` for this
/// type — confirm how (or whether) callers use it.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggingLayer;

impl LoggingLayer {
    /// Creates a new `LoggingLayer`.
    ///
    /// Equivalent to [`LoggingLayer::default`]; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
pub async fn log_request(req: Request, next: Next) -> Response {
let method = req.method().clone();
let uri = req.uri().clone();
let version = req.version();
let request_id = req
.extensions()
.get::<super::request_id::RequestId>()
.map(|id| id.as_str().to_string())
.unwrap_or_else(|| "unknown".to_string());
let span = info_span!(
"http_request",
method = %method,
uri = %uri,
version = ?version,
request_id = %request_id,
);
async move {
let start = Instant::now();
info!("request started");
let response = next.run(req).await;
let latency = start.elapsed();
let status = response.status();
info!(
status = %status.as_u16(),
latency_ms = %latency.as_millis(),
"request completed"
);
response
}
.instrument(span)
.await
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Smoke test: a router wrapped in `log_request` still serves `GET /`
    /// with `200 OK`, i.e. the middleware passes the response through.
    #[tokio::test]
    async fn test_logging_middleware() {
        // Minimal router: one trivial handler behind the logging middleware.
        let router = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(middleware::from_fn(log_request));

        // No `RequestId` extension is attached, so the middleware takes its
        // "unknown" fallback path.
        let builder = Request::builder().uri("/");
        let outgoing = builder.body(Body::empty()).unwrap();

        let reply = router.oneshot(outgoing).await.unwrap();

        assert_eq!(reply.status(), StatusCode::OK);
    }
}