tower_http_tracing/
lib.rs

1//!Tower tracing middleware to annotate every HTTP request with tracing's span.
2//!
3//!## Example
4//!
//!Below is an illustration of how to initialize the request layer to pass into your service
6//!
7//!```rust
8//!use std::net::IpAddr;
9//!
10//!use tower_http_tracing::HttpRequestLayer;
11//!
//!//The logic to extract the client IP has to be written by the user.
//!//You can use utilities from a separate crate to implement it:
14//!//https://docs.rs/http-ip/latest/http_ip/
15//!fn extract_client_ip(_parts: &http::request::Parts) -> Option<IpAddr> {
16//!    None
17//!}
18//!tower_http_tracing::make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
19//!let layer = HttpRequestLayer::new(make_my_request_span).with_extract_client_ip(extract_client_ip)
20//!                                                       .with_inspect_headers(&[&http::header::FORWARDED]);
//!//Use the layer above in your service
22//!```
23
24#![warn(missing_docs)]
25#![allow(clippy::style)]
26
27mod grpc;
28mod headers;
29
30use std::net::IpAddr;
31use core::{cmp, fmt, ptr, task};
32use core::pin::Pin;
33use core::future::Future;
34
35pub use tracing;
36
37///RequestId's header name
38pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
39///Alias to function signature required to create span
40pub type MakeSpan = fn() -> tracing::Span;
41///ALias to function signature to extract client's ip from request
42pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
43
44#[inline]
45fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
46    None
47}
48
#[derive(Copy, Clone, PartialEq, Eq)]
///Possible request protocol
pub enum Protocol {
    ///Regular HTTP call
    ///
    ///Default value for all requests
    Http,
    ///gRPC call, identified by presence of `Content-Type` with grpc protocol signature
    Grpc,
}

impl Protocol {
    #[inline(always)]
    ///Determines protocol from value of `Content-Type`
    ///
    ///Media types are case-insensitive (RFC 9110, section 8.3.1), so the `application/grpc`
    ///prefix is matched ignoring ASCII case. Any value starting with it (e.g.
    ///`application/grpc+proto`) is classified as [`Protocol::Grpc`]; anything else, including an
    ///empty value, is [`Protocol::Http`].
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        //`get` yields `None` for values shorter than the prefix, which cannot be gRPC
        match typ.get(..GRPC_PREFIX.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(GRPC_PREFIX) => Self::Grpc,
            _ => Self::Http,
        }
    }

    #[inline(always)]
    ///Returns textual representation of the `self`
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Grpc => "Grpc",
            Self::Http => "Http"
        }
    }
}

impl fmt::Debug for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), fmt)
    }
}

impl fmt::Display for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), fmt)
    }
}
94
95type RequestIdBuffer = [u8; 64];
96
97#[derive(Clone)]
98///Request's id
99///
100///By default it is extracted from `X-Request-Id` header
101pub struct RequestId {
102    buffer: RequestIdBuffer,
103    len: u8,
104}
105
106impl RequestId {
107    fn from_bytes(bytes: &[u8]) -> Self {
108        let mut buffer: RequestIdBuffer = [0; 64];
109
110        let len = cmp::min(buffer.len(), bytes.len());
111
112        unsafe {
113            ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
114        };
115
116        Self {
117            buffer,
118            len: len as _,
119        }
120    }
121
122    fn from_uuid(uuid: uuid::Uuid) -> Self {
123        let mut buffer: RequestIdBuffer = [0; 64];
124        let uuid = uuid.as_hyphenated();
125        let len = uuid.encode_lower(&mut buffer).len();
126
127        Self {
128            buffer,
129            len: len as _,
130        }
131    }
132
133    #[inline]
134    ///Returns slice to already written data.
135    pub const fn as_bytes(&self) -> &[u8] {
136        unsafe {
137            core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
138        }
139    }
140
141    #[inline(always)]
142    ///Gets textual representation of the request id, if header value is string
143    pub const fn as_str(&self) -> Option<&str> {
144        match core::str::from_utf8(self.as_bytes()) {
145            Ok(header) => Some(header),
146            Err(_) => None,
147        }
148    }
149}
150
151impl fmt::Debug for RequestId {
152    #[inline(always)]
153    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
154        match self.as_str() {
155            Some(id) => fmt::Debug::fmt(id, fmt),
156            None => fmt::Debug::fmt(self.as_bytes(), fmt),
157        }
158    }
159}
160
161impl fmt::Display for RequestId {
162    #[inline(always)]
163    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
164        match self.as_str() {
165            Some(id) => fmt::Display::fmt(id, fmt),
166            None => fmt::Display::fmt("<non-utf8>", fmt),
167        }
168    }
169}
170
///Declares `fn` function compatible with `MakeSpan` using provided parameters
///
///The generated function creates a span with the given name and level, declaring every field
///the middleware may record, all initially empty.
///
///## Usage
///
///```
///use tower_http_tracing::make_request_spanner;
///
///make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
///
///let span = make_my_request_span();
///span.record("http.url", "I can override span field");
///```
#[macro_export]
macro_rules! make_request_spanner {
    ($func:ident($span_name:literal, $span_level:expr)) => {
        #[track_caller]
        pub fn $func() -> $crate::tracing::Span {
            $crate::tracing::span!(
                $span_level,
                $span_name,
                //Filled in when the request span is created
                http.method = $crate::tracing::field::Empty,
                http.url = $crate::tracing::field::Empty,
                http.request_id = $crate::tracing::field::Empty,
                http.user_agent = $crate::tracing::field::Empty,
                http.version = $crate::tracing::field::Empty,
                http.headers = $crate::tracing::field::Empty,
                protocol = $crate::tracing::field::Empty,
                //Filled in only when available
                http.client.ip = $crate::tracing::field::Empty,
                //Filled in once the response (or error) is produced
                http.status_code = $crate::tracing::field::Empty,
                error.message = $crate::tracing::field::Empty,
            )
        }
    };
}
209
210#[derive(Clone, Debug)]
211///Request's information
212///
213///It is accessible via [extensions](https://docs.rs/http/latest/http/struct.Extensions.html)
214pub struct RequestInfo {
215    ///Request's protocol
216    pub protocol: Protocol,
217    ///Request's id
218    pub request_id: RequestId,
219    ///Client's IP address extracted, if available.
220    pub client_ip: Option<IpAddr>,
221}
222
223///Request's span information
224///
225///Created on every request by the middleware, but not accessible to the user directly
226pub struct RequestSpan {
227    ///Underlying tracing span
228    pub span: tracing::Span,
229    ///Request's information
230    pub info: RequestInfo,
231}
232
233impl RequestSpan {
234    ///Creates new request span
235    pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
236        let _entered = span.enter();
237
238        let client_ip = (extract_client_ip)(parts);
239        let protocol = parts.headers
240                            .get(http::header::CONTENT_TYPE)
241                            .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
242
243        let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
244            RequestId::from_bytes(request_id.as_bytes())
245        } else {
246            RequestId::from_uuid(uuid::Uuid::new_v4())
247        };
248
249        if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
250            span.record("http.user_agent", user_agent);
251        }
252        span.record("http.method", parts.method.as_str());
253        span.record("http.version", tracing::field::debug(&parts.version));
254        span.record("http.url", parts.uri.path());
255        if let Some(request_id) = request_id.as_str() {
256            span.record("http.request_id", &request_id);
257        } else {
258            span.record("http.request_id", request_id.as_bytes());
259        }
260        if let Some(client_ip) = client_ip {
261            span.record("http.client.ip", tracing::field::display(client_ip));
262        }
263        span.record("protocol", protocol.as_str());
264
265        drop(_entered);
266
267        Self {
268            span,
269            info: RequestInfo {
270                protocol,
271                request_id,
272                client_ip
273            }
274        }
275    }
276}
277
278#[derive(Clone)]
279///Tower layer
280pub struct HttpRequestLayer {
281    make_span: MakeSpan,
282    inspect_headers: &'static [&'static http::HeaderName],
283    extract_client_ip: ExtractClientIp,
284}
285
286impl HttpRequestLayer {
287    #[inline]
288    ///Creates new layer with provided span maker
289    pub fn new(make_span: MakeSpan) -> Self {
290        Self {
291            make_span,
292            inspect_headers: &[],
293            extract_client_ip: default_client_ip
294        }
295    }
296
297    #[inline]
298    ///Specifies list of headers you want to inspect via `http.headers` attribute.
299    ///
300    ///By default none of the headers are inspected
301    pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
302        self.inspect_headers = inspect_headers;
303        self
304    }
305
306    ///Customizes client ip extraction method
307    ///
308    ///Default extracts none
309    pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
310        self.extract_client_ip = extract_client_ip;
311        self
312    }
313}
314
315impl<S> tower_layer::Layer<S> for HttpRequestLayer {
316    type Service = HttpRequestService<S>;
317    #[inline(always)]
318    fn layer(&self, inner: S) -> Self::Service {
319        HttpRequestService {
320            layer: self.clone(),
321            inner,
322        }
323    }
324}
325
326///Tower service to annotate requests with span
327pub struct HttpRequestService<S> {
328    layer: HttpRequestLayer,
329    inner: S
330}
331
332impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
333    type Response = S::Response;
334    type Error = S::Error;
335    type Future = ResponseFut<S::Future>;
336
337    #[inline(always)]
338    fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
339        self.inner.poll_ready(ctx)
340    }
341
342    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
343        let (parts, body) = req.into_parts();
344        let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
345
346        let _entered = span.enter();
347        if !self.layer.inspect_headers.is_empty() {
348            span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
349                header_list: self.layer.inspect_headers,
350                headers: &parts.headers
351            }));
352        }
353        let request_id = info.request_id.clone();
354        let protocol = info.protocol;
355        let mut req = http::Request::from_parts(parts, body);
356        req.extensions_mut().insert(info);
357        let inner = self.inner.call(req);
358
359        drop(_entered);
360        ResponseFut {
361            inner,
362            span,
363            protocol,
364            request_id
365        }
366    }
367}
368
369///Middleware's response future
370pub struct ResponseFut<F> {
371    inner: F,
372    span: tracing::Span,
373    protocol: Protocol,
374    request_id: RequestId,
375}
376
377impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
378    type Output = F::Output;
379
380    fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
381        let (fut, span, protocol, request_id) = unsafe {
382            let this = self.get_unchecked_mut();
383            (
384                Pin::new_unchecked(&mut this.inner),
385                &this.span,
386                this.protocol,
387                &this.request_id,
388            )
389        };
390        let _entered = span.enter();
391        match Future::poll(fut, ctx) {
392            task::Poll::Ready(Ok(mut resp)) => {
393                if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
394                    resp.headers_mut().insert(REQUEST_ID, request_id);
395                }
396                let status = match protocol {
397                    Protocol::Http => resp.status().as_u16(),
398                    Protocol::Grpc => match resp.headers().get("grpc-status") {
399                        Some(status) => grpc::parse_grpc_status(status.as_bytes()),
400                        None => 2,
401                    }
402                };
403                span.record("http.status_code", status);
404
405                task::Poll::Ready(Ok(resp))
406            }
407            task::Poll::Ready(Err(error)) => {
408                let status = match protocol {
409                    Protocol::Http => 500u16,
410                    Protocol::Grpc => 13,
411                };
412                span.record("http.status_code", status);
413                span.record("error.message", tracing::field::display(&error));
414                task::Poll::Ready(Err(error))
415            },
416            task::Poll::Pending => task::Poll::Pending
417        }
418    }
419}