use axum::{body::Body, extract::Request, http::Response, middleware::Next};
use serde_json::Value;
/// In-memory copy of an HTTP response's status, headers, and full body.
///
/// [`buffer_response_middleware`] builds one of these from the response it
/// buffers and stores it in that response's extensions.
#[derive(Debug, Clone)]
pub struct BufferedResponse {
    /// HTTP status code as a plain integer (e.g. `200`).
    pub status: u16,
    /// Headers of the buffered response.
    pub headers: http::HeaderMap,
    /// Complete body bytes of the buffered response.
    pub body: axum::body::Bytes,
}
impl BufferedResponse {
pub fn json(&self) -> Option<Value> {
serde_json::from_slice(&self.body).ok()
}
pub fn text(&self) -> String {
String::from_utf8_lossy(&self.body).to_string()
}
}
/// Middleware that buffers the downstream response body in memory and attaches
/// a [`BufferedResponse`] snapshot to the response's extensions.
///
/// The returned response keeps the status, version, headers, and extensions
/// produced by the inner service; its body is replaced with the buffered bytes
/// and a [`BufferedResponse`] is inserted into its extensions.
///
/// If reading the body fails, a warning is logged and a plain
/// `500 Internal Server Error` response is returned instead; that fallback
/// response carries no [`BufferedResponse`] extension.
pub async fn buffer_response_middleware(req: Request, next: Next) -> Response<Body> {
    let (parts, body) = next.run(req).await.into_parts();

    // NOTE(review): the `usize::MAX` limit places no cap on how much of the
    // body is held in memory — confirm that is acceptable for the routes this
    // middleware is applied to.
    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::warn!("Failed to buffer response body: {}", e);
            return Response::builder()
                .status(http::StatusCode::INTERNAL_SERVER_ERROR)
                .body(Body::from("Failed to buffer response"))
                .expect("static response body should never fail to build");
        }
    };

    // Take the snapshot before `parts` is moved into the rebuilt response.
    let buffered = BufferedResponse {
        status: parts.status.as_u16(),
        headers: parts.headers.clone(),
        body: body_bytes.clone(),
    };

    // Reassemble from the original parts rather than a fresh builder so that
    // extensions set by the handler or inner layers are not dropped. This
    // construction is also infallible, so no second error path is needed.
    let mut response = Response::from_parts(parts, Body::from(body_bytes));
    response.extensions_mut().insert(buffered);
    response
}
/// Looks up the [`BufferedResponse`] stored in `response`'s extensions.
///
/// Returns a clone of the stored value, or `None` when the extensions hold no
/// [`BufferedResponse`].
pub fn get_buffered_response(response: &Response<Body>) -> Option<BufferedResponse> {
    let stored: Option<&BufferedResponse> = response.extensions().get();
    stored.cloned()
}