Skip to main content

rs_zero/observability/
rest.rs

1use std::time::Instant;
2
3use axum::{
4    Router,
5    body::Body,
6    extract::{MatchedPath, State},
7    http::Request,
8    middleware::Next,
9    response::IntoResponse,
10    routing::get,
11};
12
13use crate::observability::{CorrelationContext, HttpMetricLabels, MetricsRegistry};
14use tracing::Instrument;
15
16/// Records request metrics for an axum route.
17pub async fn record_metrics_middleware(
18    State(registry): State<MetricsRegistry>,
19    mut request: Request<Body>,
20    next: Next,
21) -> impl IntoResponse {
22    let method = request.method().to_string();
23    let route = request
24        .extensions()
25        .get::<MatchedPath>()
26        .map(|path| path.as_str().to_string())
27        .unwrap_or_else(|| "unknown".to_string());
28    let correlation = CorrelationContext::from_http_headers(
29        None,
30        method.clone(),
31        Some(&route),
32        request.headers(),
33    );
34    let request_id = correlation.request_id().unwrap_or("").to_string();
35    let incoming_traceparent = correlation.traceparent();
36    let started = Instant::now();
37    let span = tracing::info_span!(
38        "rs_zero.http.request",
39        http.method = %method,
40        http.route = %route,
41        service = "unknown",
42        transport = "http",
43        route = %route,
44        method = %method,
45        request_id = %request_id,
46        traceparent = incoming_traceparent.unwrap_or(""),
47        status = tracing::field::Empty,
48        trace_id = tracing::field::Empty,
49        span_id = tracing::field::Empty
50    );
51    registry.increment_http_in_flight();
52    if !request_id.is_empty() {
53        request
54            .extensions_mut()
55            .insert(crate::observability::CurrentRequestId(request_id.clone()));
56    }
57    #[cfg(feature = "otlp")]
58    crate::observability::set_span_parent_from_headers(&span, request.headers());
59    let response_future = next.run(request).instrument(span.clone());
60    let response = if request_id.is_empty() {
61        response_future.await
62    } else {
63        with_current_request_id(request_id.clone(), response_future).await
64    };
65    registry.decrement_http_in_flight();
66    let status = response.status().as_u16();
67    span.record("status", tracing::field::display(status));
68    let correlation = correlation.with_status(status.to_string());
69    if let Some(trace_id) = correlation.trace_id() {
70        span.record("trace_id", tracing::field::display(trace_id));
71    }
72    if let Some(span_id) = correlation.span_id() {
73        span.record("span_id", tracing::field::display(span_id));
74    }
75    registry.record_http_request(
76        HttpMetricLabels::new(method, route, status),
77        started.elapsed(),
78    );
79    response
80}
81
82async fn with_current_request_id<T>(
83    request_id: String,
84    future: impl std::future::Future<Output = T>,
85) -> T {
86    crate::layer::context::scope_request_id(request_id, future).await
87}
88
89/// Creates a router exposing Prometheus metrics at `/metrics`.
90pub fn metrics_router(registry: MetricsRegistry) -> Router {
91    Router::new().route(
92        "/metrics",
93        get(move || {
94            let registry = registry.clone();
95            async move { registry.render_prometheus() }
96        }),
97    )
98}