use axum::{
body::Body,
extract::MatchedPath,
http::{Request, Response, StatusCode},
middleware::Next,
response::IntoResponse,
};
use http_body_util::BodyExt;
use std::time::Instant;
use crate::metrics::{
record_http_request, record_http_request_size, record_http_response_size,
HTTP_CONNECTIONS_ACTIVE, HTTP_REQUEST_DURATION_SECONDS,
};
/// Returns the `(path, method)` label values shared by every metric recorded for `req`.
///
/// `path` is the route template from [`MatchedPath`] when the router attached one to the request
/// extensions, otherwise the raw URI path. `method` is the HTTP method as a string.
fn request_labels(req: &Request<Body>) -> (String, String) {
    let path = req
        .extensions()
        .get::<MatchedPath>()
        .map(|p| p.as_str().to_owned())
        // NOTE(review): this fallback turns arbitrary client-supplied paths into label values,
        // which can grow metric cardinality without bound; consider a fixed label instead.
        .unwrap_or_else(|| req.uri().path().to_owned());
    (path, req.method().to_string())
}

/// Keeps `HTTP_CONNECTIONS_ACTIVE{path}` incremented while alive and decrements it on drop.
///
/// Tying the decrement to `Drop` keeps the gauge balanced on every exit path: normal return,
/// early `return`/`?`, a panic unwinding through the middleware, or the request future being
/// dropped before it completes.
#[must_use = "dropping the guard immediately decrements the active-request gauge"]
struct ActiveRequestGuard<'a> {
    path: &'a str,
}

impl<'a> ActiveRequestGuard<'a> {
    /// Increments the gauge for `path` and returns the guard that will undo it.
    fn new(path: &'a str) -> Self {
        HTTP_CONNECTIONS_ACTIVE.with_label_values(&[path]).inc();
        Self { path }
    }
}

impl Drop for ActiveRequestGuard<'_> {
    fn drop(&mut self) {
        HTTP_CONNECTIONS_ACTIVE.with_label_values(&[self.path]).dec();
    }
}

/// Records the request duration (from `start` until now) and counts the request under `status`.
fn record_completion(path: &str, method: &str, start: Instant, status: StatusCode) {
    HTTP_REQUEST_DURATION_SECONDS
        .with_label_values(&[path, method])
        .observe(start.elapsed().as_secs_f64());
    record_http_request(path, method, status.as_u16());
}

/// Axum middleware that records per-request HTTP metrics, buffering both bodies to measure sizes.
///
/// The request body is collected into memory so its length can be recorded, then handed to the
/// inner service unchanged. The response body is likewise collected before being returned, so
/// both bodies are held fully in memory. [`metrics_middleware_streaming`] avoids this at the cost
/// of not recording sizes.
///
/// Metrics recorded, labelled by route path (see `request_labels`) and method:
/// - `HTTP_CONNECTIONS_ACTIVE` (path only) is incremented once the request body has been
///   buffered and decremented when this function returns or its future is dropped.
/// - `HTTP_REQUEST_DURATION_SECONDS`, from entry until the response body has been buffered.
/// - the request count via `record_http_request`, plus request and response body lengths in bytes.
///
/// # Errors
///
/// Returns `StatusCode::INTERNAL_SERVER_ERROR` if the request body or the response body cannot be
/// collected. A request-body failure records no metrics. A response-body failure records the
/// duration and counts the request as a 500, but records no sizes.
pub async fn metrics_middleware(
    req: Request<Body>,
    next: Next,
) -> Result<Response<Body>, StatusCode> {
    let start = Instant::now();
    let (path, method) = request_labels(&req);

    // Buffer the request body to measure it, then rebuild the request for the inner service.
    let (parts, body) = req.into_parts();
    let request_bytes = body
        .collect()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .to_bytes();
    let request_size = request_bytes.len();
    let req = Request::from_parts(parts, Body::from(request_bytes));

    // Bound to a named variable (not `_`) so the guard lives until this function exits.
    let _active = ActiveRequestGuard::new(&path);
    let response = next.run(req).await;
    let status = response.status();

    let (parts, body) = response.into_parts();
    let response_bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(_) => {
            // This middleware answers with a 500, so account for the request as one.
            record_completion(&path, &method, start, StatusCode::INTERNAL_SERVER_ERROR);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };
    let response_size = response_bytes.len();
    let response = Response::from_parts(parts, Body::from(response_bytes));

    record_completion(&path, &method, start, status);
    record_http_request_size(&path, &method, request_size);
    record_http_response_size(&path, &method, response_size);
    Ok(response)
}

/// Axum middleware that records request metrics without buffering request or response bodies.
///
/// The response body is passed through untouched. The recorded duration therefore covers the
/// time until the inner service returned its `Response`, not the time spent streaming the body.
/// No request or response sizes are recorded. `HTTP_CONNECTIONS_ACTIVE` is held until this
/// function returns or its future is dropped.
pub async fn metrics_middleware_streaming(req: Request<Body>, next: Next) -> impl IntoResponse {
    let start = Instant::now();
    let (path, method) = request_labels(&req);

    // Bound to a named variable (not `_`) so the guard lives until this function exits.
    let _active = ActiveRequestGuard::new(&path);
    let response = next.run(req).await;
    record_completion(&path, &method, start, response.status());
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        routing::{get, post},
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "Hello, World!"
    }

    /// Echoes the request body so tests can check it survived the middleware's buffering.
    async fn echo_handler(body: String) -> String {
        body
    }

    /// Collects a response body into a `String` for assertions.
    async fn body_string(response: Response<Body>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_metrics_middleware_streaming() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(axum::middleware::from_fn(metrics_middleware_streaming));
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, World!");
        let metrics = crate::metrics::encode_metrics().unwrap();
        assert!(metrics.contains("ipfrs_http_requests_total"));
        assert!(metrics.contains("ipfrs_http_request_duration_seconds"));
    }

    #[tokio::test]
    async fn test_metrics_middleware_preserves_bodies() {
        let app = Router::new()
            .route("/echo", post(echo_handler))
            .layer(axum::middleware::from_fn(metrics_middleware));
        let request = Request::builder()
            .method("POST")
            .uri("/echo")
            .body(Body::from("payload"))
            .unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        // The handler must see the buffered request body, and the client the buffered response.
        assert_eq!(body_string(response).await, "payload");
        let metrics = crate::metrics::encode_metrics().unwrap();
        assert!(metrics.contains("ipfrs_http_requests_total"));
        assert!(metrics.contains("ipfrs_http_request_duration_seconds"));
    }
}