use std::sync::{Arc, atomic::Ordering};
use axum::{body::Body, extract::State, http::Request, middleware::Next, response::Response};
use crate::metrics_server::MetricsCollector;
/// Axum middleware that records request and response-class counters on a
/// shared [`MetricsCollector`].
///
/// `http_requests_total` is bumped *before* the request is forwarded, so it
/// also counts requests whose inner future is dropped before producing a
/// response. Once the inner service responds, at most one class counter is
/// bumped:
///
/// * `http_responses_2xx` for 200-299,
/// * `http_responses_4xx` for 400-499,
/// * `http_responses_5xx` for 500-599.
///
/// Any other status (1xx, 3xx) contributes only to the total. The response
/// itself is returned untouched.
///
/// Installed via `axum::middleware::from_fn_with_state` with an
/// `Arc<MetricsCollector>` as the state.
pub async fn metrics_middleware(
    State(metrics): State<Arc<MetricsCollector>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Relaxed ordering: these are standalone tallies; nothing in this
    // middleware relies on them to order other memory accesses.
    metrics.http_requests_total.fetch_add(1, Ordering::Relaxed);

    let response = next.run(request).await;
    let code = response.status();

    // StatusCode's class predicates cover exactly 200-299, 400-499 and
    // 500-599; statuses outside those ranges fall through uncounted.
    if code.is_success() {
        metrics.http_responses_2xx.fetch_add(1, Ordering::Relaxed);
    } else if code.is_client_error() {
        metrics.http_responses_4xx.fetch_add(1, Ordering::Relaxed);
    } else if code.is_server_error() {
        metrics.http_responses_5xx.fetch_add(1, Ordering::Relaxed);
    }

    response
}
#[cfg(test)]
mod tests {
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
    };
    use tower::ServiceExt;

    use super::*;

    // Checks every counter on a `MetricsCollector` at once, so a request that
    // lands in the wrong bucket (or in several buckets) fails the test.
    // A macro rather than a fn keeps the comparison independent of the
    // counters' concrete atomic integer type.
    macro_rules! assert_counts {
        (
            $metrics:expr,
            total: $total:expr,
            c2xx: $c2xx:expr,
            c4xx: $c4xx:expr,
            c5xx: $c5xx:expr $(,)?
        ) => {{
            let m = &$metrics;
            assert_eq!(
                m.http_requests_total.load(Ordering::Relaxed),
                $total,
                "http_requests_total"
            );
            assert_eq!(
                m.http_responses_2xx.load(Ordering::Relaxed),
                $c2xx,
                "http_responses_2xx"
            );
            assert_eq!(
                m.http_responses_4xx.load(Ordering::Relaxed),
                $c4xx,
                "http_responses_4xx"
            );
            assert_eq!(
                m.http_responses_5xx.load(Ordering::Relaxed),
                $c5xx,
                "http_responses_5xx"
            );
        }};
    }

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// A 3xx status, which the middleware counts only toward the total.
    async fn redirect_handler() -> StatusCode {
        StatusCode::FOUND
    }

    async fn bad_request_handler() -> StatusCode {
        StatusCode::BAD_REQUEST
    }

    async fn internal_error_handler() -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    /// Builds a router with one route per status class, wrapped in
    /// `metrics_middleware` backed by `metrics`.
    fn build_app(metrics: &Arc<MetricsCollector>) -> Router {
        Router::new()
            .route("/ok", get(ok_handler))
            .route("/redirect", get(redirect_handler))
            .route("/bad", get(bad_request_handler))
            .route("/error", get(internal_error_handler))
            .layer(middleware::from_fn_with_state(
                Arc::clone(metrics),
                metrics_middleware,
            ))
    }

    /// Sends an empty-bodied GET to `uri` through a clone of `app` and
    /// returns the status the caller observed (verifies pass-through).
    async fn send(app: &Router, uri: &str) -> StatusCode {
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.clone().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_metrics_middleware_counts_requests() {
        let metrics = Arc::new(MetricsCollector::new());
        let app = build_app(&metrics);

        assert_eq!(send(&app, "/ok").await, StatusCode::OK);
        assert_counts!(metrics, total: 1, c2xx: 1, c4xx: 0, c5xx: 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_tracks_4xx() {
        let metrics = Arc::new(MetricsCollector::new());
        let app = build_app(&metrics);

        assert_eq!(send(&app, "/bad").await, StatusCode::BAD_REQUEST);
        assert_counts!(metrics, total: 1, c2xx: 0, c4xx: 1, c5xx: 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_tracks_5xx() {
        let metrics = Arc::new(MetricsCollector::new());
        let app = build_app(&metrics);

        assert_eq!(
            send(&app, "/error").await,
            StatusCode::INTERNAL_SERVER_ERROR
        );
        assert_counts!(metrics, total: 1, c2xx: 0, c4xx: 0, c5xx: 1);
    }

    #[tokio::test]
    async fn test_metrics_middleware_leaves_3xx_out_of_class_buckets() {
        let metrics = Arc::new(MetricsCollector::new());
        let app = build_app(&metrics);

        assert_eq!(send(&app, "/redirect").await, StatusCode::FOUND);
        assert_counts!(metrics, total: 1, c2xx: 0, c4xx: 0, c5xx: 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_accumulates_across_requests() {
        let metrics = Arc::new(MetricsCollector::new());
        let app = build_app(&metrics);

        let calls = [
            ("/ok", StatusCode::OK),
            ("/ok", StatusCode::OK),
            ("/bad", StatusCode::BAD_REQUEST),
            ("/error", StatusCode::INTERNAL_SERVER_ERROR),
            ("/redirect", StatusCode::FOUND),
        ];
        for (uri, expected) in calls {
            assert_eq!(send(&app, uri).await, expected, "status for {uri}");
        }

        assert_counts!(metrics, total: 5, c2xx: 2, c4xx: 1, c5xx: 1);
    }
}