// hyper_trace_id/layer.rs

1use std::marker::PhantomData;
2use std::str::FromStr;
3use std::task::{Context, Poll};
4
5use futures::future::BoxFuture;
6use hyper::http::{HeaderName, HeaderValue, Request, Response};
7use tower::{Layer, Service};
8
9use crate::{MakeTraceId, TraceId};
10
11/// Add the TraceId<T> extension to requests and optionally include trace ids in request and response headers.
12///
13/// ```
14/// use std::convert::Infallible;
15/// use hyper::{Body, Request, Response};
16/// use tower::ServiceBuilder;
17/// use hyper_trace_id::{SetTraceIdLayer, TraceId};
18///
19/// let trace_id_header = "x-trace-id";
20/// let svc = ServiceBuilder::new()
21///     .layer(SetTraceIdLayer::<String>::new().with_header_name(trace_id_header))
22///     .service_fn(|_req: Request<Body>| async {
23///         let res: Result<Response<Body>, Infallible> = Ok(Response::new(Body::empty()));
24///         res
25///     });
26///
27/// ```
28#[derive(Debug, Clone)]
29pub struct SetTraceIdLayer<T>
30where
31    T: MakeTraceId,
32{
33    header_name: Option<HeaderName>,
34    _phantom: PhantomData<T>,
35}
36
37impl<T> SetTraceIdLayer<T>
38where
39    T: MakeTraceId,
40{
41    pub fn new() -> Self {
42        Self {
43            header_name: None,
44            _phantom: Default::default(),
45        }
46    }
47
48    pub fn with_header_name(self, header_name: &str) -> Self {
49        Self {
50            header_name: Some(HeaderName::from_str(header_name).unwrap()),
51            _phantom: Default::default(),
52        }
53    }
54}
55
56impl<T> Default for SetTraceIdLayer<T>
57where
58    T: MakeTraceId,
59{
60    fn default() -> Self {
61        SetTraceIdLayer::new()
62    }
63}
64
65impl<S, T> Layer<S> for SetTraceIdLayer<T>
66where
67    T: MakeTraceId,
68{
69    type Service = TraceIdMiddleware<S, T>;
70
71    fn layer(&self, inner: S) -> Self::Service {
72        TraceIdMiddleware {
73            inner,
74            header_name: self.header_name.clone(),
75            _phantom: Default::default(),
76        }
77    }
78}
79
80#[derive(Clone)]
81pub struct TraceIdMiddleware<S, T> {
82    inner: S,
83    header_name: Option<HeaderName>,
84    _phantom: PhantomData<T>,
85}
86
87impl<S, T, Rq, Rs> Service<Request<Rq>> for TraceIdMiddleware<S, T>
88where
89    S: Service<Request<Rq>, Response = Response<Rs>> + Send + 'static,
90    S::Future: Send + 'static,
91    T: MakeTraceId + 'static,
92{
93    type Response = S::Response;
94    type Error = S::Error;
95    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
96
97    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        self.inner.poll_ready(cx)
99    }
100
101    fn call(&mut self, mut req: Request<Rq>) -> Self::Future {
102        let trace_id = TraceId::<T>::new();
103        req.extensions_mut().insert(trace_id.clone());
104
105        // Add TraceId header to request
106        let mut header_val: Option<HeaderValue> = None;
107        if let Some(header_name) = self.header_name.clone() {
108            header_val = Some(
109                HeaderValue::try_from(trace_id.id.to_string())
110                    .unwrap_or(HeaderValue::from_static("unavailable")),
111            );
112            req.headers_mut()
113                .insert(header_name, header_val.clone().unwrap());
114        }
115
116        let future = self.inner.call(req);
117        let moved_header_name = self.header_name.clone();
118        Box::pin(async move {
119            let mut response: Response<Rs> = future.await?;
120
121            // Add TraceId header to response
122            if let Some(header_name) = moved_header_name {
123                response
124                    .headers_mut()
125                    .insert(header_name, header_val.unwrap());
126            }
127
128            Ok(response)
129        })
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::body::Body;
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};
    use tower::{ServiceBuilder, ServiceExt};

    /// Without the layer, no `TraceId` extension is present on the request.
    #[tokio::test]
    async fn test_extension_not_added() {
        // Flipped to `true` once `assert_no_trace_id` has run.
        let called = Arc::new(AtomicBool::new(false));

        let flag = Arc::clone(&called);
        let assert_no_trace_id = move |req: Request<Body>| -> Request<Body> {
            flag.store(true, Ordering::SeqCst);
            assert!(req.extensions().get::<TraceId<String>>().is_none());
            req
        };

        let test_svc = ServiceBuilder::new()
            .map_request(assert_no_trace_id)
            .service_fn(|_req: Request<Body>| async {
                let res: Result<(), Infallible> = Ok(());
                res
            });

        let req = Request::new(Body::empty());
        test_svc.oneshot(req).await.unwrap();

        // Assert that assert_no_trace_id was actually called
        assert!(called.load(Ordering::SeqCst));
    }

    /// With the layer installed, the inner service sees a `TraceId` extension.
    #[tokio::test]
    async fn test_extension_added() {
        // Flipped to `true` once `assert_trace_id` has run.
        let called = Arc::new(AtomicBool::new(false));

        let flag = Arc::clone(&called);
        let assert_trace_id = move |req: Request<Body>| -> Request<Body> {
            flag.store(true, Ordering::SeqCst);
            assert!(req.extensions().get::<TraceId<String>>().is_some());
            req
        };

        let test_svc = ServiceBuilder::new()
            .layer(SetTraceIdLayer::<String>::new())
            .map_request(assert_trace_id)
            .service_fn(|_req: Request<Body>| async {
                let res: Result<Response<Body>, Infallible> = Ok(Response::new(Body::empty()));
                res
            });

        let req = Request::new(Body::empty());
        test_svc.oneshot(req).await.unwrap();

        // Assert that assert_trace_id was actually called
        assert!(called.load(Ordering::SeqCst));
    }

    /// With a header name configured, the trace id header is set on both the
    /// request seen by the inner service and the response, with the same value.
    #[tokio::test]
    async fn test_header_added() {
        let header_name = "x-trace-id";

        // Receives the trace id header value observed on the request; staying
        // `None` means the header was missing or the closure never ran.
        let seen_on_request: Arc<Mutex<Option<HeaderValue>>> = Arc::new(Mutex::new(None));

        let recorder = Arc::clone(&seen_on_request);
        let record_trace_id = move |req: Request<Body>| -> Request<Body> {
            assert!(req.extensions().get::<TraceId<String>>().is_some());
            *recorder.lock().unwrap() = req.headers().get(header_name).cloned();
            req
        };

        let test_svc = ServiceBuilder::new()
            .layer(SetTraceIdLayer::<String>::new().with_header_name(header_name))
            .map_request(record_trace_id)
            .service_fn(|_req: Request<Body>| async {
                let res: Result<Response<Body>, Infallible> = Ok(Response::new(Body::empty()));
                res
            });

        let req = Request::new(Body::empty());
        let resp = test_svc.oneshot(req).await.unwrap();

        let request_value = seen_on_request
            .lock()
            .unwrap()
            .take()
            .expect("trace id header should be set on the request");

        assert!(resp.headers().get(header_name).is_some());
        assert_eq!(resp.headers().get(header_name), Some(&request_value));
    }
}