1use crate::error::HttpError;
20use crate::request::RequestType;
21use crate::response::ResponseBody;
22use bytes::Bytes;
23use http::{Request, Response};
24use http_body_util::Full;
25use opentelemetry::metrics::{Histogram, Meter};
26use opentelemetry::{KeyValue, global};
27use std::borrow::Cow;
28use std::future::Future;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use std::time::Instant;
33use tower::{Layer, Service};
34
/// Callback that derives the `http.route` attribute value for an outbound request.
///
/// [`MetricsService`] invokes it once per request, before the request is forwarded
/// to the wrapped service. Every distinct string it returns becomes its own metric
/// series. Implementations should therefore map requests onto a small, bounded set
/// of labels (e.g. `"GET /users/{id}"`) rather than echoing raw paths or IDs.
pub type ClassifyFn = Arc<dyn Fn(&Request<Full<Bytes>>) -> Cow<'static, str> + Send + Sync>;
43
/// Explicit bucket boundaries, in seconds, for the `http.client.request.duration`
/// histogram (the instrument is created with unit `"s"`).
///
/// The buckets span 10 ms to 600 s (10 minutes). Observations above the last
/// boundary fall into the histogram's implicit overflow bucket.
const DURATION_BOUNDARIES_SECS: &[f64] = &[
    0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 150.0, 300.0, 600.0,
];
53
54#[must_use]
57pub fn default_classify(req: &Request<Full<Bytes>>) -> Cow<'static, str> {
58 let host = req.uri().host().unwrap_or("unknown");
59 Cow::Owned(format!("{} {}", normalize_method(req.method()), host))
63}
64
65fn normalize_method(method: &http::Method) -> &'static str {
73 match *method {
74 http::Method::GET => "GET",
75 http::Method::POST => "POST",
76 http::Method::PUT => "PUT",
77 http::Method::DELETE => "DELETE",
78 http::Method::PATCH => "PATCH",
79 http::Method::HEAD => "HEAD",
80 http::Method::OPTIONS => "OPTIONS",
81 http::Method::CONNECT => "CONNECT",
82 http::Method::TRACE => "TRACE",
83 _ => "_OTHER",
84 }
85}
86
87fn error_type(err: &HttpError) -> &'static str {
95 match err {
96 HttpError::Timeout(_) => "timeout",
97 HttpError::DeadlineExceeded(_) => "deadline_exceeded",
98 HttpError::Transport(_) => "transport",
99 HttpError::Tls(_) => "tls",
100 _ => "other",
101 }
102}
103
/// Tower [`Layer`] that wraps a client service in a [`MetricsService`], which
/// records the `http.client.request.duration` histogram for every request.
///
/// Build it with [`MetricsLayer::new`] (global meter provider) or
/// [`MetricsLayer::with_meter`] (caller-supplied meter).
#[derive(Clone)]
pub struct MetricsLayer {
    /// Histogram named `http.client.request.duration`, created in `with_meter`.
    duration: Histogram<f64>,
    /// Produces the `http.route` attribute for each request.
    classify: ClassifyFn,
}
110
111impl MetricsLayer {
112 #[must_use]
118 pub fn new(client_type: &str, classify: ClassifyFn) -> Self {
119 let scope = opentelemetry::InstrumentationScope::builder(client_type.to_owned()).build();
120 let meter = global::meter_with_scope(scope);
121 Self::with_meter(&meter, classify)
122 }
123
124 #[must_use]
130 pub fn with_meter(meter: &Meter, classify: ClassifyFn) -> Self {
131 let duration = meter
132 .f64_histogram("http.client.request.duration")
133 .with_description("Duration of outbound HTTP client requests")
134 .with_unit("s")
135 .with_boundaries(DURATION_BOUNDARIES_SECS.to_vec())
136 .build();
137 Self { duration, classify }
138 }
139}
140
141impl<S> Layer<S> for MetricsLayer {
142 type Service = MetricsService<S>;
143
144 fn layer(&self, inner: S) -> Self::Service {
145 MetricsService {
146 inner,
147 duration: self.duration.clone(),
148 classify: self.classify.clone(),
149 }
150 }
151}
152
/// Service produced by [`MetricsLayer`]. It forwards each request to `inner` and
/// records one `http.client.request.duration` observation when the call completes.
#[derive(Clone)]
pub struct MetricsService<S> {
    /// Wrapped client service that actually performs the request.
    inner: S,
    /// Histogram handle copied from the originating [`MetricsLayer`].
    duration: Histogram<f64>,
    /// Classifier copied from the originating [`MetricsLayer`]; yields `http.route`.
    classify: ClassifyFn,
}
160
161impl<S> Service<Request<Full<Bytes>>> for MetricsService<S>
162where
163 S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
164 + Clone
165 + Send
166 + 'static,
167 S::Future: Send,
168{
169 type Response = Response<ResponseBody>;
170 type Error = HttpError;
171 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
172
173 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174 self.inner.poll_ready(cx)
175 }
176
177 fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
178 let route = (self.classify)(&req).into_owned();
181 let method = normalize_method(req.method());
182 let server_address = req.uri().host().unwrap_or("unknown").to_owned();
183 let server_port = req.uri().port_u16();
184 let request_type = req
187 .extensions()
188 .get::<RequestType>()
189 .map(|rt| rt.0.clone().into_owned());
190 let duration = self.duration.clone();
191
192 let clone = self.inner.clone();
195 let mut inner = std::mem::replace(&mut self.inner, clone);
196
197 Box::pin(async move {
198 let start = Instant::now();
199 let result = inner.call(req).await;
200 let elapsed = start.elapsed().as_secs_f64();
201
202 let mut attrs = vec![
203 KeyValue::new("http.request.method", method),
204 KeyValue::new("http.route", route),
205 KeyValue::new("server.address", server_address),
206 ];
207 if let Some(port) = server_port {
210 attrs.push(KeyValue::new("server.port", i64::from(port)));
211 }
212 if let Some(rt) = request_type {
213 attrs.push(KeyValue::new("request_type", rt));
214 }
215 match &result {
216 Ok(response) => attrs.push(KeyValue::new(
217 "http.response.status_code",
218 i64::from(response.status().as_u16()),
219 )),
220 Err(e) => attrs.push(KeyValue::new("error.type", error_type(e))),
221 }
222 duration.record(elapsed, &attrs);
223
224 result
225 })
226 }
227}
228
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Each async test builds its own `SdkMeterProvider` backed by an
    // `InMemoryMetricExporter` (see `test_provider`). Assertions therefore inspect
    // only what that test exported and never touch the global meter provider.
    use super::*;
    use crate::request::RequestType;
    use http::StatusCode;
    use http_body_util::{BodyExt, Empty};
    use opentelemetry::metrics::MeterProvider;
    use opentelemetry_sdk::metrics::data::{AggregatedMetrics, HistogramDataPoint, MetricData};
    use opentelemetry_sdk::metrics::{InMemoryMetricExporter, SdkMeterProvider};
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt, service_fn};

    /// Builds a response with the given status and an empty body, boxed into the
    /// crate's `ResponseBody` type.
    fn empty_response(status: StatusCode) -> Response<ResponseBody> {
        let body: ResponseBody = Empty::<Bytes>::new()
            // `Infallible` has no values, so the empty match only converts the
            // error type; this closure can never actually run.
            .map_err(|e: Infallible| -> Box<dyn std::error::Error + Send + Sync> { match e {} })
            .boxed();
        Response::builder().status(status).body(body).unwrap()
    }

    /// Returns the first exported `http.client.request.duration` histogram data
    /// point whose attributes include every `(key, value)` pair in `expected`.
    /// Values are compared via their `to_string()` form, so integers are written
    /// as e.g. `"8443"`.
    ///
    /// Extra attributes on the data point are allowed. Tests that need to prove
    /// an attribute is absent check that separately.
    fn find_duration_point(
        exporter: &InMemoryMetricExporter,
        expected: &[(&str, &str)],
    ) -> Option<HistogramDataPoint<f64>> {
        let batches = exporter.get_finished_metrics().unwrap();
        for rm in &batches {
            for sm in rm.scope_metrics() {
                for metric in sm.metrics() {
                    if metric.name() != "http.client.request.duration" {
                        continue;
                    }
                    let AggregatedMetrics::F64(MetricData::Histogram(hist)) = metric.data() else {
                        continue;
                    };
                    for dp in hist.data_points() {
                        let matches = expected.iter().all(|(k, v)| {
                            dp.attributes()
                                .any(|kv| kv.key.as_str() == *k && kv.value.to_string() == *v)
                        });
                        if matches {
                            return Some(dp.clone());
                        }
                    }
                }
            }
        }
        None
    }

    /// Creates an isolated meter provider whose periodic reader exports into an
    /// in-memory exporter. Call `provider.force_flush()` before inspecting it.
    fn test_provider() -> (SdkMeterProvider, InMemoryMetricExporter) {
        let exporter = InMemoryMetricExporter::default();
        let provider = SdkMeterProvider::builder()
            .with_periodic_exporter(exporter.clone())
            .build();
        (provider, exporter)
    }

    // A successful call records method, classifier route, host, explicit port
    // and status code on a single observation.
    #[tokio::test]
    async fn records_duration_with_attributes_on_success() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let classify: ClassifyFn = Arc::new(|_req| Cow::Borrowed("GET /users/{id}"));
        let layer = MetricsLayer::with_meter(&meter, classify);

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Ok::<_, HttpError>(empty_response(StatusCode::OK))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);
        let req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com:8443/users/123")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        provider.force_flush().unwrap();
        let point = find_duration_point(
            &exporter,
            &[
                ("http.request.method", "GET"),
                ("http.route", "GET /users/{id}"),
                ("server.address", "example.com"),
                ("server.port", "8443"),
                ("http.response.status_code", "200"),
            ],
        )
        .expect("a duration data point with the expected attributes should be exported");
        assert_eq!(point.count(), 1, "exactly one observation recorded");
    }

    // A failed call (here `HttpError::Timeout`) records `error.type` and no
    // status code, and the original error is passed through to the caller.
    #[tokio::test]
    async fn records_error_type_on_transport_failure() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Err::<Response<ResponseBody>, _>(HttpError::Timeout(std::time::Duration::from_secs(1)))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);
        let req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let err = svc.ready().await.unwrap().call(req).await.unwrap_err();
        assert!(matches!(err, HttpError::Timeout(_)));

        provider.force_flush().unwrap();
        let point = find_duration_point(&exporter, &[("error.type", "timeout")])
            .expect("a duration data point tagged error.type=timeout should be exported");
        assert_eq!(point.count(), 1);
        assert!(
            // No response was received, so no status code may be attached.
            point
                .attributes()
                .all(|kv| kv.key.as_str() != "http.response.status_code"),
            "transport failures must not record http.response.status_code"
        );
    }

    // The default classifier keeps only method + host, and maps extension methods
    // to `_OTHER`.
    #[test]
    fn default_classify_normalizes_method_and_drops_path() {
        let req = Request::builder()
            .method(http::Method::POST)
            .uri("https://api.example.com/users/abc-123-uuid")
            .body(Full::new(Bytes::new()))
            .unwrap();
        assert_eq!(default_classify(&req), "POST api.example.com");
        // An extension method must not leak into the label.

        let exotic = Request::builder()
            .method(http::Method::from_bytes(b"PROPFIND").unwrap())
            .uri("https://api.example.com/dav")
            .body(Full::new(Bytes::new()))
            .unwrap();
        assert_eq!(default_classify(&exotic), "_OTHER api.example.com");
    }

    #[test]
    fn normalize_method_caps_unknown() {
        assert_eq!(normalize_method(&http::Method::GET), "GET");
        let custom = http::Method::from_bytes(b"PROPFIND").unwrap();
        assert_eq!(normalize_method(&custom), "_OTHER");
    }

    // A `RequestType` extension on the request surfaces as the `request_type`
    // attribute.
    #[tokio::test]
    async fn records_request_type_attribute_when_set() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Ok::<_, HttpError>(empty_response(StatusCode::OK))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);

        let mut req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/tenants/123")
            .body(Full::new(Bytes::new()))
            .unwrap();
        req.extensions_mut()
            .insert(RequestType::new("tenants_resolve"));

        svc.ready().await.unwrap().call(req).await.unwrap();

        provider.force_flush().unwrap();
        let point = find_duration_point(&exporter, &[("request_type", "tenants_resolve")])
            .expect("request_type attribute should appear in exported metric");
        assert_eq!(point.count(), 1);
    }

    // Without the extension, the attribute is omitted entirely rather than being
    // recorded with a placeholder value.
    #[tokio::test]
    async fn omits_request_type_when_not_set() {
        let (provider, exporter) = test_provider();
        let meter = provider.meter("test-client");
        let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));

        let inner = service_fn(|_req: Request<Full<Bytes>>| async {
            Ok::<_, HttpError>(empty_response(StatusCode::OK))
        });
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);

        let req = Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/tenants/123")
            .body(Full::new(Bytes::new()))
            .unwrap();

        svc.ready().await.unwrap().call(req).await.unwrap();

        provider.force_flush().unwrap();
        let dp = find_duration_point(&exporter, &[("http.request.method", "GET")])
            .expect("a data point should be exported");
        assert!(
            dp.attributes().all(|kv| kv.key.as_str() != "request_type"),
            "request_type must not appear when not set"
        );
    }

    // Spot-checks the error classification, including the catch-all "other".
    #[test]
    fn error_type_maps_transport_class_failures() {
        assert_eq!(
            error_type(&HttpError::Timeout(std::time::Duration::from_secs(1))),
            "timeout"
        );
        assert_eq!(
            error_type(&HttpError::Transport("boom".into())),
            "transport"
        );
        assert_eq!(error_type(&HttpError::Overloaded), "other");
    }
}