use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::{body::Body, extract::ConnectInfo, middleware::Next, response::Response};
use bytes::Bytes;
use http_body::Frame;
use super::request_id::XRequestId;
pub async fn access_log_middleware(req: axum::extract::Request, next: Next) -> Response {
let start = std::time::Instant::now();
let method = req.method().to_string();
let uri = req.uri().path_and_query().map_or_else(
|| req.uri().path().to_owned(),
std::string::ToString::to_string,
);
let content_length: u64 = req
.headers()
.get(axum::http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let user_agent = req
.headers()
.get(axum::http::header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_owned();
let request_id = req
.extensions()
.get::<XRequestId>()
.map_or_else(String::new, |x| x.0.clone());
let trace_id = req
.headers()
.get(modkit_http::otel::TRACEPARENT)
.and_then(|v| v.to_str().ok())
.and_then(modkit_http::otel::parse_trace_id)
.unwrap_or_default();
let (remote_addr, remote_addr_ip, remote_addr_port) = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| {
let addr = ci.0;
(addr.to_string(), addr.ip().to_string(), addr.port())
})
.unwrap_or_default();
let response = next.run(req).await;
let status = response.status().as_u16();
let log_ctx = AccessLogContext {
start,
pid: std::process::id(),
request_id,
trace_id,
method,
uri,
remote_addr,
remote_addr_ip,
remote_addr_port,
content_length,
user_agent,
status,
};
let (parts, body) = response.into_parts();
let counting_body = CountingBody {
inner: body,
bytes_sent: 0,
log_ctx: Some(log_ctx),
};
Response::from_parts(parts, Body::new(counting_body))
}
/// Everything the access-log entry needs, gathered by the middleware and
/// carried alongside the response body until the entry is emitted.
struct AccessLogContext {
    /// Instant taken when the middleware started handling the request.
    start: std::time::Instant,
    /// Id of the current process.
    pid: u32,
    /// Request id taken from the request extensions; empty when absent.
    request_id: String,
    /// Trace id derived from the trace-parent request header; default when absent.
    trace_id: String,
    /// HTTP method of the request.
    method: String,
    /// Request path including the query string when one is present.
    uri: String,
    /// Peer socket address rendered as a string; empty when unknown.
    remote_addr: String,
    /// IP portion of the peer address; empty when unknown.
    remote_addr_ip: String,
    /// Port portion of the peer address; `0` when unknown.
    remote_addr_port: u16,
    /// Value of the request `Content-Length` header; `0` when missing or unparsable.
    content_length: u64,
    /// Value of the request `User-Agent` header; empty when missing or not valid text.
    user_agent: String,
    /// Numeric HTTP status code of the response.
    status: u16,
}
impl AccessLogContext {
    /// Writes the access-log entry as a single `info` event on the
    /// `access_log` target.
    ///
    /// Consumes the context, so each context can be emitted at most once.
    /// `bytes_sent` is the number of body bytes counted by the caller. The
    /// elapsed time since `start` is reported twice: in milliseconds
    /// (`duration_ms`) and in microseconds (`duration`).
    fn emit(self, bytes_sent: u64) {
        let Self {
            start,
            pid,
            request_id,
            trace_id,
            method,
            uri,
            remote_addr,
            remote_addr_ip,
            remote_addr_port,
            content_length,
            user_agent,
            status,
        } = self;

        let elapsed = start.elapsed();
        // `Duration` reports these as `u128`; clamp instead of truncating if
        // the value ever exceeds `u64::MAX`.
        let saturate = |value: u128| u64::try_from(value).unwrap_or(u64::MAX);
        let duration_ms = saturate(elapsed.as_millis());
        let duration_micros = saturate(elapsed.as_micros());

        tracing::info!(
            target: "access_log",
            msg = "response completed",
            pid = pid,
            request_id = %request_id,
            trace_id = %trace_id,
            method = %method,
            uri = %uri,
            remote_addr = %remote_addr,
            remote_addr_ip = %remote_addr_ip,
            remote_addr_port = remote_addr_port,
            content_length = content_length,
            user_agent = %user_agent,
            duration_ms = duration_ms,
            duration = duration_micros,
            status = status,
            bytes_sent = bytes_sent,
        );
    }
}
/// Response-body wrapper that tallies the bytes of the data frames passing
/// through it and carries the pending access-log context.
struct CountingBody {
    /// The wrapped response body that actually produces the frames.
    inner: Body,
    /// Running total of data-frame bytes yielded so far.
    bytes_sent: u64,
    /// Log context waiting to be emitted; `None` once it has been taken.
    log_ctx: Option<AccessLogContext>,
}
impl http_body::Body for CountingBody {
    type Data = Bytes;
    type Error = axum::Error;

    /// Polls the wrapped body and forwards its result unchanged.
    ///
    /// Data frames add their length to `bytes_sent` (saturating). When the
    /// inner body reports end-of-stream or an error, the pending log context,
    /// if still present, is taken and emitted with the current byte count.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let me = self.get_mut();
        let polled = Pin::new(&mut me.inner).poll_frame(cx);

        // Inspect the outcome by reference so it can be returned untouched.
        match &polled {
            Poll::Pending => {}
            Poll::Ready(Some(Ok(frame))) => {
                // Non-data frames have no data payload and are not counted.
                if let Some(chunk) = frame.data_ref() {
                    me.bytes_sent = me.bytes_sent.saturating_add(chunk.len() as u64);
                }
            }
            Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
                // Terminal outcome: `take()` ensures the entry is emitted only once.
                if let Some(ctx) = me.log_ctx.take() {
                    ctx.emit(me.bytes_sent);
                }
            }
        }

        polled
    }

    /// Delegates to the wrapped body.
    fn is_end_stream(&self) -> bool {
        http_body::Body::is_end_stream(&self.inner)
    }

    /// Delegates to the wrapped body.
    fn size_hint(&self) -> http_body::SizeHint {
        http_body::Body::size_hint(&self.inner)
    }
}
impl Drop for CountingBody {
    /// Emits the access-log entry if it has not been emitted yet, using the
    /// byte count accumulated up to the point of the drop.
    fn drop(&mut self) {
        // Nothing to do when the context was already taken elsewhere.
        let Some(ctx) = self.log_ctx.take() else {
            return;
        };
        ctx.emit(self.bytes_sent);
    }
}