use crate::logger::Logger;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::{
propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
Context,
};
use opentelemetry_aws::trace::XrayPropagator;
use std::{collections::HashMap, env};
// Module-level logger, constructed with the name "propagation".
static LOGGER: Logger = Logger::const_new("propagation");
// Carrier key under which the env-var fallback hands the trace header to the
// wrapped propagator (see `extract_with_context`).
const AWS_XRAY_TRACE_HEADER: &str = "x-amzn-trace-id";
// Environment variable read as a fallback when extraction from the carrier
// does not produce a valid span context.
const AWS_XRAY_TRACE_ENV_VAR: &str = "_X_AMZN_TRACE_ID";
/// Propagator that wraps [`XrayPropagator`] and, when extraction from the
/// carrier yields no valid span context, falls back to the value of the
/// `_X_AMZN_TRACE_ID` environment variable.
#[derive(Debug, Default)]
pub struct LambdaXrayPropagator {
/// Wrapped propagator to which `fields`, extraction and injection are delegated.
inner: XrayPropagator,
}
impl LambdaXrayPropagator {
    /// Creates a propagator that wraps a default-constructed [`XrayPropagator`].
    pub fn new() -> Self {
        let inner = XrayPropagator::default();
        Self { inner }
    }
}
impl TextMapPropagator for LambdaXrayPropagator {
    /// Returns the fields reported by the wrapped [`XrayPropagator`].
    fn fields(&self) -> FieldIter<'_> {
        self.inner.fields()
    }

    /// Extracts a context from `extractor` on top of `cx`, falling back to the
    /// `_X_AMZN_TRACE_ID` environment variable.
    ///
    /// Resolution order:
    /// 1. The wrapped propagator extracts from `extractor`; if the resulting
    ///    context carries a valid span context, it is returned as-is.
    /// 2. Otherwise, if `_X_AMZN_TRACE_ID` is set:
    ///    - a value containing `Sampled=0` short-circuits and returns a clone
    ///      of `cx`;
    ///    - any other value is fed to the wrapped propagator under the
    ///      `x-amzn-trace-id` key, and the result is returned only when its
    ///      span context is both valid and sampled.
    /// 3. In every remaining case the context from step 1 is returned.
    fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
        let carrier_ctx = self.inner.extract_with_context(cx, extractor);

        // A valid span context after carrier extraction always wins over the env var.
        if has_active_span(&carrier_ctx) {
            return carrier_ctx;
        }

        let trace_id_value = match env::var(AWS_XRAY_TRACE_ENV_VAR) {
            Ok(value) => value,
            Err(_) => return carrier_ctx,
        };
        LOGGER.debug(format!("Found _X_AMZN_TRACE_ID: {trace_id_value}"));

        if trace_id_value.contains("Sampled=0") {
            LOGGER.debug("_X_AMZN_TRACE_ID has Sampled=0; skipping context extraction to allow root span sampling");
            return cx.clone();
        }

        // Re-use the wrapped propagator by presenting the env value as a carrier.
        let env_carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), trace_id_value)]);
        let env_ctx = self.inner.extract_with_context(cx, &env_carrier);

        // Scope the span borrow so `env_ctx` can be moved out afterwards.
        let usable = {
            let env_span = env_ctx.span();
            let env_span_context = env_span.span_context();
            env_span_context.is_valid() && env_span_context.is_sampled()
        };

        if usable {
            LOGGER.debug("Successfully extracted *sampled* context from _X_AMZN_TRACE_ID");
            return env_ctx;
        }

        LOGGER.debug("Ignoring _X_AMZN_TRACE_ID because context is invalid or unsampled");
        carrier_ctx
    }

    /// Extracts using [`Context::current`] as the base context.
    fn extract(&self, extractor: &dyn Extractor) -> Context {
        let current = Context::current();
        self.extract_with_context(&current, extractor)
    }

    /// Delegates injection to the wrapped [`XrayPropagator`].
    fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
        self.inner.inject_context(cx, injector)
    }
}
/// Returns `true` when the span stored in `cx` has a valid span context.
fn has_active_span(cx: &Context) -> bool {
    let span = cx.span();
    span.span_context().is_valid()
}
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::{
        SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
    };
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate `_X_AMZN_TRACE_ID`.
    ///
    /// The environment is process-wide and the test harness runs tests on
    /// multiple threads, so unsynchronized set/remove calls can make one test
    /// observe another test's value (or its removal).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// RAII guard that holds [`ENV_LOCK`], overrides `_X_AMZN_TRACE_ID`, and
    /// restores the previous value on drop (including on a panicking assert).
    struct TraceEnvGuard {
        previous: Option<OsString>,
        // Declared last so the lock is released only after `Drop::drop` has
        // restored the environment.
        _lock: MutexGuard<'static, ()>,
    }

    impl TraceEnvGuard {
        /// Acquires the lock and sets `_X_AMZN_TRACE_ID` to `value`.
        fn set(value: &str) -> Self {
            // A poisoned lock only means another test panicked; the guard
            // already restored the env in that case, so continuing is safe.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let previous = env::var_os(AWS_XRAY_TRACE_ENV_VAR);
            env::set_var(AWS_XRAY_TRACE_ENV_VAR, value);
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for TraceEnvGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(AWS_XRAY_TRACE_ENV_VAR, value),
                None => env::remove_var(AWS_XRAY_TRACE_ENV_VAR),
            }
        }
    }

    /// A valid header in the carrier yields a valid span context.
    #[test]
    fn test_extract_from_carrier() {
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={trace_id};Parent={parent_id};Sampled=1");
        let carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value)]);
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);
        assert!(context.span().span_context().is_valid());
    }

    /// With an empty carrier, a sampled `_X_AMZN_TRACE_ID` is used as fallback.
    #[test]
    fn test_extract_from_env_var() {
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={trace_id};Parent={parent_id};Sampled=1");
        let _env = TraceEnvGuard::set(&header_value);

        // Sanity check: the wrapped propagator alone can parse this header.
        let xray_propagator = XrayPropagator::default();
        let env_carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value)]);
        let direct_context = xray_propagator.extract(&env_carrier);
        assert!(
            direct_context.span().span_context().is_valid(),
            "XrayPropagator itself should be able to parse the header"
        );

        let empty_carrier = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&Context::current(), &empty_carrier);
        assert!(
            context.span().span_context().is_valid(),
            "Expected valid context from env var via extract_with_context"
        );
    }

    /// An env value containing `Sampled=0` is skipped and the base context is
    /// returned unchanged (no valid span context when starting from an empty one).
    #[test]
    fn test_env_var_unsampled_is_ignored() {
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={trace_id};Parent={parent_id};Sampled=0");
        let _env = TraceEnvGuard::set(&header_value);

        let empty_carrier = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&Context::current(), &empty_carrier);
        assert!(
            !context.span().span_context().is_valid(),
            "Unsampled _X_AMZN_TRACE_ID must not produce a parent span context"
        );
    }

    /// Injection writes Root/Parent/Sampled into the `x-amzn-trace-id` key.
    #[test]
    fn test_inject_context() {
        let span_context = SpanContext::new(
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap(),
            SpanId::from_hex("53995c3f42cd8ad8").unwrap(),
            TraceFlags::SAMPLED,
            true,
            TraceState::default(),
        );
        let context = Context::current().with_remote_span_context(span_context);
        let mut injector = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        propagator.inject_context(&context, &mut injector);
        assert!(injector.contains_key(AWS_XRAY_TRACE_HEADER));
        let header = injector.get(AWS_XRAY_TRACE_HEADER).unwrap();
        assert!(header.contains("Root=1-5759e988-bd862e3fe1be46a994272793"));
        assert!(header.contains("Parent=53995c3f42cd8ad8"));
        assert!(header.contains("Sampled=1"));
    }

    /// When both carrier and env var hold a trace header, the carrier wins.
    #[test]
    fn test_precedence() {
        let env_trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let env_parent_id = "53995c3f42cd8ad8";
        let env_header = format!("Root={env_trace_id};Parent={env_parent_id};Sampled=1");
        let _env = TraceEnvGuard::set(&env_header);

        let carrier_trace_id = "1-58406520-a006649127e371903a2de979";
        let carrier_parent_id = "4c721bf33e3caf8f";
        let carrier_header =
            format!("Root={carrier_trace_id};Parent={carrier_parent_id};Sampled=1");
        let carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), carrier_header)]);
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);
        let span = context.span();
        let span_context = span.span_context();
        assert!(span_context.is_valid());
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("58406520a006649127e371903a2de979").unwrap()
        );
    }
}