use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use hyper::header::USER_AGENT;
use lib::config::Config;
use lib::types::RequestId;
use tower_http::trace::MakeSpan;
use tracing::{error_span, info};
use crate::auth::API_KEY_PREFIX;
/// Span factory for request tracing that stamps every request span with the
/// name of the service that handled it.
#[derive(Clone, Debug)]
pub struct ApiMakeSpan {
    /// Recorded as the `service` field on each span this factory creates.
    service_name: String,
}

impl ApiMakeSpan {
    /// Builds a span factory whose spans are labelled with `service_name`.
    pub fn new(service_name: String) -> Self {
        ApiMakeSpan { service_name }
    }
}
impl<B> MakeSpan<B> for ApiMakeSpan {
    /// Creates the `http_request` span for an incoming request.
    ///
    /// The span is emitted at ERROR level (via `error_span!`) under the
    /// `request_response_tracing_metadata` target. It records:
    /// - the service name,
    /// - the request id (an empty string when no `RequestId` extension is
    ///   attached),
    /// - the method, URI, HTTP version and `User-Agent` header. The header is
    ///   recorded in its `Debug` form, so it appears as `Some(..)`/`None`.
    fn make_span(&mut self, request: &Request<B>) -> tracing::Span {
        // Fall back to an empty id rather than omitting the field.
        let request_id: String = request
            .extensions()
            .get::<RequestId>()
            .map_or_else(String::new, ToString::to_string);
        let user_agent = request.headers().get(USER_AGENT);

        error_span!(
            target: "request_response_tracing_metadata",
            "http_request",
            service = %self.service_name,
            request_id = %request_id,
            method = %request.method(),
            uri = %request.uri(),
            version = ?request.version(),
            user_agent = ?user_agent,
        )
    }
}
/// Middleware that optionally logs the full request and/or response body.
///
/// The request body is logged when `config.api.log_request_body` is set. The
/// response body is logged when `config.api.log_response_body` is set. In
/// either case the body is fully buffered, logged, and then reattached so
/// downstream consumers still see it unchanged.
///
/// # Errors
///
/// If a body cannot be read, the error `Response` produced by
/// `buffer_and_print_body` is returned in place of the normal response.
pub async fn trace_request_response(
    State(config): State<Arc<Config>>,
    request: Request<axum::body::Body>,
    next: Next<axum::body::Body>,
) -> Result<impl IntoResponse, Response> {
    let api_config = &config.api;

    // Rebuild the request around the buffered bytes only when we consumed it.
    let request = if api_config.log_request_body {
        let (parts, body) = request.into_parts();
        let buffered = buffer_and_print_body(body, "Got request", None).await?;
        Request::from_parts(parts, axum::body::Body::from(buffered))
    } else {
        request
    };

    let response = next.run(request).await;
    if !api_config.log_response_body {
        return Ok(response);
    }

    // Capture the status before the response is split into parts.
    let status = response.status().as_u16();
    let (parts, body) = response.into_parts();
    let buffered = buffer_and_print_body(body, "Sent response", Some(status)).await?;
    Ok(Response::from_parts(parts, axum::body::Body::from(buffered)).into_response())
}
/// Reads an HTTP body fully into memory, logs it, and returns the bytes so the
/// caller can rebuild an equivalent body.
///
/// The body is decoded lossily as UTF-8 for logging. If the text contains
/// `API_KEY_PREFIX`, the whole logged body is replaced by `"REDACTED"`. The
/// returned bytes are always the original, unredacted payload. The log event
/// uses the `request_response_tracing` target and carries `body`, `status`
/// and `msg` fields.
///
/// * `msg` — human-readable label recorded with the event.
/// * `status` — the response status code when logging a response; `None` when
///   logging a request. It also selects the error status below.
///
/// NOTE(review): the body is buffered with no size limit visible here, so a
/// large body is held entirely in memory while this logging is enabled.
///
/// # Errors
///
/// If the body cannot be read, returns a ready-made error `Response` whose
/// text is the underlying error:
/// - `400 Bad Request` for a request body (`status` is `None`);
/// - `500 Internal Server Error` for a response body (`status` is `Some`),
///   since that failure originates on the server side.
async fn buffer_and_print_body<B>(
    body: B,
    msg: &str,
    status: Option<u16>,
) -> Result<axum::body::Bytes, Response>
where
    B: axum::body::HttpBody,
    <B as axum::body::HttpBody>::Error: Display,
{
    let bytes = hyper::body::to_bytes(body).await.map_err(|err| {
        // Blame the client only for bodies the client sent us.
        let code = if status.is_some() {
            StatusCode::INTERNAL_SERVER_ERROR
        } else {
            StatusCode::BAD_REQUEST
        };
        (code, err.to_string()).into_response()
    })?;

    let body_str = String::from_utf8_lossy(&bytes);
    // Never write anything that looks like an API key to the logs.
    let body_str = if body_str.contains(API_KEY_PREFIX) {
        Cow::from("REDACTED")
    } else {
        body_str
    };

    info!(target: "request_response_tracing", body = %body_str, status, msg);
    Ok(bytes)
}