use std::{
sync::{Arc, atomic::Ordering},
time::Instant,
};
use axum::{body::Body, extract::State, http::Request, middleware::Next, response::Response};
use crate::metrics_server::MetricsCollector;
/// Axum middleware that records per-request HTTP metrics into a shared [`MetricsCollector`].
///
/// For every request passing through this layer it:
/// - increments `http_requests_total` before the request is forwarded downstream,
/// - measures the time spent awaiting the downstream service and records it, in
///   microseconds, via `http_request_duration.observe_us`,
/// - increments exactly one of `http_responses_2xx`, `http_responses_4xx` or
///   `http_responses_5xx` according to the response's status class.
///
/// Responses in any other class (1xx informational, 3xx redirection) are still counted
/// in `http_requests_total` and the duration metric, but in no per-class counter.
///
/// The downstream response is returned unchanged.
pub async fn metrics_middleware(
    State(metrics): State<Arc<MetricsCollector>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Relaxed ordering: each counter is bumped independently and no other memory is
    // published through these atomics.
    metrics.http_requests_total.fetch_add(1, Ordering::Relaxed);

    let start = Instant::now();
    let response = next.run(request).await;

    // `Duration::as_micros` yields a u128. Saturate rather than truncate with `as`, so an
    // overflowing duration is reported as the maximum instead of wrapping to a small value.
    let elapsed_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
    metrics.http_request_duration.observe_us(elapsed_us);

    let status = response.status();
    if status.is_success() {
        // 200..=299
        metrics.http_responses_2xx.fetch_add(1, Ordering::Relaxed);
    } else if status.is_client_error() {
        // 400..=499
        metrics.http_responses_4xx.fetch_add(1, Ordering::Relaxed);
    } else if status.is_server_error() {
        // 500..=599
        metrics.http_responses_5xx.fetch_add(1, Ordering::Relaxed);
    }
    // 1xx and 3xx responses intentionally have no dedicated counter.

    response
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::{MethodRouter, get},
    };
    use tower::ServiceExt;

    use super::*;

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    async fn redirect_handler() -> StatusCode {
        StatusCode::FOUND
    }

    async fn bad_request_handler() -> StatusCode {
        StatusCode::BAD_REQUEST
    }

    async fn internal_error_handler() -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    /// Mounts `route` at `path` behind `metrics_middleware` (sharing `metrics`), sends a
    /// single GET request to `path`, and returns the status the client received.
    async fn send_once(
        metrics: &Arc<MetricsCollector>,
        path: &str,
        route: MethodRouter,
    ) -> StatusCode {
        let app = Router::new()
            .route(path, route)
            .layer(middleware::from_fn_with_state(Arc::clone(metrics), metrics_middleware));
        let request = Request::builder().uri(path).body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_metrics_middleware_counts_requests() {
        let metrics = Arc::new(MetricsCollector::new());

        let status = send_once(&metrics, "/ok", get(ok_handler)).await;

        // The middleware must not alter the downstream response.
        assert_eq!(status, StatusCode::OK);
        assert_eq!(metrics.http_requests_total.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_2xx.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_4xx.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.http_responses_5xx.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_tracks_4xx() {
        let metrics = Arc::new(MetricsCollector::new());

        let status = send_once(&metrics, "/bad", get(bad_request_handler)).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(metrics.http_requests_total.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_2xx.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.http_responses_4xx.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_5xx.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_tracks_5xx() {
        let metrics = Arc::new(MetricsCollector::new());

        let status = send_once(&metrics, "/error", get(internal_error_handler)).await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(metrics.http_requests_total.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_2xx.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.http_responses_4xx.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.http_responses_5xx.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_metrics_middleware_ignores_3xx_class() {
        let metrics = Arc::new(MetricsCollector::new());

        let status = send_once(&metrics, "/redirect", get(redirect_handler)).await;

        // Counted as a request, but not attributed to any status-class counter.
        assert_eq!(status, StatusCode::FOUND);
        assert_eq!(metrics.http_requests_total.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_2xx.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.http_responses_4xx.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.http_responses_5xx.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_accumulates_across_requests() {
        let metrics = Arc::new(MetricsCollector::new());

        send_once(&metrics, "/ok", get(ok_handler)).await;
        send_once(&metrics, "/ok", get(ok_handler)).await;
        send_once(&metrics, "/bad", get(bad_request_handler)).await;
        send_once(&metrics, "/error", get(internal_error_handler)).await;

        assert_eq!(metrics.http_requests_total.load(Ordering::Relaxed), 4);
        assert_eq!(metrics.http_responses_2xx.load(Ordering::Relaxed), 2);
        assert_eq!(metrics.http_responses_4xx.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.http_responses_5xx.load(Ordering::Relaxed), 1);
    }
}