lambda_otel_lite/
propagation.rs

//! Context propagation extensions for AWS Lambda.
//!
//! This module provides specialized context propagators for AWS Lambda environments,
//! including an enhanced X-Ray propagator that integrates with Lambda's built-in tracing
//! by falling back to the `_X_AMZN_TRACE_ID` environment variable set by Lambda.
5
6use crate::logger::Logger;
7use opentelemetry::trace::TraceContextExt;
8use opentelemetry::{
9    propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
10    Context,
11};
12use opentelemetry_aws::trace::XrayPropagator;
13use std::{collections::HashMap, env};
14
15// Add module-specific logger
16static LOGGER: Logger = Logger::const_new("propagation");
17
18// Define the X-Ray trace header constant since it's not publicly exported
19const AWS_XRAY_TRACE_HEADER: &str = "x-amzn-trace-id";
20const AWS_XRAY_TRACE_ENV_VAR: &str = "_X_AMZN_TRACE_ID";
21
22/// A custom propagator that wraps the `XrayPropagator` with Lambda-specific enhancements.
23///
24/// This propagator extends the standard X-Ray propagator to automatically extract
25/// trace context from the Lambda `_X_AMZN_TRACE_ID` environment variable when no
26/// valid context is found in the provided carrier.
27///
28/// # Example
29///
30/// ```no_run
31/// use lambda_otel_lite::{init_telemetry, TelemetryConfig};
32/// use lambda_otel_lite::propagation::LambdaXrayPropagator;
33/// use opentelemetry::global;
34/// use lambda_runtime::Error;
35///
36/// # async fn example() -> Result<(), Error> {
37/// // Add the LambdaXrayPropagator
38/// let config = TelemetryConfig::builder()
39///     .with_named_propagator("tracecontext")
40///     .with_named_propagator("xray-lambda")
41///     .build();
42///
43/// let _ = init_telemetry(config).await?;
44/// # Ok(())
45/// # }
46/// ```
47#[derive(Debug, Default)]
48pub struct LambdaXrayPropagator {
49    /// The wrapped XrayPropagator instance
50    inner: XrayPropagator,
51}
52
53impl LambdaXrayPropagator {
54    /// Create a new instance of the LambdaXrayPropagator.
55    pub fn new() -> Self {
56        Self {
57            inner: XrayPropagator::default(),
58        }
59    }
60}
61
62impl TextMapPropagator for LambdaXrayPropagator {
63    fn fields(&self) -> FieldIter<'_> {
64        self.inner.fields()
65    }
66
67    fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
68        // First, try to extract from the provided carrier using the inner propagator
69        let ctx = self.inner.extract_with_context(cx, extractor);
70
71        // Check if we got a valid context from the carrier
72        let has_carrier_context = has_active_span(&ctx);
73
74        // If we didn't get a valid context from the carrier, try the environment variable
75        if !has_carrier_context {
76            if let Ok(trace_id_value) = env::var(AWS_XRAY_TRACE_ENV_VAR) {
77                LOGGER.debug(format!("Found _X_AMZN_TRACE_ID: {trace_id_value}"));
78
79                // If Sampled=0, skip extraction entirely so the sampler treats this as a root span.
80                // This avoids suppressing spans when X-Ray is disabled but Lambda still injects the env var.
81                if trace_id_value.contains("Sampled=0") {
82                    LOGGER.debug("_X_AMZN_TRACE_ID has Sampled=0; skipping context extraction to allow root span sampling");
83                    return cx.clone();
84                }
85
86                // Create a carrier from the environment variable
87                let mut env_carrier = HashMap::new();
88                env_carrier.insert(AWS_XRAY_TRACE_HEADER.to_string(), trace_id_value);
89
90                // Try to extract from the environment variable
91                let env_ctx = self.inner.extract_with_context(cx, &env_carrier);
92                let span = env_ctx.span();
93                let span_context = span.span_context();
94                if span_context.is_valid() && span_context.is_sampled() {
95                    LOGGER.debug("Successfully extracted *sampled* context from _X_AMZN_TRACE_ID");
96                    return env_ctx;
97                } else {
98                    LOGGER
99                        .debug("Ignoring _X_AMZN_TRACE_ID because context is invalid or unsampled");
100                }
101            }
102        }
103
104        // Return the original context
105        ctx
106    }
107
108    fn extract(&self, extractor: &dyn Extractor) -> Context {
109        self.extract_with_context(&Context::current(), extractor)
110    }
111
112    fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
113        self.inner.inject_context(cx, injector)
114    }
115}
116
117// Helper function to check if a context has an active span
118fn has_active_span(cx: &Context) -> bool {
119    cx.span().span_context().is_valid()
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::{
        SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
    };
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate `_X_AMZN_TRACE_ID`.
    ///
    /// The process environment is shared by every thread, and the default test
    /// harness runs tests in parallel. Without this lock, one test restoring or
    /// removing the variable can race with another test that still needs it set.
    /// NOTE(review): only tests in this module take the lock; tests elsewhere in
    /// the crate that touch the same variable would need to share it.
    static TRACE_ENV_LOCK: Mutex<()> = Mutex::new(());

    /// RAII guard that sets `_X_AMZN_TRACE_ID` and restores its previous value on
    /// drop, including when the owning test panics on a failed assertion.
    struct TraceEnvGuard {
        previous: Option<String>,
        // Fields are dropped only after `Drop::drop` runs, so the variable is
        // restored before the lock is released.
        _lock: MutexGuard<'static, ()>,
    }

    impl TraceEnvGuard {
        fn set(value: &str) -> Self {
            // A panicking test poisons the lock; the protected data is `()`, so
            // recovering the guard is always safe.
            let lock = TRACE_ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let previous = env::var(AWS_XRAY_TRACE_ENV_VAR).ok();
            env::set_var(AWS_XRAY_TRACE_ENV_VAR, value);
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for TraceEnvGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(AWS_XRAY_TRACE_ENV_VAR, value),
                None => env::remove_var(AWS_XRAY_TRACE_ENV_VAR),
            }
        }
    }

    #[test]
    fn test_extract_from_carrier() {
        // Create a valid X-Ray header
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={trace_id};Parent={parent_id};Sampled=1");

        // Create a carrier with the header
        let carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value)]);

        // Extract context
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);

        // Verify the extracted context is valid using TraceContextExt trait
        assert!(context.span().span_context().is_valid());
    }

    #[test]
    fn test_extract_from_env_var() {
        // Using a format that's known to be valid with the XrayPropagator
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={trace_id};Parent={parent_id};Sampled=1");
        let _env = TraceEnvGuard::set(&header_value);

        // First verify the XrayPropagator itself can parse the header
        let xray_propagator = XrayPropagator::default();
        let env_carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value)]);
        let direct_context = xray_propagator.extract(&env_carrier);
        assert!(
            direct_context.span().span_context().is_valid(),
            "XrayPropagator itself should be able to parse the header"
        );

        // Empty carrier: no headers, so the env var should be used
        let empty_carrier = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&Context::current(), &empty_carrier);

        let span = context.span();
        let span_context = span.span_context();
        assert!(
            span_context.is_valid(),
            "Expected valid context from env var via extract_with_context"
        );
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap()
        );
    }

    #[test]
    fn test_env_var_unsampled_is_ignored() {
        let _env = TraceEnvGuard::set(
            "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=0",
        );

        let propagator = LambdaXrayPropagator::default();
        let empty_carrier = HashMap::<String, String>::new();
        let context = propagator.extract_with_context(&Context::new(), &empty_carrier);

        assert!(
            !context.span().span_context().is_valid(),
            "Sampled=0 in _X_AMZN_TRACE_ID must not produce a parent context"
        );
    }

    #[test]
    fn test_inject_context() {
        // Create a test span context
        let span_context = SpanContext::new(
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap(),
            SpanId::from_hex("53995c3f42cd8ad8").unwrap(),
            TraceFlags::SAMPLED,
            true,
            TraceState::default(),
        );

        // Create context with the span context
        let context = Context::current().with_remote_span_context(span_context);

        // Inject context
        let mut injector = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        propagator.inject_context(&context, &mut injector);

        // Verify the injected header
        let header = injector
            .get(AWS_XRAY_TRACE_HEADER)
            .expect("X-Ray header should be injected");
        assert!(header.contains("Root=1-5759e988-bd862e3fe1be46a994272793"));
        assert!(header.contains("Parent=53995c3f42cd8ad8"));
        assert!(header.contains("Sampled=1"));
    }

    #[test]
    fn test_precedence() {
        // Env var context that should NOT be used because the carrier is valid
        let env_trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let env_parent_id = "53995c3f42cd8ad8";
        let env_header = format!("Root={env_trace_id};Parent={env_parent_id};Sampled=1");
        let _env = TraceEnvGuard::set(&env_header);

        // A different valid X-Ray header for the carrier
        let carrier_trace_id = "1-58406520-a006649127e371903a2de979";
        let carrier_parent_id = "4c721bf33e3caf8f";
        let carrier_header =
            format!("Root={carrier_trace_id};Parent={carrier_parent_id};Sampled=1");
        let carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), carrier_header)]);

        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);

        // Verify the extracted context used the carrier, not the env var
        let span = context.span();
        let span_context = span.span_context();
        assert!(span_context.is_valid());
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("58406520a006649127e371903a2de979").unwrap()
        );
    }
}