//tower_http_tracing/lib.rs

//!Tower tracing middleware that annotates every HTTP request with a `tracing` span.
//!
//!## Span creation
//!
//!Use the [macro](macro.make_request_spanner.html) to declare a function that creates the desired span.
//!
//!## Example
//!
//!Below is an illustration of how to initialize the request layer before passing it into your service.
//!
//!```rust
//!use std::net::IpAddr;
//!
//!use tower_http_tracing::HttpRequestLayer;
//!
//!//Logic to extract the client IP has to be written by the user.
//!//You can use the utilities of a separate crate to design this logic:
//!//https://docs.rs/http-ip/latest/http_ip/
//!fn extract_client_ip(_parts: &http::request::Parts) -> Option<IpAddr> {
//!    None
//!}
//!tower_http_tracing::make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
//!let layer = HttpRequestLayer::new(make_my_request_span).with_extract_client_ip(extract_client_ip)
//!                                                       .with_inspect_headers(&[&http::header::FORWARDED]);
//!//Use the layer above in your service
//!```
//!
//!## Features
//!
//!- `opentelemetry` - Enables integration with opentelemetry to propagate context from requests and into responses
//!- `datadog` - Enables the `datadog` integration module
31
32#![warn(missing_docs)]
33#![allow(clippy::style)]
34
35mod grpc;
36mod headers;
37#[cfg(feature = "opentelemetry")]
38pub mod opentelemetry;
39#[cfg(feature = "datadog")]
40pub mod datadog;
41
42use std::net::IpAddr;
43use core::{cmp, fmt, ptr, task};
44use core::pin::Pin;
45use core::future::Future;
46
47pub use tracing;
48
///Name of the header carrying the request id (`x-request-id`)
pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
///Alias to the function signature required to create a span
pub type MakeSpan = fn() -> tracing::Span;
///Alias to the function signature required to extract the client's IP from a request
pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
55
56#[inline]
57fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
58    None
59}
60
#[derive(Copy, Clone, PartialEq, Eq)]
///Possible request protocol
pub enum Protocol {
    ///Regular HTTP call
    ///
    ///Default value for all requests
    Http,
    ///gRPC call, identified by presence of `Content-Type` with grpc protocol signature
    Grpc,
}

impl Protocol {
    #[inline(always)]
    ///Determines protocol from value of `Content-Type`
    ///
    ///Returns `Grpc` when `typ` starts with `application/grpc`, otherwise `Http`.
    ///
    ///The prefix is compared ASCII case-insensitively, because media type and subtype are
    ///case-insensitive, so `Application/gRPC+proto` is recognized as well.
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        //`get` yields `None` for values shorter than the prefix, which cannot be gRPC
        let is_grpc = typ.get(..GRPC_PREFIX.len())
                         .map_or(false, |prefix| prefix.eq_ignore_ascii_case(GRPC_PREFIX));

        if is_grpc {
            Self::Grpc
        } else {
            Self::Http
        }
    }

    #[inline(always)]
    ///Returns textual representation of the `self`
    ///
    ///It is `"grpc"` for `Grpc` and `"http"` for `Http`
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Grpc => "grpc",
            Self::Http => "http"
        }
    }
}

impl fmt::Debug for Protocol {
    #[inline(always)]
    ///Formats protocol as debug representation of its textual name (quoted string)
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), fmt)
    }
}

impl fmt::Display for Protocol {
    #[inline(always)]
    ///Formats protocol as its textual name, honoring formatter's padding options
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), fmt)
    }
}
106
//Fixed-size inline storage for request id; `RequestId::from_bytes` keeps at most this many bytes
type RequestIdBuffer = [u8; 64];

#[derive(Clone)]
///Request's id
///
///By default it is extracted from `X-Request-Id` header
pub struct RequestId {
    //Inline storage; only the first `len` bytes are meaningful
    buffer: RequestIdBuffer,
    //Number of bytes of `buffer` in use
    len: u8,
}
117
118impl RequestId {
119    fn from_bytes(bytes: &[u8]) -> Self {
120        let mut buffer: RequestIdBuffer = [0; 64];
121
122        let len = cmp::min(buffer.len(), bytes.len());
123
124        unsafe {
125            ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
126        };
127
128        Self {
129            buffer,
130            len: len as _,
131        }
132    }
133
134    fn from_uuid(uuid: uuid::Uuid) -> Self {
135        let mut buffer: RequestIdBuffer = [0; 64];
136        let uuid = uuid.as_hyphenated();
137        let len = uuid.encode_lower(&mut buffer).len();
138
139        Self {
140            buffer,
141            len: len as _,
142        }
143    }
144
145    #[inline]
146    ///Returns slice to already written data.
147    pub const fn as_bytes(&self) -> &[u8] {
148        unsafe {
149            core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
150        }
151    }
152
153    #[inline(always)]
154    ///Gets textual representation of the request id, if header value is string
155    pub const fn as_str(&self) -> Option<&str> {
156        match core::str::from_utf8(self.as_bytes()) {
157            Ok(header) => Some(header),
158            Err(_) => None,
159        }
160    }
161}
162
163impl fmt::Debug for RequestId {
164    #[inline(always)]
165    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
166        match self.as_str() {
167            Some(id) => fmt::Debug::fmt(id, fmt),
168            None => fmt::Debug::fmt(self.as_bytes(), fmt),
169        }
170    }
171}
172
173impl fmt::Display for RequestId {
174    #[inline(always)]
175    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
176        match self.as_str() {
177            Some(id) => fmt::Display::fmt(id, fmt),
178            None => fmt::Display::fmt("<non-utf8>", fmt),
179        }
180    }
181}
182
#[macro_export]
///Declares `fn` function compatible with `MakeSpan` using provided parameters
///
///## Arguments
///
///- Name of the function to declare
///- Name of the span, as literal
///- Level of the span
///- Optionally, extra span fields which are appended after the fields listed below
///
///## Span fields
///
///Following fields are declared when span is created:
///- `span.kind` - Always set to `server`
///- `http.request.method`
///- `url.path`
///- `url.query`
///- `url.scheme`
///- `http.request_id` - Inherited from request 'X-Request-Id' or random uuid
///- `user_agent.original` - Only populated if user agent header is present
///- `http.headers` - Optional. Populated if at least one header is specified via layer [config](struct.HttpRequestLayer.html#method.with_inspect_headers)
///- `network.protocol.name` - Either `http` or `grpc` depending on `content-type`
///- `network.protocol.version` - Set to HTTP version in case of plain `http` protocol.
///- `client.address` - Optionally added if IP extractor is specified via layer [config](struct.HttpRequestLayer.html#method.with_extract_client_ip)
///- `http.response.status_code` - Semantics of this code depends on `protocol`
///- `error.type` - Populated with `core::any::type_name` value of error type used by the service.
///- `error.message` - Populated with `Display` content of the error, returned by underlying service, after processing request.
///
///Loosely follows <https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server>
///
///## Usage
///
///```
///use tower_http_tracing::make_request_spanner;
///
///make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
/////Customize span with extra fields. You can use tracing::field::Empty if you want to omit value
///make_request_spanner!(make_my_service_request_span("my_request", tracing::Level::INFO, service_name = "<your name>"));
///
///let span = make_my_request_span();
///span.record("url.path", "I can override span field");
///
///```
macro_rules! make_request_spanner {
    //Short form without extra fields: forwards to the full form with empty list of extra fields
    ($fn:ident($name:literal, $level:expr)) => {
        $crate::make_request_spanner!($fn($name, $level,));
    };
    //Full form: declares `pub fn $fn()` returning span with all fields listed in the documentation,
    //followed by extra fields supplied by the user
    ($fn:ident($name:literal, $level:expr, $($fields:tt)*)) => {
        #[track_caller]
        pub fn $fn() -> $crate::tracing::Span {
            use $crate::tracing::field;

            $crate::tracing::span!(
                $level,
                $name,
                //Defaults
                span.kind = "server",
                //Assigned on creation of span
                http.request.method = field::Empty,
                url.path = field::Empty,
                url.query = field::Empty,
                url.scheme = field::Empty,
                http.request_id = field::Empty,
                user_agent.original = field::Empty,
                http.headers = field::Empty,
                network.protocol.name = field::Empty,
                network.protocol.version = field::Empty,
                //Optional
                client.address = field::Empty,
                //Assigned after request is complete
                http.response.status_code = field::Empty,
                error.type = field::Empty,
                error.message = field::Empty,
                $(
                    $fields
                )*
            )
        }
    };
}
255
#[derive(Clone, Debug)]
///Request's information
///
///The middleware inserts it into request's [extensions](https://docs.rs/http/latest/http/struct.Extensions.html)
///before calling the inner service, so handlers can read it from there
pub struct RequestInfo {
    ///Request's protocol, determined from `Content-Type` header
    pub protocol: Protocol,
    ///Request's id, taken from `X-Request-Id` header or generated as random UUID when header is missing
    pub request_id: RequestId,
    ///Client's IP address extracted, if available.
    pub client_ip: Option<IpAddr>,
}
268
///Request's span information
///
///Created on every request by the middleware, which then splits it apart: `span` drives
///the response future, while `info` is placed into request's extensions
pub struct RequestSpan {
    ///Underlying tracing span
    pub span: tracing::Span,
    ///Request's information
    pub info: RequestInfo,
}
278
279impl RequestSpan {
280    ///Creates new request span
281    pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
282        let _entered = span.enter();
283
284        let client_ip = (extract_client_ip)(parts);
285        let protocol = parts.headers
286                            .get(http::header::CONTENT_TYPE)
287                            .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
288
289        let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
290            RequestId::from_bytes(request_id.as_bytes())
291        } else {
292            RequestId::from_uuid(uuid::Uuid::new_v4())
293        };
294
295        if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
296            span.record("user_agent.original", user_agent);
297        }
298        span.record("http.request.method", parts.method.as_str());
299        span.record("url.path", parts.uri.path());
300        if let Some(query) = parts.uri.query() {
301            span.record("url.query", query);
302        }
303        if let Some(scheme) = parts.uri.scheme() {
304            span.record("url.scheme", scheme.as_str());
305        }
306        if let Some(request_id) = request_id.as_str() {
307            span.record("http.request_id", &request_id);
308        } else {
309            span.record("http.request_id", request_id.as_bytes());
310        }
311        if let Some(client_ip) = client_ip {
312            span.record("client.address", tracing::field::display(client_ip));
313        }
314        span.record("network.protocol.name", protocol.as_str());
315        if let Protocol::Http = protocol {
316            match parts.version {
317                http::Version::HTTP_09 => span.record("network.protocol.version", 0.9),
318                http::Version::HTTP_10 => span.record("network.protocol.version", 1.0),
319                http::Version::HTTP_11 => span.record("network.protocol.version", 1.1),
320                http::Version::HTTP_2 => span.record("network.protocol.version", 2),
321                http::Version::HTTP_3 => span.record("network.protocol.version", 3),
322                //Invalid version so just set 0
323                _ => span.record("network.protocol.version", 0),
324            };
325        }
326
327        drop(_entered);
328
329        Self {
330            span,
331            info: RequestInfo {
332                protocol,
333                request_id,
334                client_ip
335            }
336        }
337    }
338}
339
#[derive(Clone)]
///Tower layer
///
///Wraps a service into [HttpRequestService](struct.HttpRequestService.html)
pub struct HttpRequestLayer {
    //Creates span for every incoming request
    make_span: MakeSpan,
    //Headers to report via `http.headers` span field; nothing is reported when list is empty
    inspect_headers: &'static [&'static http::HeaderName],
    //Extracts client's IP from request
    extract_client_ip: ExtractClientIp,
}
347
348impl HttpRequestLayer {
349    #[inline]
350    ///Creates new layer with provided span maker
351    pub fn new(make_span: MakeSpan) -> Self {
352        Self {
353            make_span,
354            inspect_headers: &[],
355            extract_client_ip: default_client_ip
356        }
357    }
358
359    #[inline]
360    ///Specifies list of headers you want to inspect via `http.headers` attribute.
361    ///
362    ///By default none of the headers are inspected
363    pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
364        self.inspect_headers = inspect_headers;
365        self
366    }
367
368    ///Customizes client ip extraction method
369    ///
370    ///Default extracts none
371    pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
372        self.extract_client_ip = extract_client_ip;
373        self
374    }
375}
376
377impl<S> tower_layer::Layer<S> for HttpRequestLayer {
378    type Service = HttpRequestService<S>;
379    #[inline(always)]
380    fn layer(&self, inner: S) -> Self::Service {
381        HttpRequestService {
382            layer: self.clone(),
383            inner,
384        }
385    }
386}
387
388///Tower service to annotate requests with span
389pub struct HttpRequestService<S> {
390    layer: HttpRequestLayer,
391    inner: S
392}
393
394impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
395    type Response = S::Response;
396    type Error = S::Error;
397    type Future = ResponseFut<S::Future>;
398
399    #[inline(always)]
400    fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
401        self.inner.poll_ready(ctx)
402    }
403
404    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
405        let (parts, body) = req.into_parts();
406        let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
407
408        let mut req = http::Request::from_parts(parts, body);
409        #[cfg(feature = "opentelemetry")]
410        opentelemetry::on_request(&span, &req);
411        #[cfg(feature = "datadog")]
412        datadog::on_request(&span, &req);
413
414        let _entered = span.enter();
415        if !self.layer.inspect_headers.is_empty() {
416            span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
417                header_list: self.layer.inspect_headers,
418                headers: req.headers()
419            }));
420        }
421        let request_id = info.request_id.clone();
422        let protocol = info.protocol;
423        req.extensions_mut().insert(info);
424
425        let inner = self.inner.call(req);
426
427        drop(_entered);
428        ResponseFut {
429            inner,
430            span,
431            protocol,
432            request_id
433        }
434    }
435}
436
///Middleware's response future
///
///Polls future of the inner service within request's span and records outcome onto the span
pub struct ResponseFut<F> {
    //Future returned by the inner service
    inner: F,
    //Request's span, entered on every poll
    span: tracing::Span,
    //Request's protocol, selects how response status is determined
    protocol: Protocol,
    //Request's id, written into `x-request-id` header of successful response
    request_id: RequestId,
}
444
445impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
446    type Output = F::Output;
447
448    fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
449        let (fut, span, protocol, request_id) = unsafe {
450            let this = self.get_unchecked_mut();
451            (
452                Pin::new_unchecked(&mut this.inner),
453                &this.span,
454                this.protocol,
455                &this.request_id,
456            )
457        };
458        let _entered = span.enter();
459        match Future::poll(fut, ctx) {
460            task::Poll::Ready(Ok(mut resp)) => {
461                if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
462                    resp.headers_mut().insert(REQUEST_ID, request_id);
463                }
464                let status = match protocol {
465                    Protocol::Http => resp.status().as_u16(),
466                    Protocol::Grpc => match resp.headers().get("grpc-status") {
467                        Some(status) => grpc::parse_grpc_status(status.as_bytes()),
468                        None => 2,
469                    }
470                };
471                span.record("http.response.status_code", status);
472
473                #[cfg(feature = "opentelemetry")]
474                opentelemetry::on_response_ok(&span, &mut resp);
475                #[cfg(feature = "datadog")]
476                datadog::on_response_ok(&span, &mut resp);
477
478                task::Poll::Ready(Ok(resp))
479            }
480            task::Poll::Ready(Err(error)) => {
481                let status = match protocol {
482                    Protocol::Http => 500u16,
483                    Protocol::Grpc => 13,
484                };
485                span.record("http.response.status_code", status);
486                span.record("error.type", core::any::type_name::<E>());
487                span.record("error.message", tracing::field::display(&error));
488
489                #[cfg(feature = "opentelemetry")]
490                opentelemetry::on_response_error(&span, &error);
491                #[cfg(feature = "datadog")]
492                datadog::on_response_error(&span, &error);
493
494                task::Poll::Ready(Err(error))
495            },
496            task::Poll::Pending => task::Poll::Pending
497        }
498    }
499}