1#![doc(html_root_url = "https://docs.rs/hyper-zipkin/0.4")]
17#![warn(missing_docs)]
18extern crate zipkin;
19
20#[macro_use]
21extern crate hyper;
22
23use hyper::header::{Formatter, Header, Headers, Raw};
24use std::fmt;
25use std::ops::{Deref, DerefMut};
26use zipkin::{SamplingFlags, SpanId, TraceContext, TraceId};
27
header! {
    /// The `X-B3-TraceId` header.
    ///
    /// Wraps a `zipkin::TraceId`. `get_trace_context` reads it into the
    /// `trace_id` of a `TraceContext`, and `set_trace_context` writes it back
    /// out. The test in this file shows an 8-byte ID rendered as 16 lowercase
    /// hexadecimal characters.
    #[derive(Copy)] (XB3TraceId, "X-B3-TraceId") => [TraceId]
}
35
header! {
    /// The `X-B3-SpanId` header.
    ///
    /// Wraps a `zipkin::SpanId`. `get_trace_context` reads it into the
    /// `span_id` of a `TraceContext`, and `set_trace_context` writes it back
    /// out.
    #[derive(Copy)] (XB3SpanId, "X-B3-SpanId") => [SpanId]
}
43
header! {
    /// The `X-B3-ParentSpanId` header.
    ///
    /// Wraps a `zipkin::SpanId`. It is optional: `get_trace_context` only sets
    /// the `parent_id` of the `TraceContext` when this header is present, and
    /// `set_trace_context` only emits it when the context has a parent ID.
    #[derive(Copy)] (XB3ParentSpanId, "X-B3-ParentSpanId") => [SpanId]
}
51
/// The `X-B3-Flags` header.
///
/// This is a unit type: the only value its `Header` implementation accepts or
/// emits is `1`. `get_sampling_flags` treats a successfully parsed header as
/// a request for debug mode, and `set_sampling_flags` emits it when the debug
/// flag is set.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct XB3Flags;
58
59impl Header for XB3Flags {
60 fn header_name() -> &'static str {
61 "X-B3-Flags"
62 }
63
64 fn parse_header(raw: &Raw) -> hyper::Result<XB3Flags> {
65 if let Some(line) = raw.one() {
66 if line.len() == 1 {
67 let byte = line[0];
68 match byte {
69 b'1' => return Ok(XB3Flags),
70 _ => {}
71 }
72 }
73 }
74 Err(hyper::Error::Header)
75 }
76
77 fn fmt_header(&self, fmt: &mut Formatter) -> fmt::Result {
78 fmt.fmt_line(&"1")
79 }
80}
81
/// The `X-B3-Sampled` header.
///
/// Wraps the boolean sampling decision; its `Header` implementation in this
/// file maps `true` to the value `1` and `false` to the value `0`. The
/// wrapped `bool` is reachable through the public field as well as through
/// `Deref`/`DerefMut`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct XB3Sampled(pub bool);

impl Deref for XB3Sampled {
    type Target = bool;

    /// Borrows the wrapped sampling decision.
    fn deref(&self) -> &Self::Target {
        let XB3Sampled(ref sampled) = *self;
        sampled
    }
}

impl DerefMut for XB3Sampled {
    /// Mutably borrows the wrapped sampling decision.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let XB3Sampled(ref mut sampled) = *self;
        sampled
    }
}
103
104impl Header for XB3Sampled {
105 fn header_name() -> &'static str {
106 "X-B3-Sampled"
107 }
108
109 fn parse_header(raw: &Raw) -> hyper::Result<XB3Sampled> {
110 if let Some(line) = raw.one() {
111 if line.len() == 1 {
112 let byte = line[0];
113 match byte {
114 b'0' => return Ok(XB3Sampled(false)),
115 b'1' => return Ok(XB3Sampled(true)),
116 _ => {}
117 }
118 }
119 }
120 Err(hyper::Error::Header)
121 }
122
123 fn fmt_header(&self, fmt: &mut Formatter) -> fmt::Result {
124 let s = if self.0 { "1" } else { "0" };
125 fmt.fmt_line(&s)
126 }
127}
128
129pub fn get_sampling_flags(headers: &Headers) -> SamplingFlags {
131 let mut builder = SamplingFlags::builder();
132
133 if let Some(sampled) = headers.get::<XB3Sampled>() {
134 builder.sampled(sampled.0);
135 }
136
137 if let Some(&XB3Flags) = headers.get::<XB3Flags>() {
138 builder.debug(true);
139 }
140
141 builder.build()
142}
143
144pub fn set_sampling_flags(flags: SamplingFlags, headers: &mut Headers) {
146 if flags.debug() {
147 headers.set(XB3Flags);
148 } else if let Some(sampled) = flags.sampled() {
149 headers.set(XB3Sampled(sampled));
150 }
151}
152
153pub fn get_trace_context(headers: &Headers) -> Option<TraceContext> {
155 let trace_id = headers.get::<XB3TraceId>()?.0;
156 let span_id = headers.get::<XB3SpanId>()?.0;
157
158 let mut context = TraceContext::builder();
159 context
160 .trace_id(trace_id)
161 .span_id(span_id)
162 .sampling_flags(get_sampling_flags(headers));
163
164 if let Some(parent_id) = headers.get::<XB3ParentSpanId>() {
165 context.parent_id(parent_id.0);
166 }
167
168 Some(context.build())
169}
170
171pub fn set_trace_context(context: TraceContext, headers: &mut Headers) {
173 headers.set(XB3TraceId(context.trace_id()));
174 headers.set(XB3SpanId(context.span_id()));
175
176 if let Some(parent_id) = context.parent_id() {
177 headers.set(XB3ParentSpanId(parent_id));
178 }
179
180 set_sampling_flags(context.sampling_flags(), headers);
181}
182
183#[cfg(test)]
184mod test {
185 use super::*;
186
187 #[test]
188 fn flags_empty() {
189 let mut headers = Headers::new();
190 let flags = SamplingFlags::builder().build();
191 set_sampling_flags(flags, &mut headers);
192
193 let expected_headers = Headers::new();
194 assert_eq!(headers, expected_headers);
195
196 assert_eq!(get_sampling_flags(&headers), flags);
197 }
198
199 #[test]
200 fn flags_debug() {
201 let mut headers = Headers::new();
202 let flags = SamplingFlags::builder().debug(true).build();
203 set_sampling_flags(flags, &mut headers);
204
205 let mut expected_headers = Headers::new();
206 expected_headers.set_raw("X-B3-Flags", "1");
207 assert_eq!(headers, expected_headers);
208
209 assert_eq!(get_sampling_flags(&headers), flags);
210 }
211
212 #[test]
213 fn flags_sampled() {
214 let mut headers = Headers::new();
215 let flags = SamplingFlags::builder().sampled(true).build();
216 set_sampling_flags(flags, &mut headers);
217
218 let mut expected_headers = Headers::new();
219 expected_headers.set_raw("X-B3-Sampled", "1");
220 assert_eq!(headers, expected_headers);
221
222 assert_eq!(get_sampling_flags(&headers), flags);
223 }
224
225 #[test]
226 fn flags_unsampled() {
227 let mut headers = Headers::new();
228 let flags = SamplingFlags::builder().sampled(false).build();
229 set_sampling_flags(flags, &mut headers);
230
231 let mut expected_headers = Headers::new();
232 expected_headers.set_raw("X-B3-Sampled", "0");
233 assert_eq!(headers, expected_headers);
234
235 assert_eq!(get_sampling_flags(&headers), flags);
236 }
237
238 #[test]
239 fn trace_context() {
240 let mut headers = Headers::new();
241 let context = TraceContext::builder()
242 .trace_id([0, 1, 2, 3, 4, 5, 6, 7].into())
243 .parent_id([1, 2, 3, 4, 5, 6, 7, 8].into())
244 .span_id([2, 3, 4, 5, 6, 7, 8, 9].into())
245 .sampled(true)
246 .build();
247 set_trace_context(context, &mut headers);
248
249 let mut expected_headers = Headers::new();
250 expected_headers.set_raw("X-B3-TraceId", "0001020304050607");
251 expected_headers.set_raw("X-B3-SpanId", "0203040506070809");
252 expected_headers.set_raw("X-B3-ParentSpanId", "0102030405060708");
253 expected_headers.set_raw("X-B3-Sampled", "1");
254 assert_eq!(headers, expected_headers);
255
256 assert_eq!(get_trace_context(&headers), Some(context));
257 }
258}