Skip to main content

nidus_http/middleware/
metrics.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, BTreeSet},
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex},
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use axum::{Router, routing::get};
12use http::{Method, Request, Response, StatusCode};
13use tower::{Layer, Service};
14
15/// Creates a metrics hook layer without a stable route label.
16///
17/// The layer will use Axum's [`axum::extract::MatchedPath`] extension when it is
18/// available, otherwise metrics are recorded with route `"<unknown>"`.
19pub fn metrics_layer<H>(hook: H) -> MetricsLayer<H>
20where
21    H: HttpMetricsHook,
22{
23    MetricsLayer::new(hook)
24}
25
26/// Creates a metrics hook layer that records a stable route label.
27///
28/// Use this for route-specific layers when you want stable labels independent
29/// of Axum extension timing.
30pub fn route_metrics_layer<H>(route: impl Into<Cow<'static, str>>, hook: H) -> MetricsLayer<H>
31where
32    H: HttpMetricsHook,
33{
34    MetricsLayer::new(hook).route(route)
35}
36
37/// Backend-neutral hook for recording HTTP request metrics.
38///
39/// Implement this trait to bridge Nidus' middleware lifecycle into a concrete
40/// metrics backend. Hooks are called in-process: one `on_request` before the
41/// inner service, one `on_response` after a response, or `on_error` if the inner
42/// service returns an error before producing a response.
43pub trait HttpMetricsHook: Clone + Send + Sync + 'static {
44    /// Records that a request entered the service.
45    fn on_request(&self, method: &Method, route: Option<&str>);
46
47    /// Records that a response left the service.
48    fn on_response(
49        &self,
50        method: &Method,
51        route: Option<&str>,
52        status: StatusCode,
53        latency: Duration,
54    );
55
56    /// Records that the inner service returned an error before producing a response.
57    fn on_error(&self, _method: &Method, _route: Option<&str>, _latency: Duration) {}
58}
59
60/// In-memory Prometheus-format HTTP metrics collector.
61///
62/// This collector stores counters and bounded duration histograms in process
63/// memory and renders Prometheus text exposition. It is useful for small
64/// services, examples, and tests; it is not a durable metrics store and values
65/// reset on process restart. The default exclusions are `/health/live`,
66/// `/health/ready`, and `/metrics`.
67///
68/// # Label cardinality
69///
70/// By default the collector records every distinct route label it observes, so
71/// the caller is responsible for keeping cardinality bounded — prefer route
72/// patterns (e.g. `"/users/:id"`) over concrete paths. To harden against
73/// accidental high-cardinality labels (which would grow memory without bound in
74/// a long-running process), apply [`PrometheusMetrics::with_max_series`]: once
75/// the configured number of distinct route labels has been admitted, every
76/// further distinct label collapses into a single `"<overflow>"` route.
77///
78/// ```ignore
79/// use axum::{Router, routing::get};
80/// use nidus_http::middleware::{PrometheusMetrics, route_metrics_layer};
81///
82/// let metrics = PrometheusMetrics::new();
83/// let app = Router::new()
84///     .route("/users/:id", get(show_user))
85///     .route_layer(route_metrics_layer("/users/:id", metrics.clone()))
86///     .merge(metrics.routes());
87/// ```
88#[derive(Clone, Debug)]
89pub struct PrometheusMetrics {
90    state: Arc<Mutex<PrometheusState>>,
91    excluded_routes: Arc<BTreeSet<String>>,
92    max_series: Option<usize>,
93}
94
95impl PrometheusMetrics {
96    /// Creates a Prometheus metrics collector with default internal route exclusions.
97    ///
98    /// The collector is unbounded by default (every distinct route label is
99    /// recorded); use [`Self::with_max_series`] to cap label cardinality.
100    pub fn new() -> Self {
101        Self {
102            state: Arc::new(Mutex::new(PrometheusState::default())),
103            excluded_routes: Arc::new(BTreeSet::from([
104                "/health/live".to_owned(),
105                "/health/ready".to_owned(),
106                "/metrics".to_owned(),
107            ])),
108            max_series: None,
109        }
110    }
111
112    /// Adds a route pattern to exclude from recording.
113    ///
114    /// Match the exact route label emitted by the metrics layer, such as a
115    /// static route supplied to [`route_metrics_layer`] or an Axum matched path.
116    pub fn exclude_route(mut self, route: impl Into<String>) -> Self {
117        Arc::make_mut(&mut self.excluded_routes).insert(route.into());
118        self
119    }
120
121    /// Bounds the number of distinct route labels retained in memory.
122    ///
123    /// The first `max_series` distinct route labels are recorded normally; any
124    /// further distinct label collapses into a single shared `"<overflow>"`
125    /// route. This prevents unbounded memory growth when a layer accidentally
126    /// emits high-cardinality labels (for example concrete request paths) while
127    /// still keeping the already-admitted routes intact. Without this cap the
128    /// collector records every distinct label it observes.
129    pub fn with_max_series(mut self, max_series: usize) -> Self {
130        self.max_series = Some(max_series);
131        self
132    }
133
134    /// Creates a metrics layer backed by this collector.
135    ///
136    /// The layer records request totals, errors, in-flight counts, and bounded
137    /// duration histograms. It does not expose a scrape endpoint; use
138    /// [`Self::routes`] or [`Self::routes_at`] for that.
139    pub fn layer(&self) -> MetricsLayer<Self> {
140        MetricsLayer::new(self.clone())
141    }
142
143    /// Creates a `/metrics` route for this collector.
144    pub fn routes(&self) -> Router {
145        self.routes_at("/metrics")
146    }
147
148    /// Creates a metrics route at a custom path.
149    pub fn routes_at(&self, path: &'static str) -> Router {
150        let metrics = self.clone();
151        Router::new().route(path, get(move || async move { metrics.render() }))
152    }
153
154    /// Renders metrics in Prometheus text exposition format.
155    ///
156    /// The output includes `nidus_http_requests_total`,
157    /// `nidus_http_request_duration_seconds_count`,
158    /// `nidus_http_request_duration_seconds_sum`,
159    /// `nidus_http_in_flight_requests`, and `nidus_http_errors_total`.
160    pub fn render(&self) -> String {
161        let state = self.snapshot();
162        render_prometheus(&state)
163    }
164
165    fn snapshot(&self) -> PrometheusState {
166        self.state
167            .lock()
168            .unwrap_or_else(|poisoned| poisoned.into_inner())
169            .clone()
170    }
171
172    fn should_record(&self, route: Option<&str>) -> bool {
173        route
174            .map(|route| !self.excluded_routes.contains(route))
175            .unwrap_or(true)
176    }
177}
178
179fn render_prometheus(state: &PrometheusState) -> String {
180    let mut output = String::new();
181    output.push_str("# TYPE nidus_http_requests_total counter\n");
182    for ((method, route, status), count) in &state.requests_total {
183        output.push_str(&format!(
184            "nidus_http_requests_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
185            escape_label(method),
186            escape_label(route),
187            status,
188            count
189        ));
190    }
191    output.push_str("# TYPE nidus_http_request_duration_seconds histogram\n");
192    for ((method, route, status), histogram) in &state.durations {
193        for (bucket, count) in HTTP_DURATION_BUCKETS
194            .iter()
195            .zip(histogram.bucket_counts.iter())
196        {
197            output.push_str(&format!(
198                    "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"{}\"}} {}\n",
199                    escape_label(method),
200                    escape_label(route),
201                    status,
202                    format_bucket(*bucket),
203                    count
204                ));
205        }
206        output.push_str(&format!(
207                "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"+Inf\"}} {}\n",
208                escape_label(method),
209                escape_label(route),
210                status,
211                histogram.count
212            ));
213        output.push_str(&format!(
214                "nidus_http_request_duration_seconds_count{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
215                escape_label(method),
216                escape_label(route),
217                status,
218                histogram.count
219            ));
220        output.push_str(&format!(
221                "nidus_http_request_duration_seconds_sum{{method=\"{}\",route=\"{}\",status=\"{}\"}} {:.6}\n",
222                escape_label(method),
223                escape_label(route),
224                status,
225                histogram.sum
226            ));
227    }
228    output.push_str("# TYPE nidus_http_in_flight_requests gauge\n");
229    for ((method, route), count) in &state.in_flight {
230        output.push_str(&format!(
231            "nidus_http_in_flight_requests{{method=\"{}\",route=\"{}\"}} {}\n",
232            escape_label(method),
233            escape_label(route),
234            count
235        ));
236    }
237    output.push_str("# TYPE nidus_http_errors_total counter\n");
238    for ((method, route, status), count) in &state.errors_total {
239        output.push_str(&format!(
240            "nidus_http_errors_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
241            escape_label(method),
242            escape_label(route),
243            status,
244            count
245        ));
246    }
247    output
248}
249
250impl Default for PrometheusMetrics {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
256impl HttpMetricsHook for PrometheusMetrics {
257    fn on_request(&self, method: &Method, route: Option<&str>) {
258        if !self.should_record(route) {
259            return;
260        }
261        let route = route.unwrap_or("<unknown>").to_owned();
262        let mut state = self
263            .state
264            .lock()
265            .unwrap_or_else(|poisoned| poisoned.into_inner());
266        let route = match self.max_series {
267            Some(max) => state.admit_route(route, max),
268            None => route,
269        };
270        *state
271            .in_flight
272            .entry((method.as_str().to_owned(), route))
273            .or_default() += 1;
274    }
275
276    fn on_response(
277        &self,
278        method: &Method,
279        route: Option<&str>,
280        status: StatusCode,
281        latency: Duration,
282    ) {
283        if !self.should_record(route) {
284            return;
285        }
286        let method = method.as_str().to_owned();
287        let route = route.unwrap_or("<unknown>").to_owned();
288        let status = status.as_u16();
289        let mut state = self
290            .state
291            .lock()
292            .unwrap_or_else(|poisoned| poisoned.into_inner());
293        let route = match self.max_series {
294            Some(max) => state.admit_route(route, max),
295            None => route,
296        };
297        *state
298            .requests_total
299            .entry((method.clone(), route.clone(), status))
300            .or_default() += 1;
301        state
302            .durations
303            .entry((method.clone(), route.clone(), status))
304            .or_default()
305            .observe(latency);
306        if StatusCode::from_u16(status)
307            .is_ok_and(|status| status.is_client_error() || status.is_server_error())
308        {
309            *state
310                .errors_total
311                .entry((method.clone(), route.clone(), status))
312                .or_default() += 1;
313        }
314        let key = (method, route);
315        if let Some(count) = state.in_flight.get_mut(&key) {
316            *count = count.saturating_sub(1);
317        }
318    }
319
320    fn on_error(&self, method: &Method, route: Option<&str>, latency: Duration) {
321        if !self.should_record(route) {
322            return;
323        }
324        let method = method.as_str().to_owned();
325        let route = route.unwrap_or("<unknown>").to_owned();
326        let mut state = self
327            .state
328            .lock()
329            .unwrap_or_else(|poisoned| poisoned.into_inner());
330        let route = match self.max_series {
331            Some(max) => state.admit_route(route, max),
332            None => route,
333        };
334        let status = StatusCode::INTERNAL_SERVER_ERROR.as_u16();
335        *state
336            .requests_total
337            .entry((method.clone(), route.clone(), status))
338            .or_default() += 1;
339        state
340            .durations
341            .entry((method.clone(), route.clone(), status))
342            .or_default()
343            .observe(latency);
344        *state
345            .errors_total
346            .entry((method.clone(), route.clone(), status))
347            .or_default() += 1;
348        let key = (method, route);
349        if let Some(count) = state.in_flight.get_mut(&key) {
350            *count = count.saturating_sub(1);
351        }
352    }
353}
354
355#[derive(Clone, Debug, Default)]
356struct PrometheusState {
357    requests_total: BTreeMap<(String, String, u16), u64>,
358    durations: BTreeMap<(String, String, u16), DurationHistogram>,
359    in_flight: BTreeMap<(String, String), u64>,
360    errors_total: BTreeMap<(String, String, u16), u64>,
361    known_routes: BTreeSet<String>,
362}
363
364impl PrometheusState {
365    /// Returns the label to record for `route`, honoring a cap on the number of
366    /// distinct route labels. Already-admitted routes are returned unchanged;
367    /// once the cap is reached, new labels collapse to `"<overflow>"`. Callers
368    /// with no cap must skip this call entirely (the uncapped path pays nothing).
369    fn admit_route(&mut self, route: String, max_series: usize) -> String {
370        if self.known_routes.contains(&route) {
371            route
372        } else if self.known_routes.len() < max_series {
373            self.known_routes.insert(route.clone());
374            route
375        } else {
376            "<overflow>".to_owned()
377        }
378    }
379}
380
/// Upper bounds, in seconds, of the finite latency histogram buckets.
const HTTP_DURATION_BUCKETS: [f64; 11] = [
    0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
];

/// Cumulative latency histogram over [`HTTP_DURATION_BUCKETS`].
#[derive(Clone, Debug, Default)]
struct DurationHistogram {
    /// Number of observations.
    count: u64,
    /// Sum of all observed latencies, in seconds.
    sum: f64,
    /// Per-bucket counts, index-aligned with `HTTP_DURATION_BUCKETS`; each
    /// slot counts the observations less than or equal to its upper bound.
    bucket_counts: [u64; HTTP_DURATION_BUCKETS.len()],
}

impl DurationHistogram {
    /// Adds one latency sample.
    ///
    /// Buckets are cumulative: the sample is counted in every bucket whose
    /// upper bound is at least the sample, so a sample larger than the last
    /// bound touches only `count` and `sum`.
    fn observe(&mut self, latency: Duration) {
        let sample = latency.as_secs_f64();
        self.count += 1;
        self.sum += sample;
        self.bucket_counts
            .iter_mut()
            .zip(HTTP_DURATION_BUCKETS.iter())
            .filter(|(_, upper_bound)| sample <= **upper_bound)
            .for_each(|(slot, _)| *slot += 1);
    }
}
407
408/// Tower layer that invokes [`HttpMetricsHook`] for request lifecycle metrics.
409///
410/// Route labels come from [`Self::route`] when set, then from Axum
411/// [`axum::extract::MatchedPath`], and finally `"<unknown>"`.
412#[derive(Clone, Debug)]
413pub struct MetricsLayer<H> {
414    hook: H,
415    route: Option<Cow<'static, str>>,
416}
417
418impl<H> MetricsLayer<H>
419where
420    H: HttpMetricsHook,
421{
422    /// Creates a metrics layer without a route label.
423    pub fn new(hook: H) -> Self {
424        Self { hook, route: None }
425    }
426
427    /// Adds a stable route label to emitted metrics.
428    ///
429    /// Prefer route patterns such as `"/users/:id"` over concrete paths to keep
430    /// label cardinality bounded.
431    pub fn route(mut self, route: impl Into<Cow<'static, str>>) -> Self {
432        self.route = Some(route.into());
433        self
434    }
435}
436
437impl<S, H> Layer<S> for MetricsLayer<H>
438where
439    H: HttpMetricsHook,
440{
441    type Service = MetricsService<S, H>;
442
443    fn layer(&self, inner: S) -> Self::Service {
444        MetricsService {
445            inner,
446            hook: self.hook.clone(),
447            route: self.route.clone(),
448        }
449    }
450}
451
452/// Service produced by [`MetricsLayer`].
453#[derive(Clone, Debug)]
454pub struct MetricsService<S, H> {
455    inner: S,
456    hook: H,
457    route: Option<Cow<'static, str>>,
458}
459
460impl<S, H, RequestBody, ResponseBody> Service<Request<RequestBody>> for MetricsService<S, H>
461where
462    S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
463    S::Future: Send + 'static,
464    S::Error: Send + 'static,
465    H: HttpMetricsHook,
466    RequestBody: Send + 'static,
467    ResponseBody: Send + 'static,
468{
469    type Response = Response<ResponseBody>;
470    type Error = S::Error;
471    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
472
473    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
474        self.inner.poll_ready(cx)
475    }
476
477    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
478        let method = request.method().clone();
479        let hook = self.hook.clone();
480        let route = self.route.clone().or_else(|| {
481            request
482                .extensions()
483                .get::<axum::extract::MatchedPath>()
484                .map(|path| Cow::Owned(path.as_str().to_owned()))
485        });
486        hook.on_request(&method, route.as_deref());
487        let started_at = Instant::now();
488        let future = self.inner.call(request);
489
490        Box::pin(async move {
491            match future.await {
492                Ok(response) => {
493                    hook.on_response(
494                        &method,
495                        route.as_deref(),
496                        response.status(),
497                        started_at.elapsed(),
498                    );
499                    Ok(response)
500                }
501                Err(error) => {
502                    hook.on_error(&method, route.as_deref(), started_at.elapsed());
503                    Err(error)
504                }
505            }
506        })
507    }
508}
509
/// Escapes a label value for Prometheus text exposition.
///
/// A backslash becomes `\\`, a line feed becomes `\n` (backslash followed by
/// `n`), and a double quote becomes `\"`; every other character is copied
/// unchanged.
fn escape_label(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str(r"\\"),
            '\n' => escaped.push_str(r"\n"),
            '"' => escaped.push_str(r#"\""#),
            other => escaped.push(other),
        }
    }
    escaped
}
516
/// Formats a histogram upper bound for use as an `le` label value.
///
/// Whole numbers are printed without a fractional part (`1.0` becomes `"1"`)
/// and fractional bounds keep all of their digits (`0.005` stays `"0.005"`).
/// Formatting to a fixed three decimals and trimming zeros would instead
/// round finer bounds or collapse them to `"0."`, so the value is rendered
/// with `f64`'s `Display` implementation.
fn format_bucket(bucket: f64) -> String {
    bucket.to_string()
}
524}