1use std::{
4 fmt::Display,
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use http::{Request, Response};
11use pin_project::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14use tracing::{Level, Span};
15use tracing_opentelemetry::OpenTelemetrySpanExt;
16
17use crate::{
18 trace::{extractor::HeaderExtractor, injector::HeaderInjector},
19 util,
20};
21
/// Which side of a gRPC call a span describes.
///
/// Chooses the `otel.kind` value recorded on the span. It also decides how trace context
/// moves through the request headers: it is injected into them for [`SpanKind::Client`]
/// and extracted from them for [`SpanKind::Server`].
#[derive(Debug, Clone, Copy)]
enum SpanKind {
    /// A call made by this process to a remote service.
    Client,
    /// A call received by this process and handled by the wrapped service.
    Server,
}
30
31#[derive(Clone, Debug)]
33pub struct GrpcLayer {
34 level: Level,
35 kind: SpanKind,
36}
37
38impl GrpcLayer {
39 pub fn server(level: Level) -> Self {
41 Self {
42 level,
43 kind: SpanKind::Server,
44 }
45 }
46
47 pub fn client(level: Level) -> Self {
49 Self {
50 level,
51 kind: SpanKind::Client,
52 }
53 }
54}
55
56impl<S> Layer<S> for GrpcLayer {
57 type Service = Grpc<S>;
58
59 fn layer(&self, inner: S) -> Self::Service {
60 Grpc {
61 inner,
62 level: self.level,
63 kind: self.kind,
64 }
65 }
66}
67
68#[derive(Clone, Debug)]
70pub struct Grpc<S> {
71 inner: S,
72 level: Level,
73 kind: SpanKind,
74}
75
76impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Grpc<S>
77where
78 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
79 S::Error: Display,
80{
81 type Response = S::Response;
82 type Error = S::Error;
83 type Future = ResponseFuture<S::Future>;
84
85 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.inner.poll_ready(cx)
87 }
88
89 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
90 let span = make_request_span(self.level, self.kind, &mut req);
91 let inner = {
92 let _enter = span.enter();
93 self.inner.call(req)
94 };
95
96 ResponseFuture { inner, span }
97 }
98}
99
100#[pin_project]
102pub struct ResponseFuture<F> {
103 #[pin]
104 inner: F,
105 span: Span,
106}
107
108impl<F, ResBody, E> Future for ResponseFuture<F>
109where
110 F: Future<Output = Result<Response<ResBody>, E>>,
111 E: Display,
112{
113 type Output = Result<Response<ResBody>, E>;
114
115 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116 let this = self.project();
117 let _enter = this.span.enter();
118
119 match ready!(this.inner.poll(cx)) {
120 Ok(response) => {
121 record_response(this.span, &response);
122 Poll::Ready(Ok(response))
123 }
124 Err(err) => {
125 record_error(this.span, &err);
126 Poll::Ready(Err(err))
127 }
128 }
129 }
130}
131
132fn span_kind(kind: SpanKind) -> &'static str {
134 match kind {
135 SpanKind::Client => "client",
136 SpanKind::Server => "server",
137 }
138}
139
140fn make_request_span<B>(level: Level, kind: SpanKind, request: &mut Request<B>) -> Span {
142 macro_rules! make_span {
143 ($level:expr) => {{
144 use tracing::field::Empty;
145
146 tracing::span!(
147 $level,
148 "GRPC",
149 "client.address" = Empty,
150 "client.port" = Empty,
151 "error.message" = Empty,
152 "otel.kind" = span_kind(kind),
153 "otel.name" = Empty,
154 "otel.status_code" = Empty,
155 "rpc.grpc.status_code" = Empty,
156 "rpc.method" = Empty,
157 "rpc.service" = Empty,
158 "rpc.system" = "grpc",
159 "server.address" = Empty,
160 "server.port" = Empty,
161 )
162 }};
163 }
164
165 let span = match level {
166 Level::ERROR => make_span!(Level::ERROR),
167 Level::WARN => make_span!(Level::WARN),
168 Level::INFO => make_span!(Level::INFO),
169 Level::DEBUG => make_span!(Level::DEBUG),
170 Level::TRACE => make_span!(Level::TRACE),
171 };
172
173 for (header_name, header_value) in request.headers().iter() {
174 if let Ok(attribute_value) = header_value.to_str() {
175 let attribute_name = format!("rpc.grpc.request.metadata.{}", header_name);
176 span.set_attribute(attribute_name, attribute_value.to_owned());
177 }
178 }
179
180 let path = request.uri().path();
181 let name = path.trim_start_matches('/');
182 span.record("otel.name", name);
183 if let Some((service, method)) = name.split_once('/') {
184 span.record("rpc.service", service);
185 span.record("rpc.method", method);
186 }
187
188 match kind {
189 SpanKind::Client => {
190 let util::HttpRequestAttributes {
191 server_address,
192 server_port,
193 ..
194 } = util::HttpRequestAttributes::from_sent_request(request);
195
196 if let Some(server_address) = server_address {
197 span.record("server.address", server_address);
198 }
199 if let Some(server_port) = server_port {
200 span.record("server.port", server_port);
201 }
202
203 let context = span.context();
204 opentelemetry::global::get_text_map_propagator(|injector| {
205 injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
206 });
207 }
208 SpanKind::Server => {
209 if let Some(client_address) = util::client_address(request) {
210 let ip = client_address.ip();
211 span.record("client.address", tracing::field::display(ip));
212 span.record("client.port", client_address.port());
213 }
214
215 let util::HttpRequestAttributes {
216 server_address,
217 server_port,
218 ..
219 } = util::HttpRequestAttributes::from_recv_request(request);
220
221 if let Some(server_address) = server_address {
222 span.record("server.address", server_address);
223 }
224 if let Some(server_port) = server_port {
225 span.record("server.port", server_port);
226 }
227
228 let context = opentelemetry::global::get_text_map_propagator(|extractor| {
229 extractor.extract(&HeaderExtractor(request.headers_mut()))
230 });
231 if let Err(err) = span.set_parent(context) {
232 tracing::warn!("Failed to set parent span: {err}");
233 }
234 }
235 }
236
237 span
238}
239
240fn record_response<B>(span: &Span, response: &Response<B>) {
242 for (header_name, header_value) in response.headers().iter() {
243 if let Ok(attribute_value) = header_value.to_str() {
244 let attribute_name = format!("rpc.grpc.response.metadata.{}", header_name);
245 span.set_attribute(attribute_name, attribute_value.to_owned());
246 }
247 }
248
249 if let Some(header_value) = response.headers().get("grpc-status") {
250 if let Ok(header_value) = header_value.to_str() {
251 if let Ok(status_code) = header_value.parse::<i32>() {
252 span.record("rpc.grpc.status_code", status_code);
253 }
254 }
255 } else {
256 span.record("rpc.grpc.status_code", 0);
257 }
258}
259
260fn record_error<E: Display>(span: &Span, err: &E) {
262 span.record("otel.status_code", "ERROR");
263 span.record("error.message", err.to_string());
264}