hyper_zipkin/
lib.rs

1//  Copyright 2017 Palantir Technologies, Inc.
2//
3//  Licensed under the Apache License, Version 2.0 (the "License");
4//  you may not use this file except in compliance with the License.
5//  You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14
15//! Hyper definitions for Zipkin headers.
16#![doc(html_root_url = "https://docs.rs/hyper-zipkin/0.4")]
17#![warn(missing_docs)]
18extern crate zipkin;
19
20#[macro_use]
21extern crate hyper;
22
23use hyper::header::{Formatter, Header, Headers, Raw};
24use std::fmt;
25use std::ops::{Deref, DerefMut};
26use zipkin::{SamplingFlags, SpanId, TraceContext, TraceId};
27
28header! {
29    /// The `X-B3-TraceId` header.
30    ///
31    /// Its value is a hexadecimal-encoded 8 or 16 byte trace ID. It corresponds
32    /// to the `trace_id` field of a `TraceContext`.
33    #[derive(Copy)] (XB3TraceId, "X-B3-TraceId") => [TraceId]
34}
35
36header! {
37    /// The `X-B3-SpanId` header.
38    ///
39    /// Its value is a hexadecimal-encoded 8 byte span ID. It corresponds to the
40    /// `span_id` field of a `TraceContext`.
41    #[derive(Copy)] (XB3SpanId, "X-B3-SpanId") => [SpanId]
42}
43
44header! {
45    /// The `X-B3-ParentSpanID` header.
46    ///
47    /// Its value is a hexadecimal-encoded 8 byte span ID. It corresponds to the
48    /// `parent_id` field of a `TraceContext`.
49    #[derive(Copy)] (XB3ParentSpanId, "X-B3-ParentSpanId") => [SpanId]
50}
51
/// The `X-B3-Flags` header.
///
/// When present its value is always `1`, which marks the context as being in
/// debug mode. Maps onto the `debug` field of a `TraceContext`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct XB3Flags;
58
59impl Header for XB3Flags {
60    fn header_name() -> &'static str {
61        "X-B3-Flags"
62    }
63
64    fn parse_header(raw: &Raw) -> hyper::Result<XB3Flags> {
65        if let Some(line) = raw.one() {
66            if line.len() == 1 {
67                let byte = line[0];
68                match byte {
69                    b'1' => return Ok(XB3Flags),
70                    _ => {}
71                }
72            }
73        }
74        Err(hyper::Error::Header)
75    }
76
77    fn fmt_header(&self, fmt: &mut Formatter) -> fmt::Result {
78        fmt.fmt_line(&"1")
79    }
80}
81
/// The `X-B3-Sampled` header.
///
/// It indicates whether the client has requested that the context be sampled,
/// and is serialized as `1` (sampled) or `0` (not sampled). It corresponds to
/// the `sampled` field of a `TraceContext`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct XB3Sampled(pub bool);
89
90impl Deref for XB3Sampled {
91    type Target = bool;
92
93    fn deref(&self) -> &bool {
94        &self.0
95    }
96}
97
98impl DerefMut for XB3Sampled {
99    fn deref_mut(&mut self) -> &mut bool {
100        &mut self.0
101    }
102}
103
104impl Header for XB3Sampled {
105    fn header_name() -> &'static str {
106        "X-B3-Sampled"
107    }
108
109    fn parse_header(raw: &Raw) -> hyper::Result<XB3Sampled> {
110        if let Some(line) = raw.one() {
111            if line.len() == 1 {
112                let byte = line[0];
113                match byte {
114                    b'0' => return Ok(XB3Sampled(false)),
115                    b'1' => return Ok(XB3Sampled(true)),
116                    _ => {}
117                }
118            }
119        }
120        Err(hyper::Error::Header)
121    }
122
123    fn fmt_header(&self, fmt: &mut Formatter) -> fmt::Result {
124        let s = if self.0 { "1" } else { "0" };
125        fmt.fmt_line(&s)
126    }
127}
128
129/// Constructs `SamplingFlags` from a set of headers.
130pub fn get_sampling_flags(headers: &Headers) -> SamplingFlags {
131    let mut builder = SamplingFlags::builder();
132
133    if let Some(sampled) = headers.get::<XB3Sampled>() {
134        builder.sampled(sampled.0);
135    }
136
137    if let Some(&XB3Flags) = headers.get::<XB3Flags>() {
138        builder.debug(true);
139    }
140
141    builder.build()
142}
143
144/// Serializes `SamplingFlags` into a set of headers.
145pub fn set_sampling_flags(flags: SamplingFlags, headers: &mut Headers) {
146    if flags.debug() {
147        headers.set(XB3Flags);
148    } else if let Some(sampled) = flags.sampled() {
149        headers.set(XB3Sampled(sampled));
150    }
151}
152
153/// Constructs a `TraceContext` from a set of headers.
154pub fn get_trace_context(headers: &Headers) -> Option<TraceContext> {
155    let trace_id = headers.get::<XB3TraceId>()?.0;
156    let span_id = headers.get::<XB3SpanId>()?.0;
157
158    let mut context = TraceContext::builder();
159    context
160        .trace_id(trace_id)
161        .span_id(span_id)
162        .sampling_flags(get_sampling_flags(headers));
163
164    if let Some(parent_id) = headers.get::<XB3ParentSpanId>() {
165        context.parent_id(parent_id.0);
166    }
167
168    Some(context.build())
169}
170
171/// Serializes a `TraceContext` into a set of headers.
172pub fn set_trace_context(context: TraceContext, headers: &mut Headers) {
173    headers.set(XB3TraceId(context.trace_id()));
174    headers.set(XB3SpanId(context.span_id()));
175
176    if let Some(parent_id) = context.parent_id() {
177        headers.set(XB3ParentSpanId(parent_id));
178    }
179
180    set_sampling_flags(context.sampling_flags(), headers);
181}
182
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a `Headers` map holding exactly the given raw name/value pairs.
    fn raw_headers(pairs: &[(&'static str, &'static str)]) -> Headers {
        let mut headers = Headers::new();
        for &(name, value) in pairs {
            headers.set_raw(name, value);
        }
        headers
    }

    /// Serializes `flags`, checks that exactly the `expected` headers were
    /// produced, then checks that parsing them back yields `flags` again.
    fn assert_flags_round_trip(flags: SamplingFlags, expected: &[(&'static str, &'static str)]) {
        let mut headers = Headers::new();
        set_sampling_flags(flags, &mut headers);

        assert_eq!(headers, raw_headers(expected));
        assert_eq!(get_sampling_flags(&headers), flags);
    }

    #[test]
    fn flags_empty() {
        assert_flags_round_trip(SamplingFlags::builder().build(), &[]);
    }

    #[test]
    fn flags_debug() {
        let flags = SamplingFlags::builder().debug(true).build();
        assert_flags_round_trip(flags, &[("X-B3-Flags", "1")]);
    }

    #[test]
    fn flags_sampled() {
        let flags = SamplingFlags::builder().sampled(true).build();
        assert_flags_round_trip(flags, &[("X-B3-Sampled", "1")]);
    }

    #[test]
    fn flags_unsampled() {
        let flags = SamplingFlags::builder().sampled(false).build();
        assert_flags_round_trip(flags, &[("X-B3-Sampled", "0")]);
    }

    #[test]
    fn trace_context() {
        let context = TraceContext::builder()
            .trace_id([0, 1, 2, 3, 4, 5, 6, 7].into())
            .parent_id([1, 2, 3, 4, 5, 6, 7, 8].into())
            .span_id([2, 3, 4, 5, 6, 7, 8, 9].into())
            .sampled(true)
            .build();

        let mut headers = Headers::new();
        set_trace_context(context, &mut headers);

        let expected = raw_headers(&[
            ("X-B3-TraceId", "0001020304050607"),
            ("X-B3-SpanId", "0203040506070809"),
            ("X-B3-ParentSpanId", "0102030405060708"),
            ("X-B3-Sampled", "1"),
        ]);
        assert_eq!(headers, expected);

        assert_eq!(get_trace_context(&headers), Some(context));
    }
}