1use crate::{
21 propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
22 trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState},
23 Context,
24};
25use std::str::FromStr;
26
/// `traceparent` version written by `inject_context`.
const SUPPORTED_VERSION: u8 = 0x00;
/// Highest `traceparent` version accepted on extraction; anything above it
/// (i.e. `ff`) is rejected.
const MAX_VERSION: u8 = 0xfe;
/// Header carrying version, trace id, span id and trace flags.
const TRACEPARENT_HEADER: &str = "traceparent";
/// Header carrying the serialized `TraceState`.
const TRACESTATE_HEADER: &str = "tracestate";
31
32lazy_static::lazy_static! {
33 static ref TRACE_CONTEXT_HEADER_FIELDS: [String; 2] = [
34 TRACEPARENT_HEADER.to_string(),
35 TRACESTATE_HEADER.to_string()
36 ];
37}
38
/// Propagates span context through the `traceparent` and `tracestate`
/// headers.
///
/// The single private unit field keeps the struct from being built with a
/// struct literal outside this module; use `Default` or the constructor.
#[derive(Debug, Default, Clone)]
pub struct TraceContextPropagator {
    _private: (),
}
46
47impl TraceContextPropagator {
48 pub fn new() -> Self {
50 TraceContextPropagator { _private: () }
51 }
52
53 fn extract_span_context(&self, extractor: &dyn Extractor) -> Result<SpanContext, ()> {
55 let header_value = extractor.get(TRACEPARENT_HEADER).unwrap_or("").trim();
56 let parts = header_value.split_terminator('-').collect::<Vec<&str>>();
57 if parts.len() < 4 {
59 return Err(());
60 }
61
62 let version = u8::from_str_radix(parts[0], 16).map_err(|_| ())?;
64 if version > MAX_VERSION || version == 0 && parts.len() != 4 {
65 return Err(());
66 }
67
68 if parts[1].chars().any(|c| c.is_ascii_uppercase()) {
70 return Err(());
71 }
72
73 let trace_id = TraceId::from_hex(parts[1]).map_err(|_| ())?;
75
76 if parts[2].chars().any(|c| c.is_ascii_uppercase()) {
78 return Err(());
79 }
80
81 let span_id = SpanId::from_hex(parts[2]).map_err(|_| ())?;
83
84 let opts = u8::from_str_radix(parts[3], 16).map_err(|_| ())?;
86
87 if version == 0 && opts > 2 {
89 return Err(());
90 }
91
92 let trace_flags = TraceFlags::new(opts) & TraceFlags::SAMPLED;
95
96 let trace_state: TraceState =
97 TraceState::from_str(extractor.get(TRACESTATE_HEADER).unwrap_or(""))
98 .unwrap_or_else(|_| TraceState::default());
99
100 let span_context = SpanContext::new(trace_id, span_id, trace_flags, true, trace_state);
102
103 if !span_context.is_valid() {
105 return Err(());
106 }
107
108 Ok(span_context)
109 }
110}
111
112impl TextMapPropagator for TraceContextPropagator {
113 fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
116 let span = cx.span();
117 let span_context = span.span_context();
118 if span_context.is_valid() {
119 let header_value = format!(
120 "{:02x}-{:032x}-{:016x}-{:02x}",
121 SUPPORTED_VERSION,
122 span_context.trace_id(),
123 span_context.span_id(),
124 span_context.trace_flags() & TraceFlags::SAMPLED
125 );
126 injector.set(TRACEPARENT_HEADER, header_value);
127 injector.set(TRACESTATE_HEADER, span_context.trace_state().header());
128 }
129 }
130
131 fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
136 self.extract_span_context(extractor)
137 .map(|sc| cx.with_remote_span_context(sc))
138 .unwrap_or_else(|_| cx.clone())
139 }
140
141 fn fields(&self) -> FieldIter<'_> {
142 FieldIter::new(TRACE_CONTEXT_HEADER_FIELDS.as_ref())
143 }
144}
145
// Unit tests for `TraceContextPropagator`. Only compiled for test builds with
// both the `testing` and `trace` features enabled.
#[cfg(all(test, feature = "testing", feature = "trace"))]
mod tests {
    use super::*;
    use crate::testing::trace::TestSpan;
    use crate::{
        propagation::{Extractor, Injector, TextMapPropagator},
        trace::{SpanContext, SpanId, TraceId},
    };
    use std::collections::HashMap;
    use std::str::FromStr;

    // Valid `(traceparent, tracestate, expected span context)` triples.
    // Covers sampled/unsampled flags and higher versions (`01`, `02`). For
    // those versions, flag bits other than sampled are masked off (`09` ->
    // SAMPLED, `08` -> default) and extra fields after the flags are allowed.
    // Also covers a trailing `-` after the flags.
    #[rustfmt::skip]
    fn extract_data() -> Vec<(&'static str, &'static str, SpanContext)> {
        vec![
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::default(), true, TraceState::from_str("foo=bar").unwrap())),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-08", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::default(), true, TraceState::from_str("foo=bar").unwrap())),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-XYZxsf09", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01-", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
            ("01-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
        ]
    }

    // `(traceparent, reason)` pairs that extraction must reject, yielding an
    // empty span context. The reason is used as the assertion message.
    #[rustfmt::skip]
    fn extract_data_invalid() -> Vec<(&'static str, &'static str)> {
        vec![
            ("0000-00000000000000000000000000000000-0000000000000000-01", "wrong version length"),
            ("00-ab00000000000000000000000000000000-cd00000000000000-01", "wrong trace ID length"),
            ("00-ab000000000000000000000000000000-cd0000000000000000-01", "wrong span ID length"),
            ("00-ab000000000000000000000000000000-cd00000000000000-0100", "wrong trace flag length"),
            ("qw-00000000000000000000000000000000-0000000000000000-01", "bogus version"),
            ("00-qw000000000000000000000000000000-cd00000000000000-01", "bogus trace ID"),
            ("00-ab000000000000000000000000000000-qw00000000000000-01", "bogus span ID"),
            ("00-ab000000000000000000000000000000-cd00000000000000-qw", "bogus trace flag"),
            ("A0-00000000000000000000000000000000-0000000000000000-01", "upper case version"),
            ("00-AB000000000000000000000000000000-cd00000000000000-01", "upper case trace ID"),
            ("00-ab000000000000000000000000000000-CD00000000000000-01", "upper case span ID"),
            ("00-ab000000000000000000000000000000-cd00000000000000-A1", "upper case trace flag"),
            ("00-00000000000000000000000000000000-0000000000000000-01", "zero trace ID and span ID"),
            ("00-ab000000000000000000000000000000-cd00000000000000-09", "trace-flag unused bits set"),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7", "missing options"),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-", "empty options"),
        ]
    }

    // `(expected traceparent, expected tracestate, span context)` triples for
    // injection. `TraceFlags::new(0xff)` must be injected as `01`, since only
    // the sampled bit is written. The empty context must produce no headers,
    // which the test reads back as "".
    #[rustfmt::skip]
    fn inject_data() -> Vec<(&'static str, &'static str, SpanContext)> {
        vec![
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::default(), true, TraceState::from_str("foo=bar").unwrap())),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::new(0xff), true, TraceState::from_str("foo=bar").unwrap())),
            ("", "", SpanContext::empty_context()),
        ]
    }

    // Every valid header pair must extract to exactly the expected context.
    #[test]
    fn extract_w3c() {
        let propagator = TraceContextPropagator::new();

        for (trace_parent, trace_state, expected_context) in extract_data() {
            let mut extractor = HashMap::new();
            extractor.insert(TRACEPARENT_HEADER.to_string(), trace_parent.to_string());
            extractor.insert(TRACESTATE_HEADER.to_string(), trace_state.to_string());

            assert_eq!(
                propagator.extract(&extractor).span().span_context(),
                &expected_context
            )
        }
    }

    // The extracted trace state must serialize back to the original header.
    #[test]
    fn extract_w3c_tracestate() {
        let propagator = TraceContextPropagator::new();
        let state = "foo=bar".to_string();
        let parent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00".to_string();

        let mut extractor = HashMap::new();
        extractor.insert(TRACEPARENT_HEADER.to_string(), parent);
        extractor.insert(TRACESTATE_HEADER.to_string(), state.clone());

        assert_eq!(
            propagator
                .extract(&extractor)
                .span()
                .span_context()
                .trace_state()
                .header(),
            state
        )
    }

    // Every malformed traceparent must fall back to the empty span context.
    #[test]
    fn extract_w3c_reject_invalid() {
        let propagator = TraceContextPropagator::new();

        for (invalid_header, reason) in extract_data_invalid() {
            let mut extractor = HashMap::new();
            extractor.insert(TRACEPARENT_HEADER.to_string(), invalid_header.to_string());

            assert_eq!(
                propagator.extract(&extractor).span().span_context(),
                &SpanContext::empty_context(),
                "{}",
                reason
            )
        }
    }

    // Injecting each span context must write exactly the expected headers.
    #[test]
    fn inject_w3c() {
        let propagator = TraceContextPropagator::new();

        for (expected_trace_parent, expected_trace_state, context) in inject_data() {
            let mut injector = HashMap::new();
            propagator.inject_context(
                &Context::current_with_span(TestSpan(context)),
                &mut injector,
            );

            assert_eq!(
                Extractor::get(&injector, TRACEPARENT_HEADER).unwrap_or(""),
                expected_trace_parent
            );

            assert_eq!(
                Extractor::get(&injector, TRACESTATE_HEADER).unwrap_or(""),
                expected_trace_state
            );
        }
    }

    // A pre-existing tracestate header must survive injection from
    // `Context::current()`. Assumes the current context carries no valid
    // span, so nothing is written. TODO confirm.
    #[test]
    fn inject_w3c_tracestate() {
        let propagator = TraceContextPropagator::new();
        let state = "foo=bar";

        let mut injector: HashMap<String, String> = HashMap::new();
        injector.set(TRACESTATE_HEADER, state.to_string());

        propagator.inject_context(&Context::current(), &mut injector);

        assert_eq!(Extractor::get(&injector, TRACESTATE_HEADER), Some(state))
    }
}