1#![doc(html_root_url = "https://docs.rs/http-zipkin/0.3")]
17#![warn(missing_docs)]
18
19use http::header::{HeaderMap, HeaderValue};
20use std::fmt::Write;
21use std::str::FromStr;
22use zipkin::{SamplingFlags, TraceContext};
23
/// Name of the single-header B3 encoding (`b3`).
const B3: &str = "b3";

/// Multi-header B3: the trace ID.
const X_B3_TRACEID: &str = "X-B3-TraceId";
/// Multi-header B3: the parent span ID; only written when the context has a parent.
const X_B3_PARENTSPANID: &str = "X-B3-ParentSpanId";
/// Multi-header B3: the span ID.
const X_B3_SPANID: &str = "X-B3-SpanId";
/// Multi-header B3: the sampling decision, written as `1` or `0`.
const X_B3_SAMPLED: &str = "X-B3-Sampled";
/// Multi-header B3: debug flag, written as `1` when debug is enabled.
const X_B3_FLAGS: &str = "X-B3-Flags";
30
31pub fn set_sampling_flags_single(flags: SamplingFlags, headers: &mut HeaderMap) {
35 if flags.debug() {
36 headers.insert(B3, HeaderValue::from_static("d"));
37 } else if flags.sampled() == Some(true) {
38 headers.insert(B3, HeaderValue::from_static("1"));
39 } else if flags.sampled() == Some(false) {
40 headers.insert(B3, HeaderValue::from_static("0"));
41 } else {
42 headers.remove(B3);
43 }
44}
45
46pub fn set_sampling_flags(flags: SamplingFlags, headers: &mut HeaderMap) {
48 if flags.debug() {
49 headers.insert(X_B3_FLAGS, HeaderValue::from_static("1"));
50 headers.remove(X_B3_SAMPLED);
51 } else {
52 headers.remove(X_B3_FLAGS);
53 match flags.sampled() {
54 Some(true) => {
55 headers.insert(X_B3_SAMPLED, HeaderValue::from_static("1"));
56 }
57 Some(false) => {
58 headers.insert(X_B3_SAMPLED, HeaderValue::from_static("0"));
59 }
60 None => {
61 headers.remove(X_B3_SAMPLED);
62 }
63 }
64 }
65}
66
67pub fn get_sampling_flags(headers: &HeaderMap) -> SamplingFlags {
69 match headers.get(B3) {
70 Some(value) => get_sampling_flags_single(value),
71 None => get_sampling_flags_multi(headers),
72 }
73}
74
75fn get_sampling_flags_single(value: &HeaderValue) -> SamplingFlags {
76 let mut builder = SamplingFlags::builder();
77
78 if value == "d" {
79 builder.debug(true);
80 } else if value == "1" {
81 builder.sampled(true);
82 } else if value == "0" {
83 builder.sampled(false);
84 } else if let Some(context) = get_trace_context_single(value) {
85 return context.sampling_flags();
86 }
87
88 builder.build()
89}
90
91fn get_sampling_flags_multi(headers: &HeaderMap) -> SamplingFlags {
92 let mut builder = SamplingFlags::builder();
93
94 if let Some(flags) = headers.get(X_B3_FLAGS) {
95 if flags == "1" {
96 builder.debug(true);
97 }
98 } else if let Some(sampled) = headers.get(X_B3_SAMPLED) {
99 if sampled == "1" {
100 builder.sampled(true);
101 } else if sampled == "0" {
102 builder.sampled(false);
103 }
104 }
105
106 builder.build()
107}
108
109pub fn set_trace_context_single(context: TraceContext, headers: &mut HeaderMap) {
113 let mut value = String::new();
114 write!(value, "{}-{}", context.trace_id(), context.span_id()).unwrap();
115 if context.debug() {
116 value.push_str("-d");
117 } else if context.sampled() == Some(true) {
118 value.push_str("-1");
119 } else if context.sampled() == Some(false) {
120 value.push_str("-0");
121 }
122 if let Some(parent_id) = context.parent_id() {
123 write!(value, "-{}", parent_id).unwrap();
124 }
125 headers.insert(B3, HeaderValue::from_str(&value).unwrap());
126}
127
128pub fn set_trace_context(context: TraceContext, headers: &mut HeaderMap) {
130 set_sampling_flags(context.sampling_flags(), headers);
131
132 headers.insert(
133 X_B3_TRACEID,
134 HeaderValue::from_str(&context.trace_id().to_string()).unwrap(),
135 );
136 match context.parent_id() {
137 Some(parent_id) => {
138 headers.insert(
139 X_B3_PARENTSPANID,
140 HeaderValue::from_str(&parent_id.to_string()).unwrap(),
141 );
142 }
143 None => {
144 headers.remove(X_B3_PARENTSPANID);
145 }
146 }
147 headers.insert(
148 X_B3_SPANID,
149 HeaderValue::from_str(&context.span_id().to_string()).unwrap(),
150 );
151}
152
153pub fn get_trace_context(headers: &HeaderMap) -> Option<TraceContext> {
155 match headers.get(B3) {
156 Some(value) => get_trace_context_single(value),
157 None => get_trace_context_multi(headers),
158 }
159}
160
161fn get_trace_context_single(value: &HeaderValue) -> Option<TraceContext> {
162 let mut parts = value.to_str().ok()?.split('-');
163
164 let trace_id = parts.next()?.parse().ok()?;
165 let span_id = parts.next()?.parse().ok()?;
166
167 let mut builder = TraceContext::builder();
168 builder.trace_id(trace_id).span_id(span_id);
169
170 let maybe_sampling = match parts.next() {
171 Some(next) => next,
172 None => return Some(builder.build()),
173 };
174
175 let parent_id = if maybe_sampling == "d" {
176 builder.debug(true);
177 parts.next()
178 } else if maybe_sampling == "1" {
179 builder.sampled(true);
180 parts.next()
181 } else if maybe_sampling == "0" {
182 builder.sampled(false);
183 parts.next()
184 } else {
185 Some(maybe_sampling)
186 };
187
188 if let Some(parent_id) = parent_id {
189 builder.parent_id(parent_id.parse().ok()?);
190 }
191
192 Some(builder.build())
193}
194
195fn get_trace_context_multi(headers: &HeaderMap) -> Option<TraceContext> {
196 let trace_id = parse_header(headers, X_B3_TRACEID)?;
197 let span_id = parse_header(headers, X_B3_SPANID)?;
198
199 let mut builder = TraceContext::builder();
200 builder
201 .trace_id(trace_id)
202 .span_id(span_id)
203 .sampling_flags(get_sampling_flags_multi(headers));
204
205 if let Some(parent_id) = parse_header(headers, X_B3_PARENTSPANID) {
206 builder.parent_id(parent_id);
207 }
208
209 Some(builder.build())
210}
211
212fn parse_header<T>(headers: &HeaderMap, name: &str) -> Option<T>
213where
214 T: FromStr,
215{
216 headers
217 .get(name)
218 .and_then(|v| v.to_str().ok())
219 .and_then(|s| s.parse().ok())
220}
221
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a header map from `(name, value)` pairs.
    fn header_map(entries: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in entries {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Writes `flags` with `set`, checks the exact headers produced, and
    /// checks that `get_sampling_flags` reads the same flags back.
    fn assert_flags_round_trip(
        set: fn(SamplingFlags, &mut HeaderMap),
        flags: SamplingFlags,
        expected: &[(&'static str, &'static str)],
    ) {
        let mut headers = HeaderMap::new();
        set(flags, &mut headers);
        assert_eq!(headers, header_map(expected));
        assert_eq!(get_sampling_flags(&headers), flags);
    }

    /// Writes `context` with `set`, checks the exact headers produced, and
    /// checks that `get_trace_context` reads the same context back.
    fn assert_context_round_trip(
        set: fn(TraceContext, &mut HeaderMap),
        context: TraceContext,
        expected: &[(&'static str, &'static str)],
    ) {
        let mut headers = HeaderMap::new();
        set(context, &mut headers);
        assert_eq!(headers, header_map(expected));
        assert_eq!(get_trace_context(&headers), Some(context));
    }

    /// Builds the fixed trace/span IDs used by every context test, optionally
    /// with a parent ID and a sampling decision.
    fn sample_context(with_parent: bool, sampled: Option<bool>) -> TraceContext {
        let mut builder = TraceContext::builder();
        builder
            .trace_id([0, 1, 2, 3, 4, 5, 6, 7].into())
            .span_id([2, 3, 4, 5, 6, 7, 8, 9].into());
        if with_parent {
            builder.parent_id([1, 2, 3, 4, 5, 6, 7, 8].into());
        }
        if let Some(sampled) = sampled {
            builder.sampled(sampled);
        }
        builder.build()
    }

    #[test]
    fn flags_empty() {
        assert_flags_round_trip(set_sampling_flags, SamplingFlags::builder().build(), &[]);
    }

    #[test]
    fn flags_empty_single() {
        assert_flags_round_trip(
            set_sampling_flags_single,
            SamplingFlags::builder().build(),
            &[],
        );
    }

    #[test]
    fn flags_debug() {
        assert_flags_round_trip(
            set_sampling_flags,
            SamplingFlags::builder().debug(true).build(),
            &[("X-B3-Flags", "1")],
        );
    }

    #[test]
    fn flags_debug_single() {
        assert_flags_round_trip(
            set_sampling_flags_single,
            SamplingFlags::builder().debug(true).build(),
            &[("b3", "d")],
        );
    }

    #[test]
    fn flags_sampled() {
        assert_flags_round_trip(
            set_sampling_flags,
            SamplingFlags::builder().sampled(true).build(),
            &[("X-B3-Sampled", "1")],
        );
    }

    #[test]
    fn flags_sampled_single() {
        assert_flags_round_trip(
            set_sampling_flags_single,
            SamplingFlags::builder().sampled(true).build(),
            &[("b3", "1")],
        );
    }

    #[test]
    fn flags_unsampled() {
        assert_flags_round_trip(
            set_sampling_flags,
            SamplingFlags::builder().sampled(false).build(),
            &[("X-B3-Sampled", "0")],
        );
    }

    #[test]
    fn flags_unsampled_single() {
        assert_flags_round_trip(
            set_sampling_flags_single,
            SamplingFlags::builder().sampled(false).build(),
            &[("b3", "0")],
        );
    }

    #[test]
    fn trace_context() {
        assert_context_round_trip(
            set_trace_context,
            sample_context(true, Some(true)),
            &[
                ("X-B3-TraceId", "0001020304050607"),
                ("X-B3-SpanId", "0203040506070809"),
                ("X-B3-ParentSpanId", "0102030405060708"),
                ("X-B3-Sampled", "1"),
            ],
        );
    }

    #[test]
    fn trace_context_single() {
        assert_context_round_trip(
            set_trace_context_single,
            sample_context(true, Some(true)),
            &[("b3", "0001020304050607-0203040506070809-1-0102030405060708")],
        );
    }

    #[test]
    fn trace_context_unsampled_single() {
        assert_context_round_trip(
            set_trace_context_single,
            sample_context(true, None),
            &[("b3", "0001020304050607-0203040506070809-0102030405060708")],
        );
    }

    #[test]
    fn trace_context_parentless_single() {
        assert_context_round_trip(
            set_trace_context_single,
            sample_context(false, Some(true)),
            &[("b3", "0001020304050607-0203040506070809-1")],
        );
    }

    #[test]
    fn trace_context_minimal_single() {
        assert_context_round_trip(
            set_trace_context_single,
            sample_context(false, None),
            &[("b3", "0001020304050607-0203040506070809")],
        );
    }
}