use crate::error::HttpError;
use crate::request::RequestType;
use crate::response::ResponseBody;
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::Full;
use opentelemetry::metrics::{Histogram, Meter};
use opentelemetry::{KeyValue, global};
use std::borrow::Cow;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use tower::{Layer, Service};
/// Shared, thread-safe callback that turns a request into a label.
///
/// [`MetricsService`] invokes it once per call and records the result as the
/// `http.route` attribute. The `Cow<'static, str>` return type lets a
/// classifier hand back either a string literal or a freshly built `String`.
pub type ClassifyFn = Arc<dyn Fn(&Request<Full<Bytes>>) -> Cow<'static, str> + Send + Sync>;
/// Explicit bucket boundaries, in seconds, for the
/// `http.client.request.duration` histogram built in
/// `MetricsLayer::with_meter`. They run from 10 ms up to 600 s.
const DURATION_BOUNDARIES_SECS: &[f64] = &[
0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 150.0, 300.0, 600.0,
];
/// Default request classifier: `"<METHOD> <host>"`.
///
/// The method is passed through `normalize_method`, and the host is taken
/// from the request URI. When the URI carries no host the literal
/// `"unknown"` is used. The path is never part of the label.
#[must_use]
pub fn default_classify(req: &Request<Full<Bytes>>) -> Cow<'static, str> {
    let verb = normalize_method(req.method());
    let authority = match req.uri().host() {
        Some(name) => name,
        None => "unknown",
    };

    // Assemble "<verb> <authority>" by hand; the result is always owned.
    let mut label = String::with_capacity(verb.len() + 1 + authority.len());
    label.push_str(verb);
    label.push(' ');
    label.push_str(authority);
    Cow::Owned(label)
}
fn normalize_method(method: &http::Method) -> &'static str {
match *method {
http::Method::GET => "GET",
http::Method::POST => "POST",
http::Method::PUT => "PUT",
http::Method::DELETE => "DELETE",
http::Method::PATCH => "PATCH",
http::Method::HEAD => "HEAD",
http::Method::OPTIONS => "OPTIONS",
http::Method::CONNECT => "CONNECT",
http::Method::TRACE => "TRACE",
_ => "_OTHER",
}
}
/// Returns the `error.type` label for a failed request.
///
/// Four variants get a dedicated label; anything else is reported as
/// `"other"`.
fn error_type(err: &HttpError) -> &'static str {
    if matches!(err, HttpError::Timeout(_)) {
        "timeout"
    } else if matches!(err, HttpError::DeadlineExceeded(_)) {
        "deadline_exceeded"
    } else if matches!(err, HttpError::Transport(_)) {
        "transport"
    } else if matches!(err, HttpError::Tls(_)) {
        "tls"
    } else {
        "other"
    }
}
/// Tower [`Layer`] that wraps a service in a [`MetricsService`].
///
/// Both fields are cloned into every service produced by `layer`.
#[derive(Clone)]
pub struct MetricsLayer {
/// Histogram the wrapped services record request durations into
/// (created by `with_meter`).
duration: Histogram<f64>,
/// Callback whose result is recorded as the `http.route` attribute.
classify: ClassifyFn,
}
impl MetricsLayer {
    /// Builds a layer on top of the global meter provider.
    ///
    /// `client_type` becomes the name of the instrumentation scope the meter
    /// is obtained with; `classify` is stored as the request classifier.
    #[must_use]
    pub fn new(client_type: &str, classify: ClassifyFn) -> Self {
        let scope_name = String::from(client_type);
        let meter = global::meter_with_scope(
            opentelemetry::InstrumentationScope::builder(scope_name).build(),
        );
        Self::with_meter(&meter, classify)
    }

    /// Builds a layer that records into a histogram created from `meter`.
    ///
    /// The histogram is named `http.client.request.duration`, uses the unit
    /// `s`, and is configured with `DURATION_BOUNDARIES_SECS` as its bucket
    /// boundaries.
    #[must_use]
    pub fn with_meter(meter: &Meter, classify: ClassifyFn) -> Self {
        let builder = meter
            .f64_histogram("http.client.request.duration")
            .with_description("Duration of outbound HTTP client requests")
            .with_unit("s")
            .with_boundaries(Vec::from(DURATION_BOUNDARIES_SECS));

        Self {
            duration: builder.build(),
            classify,
        }
    }
}
impl<S> Layer<S> for MetricsLayer {
type Service = MetricsService<S>;
fn layer(&self, inner: S) -> Self::Service {
MetricsService {
inner,
duration: self.duration.clone(),
classify: self.classify.clone(),
}
}
}
/// Service produced by [`MetricsLayer`].
///
/// It forwards each request to `inner` and records one duration observation
/// when the inner call completes.
#[derive(Clone)]
pub struct MetricsService<S> {
/// The wrapped service that receives the request.
inner: S,
/// Histogram that durations (in seconds) are recorded into.
duration: Histogram<f64>,
/// Callback whose result is recorded as the `http.route` attribute.
classify: ClassifyFn,
}
/// Times each call to the wrapped service and records one histogram
/// observation per completed request.
impl<S> Service<Request<Full<Bytes>>> for MetricsService<S>
where
    S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = Response<ResponseBody>;
    type Error = HttpError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the wrapped service and records how long it took.
    ///
    /// The recorded attributes are:
    /// - `http.request.method`: normalized method name
    /// - `http.route`: result of the classifier
    /// - `server.address`: URI host, or `"unknown"` when absent
    /// - `server.port`: only when the URI carries an explicit port
    /// - `request_type`: only when a `RequestType` extension is present
    /// - `http.response.status_code` on `Ok`, or `error.type` on `Err`
    ///
    /// The clock starts when the returned future is first polled, not when
    /// `call` returns. Nothing is recorded if the future is dropped before
    /// the inner call completes. The inner result is returned unchanged.
    fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
        // Everything derived from `req` is captured up front because the
        // request itself is moved into the inner call below.
        //
        // The classifier's `Cow` is kept as-is: a `Cow::Borrowed` label is
        // handed to `KeyValue::new` without a per-request allocation.
        let route = (self.classify)(&req);
        let method = normalize_method(req.method());
        let server_address = req.uri().host().unwrap_or("unknown").to_owned();
        let server_port = req.uri().port_u16();
        let request_type = req
            .extensions()
            .get::<RequestType>()
            .map(|rt| rt.0.clone().into_owned());
        let duration = self.duration.clone();

        // Move the current inner service into the future and leave a clone in
        // `self`, so the instance that was polled ready is the one called.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(async move {
            let start = Instant::now();
            let result = inner.call(req).await;
            let elapsed = start.elapsed().as_secs_f64();

            // At most six attributes: three fixed, two optional, one outcome.
            let mut attrs = Vec::with_capacity(6);
            attrs.push(KeyValue::new("http.request.method", method));
            attrs.push(KeyValue::new("http.route", route));
            attrs.push(KeyValue::new("server.address", server_address));
            if let Some(port) = server_port {
                attrs.push(KeyValue::new("server.port", i64::from(port)));
            }
            if let Some(rt) = request_type {
                attrs.push(KeyValue::new("request_type", rt));
            }
            // Exactly one of status code / error type is attached.
            match &result {
                Ok(response) => attrs.push(KeyValue::new(
                    "http.response.status_code",
                    i64::from(response.status().as_u16()),
                )),
                Err(e) => attrs.push(KeyValue::new("error.type", error_type(e))),
            }

            duration.record(elapsed, &attrs);
            result
        })
    }
}
// Tests for the metrics layer. Each async test builds its own meter provider
// backed by an in-memory exporter, drives one request through the layer, and
// then inspects the exported histogram data points.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use crate::request::RequestType;
use http::StatusCode;
use http_body_util::{BodyExt, Empty};
use opentelemetry::metrics::MeterProvider;
use opentelemetry_sdk::metrics::data::{AggregatedMetrics, HistogramDataPoint, MetricData};
use opentelemetry_sdk::metrics::{InMemoryMetricExporter, SdkMeterProvider};
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt, service_fn};
/// Builds a response with the given status and an empty, boxed body.
fn empty_response(status: StatusCode) -> Response<ResponseBody> {
let body: ResponseBody = Empty::<Bytes>::new()
// `Infallible` has no values, so the empty match covers every case.
.map_err(|e: Infallible| -> Box<dyn std::error::Error + Send + Sync> { match e {} })
.boxed();
Response::builder().status(status).body(body).unwrap()
}
/// Searches the exported metrics for a `http.client.request.duration`
/// histogram data point carrying every `(key, value)` pair in `expected`.
///
/// Attribute values are compared through their `to_string()` form, so numeric
/// attributes are matched against their decimal text. Returns a clone of the
/// first matching data point, or `None` when nothing matches. Panics if the
/// exporter cannot return its finished metrics.
fn find_duration_point(
exporter: &InMemoryMetricExporter,
expected: &[(&str, &str)],
) -> Option<HistogramDataPoint<f64>> {
let batches = exporter.get_finished_metrics().unwrap();
for rm in &batches {
for sm in rm.scope_metrics() {
for metric in sm.metrics() {
if metric.name() != "http.client.request.duration" {
continue;
}
// Only f64 histograms are of interest; skip any other aggregation.
let AggregatedMetrics::F64(MetricData::Histogram(hist)) = metric.data() else {
continue;
};
for dp in hist.data_points() {
// The point may carry extra attributes; every expected pair must be present.
let matches = expected.iter().all(|(k, v)| {
dp.attributes()
.any(|kv| kv.key.as_str() == *k && kv.value.to_string() == *v)
});
if matches {
return Some(dp.clone());
}
}
}
}
}
None
}
/// Creates a meter provider wired to an in-memory exporter and returns both,
/// so a test can flush the provider and then read back what was exported.
fn test_provider() -> (SdkMeterProvider, InMemoryMetricExporter) {
let exporter = InMemoryMetricExporter::default();
let provider = SdkMeterProvider::builder()
.with_periodic_exporter(exporter.clone())
.build();
(provider, exporter)
}
// A successful call records one observation carrying method, route, address,
// explicit port and status code.
#[tokio::test]
async fn records_duration_with_attributes_on_success() {
let (provider, exporter) = test_provider();
let meter = provider.meter("test-client");
let classify: ClassifyFn = Arc::new(|_req| Cow::Borrowed("GET /users/{id}"));
let layer = MetricsLayer::with_meter(&meter, classify);
let inner = service_fn(|_req: Request<Full<Bytes>>| async {
Ok::<_, HttpError>(empty_response(StatusCode::OK))
});
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
let req = Request::builder()
.method(http::Method::GET)
.uri("https://example.com:8443/users/123")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// Flush so the recorded observation reaches the in-memory exporter.
provider.force_flush().unwrap();
let point = find_duration_point(
&exporter,
&[
("http.request.method", "GET"),
("http.route", "GET /users/{id}"),
("server.address", "example.com"),
("server.port", "8443"),
("http.response.status_code", "200"),
],
)
.expect("a duration data point with the expected attributes should be exported");
assert_eq!(point.count(), 1, "exactly one observation recorded");
}
// A failed call (the inner service returns `HttpError::Timeout`) records
// `error.type=timeout` and no status code attribute.
#[tokio::test]
async fn records_error_type_on_transport_failure() {
let (provider, exporter) = test_provider();
let meter = provider.meter("test-client");
let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));
let inner = service_fn(|_req: Request<Full<Bytes>>| async {
Err::<Response<ResponseBody>, _>(HttpError::Timeout(std::time::Duration::from_secs(1)))
});
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
let req = Request::builder()
.method(http::Method::GET)
.uri("https://example.com/")
.body(Full::new(Bytes::new()))
.unwrap();
let err = svc.ready().await.unwrap().call(req).await.unwrap_err();
assert!(matches!(err, HttpError::Timeout(_)));
provider.force_flush().unwrap();
let point = find_duration_point(&exporter, &[("error.type", "timeout")])
.expect("a duration data point tagged error.type=timeout should be exported");
assert_eq!(point.count(), 1);
assert!(
point
.attributes()
.all(|kv| kv.key.as_str() != "http.response.status_code"),
"transport failures must not record http.response.status_code"
);
}
// The default classifier yields "<METHOD> <host>": the path is left out and
// a non-standard method is reported as `_OTHER`.
#[test]
fn default_classify_normalizes_method_and_drops_path() {
let req = Request::builder()
.method(http::Method::POST)
.uri("https://api.example.com/users/abc-123-uuid")
.body(Full::new(Bytes::new()))
.unwrap();
assert_eq!(default_classify(&req), "POST api.example.com");
let exotic = Request::builder()
.method(http::Method::from_bytes(b"PROPFIND").unwrap())
.uri("https://api.example.com/dav")
.body(Full::new(Bytes::new()))
.unwrap();
assert_eq!(default_classify(&exotic), "_OTHER api.example.com");
}
// Standard methods keep their name; extension methods map to `_OTHER`.
#[test]
fn normalize_method_caps_unknown() {
assert_eq!(normalize_method(&http::Method::GET), "GET");
let custom = http::Method::from_bytes(b"PROPFIND").unwrap();
assert_eq!(normalize_method(&custom), "_OTHER");
}
// A `RequestType` request extension is exported as the `request_type`
// attribute.
#[tokio::test]
async fn records_request_type_attribute_when_set() {
let (provider, exporter) = test_provider();
let meter = provider.meter("test-client");
let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));
let inner = service_fn(|_req: Request<Full<Bytes>>| async {
Ok::<_, HttpError>(empty_response(StatusCode::OK))
});
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
let mut req = Request::builder()
.method(http::Method::GET)
.uri("https://example.com/tenants/123")
.body(Full::new(Bytes::new()))
.unwrap();
req.extensions_mut()
.insert(RequestType::new("tenants_resolve"));
svc.ready().await.unwrap().call(req).await.unwrap();
provider.force_flush().unwrap();
let point = find_duration_point(&exporter, &[("request_type", "tenants_resolve")])
.expect("request_type attribute should appear in exported metric");
assert_eq!(point.count(), 1);
}
// Without a `RequestType` extension the `request_type` attribute is absent.
#[tokio::test]
async fn omits_request_type_when_not_set() {
let (provider, exporter) = test_provider();
let meter = provider.meter("test-client");
let layer = MetricsLayer::with_meter(&meter, Arc::new(default_classify));
let inner = service_fn(|_req: Request<Full<Bytes>>| async {
Ok::<_, HttpError>(empty_response(StatusCode::OK))
});
let mut svc = ServiceBuilder::new().layer(layer).service(inner);
let req = Request::builder()
.method(http::Method::GET)
.uri("https://example.com/tenants/123")
.body(Full::new(Bytes::new()))
.unwrap();
svc.ready().await.unwrap().call(req).await.unwrap();
provider.force_flush().unwrap();
let dp = find_duration_point(&exporter, &[("http.request.method", "GET")])
.expect("a data point should be exported");
assert!(
dp.attributes().all(|kv| kv.key.as_str() != "request_type"),
"request_type must not appear when not set"
);
}
// Spot-checks the `error_type` mapping: two dedicated labels and the
// catch-all `other`.
#[test]
fn error_type_maps_transport_class_failures() {
assert_eq!(
error_type(&HttpError::Timeout(std::time::Duration::from_secs(1))),
"timeout"
);
assert_eq!(
error_type(&HttpError::Transport("boom".into())),
"transport"
);
assert_eq!(error_type(&HttpError::Overloaded), "other");
}
}