//! `tonic_otel_layer/lib.rs`
//!
//! Tower middleware that records OpenTelemetry metrics for gRPC server requests.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

use opentelemetry::metrics::{Counter, Histogram, UpDownCounter};
use opentelemetry::{KeyValue, global};
use pin_project::{pin_project, pinned_drop};
use tonic::Code;
use tonic::codegen::http::{request, response};
use tower::{Layer, Service};
13
14#[derive(Clone)]
15pub struct MetricsLayer {
16    metrics: Metrics,
17}
18
19#[derive(Clone)]
20pub struct Metrics {
21    pub started_total: Counter<u64>,
22    pub handled_total: Counter<u64>,
23    pub handling_duration: Histogram<f64>,
24    pub active_requests: UpDownCounter<i64>,
25}
26
/// Histogram bucket boundaries (in seconds) used by the duration histogram
/// when the builder is not given custom buckets.
///
/// NOTE(review): the boundaries jump from 0.025 straight to 0.50; confirm that
/// the gap is intentional and not a typo for finer-grained values.
const DEFAULT_HISTOGRAM_BUCKETS: [f64; 10] = [
    // Millisecond range.
    0.001,
    0.005,
    // Tens of milliseconds.
    0.01,
    0.015,
    0.020,
    0.025,
    // Sub-second range.
    0.50,
    0.75,
    // Seconds.
    1.0,
    2.0,
];
30
31#[derive(Default)]
32pub struct MetricsLayerBuilder {
33    buckets: Option<Vec<f64>>,
34    provider: Option<Arc<dyn opentelemetry::metrics::MeterProvider + Send + Sync>>,
35}
36
37impl MetricsLayerBuilder {
38    pub fn new() -> Self {
39        MetricsLayerBuilder::default()
40    }
41    pub fn with_buckets(mut self, buckets: Vec<f64>) -> Self {
42        self.buckets = Some(buckets);
43        self
44    }
45
46    pub fn with_provider<P>(mut self, provider: P) -> Self
47    where
48        P: opentelemetry::metrics::MeterProvider + Send + Sync + 'static,
49    {
50        self.provider = Some(Arc::new(provider));
51        self
52    }
53    pub fn build(self) -> MetricsLayer {
54        let provider = self.provider.unwrap_or_else(|| global::meter_provider());
55
56        let meter = provider.meter("tonic");
57
58        let buckets = self
59            .buckets
60            .unwrap_or_else(|| DEFAULT_HISTOGRAM_BUCKETS.to_vec());
61
62        let started_total = meter
63            .u64_counter("grpc_server_started")
64            .with_description("Total number of RPCs started on the server.")
65            .build();
66        let handled_total = meter
67            .u64_counter("grpc_server_handled")
68            .with_description("Total number of RPCs completed on the server.")
69            .build();
70        let handling_duration = meter
71            .f64_histogram("grpc_server_handling_duration_seconds")
72            .with_description("Rpc call duration")
73            .with_boundaries(buckets)
74            .build();
75        let active_requests = meter
76            .i64_up_down_counter("grpc_server_active_requests")
77            .with_description("Current number of active server requests.")
78            .build();
79        let metrics = Metrics {
80            started_total,
81            handled_total,
82            handling_duration,
83            active_requests,
84        };
85        MetricsLayer { metrics }
86    }
87}
88
89impl<S> Layer<S> for MetricsLayer {
90    type Service = MetricsService<S>;
91
92    fn layer(&self, inner: S) -> Self::Service {
93        MetricsService {
94            service: inner,
95            metrics: self.metrics.clone(),
96        }
97    }
98}
99
100#[derive(Clone)]
101pub struct MetricsService<S> {
102    metrics: Metrics,
103    service: S,
104}
105
106impl<S, B, C> Service<request::Request<B>> for MetricsService<S>
107where
108    S: Service<request::Request<B>, Response = response::Response<C>>,
109{
110    type Response = S::Response;
111    type Error = S::Error;
112    type Future = MetricsFuture<S::Future>;
113
114    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        self.service.poll_ready(cx)
116    }
117
118    fn call(&mut self, req: request::Request<B>) -> Self::Future {
119        let path = req.uri().path();
120        let (service, method) = path.rsplit_once("/").expect("Path must contain a method");
121        let service = service.to_owned();
122        let method = method.to_owned();
123        let metrics = self.metrics.clone();
124        let inner = self.service.call(req);
125        MetricsFuture {
126            inner,
127            metrics,
128            service,
129            method,
130            started_at: None,
131        }
132    }
133}
134
135#[pin_project]
136pub struct MetricsFuture<F> {
137    #[pin]
138    inner: F,
139    metrics: Metrics,
140    service: String,
141    method: String,
142    started_at: Option<Instant>,
143}
144
145impl<F, B, E> Future for MetricsFuture<F>
146where
147    F: Future<Output = Result<response::Response<B>, E>>,
148{
149    type Output = F::Output;
150
151    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152        let this = self.project();
153
154        let sm_labels = vec![
155            KeyValue::new("grpc_service", this.service.clone()),
156            KeyValue::new("grpc_method", this.method.clone()),
157        ];
158
159        let started_at = this.started_at.get_or_insert_with(|| {
160            this.metrics.active_requests.add(1, &sm_labels);
161            this.metrics.started_total.add(1, &sm_labels);
162            Instant::now()
163        });
164
165        if let Poll::Ready(res) = this.inner.poll(cx) {
166            let code = res.as_ref().map_or(Code::Unknown, |resp| {
167                resp.headers()
168                    .get("grpc-status")
169                    .map(|s| Code::from_bytes(s.as_bytes()))
170                    .unwrap_or(Code::Ok)
171            });
172            let smc_labels = [
173                KeyValue::new("grpc_service", this.service.clone()),
174                KeyValue::new("grpc_method", this.method.clone()),
175                KeyValue::new("grpc_code", format!("{:?}", code)),
176            ];
177            let elapsed = started_at.elapsed().as_secs_f64();
178            this.metrics.active_requests.add(-1, &sm_labels);
179            this.metrics.handled_total.add(1, &smc_labels);
180            this.metrics.handling_duration.record(elapsed, &smc_labels);
181
182            Poll::Ready(res)
183        } else {
184            Poll::Pending
185        }
186    }
187}