Skip to main content

nidus_http/middleware/
metrics.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, BTreeSet},
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex},
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use axum::{Router, routing::get};
12use http::{Method, Request, Response, StatusCode};
13use tower::{Layer, Service};
14
15/// Creates a metrics hook layer without a stable route label.
16///
17/// The layer will use Axum's [`axum::extract::MatchedPath`] extension when it is
18/// available, otherwise metrics are recorded with route `"<unknown>"`.
19pub fn metrics_layer<H>(hook: H) -> MetricsLayer<H>
20where
21    H: HttpMetricsHook,
22{
23    MetricsLayer::new(hook)
24}
25
26/// Creates a metrics hook layer that records a stable route label.
27///
28/// Use this for route-specific layers when you want stable labels independent
29/// of Axum extension timing.
30pub fn route_metrics_layer<H>(route: impl Into<Cow<'static, str>>, hook: H) -> MetricsLayer<H>
31where
32    H: HttpMetricsHook,
33{
34    MetricsLayer::new(hook).route(route)
35}
36
37/// Backend-neutral hook for recording HTTP request metrics.
38///
39/// Implement this trait to bridge Nidus' middleware lifecycle into a concrete
40/// metrics backend. Hooks are called in-process: one `on_request` before the
41/// inner service, one `on_response` after a response, or `on_error` if the inner
42/// service returns an error before producing a response.
43pub trait HttpMetricsHook: Clone + Send + Sync + 'static {
44    /// Records that a request entered the service.
45    fn on_request(&self, method: &Method, route: Option<&str>);
46
47    /// Records that a response left the service.
48    fn on_response(
49        &self,
50        method: &Method,
51        route: Option<&str>,
52        status: StatusCode,
53        latency: Duration,
54    );
55
56    /// Records that the inner service returned an error before producing a response.
57    fn on_error(&self, _method: &Method, _route: Option<&str>, _latency: Duration) {}
58}
59
60/// In-memory Prometheus-format HTTP metrics collector.
61///
62/// This collector stores counters and bounded duration histograms in process
63/// memory and renders Prometheus text exposition. It is useful for small
64/// services, examples, and tests; it is not a durable metrics store and values
65/// reset on process restart. The default exclusions are `/health/live`,
66/// `/health/ready`, and `/metrics`.
67///
68/// # Label cardinality
69///
70/// By default the collector records every distinct route label it observes, so
71/// the caller is responsible for keeping cardinality bounded. Prefer route
72/// patterns (e.g. `"/users/{id}"`) over concrete paths. To harden against
73/// accidental high-cardinality labels (which would grow memory without bound in
74/// a long-running process), apply [`PrometheusMetrics::with_max_series`]: once
75/// the configured number of distinct route labels has been admitted, every
76/// further distinct label collapses into a single `"<overflow>"` route.
77///
78/// ```
79/// use axum::{Router, routing::get};
80/// use nidus_http::middleware::{PrometheusMetrics, route_metrics_layer};
81/// # async fn show_user() -> &'static str { "user" }
82///
83/// let metrics = PrometheusMetrics::new();
84/// let app = Router::new()
85///     .route("/users/{id}", get(show_user))
86///     .route_layer(route_metrics_layer("/users/{id}", metrics.clone()))
87///     .merge(metrics.routes());
88/// # let _: Router = app;
89/// ```
90#[derive(Clone, Debug)]
91pub struct PrometheusMetrics {
92    state: Arc<Mutex<PrometheusState>>,
93    excluded_routes: Arc<BTreeSet<String>>,
94    max_series: Option<usize>,
95}
96
97impl PrometheusMetrics {
98    /// Creates a Prometheus metrics collector with default internal route exclusions.
99    ///
100    /// The collector is unbounded by default (every distinct route label is
101    /// recorded); use [`Self::with_max_series`] to cap label cardinality.
102    pub fn new() -> Self {
103        Self {
104            state: Arc::new(Mutex::new(PrometheusState::default())),
105            excluded_routes: Arc::new(BTreeSet::from([
106                "/health/live".to_owned(),
107                "/health/ready".to_owned(),
108                "/metrics".to_owned(),
109            ])),
110            max_series: None,
111        }
112    }
113
114    /// Adds a route pattern to exclude from recording.
115    ///
116    /// Match the exact route label emitted by the metrics layer, such as a
117    /// static route supplied to [`route_metrics_layer`] or an Axum matched path.
118    pub fn exclude_route(mut self, route: impl Into<String>) -> Self {
119        Arc::make_mut(&mut self.excluded_routes).insert(route.into());
120        self
121    }
122
123    /// Bounds the number of distinct route labels retained in memory.
124    ///
125    /// The first `max_series` distinct route labels are recorded normally; any
126    /// further distinct label collapses into a single shared `"<overflow>"`
127    /// route. This prevents unbounded memory growth when a layer accidentally
128    /// emits high-cardinality labels (for example concrete request paths) while
129    /// still keeping the already-admitted routes intact. Without this cap the
130    /// collector records every distinct label it observes.
131    pub fn with_max_series(mut self, max_series: usize) -> Self {
132        self.max_series = Some(max_series);
133        self
134    }
135
136    /// Creates a metrics layer backed by this collector.
137    ///
138    /// The layer records request totals, errors, in-flight counts, and bounded
139    /// duration histograms. It does not expose a scrape endpoint; use
140    /// [`Self::routes`] or [`Self::routes_at`] for that.
141    pub fn layer(&self) -> MetricsLayer<Self> {
142        MetricsLayer::new(self.clone())
143    }
144
145    /// Creates a `/metrics` route for this collector.
146    pub fn routes(&self) -> Router {
147        self.routes_at("/metrics")
148    }
149
150    /// Creates a metrics route at a custom path.
151    pub fn routes_at(&self, path: &'static str) -> Router {
152        let metrics = self.clone();
153        Router::new().route(path, get(move || async move { metrics.render() }))
154    }
155
156    /// Renders metrics in Prometheus text exposition format.
157    ///
158    /// The output includes `nidus_http_requests_total`,
159    /// `nidus_http_request_duration_seconds_count`,
160    /// `nidus_http_request_duration_seconds_sum`,
161    /// `nidus_http_in_flight_requests`, and `nidus_http_errors_total`.
162    pub fn render(&self) -> String {
163        let state = self.snapshot();
164        render_prometheus(&state)
165    }
166
167    fn snapshot(&self) -> PrometheusState {
168        self.state
169            .lock()
170            .unwrap_or_else(|poisoned| poisoned.into_inner())
171            .clone()
172    }
173
174    fn should_record(&self, route: Option<&str>) -> bool {
175        route
176            .map(|route| !self.excluded_routes.contains(route))
177            .unwrap_or(true)
178    }
179}
180
181fn render_prometheus(state: &PrometheusState) -> String {
182    let mut output = String::new();
183    output.push_str("# TYPE nidus_http_requests_total counter\n");
184    for ((method, route, status), count) in &state.requests_total {
185        output.push_str(&format!(
186            "nidus_http_requests_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
187            escape_label(method),
188            escape_label(route),
189            status,
190            count
191        ));
192    }
193    output.push_str("# TYPE nidus_http_request_duration_seconds histogram\n");
194    for ((method, route, status), histogram) in &state.durations {
195        for (bucket, count) in HTTP_DURATION_BUCKETS
196            .iter()
197            .zip(histogram.bucket_counts.iter())
198        {
199            output.push_str(&format!(
200                    "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"{}\"}} {}\n",
201                    escape_label(method),
202                    escape_label(route),
203                    status,
204                    format_bucket(*bucket),
205                    count
206                ));
207        }
208        output.push_str(&format!(
209                "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"+Inf\"}} {}\n",
210                escape_label(method),
211                escape_label(route),
212                status,
213                histogram.count
214            ));
215        output.push_str(&format!(
216                "nidus_http_request_duration_seconds_count{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
217                escape_label(method),
218                escape_label(route),
219                status,
220                histogram.count
221            ));
222        output.push_str(&format!(
223                "nidus_http_request_duration_seconds_sum{{method=\"{}\",route=\"{}\",status=\"{}\"}} {:.6}\n",
224                escape_label(method),
225                escape_label(route),
226                status,
227                histogram.sum
228            ));
229    }
230    output.push_str("# TYPE nidus_http_in_flight_requests gauge\n");
231    for ((method, route), count) in &state.in_flight {
232        output.push_str(&format!(
233            "nidus_http_in_flight_requests{{method=\"{}\",route=\"{}\"}} {}\n",
234            escape_label(method),
235            escape_label(route),
236            count
237        ));
238    }
239    output.push_str("# TYPE nidus_http_errors_total counter\n");
240    for ((method, route, status), count) in &state.errors_total {
241        output.push_str(&format!(
242            "nidus_http_errors_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
243            escape_label(method),
244            escape_label(route),
245            status,
246            count
247        ));
248    }
249    output
250}
251
252impl Default for PrometheusMetrics {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258impl HttpMetricsHook for PrometheusMetrics {
259    fn on_request(&self, method: &Method, route: Option<&str>) {
260        if !self.should_record(route) {
261            return;
262        }
263        let route = route.unwrap_or("<unknown>").to_owned();
264        let mut state = self
265            .state
266            .lock()
267            .unwrap_or_else(|poisoned| poisoned.into_inner());
268        let route = match self.max_series {
269            Some(max) => state.admit_route(route, max),
270            None => route,
271        };
272        *state
273            .in_flight
274            .entry((method.as_str().to_owned(), route))
275            .or_default() += 1;
276    }
277
278    fn on_response(
279        &self,
280        method: &Method,
281        route: Option<&str>,
282        status: StatusCode,
283        latency: Duration,
284    ) {
285        if !self.should_record(route) {
286            return;
287        }
288        let method = method.as_str().to_owned();
289        let route = route.unwrap_or("<unknown>").to_owned();
290        let status = status.as_u16();
291        let mut state = self
292            .state
293            .lock()
294            .unwrap_or_else(|poisoned| poisoned.into_inner());
295        let route = match self.max_series {
296            Some(max) => state.admit_route(route, max),
297            None => route,
298        };
299        *state
300            .requests_total
301            .entry((method.clone(), route.clone(), status))
302            .or_default() += 1;
303        state
304            .durations
305            .entry((method.clone(), route.clone(), status))
306            .or_default()
307            .observe(latency);
308        if StatusCode::from_u16(status)
309            .is_ok_and(|status| status.is_client_error() || status.is_server_error())
310        {
311            *state
312                .errors_total
313                .entry((method.clone(), route.clone(), status))
314                .or_default() += 1;
315        }
316        let key = (method, route);
317        if let Some(count) = state.in_flight.get_mut(&key) {
318            *count = count.saturating_sub(1);
319        }
320    }
321
322    fn on_error(&self, method: &Method, route: Option<&str>, latency: Duration) {
323        if !self.should_record(route) {
324            return;
325        }
326        let method = method.as_str().to_owned();
327        let route = route.unwrap_or("<unknown>").to_owned();
328        let mut state = self
329            .state
330            .lock()
331            .unwrap_or_else(|poisoned| poisoned.into_inner());
332        let route = match self.max_series {
333            Some(max) => state.admit_route(route, max),
334            None => route,
335        };
336        let status = StatusCode::INTERNAL_SERVER_ERROR.as_u16();
337        *state
338            .requests_total
339            .entry((method.clone(), route.clone(), status))
340            .or_default() += 1;
341        state
342            .durations
343            .entry((method.clone(), route.clone(), status))
344            .or_default()
345            .observe(latency);
346        *state
347            .errors_total
348            .entry((method.clone(), route.clone(), status))
349            .or_default() += 1;
350        let key = (method, route);
351        if let Some(count) = state.in_flight.get_mut(&key) {
352            *count = count.saturating_sub(1);
353        }
354    }
355}
356
357#[derive(Clone, Debug, Default)]
358struct PrometheusState {
359    requests_total: BTreeMap<(String, String, u16), u64>,
360    durations: BTreeMap<(String, String, u16), DurationHistogram>,
361    in_flight: BTreeMap<(String, String), u64>,
362    errors_total: BTreeMap<(String, String, u16), u64>,
363    known_routes: BTreeSet<String>,
364}
365
366impl PrometheusState {
367    /// Returns the label to record for `route`, honoring a cap on the number of
368    /// distinct route labels. Already-admitted routes are returned unchanged;
369    /// once the cap is reached, new labels collapse to `"<overflow>"`. Callers
370    /// with no cap must skip this call entirely (the uncapped path pays nothing).
371    fn admit_route(&mut self, route: String, max_series: usize) -> String {
372        if self.known_routes.contains(&route) {
373            route
374        } else if self.known_routes.len() < max_series {
375            self.known_routes.insert(route.clone());
376            route
377        } else {
378            "<overflow>".to_owned()
379        }
380    }
381}
382
/// Upper bounds, in seconds, of the request-duration histogram buckets.
///
/// They range from 5 ms to 10 s; slower requests are counted only by the
/// implicit `+Inf` bucket.
const HTTP_DURATION_BUCKETS: [f64; 11] = [
    0.005, 0.010, 0.025, 0.050, 0.100, 0.250, 0.500, 1.000, 2.500, 5.000, 10.000,
];

/// Cumulative request-duration histogram over [`HTTP_DURATION_BUCKETS`].
#[derive(Clone, Debug, Default)]
struct DurationHistogram {
    /// Total number of observations.
    count: u64,
    /// Sum of all observed durations, in seconds.
    sum: f64,
    /// `bucket_counts[i]` counts observations `<= HTTP_DURATION_BUCKETS[i]`.
    bucket_counts: [u64; HTTP_DURATION_BUCKETS.len()],
}

impl DurationHistogram {
    /// Adds one observation of `latency` to the histogram.
    fn observe(&mut self, latency: Duration) {
        let seconds = latency.as_secs_f64();
        self.sum += seconds;
        self.count += 1;
        // Buckets are cumulative: the observation counts towards every bucket
        // whose upper bound it does not exceed.
        HTTP_DURATION_BUCKETS
            .iter()
            .zip(self.bucket_counts.iter_mut())
            .filter(|(upper_bound, _)| seconds <= **upper_bound)
            .for_each(|(_, bucket_count)| *bucket_count += 1);
    }
}
409
/// Tower layer that invokes [`HttpMetricsHook`] for request lifecycle metrics.
///
/// The route label handed to the hook is, in order of preference, the static
/// label set with [`Self::route`], then Axum's [`axum::extract::MatchedPath`]
/// extension. If neither is present the hook receives `None`; the built-in
/// [`PrometheusMetrics`] collector records such requests as `"<unknown>"`.
#[derive(Clone, Debug)]
pub struct MetricsLayer<H> {
    hook: H,
    route: Option<Cow<'static, str>>,
}
419
420impl<H> MetricsLayer<H>
421where
422    H: HttpMetricsHook,
423{
424    /// Creates a metrics layer without a route label.
425    pub fn new(hook: H) -> Self {
426        Self { hook, route: None }
427    }
428
429    /// Adds a stable route label to emitted metrics.
430    ///
431    /// Prefer route patterns such as `"/users/:id"` over concrete paths to keep
432    /// label cardinality bounded.
433    pub fn route(mut self, route: impl Into<Cow<'static, str>>) -> Self {
434        self.route = Some(route.into());
435        self
436    }
437}
438
439impl<S, H> Layer<S> for MetricsLayer<H>
440where
441    H: HttpMetricsHook,
442{
443    type Service = MetricsService<S, H>;
444
445    fn layer(&self, inner: S) -> Self::Service {
446        MetricsService {
447            inner,
448            hook: self.hook.clone(),
449            route: self.route.clone(),
450        }
451    }
452}
453
/// Service produced by [`MetricsLayer`].
///
/// Wraps an inner service and reports each request it handles to the hook.
#[derive(Clone, Debug)]
pub struct MetricsService<S, H> {
    /// The wrapped service.
    inner: S,
    /// Clone of the layer's metrics hook.
    hook: H,
    /// Static route label, if the layer was configured with one.
    route: Option<Cow<'static, str>>,
}
461
462impl<S, H, RequestBody, ResponseBody> Service<Request<RequestBody>> for MetricsService<S, H>
463where
464    S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
465    S::Future: Send + 'static,
466    S::Error: Send + 'static,
467    H: HttpMetricsHook,
468    RequestBody: Send + 'static,
469    ResponseBody: Send + 'static,
470{
471    type Response = Response<ResponseBody>;
472    type Error = S::Error;
473    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
474
475    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
476        self.inner.poll_ready(cx)
477    }
478
479    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
480        let method = request.method().clone();
481        let hook = self.hook.clone();
482        let route = self.route.clone().or_else(|| {
483            request
484                .extensions()
485                .get::<axum::extract::MatchedPath>()
486                .map(|path| Cow::Owned(path.as_str().to_owned()))
487        });
488        hook.on_request(&method, route.as_deref());
489        let started_at = Instant::now();
490        let future = self.inner.call(request);
491
492        Box::pin(async move {
493            match future.await {
494                Ok(response) => {
495                    hook.on_response(
496                        &method,
497                        route.as_deref(),
498                        response.status(),
499                        started_at.elapsed(),
500                    );
501                    Ok(response)
502                }
503                Err(error) => {
504                    hook.on_error(&method, route.as_deref(), started_at.elapsed());
505                    Err(error)
506                }
507            }
508        })
509    }
510}
511
/// Escapes a Prometheus label value: backslash, line feed, and double quote
/// become `\\`, `\n`, and `\"`; every other character is copied unchanged.
fn escape_label(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str(r"\\"),
            '\n' => escaped.push_str(r"\n"),
            '"' => escaped.push_str(r#"\""#),
            other => escaped.push(other),
        }
    }
    escaped
}
518
/// Formats a histogram bucket upper bound for the Prometheus `le` label.
///
/// Finite bounds use the shortest decimal that round-trips to the same `f64`:
/// whole numbers print without a fractional part (`1.0` → `"1"`) and fractions
/// keep every significant digit (`0.01` → `"0.01"`, `0.0005` → `"0.0005"`).
/// Infinite bounds use the exposition-format spellings `+Inf` / `-Inf`.
fn format_bucket(bucket: f64) -> String {
    if bucket.is_infinite() {
        let spelled = if bucket.is_sign_positive() { "+Inf" } else { "-Inf" };
        return spelled.to_owned();
    }
    // `Display` (not fixed `{:.3}` precision) so sub-millisecond bounds are
    // neither rounded to a different value nor truncated to `"0."`.
    format!("{bucket}")
}