tower_http_tracing/
lib.rs

1//!Tower tracing middleware to annotate every HTTP request with tracing's span.
2//!
3//!## Span creation
4//!
5//!Use the [`make_request_spanner`](macro.make_request_spanner.html) macro to declare a function that creates the desired span
6//!
7//!## Example
8//!
9//!Below is an illustration of how to initialize the request layer to pass into your service
10//!
11//!```rust
12//!use std::net::IpAddr;
13//!
14//!use tower_http_tracing::HttpRequestLayer;
15//!
16//!//Logic to extract client ip has to be written by user
17//!//You can use utilities in separate crate to design this logic:
18//!//https://docs.rs/http-ip/latest/http_ip/
19//!fn extract_client_ip(_parts: &http::request::Parts) -> Option<IpAddr> {
20//!    None
21//!}
22//!tower_http_tracing::make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
23//!let layer = HttpRequestLayer::new(make_my_request_span).with_extract_client_ip(extract_client_ip)
24//!                                                       .with_inspect_headers(&[&http::header::FORWARDED]);
25//!//Use above layer in your service
26//!```
27//!
28//!## Features
29//!
30//!- `opentelemetry` - Enables integration with opentelemetry to propagate context from requests and into responses
31
32#![warn(missing_docs)]
33#![allow(clippy::style)]
34
35mod grpc;
36mod headers;
37#[cfg(feature = "opentelemetry")]
38pub mod opentelemetry;
39
40use std::net::IpAddr;
41use core::{cmp, fmt, ptr, task};
42use core::pin::Pin;
43use core::future::Future;
44
45pub use tracing;
46
47///RequestId's header name
48pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
49///Alias to function signature required to create span
50pub type MakeSpan = fn() -> tracing::Span;
51///ALias to function signature to extract client's ip from request
52pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
53
54#[inline]
55fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
56    None
57}
58
#[derive(Copy, Clone, PartialEq, Eq)]
///Possible request protocol
pub enum Protocol {
    ///Regular HTTP call
    ///
    ///Default value for all requests
    Http,
    ///gRPC call, identified by presence of `Content-Type` with grpc protocol signature
    Grpc,
}

impl Protocol {
    #[inline(always)]
    ///Determines protocol from value of `Content-Type`
    ///
    ///Returns `Grpc` when `typ` starts with `application/grpc`, otherwise `Http`.
    ///
    ///Media type's type and subtype are case-insensitive, therefore prefix is compared ignoring
    ///ASCII case, i.e. `Application/gRPC+proto` is recognized as gRPC too.
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        //`get` yields `None` when value is shorter than prefix, which covers empty input too
        match typ.get(..GRPC_PREFIX.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(GRPC_PREFIX) => Self::Grpc,
            _ => Self::Http,
        }
    }

    #[inline(always)]
    ///Returns textual representation of the `self`
    ///
    ///It is `grpc` or `http`
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Grpc => "grpc",
            Self::Http => "http"
        }
    }
}

///Formats protocol as quoted string, delegating to `Debug` of [`Protocol::as_str`]
impl fmt::Debug for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), fmt)
    }
}

///Formats protocol as plain string, delegating to `Display` of [`Protocol::as_str`]
impl fmt::Display for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), fmt)
    }
}
104
105type RequestIdBuffer = [u8; 64];
106
107#[derive(Clone)]
108///Request's id
109///
110///By default it is extracted from `X-Request-Id` header
111pub struct RequestId {
112    buffer: RequestIdBuffer,
113    len: u8,
114}
115
116impl RequestId {
117    fn from_bytes(bytes: &[u8]) -> Self {
118        let mut buffer: RequestIdBuffer = [0; 64];
119
120        let len = cmp::min(buffer.len(), bytes.len());
121
122        unsafe {
123            ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
124        };
125
126        Self {
127            buffer,
128            len: len as _,
129        }
130    }
131
132    fn from_uuid(uuid: uuid::Uuid) -> Self {
133        let mut buffer: RequestIdBuffer = [0; 64];
134        let uuid = uuid.as_hyphenated();
135        let len = uuid.encode_lower(&mut buffer).len();
136
137        Self {
138            buffer,
139            len: len as _,
140        }
141    }
142
143    #[inline]
144    ///Returns slice to already written data.
145    pub const fn as_bytes(&self) -> &[u8] {
146        unsafe {
147            core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
148        }
149    }
150
151    #[inline(always)]
152    ///Gets textual representation of the request id, if header value is string
153    pub const fn as_str(&self) -> Option<&str> {
154        match core::str::from_utf8(self.as_bytes()) {
155            Ok(header) => Some(header),
156            Err(_) => None,
157        }
158    }
159}
160
161impl fmt::Debug for RequestId {
162    #[inline(always)]
163    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
164        match self.as_str() {
165            Some(id) => fmt::Debug::fmt(id, fmt),
166            None => fmt::Debug::fmt(self.as_bytes(), fmt),
167        }
168    }
169}
170
171impl fmt::Display for RequestId {
172    #[inline(always)]
173    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
174        match self.as_str() {
175            Some(id) => fmt::Display::fmt(id, fmt),
176            None => fmt::Display::fmt("<non-utf8>", fmt),
177        }
178    }
179}
180
#[macro_export]
///Declares `pub fn` compatible with `MakeSpan`, which creates span with provided name and level
///
///## Span fields
///
///Following fields are declared when span is created:
///- `span.kind` - Always `server`
///- `http.request.method`
///- `url.path`
///- `url.query` - Only populated if request's URI has query
///- `url.scheme` - Only populated if request's URI has scheme
///- `http.request_id` - Inherited from request 'X-Request-Id' or random uuid
///- `user_agent.original` - Only populated if user agent header is present
///- `http.headers` - Optional. Populated if at least one header is specified via layer [config](struct.HttpRequestLayer.html#method.with_inspect_headers)
///- `network.protocol.name` - Either `http` or `grpc` depending on `content-type`
///- `network.protocol.version` - Set to HTTP version in case of plain `http` protocol.
///- `client.address` - Optionally added if IP extractor is specified via layer [config](struct.HttpRequestLayer.html#method.with_extract_client_ip)
///- `http.response.status_code` - Semantics of this code depends on `protocol`
///- `error.type` - Populated with `core::any::type_name` value of error type used by the service.
///- `error.message` - Populated with `Display` content of the error, returned by underlying service, after processing request.
///
///All fields, except `span.kind`, are declared empty and get recorded later by the middleware.
///
///Loosely follows <https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server>
///
///## Usage
///
///```
///use tower_http_tracing::make_request_spanner;
///
///make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
/////Customize span with extra fields. You can use tracing::field::Empty if you want to omit value
///make_request_spanner!(make_my_service_request_span("my_request", tracing::Level::INFO, service_name = "<your name>"));
///
///let span = make_my_request_span();
///span.record("url.path", "I can override span field");
///
///```
macro_rules! make_request_spanner {
    //Short form without extra fields: forwards to the full form with empty field list
    ($maker:ident($span_name:literal, $span_level:expr)) => {
        $crate::make_request_spanner!($maker($span_name, $span_level,));
    };
    //Full form: tokens after the level are appended to the span's field list as is
    ($maker:ident($span_name:literal, $span_level:expr, $($extra:tt)*)) => {
        #[track_caller]
        pub fn $maker() -> $crate::tracing::Span {
            use $crate::tracing::field;

            $crate::tracing::span!(
                $span_level,
                $span_name,
                //Constant for every request
                span.kind = "server",
                //Recorded when middleware receives request
                http.request.method = field::Empty,
                url.path = field::Empty,
                url.query = field::Empty,
                url.scheme = field::Empty,
                http.request_id = field::Empty,
                user_agent.original = field::Empty,
                http.headers = field::Empty,
                network.protocol.name = field::Empty,
                network.protocol.version = field::Empty,
                //Recorded only when client's ip is extracted
                client.address = field::Empty,
                //Recorded once response or error is available
                http.response.status_code = field::Empty,
                error.type = field::Empty,
                error.message = field::Empty,
                $($extra)*
            )
        }
    };
}
253
254#[derive(Clone, Debug)]
255///Request's information
256///
257///It is accessible via [extensions](https://docs.rs/http/latest/http/struct.Extensions.html)
258pub struct RequestInfo {
259    ///Request's protocol
260    pub protocol: Protocol,
261    ///Request's id
262    pub request_id: RequestId,
263    ///Client's IP address extracted, if available.
264    pub client_ip: Option<IpAddr>,
265}
266
267///Request's span information
268///
269///Created on every request by the middleware, but not accessible to the user directly
270pub struct RequestSpan {
271    ///Underlying tracing span
272    pub span: tracing::Span,
273    ///Request's information
274    pub info: RequestInfo,
275}
276
277impl RequestSpan {
278    ///Creates new request span
279    pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
280        let _entered = span.enter();
281
282        let client_ip = (extract_client_ip)(parts);
283        let protocol = parts.headers
284                            .get(http::header::CONTENT_TYPE)
285                            .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
286
287        let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
288            RequestId::from_bytes(request_id.as_bytes())
289        } else {
290            RequestId::from_uuid(uuid::Uuid::new_v4())
291        };
292
293        if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
294            span.record("user_agent.original", user_agent);
295        }
296        span.record("http.request.method", parts.method.as_str());
297        span.record("url.path", parts.uri.path());
298        if let Some(query) = parts.uri.query() {
299            span.record("url.query", query);
300        }
301        if let Some(scheme) = parts.uri.scheme() {
302            span.record("url.scheme", scheme.as_str());
303        }
304        if let Some(request_id) = request_id.as_str() {
305            span.record("http.request_id", &request_id);
306        } else {
307            span.record("http.request_id", request_id.as_bytes());
308        }
309        if let Some(client_ip) = client_ip {
310            span.record("client.address", tracing::field::display(client_ip));
311        }
312        span.record("network.protocol.name", protocol.as_str());
313        if let Protocol::Http = protocol {
314            match parts.version {
315                http::Version::HTTP_09 => span.record("network.protocol.version", 0.9),
316                http::Version::HTTP_10 => span.record("network.protocol.version", 1.0),
317                http::Version::HTTP_11 => span.record("network.protocol.version", 1.1),
318                http::Version::HTTP_2 => span.record("network.protocol.version", 2),
319                http::Version::HTTP_3 => span.record("network.protocol.version", 3),
320                //Invalid version so just set 0
321                _ => span.record("network.protocol.version", 0),
322            };
323        }
324
325        drop(_entered);
326
327        Self {
328            span,
329            info: RequestInfo {
330                protocol,
331                request_id,
332                client_ip
333            }
334        }
335    }
336}
337
338#[derive(Clone)]
339///Tower layer
340pub struct HttpRequestLayer {
341    make_span: MakeSpan,
342    inspect_headers: &'static [&'static http::HeaderName],
343    extract_client_ip: ExtractClientIp,
344}
345
346impl HttpRequestLayer {
347    #[inline]
348    ///Creates new layer with provided span maker
349    pub fn new(make_span: MakeSpan) -> Self {
350        Self {
351            make_span,
352            inspect_headers: &[],
353            extract_client_ip: default_client_ip
354        }
355    }
356
357    #[inline]
358    ///Specifies list of headers you want to inspect via `http.headers` attribute.
359    ///
360    ///By default none of the headers are inspected
361    pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
362        self.inspect_headers = inspect_headers;
363        self
364    }
365
366    ///Customizes client ip extraction method
367    ///
368    ///Default extracts none
369    pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
370        self.extract_client_ip = extract_client_ip;
371        self
372    }
373}
374
375impl<S> tower_layer::Layer<S> for HttpRequestLayer {
376    type Service = HttpRequestService<S>;
377    #[inline(always)]
378    fn layer(&self, inner: S) -> Self::Service {
379        HttpRequestService {
380            layer: self.clone(),
381            inner,
382        }
383    }
384}
385
386///Tower service to annotate requests with span
387pub struct HttpRequestService<S> {
388    layer: HttpRequestLayer,
389    inner: S
390}
391
392impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
393    type Response = S::Response;
394    type Error = S::Error;
395    type Future = ResponseFut<S::Future>;
396
397    #[inline(always)]
398    fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
399        self.inner.poll_ready(ctx)
400    }
401
402    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
403        let (parts, body) = req.into_parts();
404        let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
405
406        let mut req = http::Request::from_parts(parts, body);
407        #[cfg(feature = "opentelemetry")]
408        opentelemetry::on_request(&span, &req);
409
410        let _entered = span.enter();
411        if !self.layer.inspect_headers.is_empty() {
412            span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
413                header_list: self.layer.inspect_headers,
414                headers: req.headers()
415            }));
416        }
417        let request_id = info.request_id.clone();
418        let protocol = info.protocol;
419        req.extensions_mut().insert(info);
420
421        let inner = self.inner.call(req);
422
423        drop(_entered);
424        ResponseFut {
425            inner,
426            span,
427            protocol,
428            request_id
429        }
430    }
431}
432
433///Middleware's response future
434pub struct ResponseFut<F> {
435    inner: F,
436    span: tracing::Span,
437    protocol: Protocol,
438    request_id: RequestId,
439}
440
441impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
442    type Output = F::Output;
443
444    fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
445        let (fut, span, protocol, request_id) = unsafe {
446            let this = self.get_unchecked_mut();
447            (
448                Pin::new_unchecked(&mut this.inner),
449                &this.span,
450                this.protocol,
451                &this.request_id,
452            )
453        };
454        let _entered = span.enter();
455        match Future::poll(fut, ctx) {
456            task::Poll::Ready(Ok(mut resp)) => {
457                if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
458                    resp.headers_mut().insert(REQUEST_ID, request_id);
459                }
460                let status = match protocol {
461                    Protocol::Http => resp.status().as_u16(),
462                    Protocol::Grpc => match resp.headers().get("grpc-status") {
463                        Some(status) => grpc::parse_grpc_status(status.as_bytes()),
464                        None => 2,
465                    }
466                };
467                span.record("http.response.status_code", status);
468
469                #[cfg(feature = "opentelemetry")]
470                opentelemetry::on_response_ok(&span, &mut resp);
471
472                task::Poll::Ready(Ok(resp))
473            }
474            task::Poll::Ready(Err(error)) => {
475                let status = match protocol {
476                    Protocol::Http => 500u16,
477                    Protocol::Grpc => 13,
478                };
479                span.record("http.response.status_code", status);
480                span.record("error.type", core::any::type_name::<E>());
481                span.record("error.message", tracing::field::display(&error));
482
483                #[cfg(feature = "opentelemetry")]
484                opentelemetry::on_response_error(&span, &error);
485
486                task::Poll::Ready(Err(error))
487            },
488            task::Poll::Pending => task::Poll::Pending
489        }
490    }
491}