#![doc(html_root_url = "https://docs.rs/http-zipkin/0.2")]
#![warn(missing_docs)]
use http::header::{HeaderMap, HeaderValue};
use std::fmt::Write;
use std::str::FromStr;
use zipkin::{SamplingFlags, TraceContext};
/// Name of the single-header B3 encoding.
///
/// It carries either a bare sampling state (`d`, `1` or `0`) or a full context of the form
/// `{TraceId}-{SpanId}[-{SamplingState}][-{ParentSpanId}]`.
const B3: &str = "b3";

// Multi-header B3 encoding: one header per field of the trace context.
const X_B3_TRACEID: &str = "X-B3-TraceId";
const X_B3_SPANID: &str = "X-B3-SpanId";
const X_B3_PARENTSPANID: &str = "X-B3-ParentSpanId";
const X_B3_SAMPLED: &str = "X-B3-Sampled";
const X_B3_FLAGS: &str = "X-B3-Flags";
/// Serializes sampling flags into the single `b3` header.
///
/// Debug flags are written as `d`, a positive sampling decision as `1` and a negative one as
/// `0`. When no decision has been made, any existing `b3` header is removed instead.
pub fn set_sampling_flags_single(flags: SamplingFlags, headers: &mut HeaderMap) {
    // Debug wins over the sampled bit, so check it first.
    let state = match (flags.debug(), flags.sampled()) {
        (true, _) => "d",
        (false, Some(true)) => "1",
        (false, Some(false)) => "0",
        (false, None) => {
            headers.remove(B3);
            return;
        }
    };

    headers.insert(B3, HeaderValue::from_static(state));
}
/// Serializes sampling flags into the multi-header B3 encoding.
///
/// If the flags are marked debug, `X-B3-Flags: 1` is written and `X-B3-Sampled` is removed.
/// Otherwise `X-B3-Flags` is removed, and `X-B3-Sampled` is set to `1` or `0` according to the
/// sampling decision, or removed when there is no decision.
pub fn set_sampling_flags(flags: SamplingFlags, headers: &mut HeaderMap) {
    /// Writes `value` under `name`, or deletes `name` when there is nothing to write.
    fn put_or_remove(headers: &mut HeaderMap, name: &'static str, value: Option<&'static str>) {
        match value {
            Some(value) => {
                headers.insert(name, HeaderValue::from_static(value));
            }
            None => {
                headers.remove(name);
            }
        }
    }

    // At most one of the two headers is ever emitted: debug supersedes the sampled bit.
    let (debug_value, sampled_value) = if flags.debug() {
        (Some("1"), None)
    } else {
        let sampled = flags
            .sampled()
            .map(|sampled| if sampled { "1" } else { "0" });
        (None, sampled)
    };

    put_or_remove(headers, X_B3_FLAGS, debug_value);
    put_or_remove(headers, X_B3_SAMPLED, sampled_value);
}
/// Deserializes sampling flags from a set of HTTP headers.
///
/// If a `b3` header is present it is used exclusively and the multi-header fields are ignored.
/// Otherwise the flags are read from `X-B3-Flags` / `X-B3-Sampled`. Values that cannot be
/// interpreted yield flags with no sampling decision.
pub fn get_sampling_flags(headers: &HeaderMap) -> SamplingFlags {
    if let Some(single) = headers.get(B3) {
        return get_sampling_flags_single(single);
    }

    get_sampling_flags_multi(headers)
}
/// Decodes sampling flags from a `b3` header value.
///
/// The value may be a bare sampling state (`d`, `1` or `0`) or a full trace context. In the
/// latter case the context's own sampling flags are returned. Anything unparseable yields
/// flags with no sampling decision.
fn get_sampling_flags_single(value: &HeaderValue) -> SamplingFlags {
    match value.as_bytes() {
        b"d" => SamplingFlags::builder().debug(true).build(),
        b"1" => SamplingFlags::builder().sampled(true).build(),
        b"0" => SamplingFlags::builder().sampled(false).build(),
        _ => get_trace_context_single(value)
            .map(|context| context.sampling_flags())
            .unwrap_or_else(|| SamplingFlags::builder().build()),
    }
}
/// Decodes sampling flags from the multi-header B3 encoding.
///
/// `X-B3-Flags: 1` marks the trace as debug. Otherwise `X-B3-Sampled` is consulted:
/// `1`/`true` mean sampled and `0`/`false` mean not sampled. The boolean spellings are accepted
/// because some older tracers emit them. A missing header, or any other value, leaves the
/// sampling decision unset.
fn get_sampling_flags_multi(headers: &HeaderMap) -> SamplingFlags {
    let mut builder = SamplingFlags::builder();

    // Only the exact value "1" means debug. Any other `X-B3-Flags` value must still fall
    // through to `X-B3-Sampled`, otherwise an explicit sampling decision would be dropped.
    let is_debug = headers.get(X_B3_FLAGS).map_or(false, |flags| flags == "1");

    if is_debug {
        builder.debug(true);
    } else if let Some(sampled) = headers.get(X_B3_SAMPLED) {
        if sampled == "1" || sampled == "true" {
            builder.sampled(true);
        } else if sampled == "0" || sampled == "false" {
            builder.sampled(false);
        }
    }

    builder.build()
}
/// Serializes a trace context into the single `b3` header, replacing any existing value.
///
/// The value has the form `{trace_id}-{span_id}[-{sampling}][-{parent_id}]`. The sampling
/// field is `d` for debug, `1`/`0` for a sampling decision, and omitted when there is none.
/// The parent ID is appended only when the context has one.
///
/// # Panics
///
/// Panics if the assembled text is not a valid header value. This presumably cannot happen,
/// because the IDs render as hex in the tests; confirm if the ID `Display` impls ever change.
pub fn set_trace_context_single(context: TraceContext, headers: &mut HeaderMap) {
    let sampling = if context.debug() {
        Some("d")
    } else {
        match context.sampled() {
            Some(true) => Some("1"),
            Some(false) => Some("0"),
            None => None,
        }
    };

    let mut encoded = String::new();
    write!(encoded, "{}-{}", context.trace_id(), context.span_id()).unwrap();

    if let Some(state) = sampling {
        encoded.push('-');
        encoded.push_str(state);
    }

    if let Some(parent) = context.parent_id() {
        write!(encoded, "-{}", parent).unwrap();
    }

    let header = HeaderValue::from_str(&encoded).unwrap();
    headers.insert(B3, header);
}
pub fn set_trace_context(context: TraceContext, headers: &mut HeaderMap) {
set_sampling_flags(context.sampling_flags(), headers);
headers.insert(
X_B3_TRACEID,
HeaderValue::from_str(&context.trace_id().to_string()).unwrap(),
);
match context.parent_id() {
Some(parent_id) => {
headers.insert(
X_B3_PARENTSPANID,
HeaderValue::from_str(&parent_id.to_string()).unwrap(),
);
}
None => {
headers.remove(X_B3_PARENTSPANID);
}
}
headers.insert(
X_B3_SPANID,
HeaderValue::from_str(&context.span_id().to_string()).unwrap(),
);
}
pub fn get_trace_context(headers: &HeaderMap) -> Option<TraceContext> {
match headers.get(B3) {
Some(value) => get_trace_context_single(value),
None => get_trace_context_multi(headers),
}
}
/// Parses a `b3` header value of the form `{trace_id}-{span_id}[-{sampling}][-{parent_id}]`.
///
/// Returns `None` when the value is not valid text, the trace or span ID is missing or
/// unparseable, or a present parent ID fails to parse. Fields after the parent ID are not
/// examined.
fn get_trace_context_single(value: &HeaderValue) -> Option<TraceContext> {
    let text = value.to_str().ok()?;
    let mut fields = text.split('-');

    let trace_id = fields.next()?.parse().ok()?;
    let span_id = fields.next()?.parse().ok()?;

    let mut builder = TraceContext::builder();
    builder.trace_id(trace_id).span_id(span_id);

    // The third field is either a sampling state or, when the sampling state is absent,
    // the parent span ID itself.
    let parent_field = match fields.next() {
        Some("d") => {
            builder.debug(true);
            fields.next()
        }
        Some("1") => {
            builder.sampled(true);
            fields.next()
        }
        Some("0") => {
            builder.sampled(false);
            fields.next()
        }
        other => other,
    };

    if let Some(parent) = parent_field {
        builder.parent_id(parent.parse().ok()?);
    }

    Some(builder.build())
}
/// Reads a trace context from the multi-header B3 encoding.
///
/// `X-B3-TraceId` and `X-B3-SpanId` are required; `None` is returned if either is missing or
/// unparseable. `X-B3-ParentSpanId` is optional and ignored if it does not parse. Sampling
/// flags come from [`get_sampling_flags_multi`].
fn get_trace_context_multi(headers: &HeaderMap) -> Option<TraceContext> {
    let mut builder = TraceContext::builder();

    builder
        .trace_id(parse_header(headers, X_B3_TRACEID)?)
        .span_id(parse_header(headers, X_B3_SPANID)?)
        .sampling_flags(get_sampling_flags_multi(headers));

    match parse_header(headers, X_B3_PARENTSPANID) {
        Some(parent_id) => {
            builder.parent_id(parent_id);
        }
        None => {}
    }

    Some(builder.build())
}
/// Looks up header `name` and parses its value with `T::from_str`.
///
/// Returns `None` if the header is absent, is not valid visible-ASCII text, or fails to parse.
fn parse_header<T>(headers: &HeaderMap, name: &str) -> Option<T>
where
    T: FromStr,
{
    let raw = headers.get(name)?;
    let text = raw.to_str().ok()?;
    text.parse().ok()
}
#[cfg(test)]
mod test {
    use super::*;

    // IDs shared by every trace-context test.
    const TRACE_ID: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
    const PARENT_ID: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
    const SPAN_ID: [u8; 8] = [2, 3, 4, 5, 6, 7, 8, 9];

    /// Builds the header map an encoder is expected to produce.
    fn header_map(entries: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in entries {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Encodes `flags`, checks the exact headers written, then checks they decode back.
    fn check_flags(
        flags: SamplingFlags,
        encode: fn(SamplingFlags, &mut HeaderMap),
        expected: &[(&'static str, &'static str)],
    ) {
        let mut headers = HeaderMap::new();
        encode(flags, &mut headers);
        assert_eq!(headers, header_map(expected));
        assert_eq!(get_sampling_flags(&headers), flags);
    }

    /// Builds a context from the shared IDs, optionally with a parent and a sampling decision.
    fn context(with_parent: bool, sampled: Option<bool>) -> TraceContext {
        let mut builder = TraceContext::builder();
        builder.trace_id(TRACE_ID.into()).span_id(SPAN_ID.into());
        if with_parent {
            builder.parent_id(PARENT_ID.into());
        }
        if let Some(sampled) = sampled {
            builder.sampled(sampled);
        }
        builder.build()
    }

    /// Encodes `context`, checks the exact headers written, then checks they decode back.
    fn check_context(
        context: TraceContext,
        encode: fn(TraceContext, &mut HeaderMap),
        expected: &[(&'static str, &'static str)],
    ) {
        let mut headers = HeaderMap::new();
        encode(context, &mut headers);
        assert_eq!(headers, header_map(expected));
        assert_eq!(get_trace_context(&headers), Some(context));
    }

    #[test]
    fn flags_empty() {
        check_flags(SamplingFlags::builder().build(), set_sampling_flags, &[]);
    }

    #[test]
    fn flags_empty_single() {
        check_flags(SamplingFlags::builder().build(), set_sampling_flags_single, &[]);
    }

    #[test]
    fn flags_debug() {
        let flags = SamplingFlags::builder().debug(true).build();
        check_flags(flags, set_sampling_flags, &[("X-B3-Flags", "1")]);
    }

    #[test]
    fn flags_debug_single() {
        let flags = SamplingFlags::builder().debug(true).build();
        check_flags(flags, set_sampling_flags_single, &[("b3", "d")]);
    }

    #[test]
    fn flags_sampled() {
        let flags = SamplingFlags::builder().sampled(true).build();
        check_flags(flags, set_sampling_flags, &[("X-B3-Sampled", "1")]);
    }

    #[test]
    fn flags_sampled_single() {
        let flags = SamplingFlags::builder().sampled(true).build();
        check_flags(flags, set_sampling_flags_single, &[("b3", "1")]);
    }

    #[test]
    fn flags_unsampled() {
        let flags = SamplingFlags::builder().sampled(false).build();
        check_flags(flags, set_sampling_flags, &[("X-B3-Sampled", "0")]);
    }

    #[test]
    fn flags_unsampled_single() {
        let flags = SamplingFlags::builder().sampled(false).build();
        check_flags(flags, set_sampling_flags_single, &[("b3", "0")]);
    }

    #[test]
    fn trace_context() {
        check_context(
            context(true, Some(true)),
            set_trace_context,
            &[
                ("X-B3-TraceId", "0001020304050607"),
                ("X-B3-SpanId", "0203040506070809"),
                ("X-B3-ParentSpanId", "0102030405060708"),
                ("X-B3-Sampled", "1"),
            ],
        );
    }

    #[test]
    fn trace_context_single() {
        check_context(
            context(true, Some(true)),
            set_trace_context_single,
            &[("b3", "0001020304050607-0203040506070809-1-0102030405060708")],
        );
    }

    #[test]
    fn trace_context_unsampled_single() {
        check_context(
            context(true, None),
            set_trace_context_single,
            &[("b3", "0001020304050607-0203040506070809-0102030405060708")],
        );
    }

    #[test]
    fn trace_context_parentless_single() {
        check_context(
            context(false, Some(true)),
            set_trace_context_single,
            &[("b3", "0001020304050607-0203040506070809-1")],
        );
    }

    #[test]
    fn trace_context_minimal_single() {
        check_context(
            context(false, None),
            set_trace_context_single,
            &[("b3", "0001020304050607-0203040506070809")],
        );
    }
}