1use std::marker::PhantomData;
2use std::str::FromStr;
3use std::task::{Context, Poll};
4
5use futures::future::BoxFuture;
6use hyper::http::{HeaderName, HeaderValue, Request, Response};
7use tower::{Layer, Service};
8
9use crate::{MakeTraceId, TraceId};
10
11#[derive(Debug, Clone)]
29pub struct SetTraceIdLayer<T>
30where
31 T: MakeTraceId,
32{
33 header_name: Option<HeaderName>,
34 _phantom: PhantomData<T>,
35}
36
37impl<T> SetTraceIdLayer<T>
38where
39 T: MakeTraceId,
40{
41 pub fn new() -> Self {
42 Self {
43 header_name: None,
44 _phantom: Default::default(),
45 }
46 }
47
48 pub fn with_header_name(self, header_name: &str) -> Self {
49 Self {
50 header_name: Some(HeaderName::from_str(header_name).unwrap()),
51 _phantom: Default::default(),
52 }
53 }
54}
55
56impl<T> Default for SetTraceIdLayer<T>
57where
58 T: MakeTraceId,
59{
60 fn default() -> Self {
61 SetTraceIdLayer::new()
62 }
63}
64
65impl<S, T> Layer<S> for SetTraceIdLayer<T>
66where
67 T: MakeTraceId,
68{
69 type Service = TraceIdMiddleware<S, T>;
70
71 fn layer(&self, inner: S) -> Self::Service {
72 TraceIdMiddleware {
73 inner,
74 header_name: self.header_name.clone(),
75 _phantom: Default::default(),
76 }
77 }
78}
79
80#[derive(Clone)]
81pub struct TraceIdMiddleware<S, T> {
82 inner: S,
83 header_name: Option<HeaderName>,
84 _phantom: PhantomData<T>,
85}
86
87impl<S, T, Rq, Rs> Service<Request<Rq>> for TraceIdMiddleware<S, T>
88where
89 S: Service<Request<Rq>, Response = Response<Rs>> + Send + 'static,
90 S::Future: Send + 'static,
91 T: MakeTraceId + 'static,
92{
93 type Response = S::Response;
94 type Error = S::Error;
95 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
96
97 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 self.inner.poll_ready(cx)
99 }
100
101 fn call(&mut self, mut req: Request<Rq>) -> Self::Future {
102 let trace_id = TraceId::<T>::new();
103 req.extensions_mut().insert(trace_id.clone());
104
105 let mut header_val: Option<HeaderValue> = None;
107 if let Some(header_name) = self.header_name.clone() {
108 header_val = Some(
109 HeaderValue::try_from(trace_id.id.to_string())
110 .unwrap_or(HeaderValue::from_static("unavailable")),
111 );
112 req.headers_mut()
113 .insert(header_name, header_val.clone().unwrap());
114 }
115
116 let future = self.inner.call(req);
117 let moved_header_name = self.header_name.clone();
118 Box::pin(async move {
119 let mut response: Response<Rs> = future.await?;
120
121 if let Some(header_name) = moved_header_name {
123 response
124 .headers_mut()
125 .insert(header_name, header_val.unwrap());
126 }
127
128 Ok(response)
129 })
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::body::Body;
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use tower::{ServiceBuilder, ServiceExt};

    /// Header name used by the tests that inspect headers.
    const HEADER_NAME: &str = "x-trace-id";

    /// Innermost service for every test: always answers with an empty response.
    async fn empty_response(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn test_extension_not_added() {
        // Without the layer, nothing should put a TraceId on the request.
        let calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&calls);

        let test_svc = ServiceBuilder::new()
            .map_request(move |req: Request<Body>| -> Request<Body> {
                seen.fetch_add(1, Ordering::SeqCst);
                assert!(req.extensions().get::<TraceId<String>>().is_none());
                req
            })
            .service_fn(empty_response);

        test_svc.oneshot(Request::new(Body::empty())).await.unwrap();

        // Guards against the assertion inside map_request being silently skipped.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_extension_added() {
        let calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&calls);

        let test_svc = ServiceBuilder::new()
            .layer(SetTraceIdLayer::<String>::new())
            .map_request(move |req: Request<Body>| -> Request<Body> {
                seen.fetch_add(1, Ordering::SeqCst);
                assert!(req.extensions().get::<TraceId<String>>().is_some());
                // No header name was configured, so no header may be written.
                assert!(req.headers().get(HEADER_NAME).is_none());
                req
            })
            .service_fn(empty_response);

        let resp = test_svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(resp.headers().get(HEADER_NAME).is_none());
    }

    #[tokio::test]
    async fn test_header_added() {
        // Records the header value the inner service saw on the request.
        let request_header: Arc<Mutex<Option<HeaderValue>>> = Arc::new(Mutex::new(None));
        let captured = Arc::clone(&request_header);

        let test_svc = ServiceBuilder::new()
            .layer(SetTraceIdLayer::<String>::new().with_header_name(HEADER_NAME))
            .map_request(move |req: Request<Body>| -> Request<Body> {
                assert!(req.extensions().get::<TraceId<String>>().is_some());
                *captured.lock().unwrap() = req.headers().get(HEADER_NAME).cloned();
                req
            })
            .service_fn(empty_response);

        let resp = test_svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let sent = request_header
            .lock()
            .unwrap()
            .take()
            .expect("trace id header was not set on the request");
        // The response must echo exactly the id that was put on the request.
        assert_eq!(resp.headers().get(HEADER_NAME), Some(&sent));
    }
}