Skip to main content

toolkit_http/layers/
metrics.rs

1//! Tower layer that records OpenTelemetry metrics for outbound HTTP requests.
2//!
3//! Emits a single instrument following [OpenTelemetry HTTP client semantic
4//! conventions][semconv]:
5//! - `http.client.request.duration` — histogram (seconds)
6//!
//! Attributes: `http.request.method`, `http.route`, `server.address`,
//! `server.port` (when it can be determined from the request URI), the
//! optional `request_type` (when the request carries a `RequestType`
//! extension), and `http.response.status_code` (on success) or `error.type`
//! (on failure).
10//!
11//! Modeled after go-appkit's `MetricsRoundTripper`: one duration histogram plus
12//! a build-time request classifier that produces the bounded `http.route`
13//! label, preventing cardinality explosion from raw paths. Like the Go version,
14//! this layer sits outside the retry loop, so it observes one logical request
15//! regardless of transport-level retries.
16//!
17//! [semconv]: https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
18
19use crate::error::HttpError;
20use crate::request::RequestType;
21use crate::response::ResponseBody;
22use bytes::Bytes;
23use http::{Request, Response};
24use http_body_util::Full;
25use opentelemetry::metrics::{Histogram, Meter};
26use opentelemetry::{KeyValue, global};
27use std::borrow::Cow;
28use std::future::Future;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use std::time::Instant;
33use tower::{Layer, Service};
34
35/// Classifies a request into a low-cardinality route label (the `http.route`
36/// attribute). Set once when the client is built; invoked on every request.
37///
38/// This is the Rust analogue of go-appkit's `ClassifyRequest` callback. It must
39/// return a *bounded* set of values (e.g. route templates like
40/// `GET /users/{id}`), never a raw path containing identifiers, otherwise the
41/// metric cardinality is unbounded.
42pub type ClassifyFn = Arc<dyn Fn(&Request<Full<Bytes>>) -> Cow<'static, str> + Send + Sync>;
43
/// Histogram bucket upper bounds, in seconds, for `http.client.request.duration`.
///
/// The SDK's default boundaries are sized for counts (hundreds to thousands),
/// which makes them useless for a duration measured in seconds. This set
/// follows go-appkit's buckets with extra resolution at the low end, so
/// client-side percentiles stay meaningful and comparable between the Rust and
/// Go implementations.
const DURATION_BOUNDARIES_SECS: &[f64] = &[
    // Sub-second: 10 ms up to 500 ms.
    0.01, 0.025, 0.05, 0.1, 0.25, 0.5,
    // Seconds: 1 s up to 1 min.
    1.0, 2.5, 5.0, 10.0, 30.0, 60.0,
    // Long-running: 2.5 min up to 10 min.
    150.0, 300.0, 600.0,
];
53
54/// Default classifier producing `"METHOD host"` (mirrors go-appkit's default
55/// `summary`). Never returns a raw path, so it cannot blow up cardinality.
56#[must_use]
57pub fn default_classify(req: &Request<Full<Bytes>>) -> Cow<'static, str> {
58    let host = req.uri().host().unwrap_or("unknown");
59    // Use the normalized method (`_OTHER` for unknown verbs) so the route label
60    // stays consistent with the `http.request.method` attribute and cannot be
61    // widened by arbitrary method strings.
62    Cow::Owned(format!("{} {}", normalize_method(req.method()), host))
63}
64
65/// Normalize HTTP method per [OTel semantic conventions][semconv].
66///
67/// Unknown methods map to `_OTHER` to bound attribute cardinality. Mirrors the
68/// server-side helper in `api-gateway`'s `http_metrics` middleware (duplicated
69/// here so `toolkit-http` stays free of a dependency on that gear).
70///
71/// [semconv]: https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
72fn normalize_method(method: &http::Method) -> &'static str {
73    match *method {
74        http::Method::GET => "GET",
75        http::Method::POST => "POST",
76        http::Method::PUT => "PUT",
77        http::Method::DELETE => "DELETE",
78        http::Method::PATCH => "PATCH",
79        http::Method::HEAD => "HEAD",
80        http::Method::OPTIONS => "OPTIONS",
81        http::Method::CONNECT => "CONNECT",
82        http::Method::TRACE => "TRACE",
83        _ => "_OTHER",
84    }
85}
86
87/// Low-cardinality `error.type` value for a transport-level failure.
88///
89/// This layer sits inside the load-shed/buffer layers and outside retry, and the
90/// inner service returns `Ok(Response)` for all HTTP statuses (including
91/// 4xx/5xx). Only transport-class failures reach the `Err` arm here — the
92/// `OTel` analogue of go-appkit's `status="0"`. Everything else collapses to
93/// `"other"` rather than enumerating variants that cannot occur at this point.
94fn error_type(err: &HttpError) -> &'static str {
95    match err {
96        HttpError::Timeout(_) => "timeout",
97        HttpError::DeadlineExceeded(_) => "deadline_exceeded",
98        HttpError::Transport(_) => "transport",
99        HttpError::Tls(_) => "tls",
100        _ => "other",
101    }
102}
103
104/// Tower layer recording HTTP client request-duration metrics.
105#[derive(Clone)]
106pub struct MetricsLayer {
107    duration: Histogram<f64>,
108    classify: ClassifyFn,
109}
110
111impl MetricsLayer {
112    /// Create a metrics layer.
113    ///
114    /// `client_type` names the OpenTelemetry instrumentation scope (the meter),
115    /// mirroring go-appkit's `ClientType` and the server-side `gear_name`.
116    /// `classify` produces the bounded `http.route` attribute for each request.
117    #[must_use]
118    pub fn new(client_type: &str, classify: ClassifyFn) -> Self {
119        let scope = opentelemetry::InstrumentationScope::builder(client_type.to_owned()).build();
120        let meter = global::meter_with_scope(scope);
121        Self::with_meter(&meter, classify)
122    }
123
124    /// Create a metrics layer using a caller-provided [`Meter`].
125    ///
126    /// Use this to bind the instrument to a specific `MeterProvider` instead of
127    /// the global one (e.g. for tests or multi-provider setups). The instrument
128    /// name, unit, bucket boundaries, and behavior are identical to [`new`](Self::new).
129    #[must_use]
130    pub fn with_meter(meter: &Meter, classify: ClassifyFn) -> Self {
131        let duration = meter
132            .f64_histogram("http.client.request.duration")
133            .with_description("Duration of outbound HTTP client requests")
134            .with_unit("s")
135            .with_boundaries(DURATION_BOUNDARIES_SECS.to_vec())
136            .build();
137        Self { duration, classify }
138    }
139}
140
141impl<S> Layer<S> for MetricsLayer {
142    type Service = MetricsService<S>;
143
144    fn layer(&self, inner: S) -> Self::Service {
145        MetricsService {
146            inner,
147            duration: self.duration.clone(),
148            classify: self.classify.clone(),
149        }
150    }
151}
152
153/// Service that records a duration metric for each outbound request.
154#[derive(Clone)]
155pub struct MetricsService<S> {
156    inner: S,
157    duration: Histogram<f64>,
158    classify: ClassifyFn,
159}
160
161impl<S> Service<Request<Full<Bytes>>> for MetricsService<S>
162where
163    S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
164        + Clone
165        + Send
166        + 'static,
167    S::Future: Send,
168{
169    type Response = Response<ResponseBody>;
170    type Error = HttpError;
171    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
172
173    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174        self.inner.poll_ready(cx)
175    }
176
177    fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
178        // Compute attributes from the request up front; the request itself is
179        // moved into the inner service.
180        let route = (self.classify)(&req).into_owned();
181        let method = normalize_method(req.method());
182        let server_address = req.uri().host().unwrap_or("unknown").to_owned();
183        let server_port = req.uri().port_u16();
184        // Read request_type set by RequestBuilder::with_request_type — mirrors
185        // go-appkit's GetRequestTypeFromContext.
186        let request_type = req
187            .extensions()
188            .get::<RequestType>()
189            .map(|rt| rt.0.clone().into_owned());
190        let duration = self.duration.clone();
191
192        // Swap so we call the instance that was poll_ready'd, leaving a fresh
193        // clone for the next poll_ready cycle (Tower Service contract).
194        let clone = self.inner.clone();
195        let mut inner = std::mem::replace(&mut self.inner, clone);
196
197        Box::pin(async move {
198            let start = Instant::now();
199            let result = inner.call(req).await;
200            let elapsed = start.elapsed().as_secs_f64();
201
202            let mut attrs = vec![
203                KeyValue::new("http.request.method", method),
204                KeyValue::new("http.route", route),
205                KeyValue::new("server.address", server_address),
206            ];
207            // OTel client semconv pairs server.address with server.port; only
208            // explicit ports are present in the URI (default 80/443 are elided).
209            if let Some(port) = server_port {
210                attrs.push(KeyValue::new("server.port", i64::from(port)));
211            }
212            if let Some(rt) = request_type {
213                attrs.push(KeyValue::new("request_type", rt));
214            }
215            match &result {
216                Ok(response) => attrs.push(KeyValue::new(
217                    "http.response.status_code",
218                    i64::from(response.status().as_u16()),
219                )),
220                Err(e) => attrs.push(KeyValue::new("error.type", error_type(e))),
221            }
222            duration.record(elapsed, &attrs);
223
224            result
225        })
226    }
227}
228
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // `RequestType` and the other parent-module imports already arrive through
    // the `super::*` glob, so they are not imported a second time here.
    use super::*;
    use http::StatusCode;
    use http_body_util::{BodyExt, Empty};
    use opentelemetry::metrics::MeterProvider;
    use opentelemetry_sdk::metrics::data::{AggregatedMetrics, HistogramDataPoint, MetricData};
    use opentelemetry_sdk::metrics::{InMemoryMetricExporter, SdkMeterProvider};
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt, service_fn};

    /// Build a response with the given status and an empty `ResponseBody`.
    fn empty_response(status: StatusCode) -> Response<ResponseBody> {
        let body: ResponseBody = Empty::<Bytes>::new()
            .map_err(|e: Infallible| -> Box<dyn std::error::Error + Send + Sync> { match e {} })
            .boxed();
        Response::builder().status(status).body(body).unwrap()
    }

    /// Collect the histogram data point for `http.client.request.duration` whose
    /// attributes contain every `(key, value)` in `expected`. Returns `None` if
    /// no matching point was exported.
    fn find_duration_point(
        exporter: &InMemoryMetricExporter,
        expected: &[(&str, &str)],
    ) -> Option<HistogramDataPoint<f64>> {
        let batches = exporter.get_finished_metrics().unwrap();
        for rm in &batches {
            for sm in rm.scope_metrics() {
                for metric in sm.metrics() {
                    if metric.name() != "http.client.request.duration" {
                        continue;
                    }
                    let AggregatedMetrics::F64(MetricData::Histogram(hist)) = metric.data() else {
                        continue;
                    };
                    for dp in hist.data_points() {
                        let matches = expected.iter().all(|(k, v)| {
                            dp.attributes()
                                .any(|kv| kv.key.as_str() == *k && kv.value.to_string() == *v)
                        });
                        if matches {
                            return Some(dp.clone());
                        }
                    }
                }
            }
        }
        None
    }

    /// Meter provider backed by an in-memory exporter. The tests call
    /// `force_flush` on the provider before reading from the exporter.
    fn test_provider() -> (SdkMeterProvider, InMemoryMetricExporter) {
        let exporter = InMemoryMetricExporter::default();
        let provider = SdkMeterProvider::builder()
            .with_periodic_exporter(exporter.clone())
            .build();
        (provider, exporter)
    }

    /// A successful request records method, route, address, explicit port and
    /// status code on exactly one observation.
    #[tokio::test]
    async fn records_duration_with_attributes_on_success() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let classify: ClassifyFn = Arc::new(|_req| Cow::Borrowed("GET /users/{id}"));
        let layer = MetricsLayer::with_meter(&meter, classify);

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Ok::<_, HttpError>(empty_response(StatusCode::OK))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);
        let req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com:8443/users/123")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        provider.force_flush().unwrap();
        let point = find_duration_point(
            &exporter,
            &[
                ("http.request.method", "GET"),
                ("http.route", "GET /users/{id}"),
                ("server.address", "example.com"),
                ("server.port", "8443"),
                ("http.response.status_code", "200"),
            ],
        )
        .expect("a duration data point with the expected attributes should be exported");
        assert_eq!(point.count(), 1, "exactly one observation recorded");
    }

    /// An inner `Err` records `error.type` and no status code.
    #[tokio::test]
    async fn records_error_type_on_transport_failure() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Err::<Response<ResponseBody>, _>(HttpError::Timeout(std::time::Duration::from_secs(1)))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);
        let req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let err = svc.ready().await.unwrap().call(req).await.unwrap_err();
        assert!(matches!(err, HttpError::Timeout(_)));

        provider.force_flush().unwrap();
        let point = find_duration_point(&exporter, &[("error.type", "timeout")])
            .expect("a duration data point tagged error.type=timeout should be exported");
        assert_eq!(point.count(), 1);
        // Failures must not carry a status code.
        assert!(
            point
                .attributes()
                .all(|kv| kv.key.as_str() != "http.response.status_code"),
            "transport failures must not record http.response.status_code"
        );
    }

    /// The default classifier keeps only the normalized method and host.
    #[test]
    fn default_classify_normalizes_method_and_drops_path() {
        let req = Request::builder()
            .method(http::Method::POST)
            .uri("https://api.example.com/users/abc-123-uuid")
            .body(Full::new(Bytes::new()))
            .unwrap();
        // Raw path with an identifier must never leak into the label.
        assert_eq!(default_classify(&req), "POST api.example.com");

        let exotic = Request::builder()
            .method(http::Method::from_bytes(b"PROPFIND").unwrap())
            .uri("https://api.example.com/dav")
            .body(Full::new(Bytes::new()))
            .unwrap();
        assert_eq!(default_classify(&exotic), "_OTHER api.example.com");
    }

    /// Standard methods pass through; custom verbs collapse to `_OTHER`.
    #[test]
    fn normalize_method_caps_unknown() {
        assert_eq!(normalize_method(&http::Method::GET), "GET");
        let custom = http::Method::from_bytes(b"PROPFIND").unwrap();
        assert_eq!(normalize_method(&custom), "_OTHER");
    }

    /// A `RequestType` extension on the request becomes the `request_type`
    /// attribute.
    #[tokio::test]
    async fn records_request_type_attribute_when_set() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Ok::<_, HttpError>(empty_response(StatusCode::OK))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);

        let mut req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/tenants/123")
            .body(Full::new(Bytes::new()))
            .unwrap();
        req.extensions_mut()
            .insert(RequestType::new("tenants_resolve"));

        svc.ready().await.unwrap().call(req).await.unwrap();

        provider.force_flush().unwrap();
        let point = find_duration_point(&exporter, &[("request_type", "tenants_resolve")])
            .expect("request_type attribute should appear in exported metric");
        assert_eq!(point.count(), 1);
    }

    /// Without a `RequestType` extension, no `request_type` attribute is added.
    #[tokio::test]
    async fn omits_request_type_when_not_set() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Ok::<_, HttpError>(empty_response(StatusCode::OK))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);

        let req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/tenants/123")
            .body(Full::new(Bytes::new()))
            .unwrap();

        svc.ready().await.unwrap().call(req).await.unwrap();

        provider.force_flush().unwrap();
        let dp = find_duration_point(&exporter, &[("http.request.method", "GET")])
            .expect("a data point should be exported");
        assert!(
            dp.attributes().all(|kv| kv.key.as_str() != "request_type"),
            "request_type must not appear when not set"
        );
    }

    /// Transport-class variants get dedicated labels; others fall to `other`.
    #[test]
    fn error_type_maps_transport_class_failures() {
        assert_eq!(
            error_type(&HttpError::Timeout(std::time::Duration::from_secs(1))),
            "timeout"
        );
        assert_eq!(
            error_type(&HttpError::Transport("boom".into())),
            "transport"
        );
        assert_eq!(error_type(&HttpError::Overloaded), "other");
    }
}