tower_otel/trace/
grpc.rs

1//! Middleware that adds tracing to a [`Service`] that handles gRPC requests.
2
3use std::{
4    fmt::Display,
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use http::{Request, Response};
11use pin_project::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14use tracing::{Level, Span};
15use tracing_opentelemetry::OpenTelemetrySpanExt;
16
17use crate::{
18    trace::{extractor::HeaderExtractor, injector::HeaderInjector},
19    util,
20};
21
/// Which side of a gRPC call a [`Span`] represents.
#[derive(Clone, Copy, Debug)]
enum SpanKind {
    /// The span covers the server-side handling of an incoming request.
    Server,
    /// The span covers a request sent to some remote service.
    Client,
}
30
31/// [`Layer`] that adds tracing to a [`Service`] that handles gRRC requests.
32#[derive(Clone, Debug)]
33pub struct GrpcLayer {
34    level: Level,
35    kind: SpanKind,
36}
37
38impl GrpcLayer {
39    /// [`Span`]s are constructed at the given level from server side.
40    pub fn server(level: Level) -> Self {
41        Self {
42            level,
43            kind: SpanKind::Server,
44        }
45    }
46
47    /// [`Span`]s are constructed at the given level from client side.
48    pub fn client(level: Level) -> Self {
49        Self {
50            level,
51            kind: SpanKind::Client,
52        }
53    }
54}
55
56impl<S> Layer<S> for GrpcLayer {
57    type Service = Grpc<S>;
58
59    fn layer(&self, inner: S) -> Self::Service {
60        Grpc {
61            inner,
62            level: self.level,
63            kind: self.kind,
64        }
65    }
66}
67
68/// Middleware that adds tracing to a [`Service`] that handles gRPC requests.
69#[derive(Clone, Debug)]
70pub struct Grpc<S> {
71    inner: S,
72    level: Level,
73    kind: SpanKind,
74}
75
76impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Grpc<S>
77where
78    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
79    S::Error: Display,
80{
81    type Response = S::Response;
82    type Error = S::Error;
83    type Future = ResponseFuture<S::Future>;
84
85    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.inner.poll_ready(cx)
87    }
88
89    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
90        let span = make_request_span(self.level, self.kind, &mut req);
91        let inner = {
92            let _enter = span.enter();
93            self.inner.call(req)
94        };
95
96        ResponseFuture { inner, span }
97    }
98}
99
100/// Response future for [`Grpc`].
101#[pin_project]
102pub struct ResponseFuture<F> {
103    #[pin]
104    inner: F,
105    span: Span,
106}
107
108impl<F, ResBody, E> Future for ResponseFuture<F>
109where
110    F: Future<Output = Result<Response<ResBody>, E>>,
111    E: Display,
112{
113    type Output = Result<Response<ResBody>, E>;
114
115    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116        let this = self.project();
117        let _enter = this.span.enter();
118
119        match ready!(this.inner.poll(cx)) {
120            Ok(response) => {
121                record_response(this.span, &response);
122                Poll::Ready(Ok(response))
123            }
124            Err(err) => {
125                record_error(this.span, &err);
126                Poll::Ready(Err(err))
127            }
128        }
129    }
130}
131
132/// String representation of span kind
133fn span_kind(kind: SpanKind) -> &'static str {
134    match kind {
135        SpanKind::Client => "client",
136        SpanKind::Server => "server",
137    }
138}
139
140/// Creates a new [`Span`] for the given request.
141fn make_request_span<B>(level: Level, kind: SpanKind, request: &mut Request<B>) -> Span {
142    macro_rules! make_span {
143        ($level:expr) => {{
144            use tracing::field::Empty;
145
146            tracing::span!(
147                $level,
148                "GRPC",
149                "client.address" = Empty,
150                "client.port" = Empty,
151                "error.message" = Empty,
152                "otel.kind" = span_kind(kind),
153                "otel.name" = Empty,
154                "otel.status_code" = Empty,
155                "rpc.grpc.status_code" = Empty,
156                "rpc.method" = Empty,
157                "rpc.service" = Empty,
158                "rpc.system" = "grpc",
159                "server.address" = Empty,
160                "server.port" = Empty,
161            )
162        }};
163    }
164
165    let span = match level {
166        Level::ERROR => make_span!(Level::ERROR),
167        Level::WARN => make_span!(Level::WARN),
168        Level::INFO => make_span!(Level::INFO),
169        Level::DEBUG => make_span!(Level::DEBUG),
170        Level::TRACE => make_span!(Level::TRACE),
171    };
172
173    for (header_name, header_value) in request.headers().iter() {
174        if let Ok(attribute_value) = header_value.to_str() {
175            let attribute_name = format!("rpc.grpc.request.metadata.{}", header_name);
176            span.set_attribute(attribute_name, attribute_value.to_owned());
177        }
178    }
179
180    let path = request.uri().path();
181    let name = path.trim_start_matches('/');
182    span.record("otel.name", name);
183    if let Some((service, method)) = name.split_once('/') {
184        span.record("rpc.service", service);
185        span.record("rpc.method", method);
186    }
187
188    match kind {
189        SpanKind::Client => {
190            let util::HttpRequestAttributes {
191                server_address,
192                server_port,
193                ..
194            } = util::HttpRequestAttributes::from_sent_request(request);
195
196            if let Some(server_address) = server_address {
197                span.record("server.address", server_address);
198            }
199            if let Some(server_port) = server_port {
200                span.record("server.port", server_port);
201            }
202
203            let context = span.context();
204            opentelemetry::global::get_text_map_propagator(|injector| {
205                injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
206            });
207        }
208        SpanKind::Server => {
209            if let Some(client_address) = util::client_address(request) {
210                let ip = client_address.ip();
211                span.record("client.address", tracing::field::display(ip));
212                span.record("client.port", client_address.port());
213            }
214
215            let util::HttpRequestAttributes {
216                server_address,
217                server_port,
218                ..
219            } = util::HttpRequestAttributes::from_recv_request(request);
220
221            if let Some(server_address) = server_address {
222                span.record("server.address", server_address);
223            }
224            if let Some(server_port) = server_port {
225                span.record("server.port", server_port);
226            }
227
228            let context = opentelemetry::global::get_text_map_propagator(|extractor| {
229                extractor.extract(&HeaderExtractor(request.headers_mut()))
230            });
231            if let Err(err) = span.set_parent(context) {
232                tracing::warn!("Failed to set parent span: {err}");
233            }
234        }
235    }
236
237    span
238}
239
240/// Records fields associated to the response.
241fn record_response<B>(span: &Span, response: &Response<B>) {
242    for (header_name, header_value) in response.headers().iter() {
243        if let Ok(attribute_value) = header_value.to_str() {
244            let attribute_name = format!("rpc.grpc.response.metadata.{}", header_name);
245            span.set_attribute(attribute_name, attribute_value.to_owned());
246        }
247    }
248
249    if let Some(header_value) = response.headers().get("grpc-status") {
250        if let Ok(header_value) = header_value.to_str() {
251            if let Ok(status_code) = header_value.parse::<i32>() {
252                span.record("rpc.grpc.status_code", status_code);
253            }
254        }
255    } else {
256        span.record("rpc.grpc.status_code", 0);
257    }
258}
259
260/// Records the error message.
261fn record_error<E: Display>(span: &Span, err: &E) {
262    span.record("otel.status_code", "ERROR");
263    span.record("error.message", err.to_string());
264}