use axum::{
body::Body,
extract::Extension,
http::Request,
middleware::Next,
response::{IntoResponse, Response},
};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::time::Instant;
/// Handle onto the Prometheus recorder, used to render the metrics it has
/// collected.
///
/// Created by `install_recorder` and shared with request handlers (for
/// example via an axum `Extension`).
#[derive(Clone)]
pub struct MetricsState {
    handle: PrometheusHandle,
}

impl MetricsState {
    /// Renders every currently recorded metric as Prometheus exposition text.
    pub fn render(&self) -> String {
        let Self { handle } = self;
        handle.render()
    }
}
/// Installs a Prometheus-backed recorder as the global `metrics` recorder.
///
/// The recorder is built from a default-configured [`PrometheusBuilder`]. The
/// returned handle is wrapped in a [`MetricsState`] so it can be shared with
/// the `/metrics` endpoint.
///
/// NOTE(review): confirm whether the exporter version in use expects callers
/// of `install_recorder` to run `PrometheusHandle::run_upkeep` periodically.
///
/// # Errors
///
/// Returns an error if the exporter fails to build or install the recorder
/// (presumably including when a global recorder is already installed —
/// verify against the exporter docs).
pub fn install_recorder() -> anyhow::Result<MetricsState> {
    match PrometheusBuilder::new().install_recorder() {
        Ok(handle) => Ok(MetricsState { handle }),
        Err(e) => Err(anyhow::anyhow!(
            "failed to install prometheus recorder: {e}"
        )),
    }
}
pub async fn metrics_handler(Extension(state): Extension<MetricsState>) -> Response {
let body = state.render();
(
[(
axum::http::header::CONTENT_TYPE,
"text/plain; version=0.0.4",
)],
body,
)
.into_response()
}
/// Axum middleware that records per-request metrics.
///
/// For each request that passes through this layer it records:
/// - `trusty_requests_total` (counter), labelled by `endpoint` and `method`;
/// - `trusty_request_latency_ms` (histogram), labelled by `endpoint` and the
///   numeric response `status`, in milliseconds with sub-millisecond
///   precision.
///
/// The `endpoint` label is the matched route template when axum has attached a
/// [`MatchedPath`](axum::extract::MatchedPath) extension. Otherwise it is the
/// raw request path.
///
/// Latency is measured until the inner service returns its `Response`. Body
/// bytes streamed after that point are presumably not included.
///
/// NOTE(review): the raw-path fallback makes the `endpoint` label unbounded
/// for unmatched requests (e.g. 404 scans). Consider a fixed placeholder if
/// this layer also wraps the fallback handler.
pub async fn request_metrics_middleware(request: Request<Body>, next: Next) -> Response {
    let started = Instant::now();
    let method = request.method().as_str().to_owned();
    let endpoint = request
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(
            || request.uri().path().to_owned(),
            |matched| matched.as_str().to_owned(),
        );

    let response = next.run(request).await;

    // `as_millis()` truncates to whole milliseconds, which would record every
    // sub-millisecond request as 0; keep the fractional part instead.
    let elapsed_ms = started.elapsed().as_secs_f64() * 1_000.0;
    let status = response.status().as_u16().to_string();

    metrics::counter!(
        "trusty_requests_total",
        "endpoint" => endpoint.clone(),
        "method" => method,
    )
    .increment(1);
    metrics::histogram!(
        "trusty_request_latency_ms",
        "endpoint" => endpoint,
        "status" => status,
    )
    .record(elapsed_ms);

    response
}
/// Sets the `trusty_index_count` gauge to `count`.
///
/// The gauge is overwritten (not incremented), so pass the absolute count.
pub fn set_index_count(count: usize) {
    let value = count as f64;
    let gauge = metrics::gauge!("trusty_index_count");
    gauge.set(value);
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use std::sync::OnceLock;
    use tower::ServiceExt;

    /// Returns the process-wide `MetricsState`.
    ///
    /// The recorder is installed globally, so all tests share one instance
    /// through a `OnceLock` instead of installing it per test.
    fn shared_state() -> MetricsState {
        static STATE: OnceLock<MetricsState> = OnceLock::new();
        STATE
            .get_or_init(|| install_recorder().expect("recorder installs"))
            .clone()
    }

    /// Builds a GET request for `uri` with an empty body.
    fn get_request(uri: &str) -> axum::http::Request<Body> {
        axum::http::Request::builder()
            .uri(uri)
            .body(Body::empty())
            .expect("valid request")
    }

    /// Collects a response body into a (lossily decoded) UTF-8 string.
    async fn body_text(resp: Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .expect("body");
        String::from_utf8_lossy(&bytes).into_owned()
    }

    #[tokio::test]
    async fn metrics_handler_returns_prometheus_text() {
        let state = shared_state();
        let app = Router::new()
            .route("/metrics", get(metrics_handler))
            .layer(Extension(state));
        metrics::counter!("trusty_test_counter").increment(1);

        let resp = app
            .oneshot(get_request("/metrics"))
            .await
            .expect("response");
        assert_eq!(resp.status(), axum::http::StatusCode::OK);
        assert_eq!(
            resp.headers()
                .get(axum::http::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("text/plain; version=0.0.4")
        );

        let text = body_text(resp).await;
        assert!(
            text.contains("trusty_test_counter"),
            "rendered metrics missing counter: {text}"
        );
    }

    #[tokio::test]
    async fn request_middleware_records_latency_and_total() {
        let state = shared_state();
        // `/metrics` is added after the middleware layer, so scraping it does
        // not itself record request metrics.
        let app = Router::new()
            .route("/ping", get(|| async { "pong" }))
            .layer(axum::middleware::from_fn(request_metrics_middleware))
            .route("/metrics", get(metrics_handler))
            .layer(Extension(state));

        let ping = app
            .clone()
            .oneshot(get_request("/ping"))
            .await
            .expect("ping response");
        assert_eq!(ping.status(), axum::http::StatusCode::OK);

        let resp = app
            .oneshot(get_request("/metrics"))
            .await
            .expect("metrics response");
        let text = body_text(resp).await;
        assert!(
            text.contains("trusty_requests_total"),
            "missing requests_total: {text}"
        );
        assert!(
            text.contains("trusty_request_latency_ms"),
            "missing request_latency_ms: {text}"
        );
        assert!(
            text.contains("endpoint=\"/ping\""),
            "missing endpoint label for /ping: {text}"
        );
    }
}