lambda_otel_lite/
propagation.rs

1//! Context propagation extensions for AWS Lambda.
2//!
3//! This module provides specialized context propagators for AWS Lambda environments,
4//! including enhanced X-Ray propagation that integrates with Lambda's built-in tracing.
5
6use opentelemetry::{
7    propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
8    Context,
9};
10use opentelemetry_aws::trace::XrayPropagator;
11use std::{collections::HashMap, env};
12use tracing::debug;
13
/// Header name under which X-Ray trace context is read and written.
/// `opentelemetry_aws` does not export its own copy, so it is redeclared here.
const AWS_XRAY_TRACE_HEADER: &str = "x-amzn-trace-id";
16
17/// A custom propagator that wraps the `XrayPropagator` with Lambda-specific enhancements.
18///
19/// This propagator extends the standard X-Ray propagator to automatically extract
20/// trace context from the Lambda `_X_AMZN_TRACE_ID` environment variable when no
21/// valid context is found in the provided carrier.
22///
23/// # Example
24///
25/// ```no_run
26/// use lambda_otel_lite::{init_telemetry, TelemetryConfig};
27/// use lambda_otel_lite::propagation::LambdaXrayPropagator;
28/// use opentelemetry::global;
29/// use lambda_runtime::Error;
30///
31/// # async fn example() -> Result<(), Error> {
32/// // Add the LambdaXrayPropagator
33/// let config = TelemetryConfig::builder()
34///     .with_named_propagator("tracecontext")
35///     .with_named_propagator("xray-lambda")
36///     .build();
37///
38/// let _ = init_telemetry(config).await?;
39/// # Ok(())
40/// # }
41/// ```
42#[derive(Debug, Default)]
43pub struct LambdaXrayPropagator {
44    /// The wrapped XrayPropagator instance
45    inner: XrayPropagator,
46}
47
48impl LambdaXrayPropagator {
49    /// Create a new instance of the LambdaXrayPropagator.
50    pub fn new() -> Self {
51        Self {
52            inner: XrayPropagator::default(),
53        }
54    }
55}
56
57impl TextMapPropagator for LambdaXrayPropagator {
58    fn fields(&self) -> FieldIter<'_> {
59        self.inner.fields()
60    }
61
62    fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
63        // First, try to extract from the provided carrier using the inner propagator
64        let ctx = self.inner.extract_with_context(cx, extractor);
65
66        // Check if we got a valid context from the carrier
67        let has_carrier_context = has_active_span(&ctx);
68
69        // If we didn't get a valid context from the carrier, try the environment variable
70        if !has_carrier_context {
71            if let Ok(trace_id_value) = env::var("_X_AMZN_TRACE_ID") {
72                debug!("Found _X_AMZN_TRACE_ID: {}", trace_id_value);
73
74                // Create a carrier from the environment variable
75                let mut env_carrier = HashMap::new();
76                env_carrier.insert(AWS_XRAY_TRACE_HEADER.to_string(), trace_id_value);
77
78                // Try to extract from the environment variable
79                let env_ctx = self.inner.extract_with_context(cx, &env_carrier);
80                if has_active_span(&env_ctx) {
81                    debug!("Successfully extracted context from _X_AMZN_TRACE_ID");
82                    return env_ctx;
83                }
84            }
85        }
86
87        // Return the original context
88        ctx
89    }
90
91    fn extract(&self, extractor: &dyn Extractor) -> Context {
92        self.extract_with_context(&Context::current(), extractor)
93    }
94
95    fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
96        self.inner.inject_context(cx, injector)
97    }
98}
99
100// Helper function to check if a context has an active span
101fn has_active_span(cx: &Context) -> bool {
102    use opentelemetry::trace::TraceContextExt;
103    cx.span().span_context().is_valid()
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::{
        SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
    };
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Environment variable the propagator falls back to.
    const TRACE_ENV_VAR: &str = "_X_AMZN_TRACE_ID";

    /// Serializes every test that touches `_X_AMZN_TRACE_ID`.
    ///
    /// The test harness runs tests on parallel threads and the process
    /// environment is shared. Without this lock, one test could overwrite or
    /// remove the variable while another sits between setting it and extracting.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets an environment variable for as long as the guard lives.
    ///
    /// The guard holds [`ENV_LOCK`] while alive. On drop it restores the
    /// previous value, or removes the variable if it was unset. This also
    /// happens when the test panics.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            // A panicking test poisons the mutex. The `()` payload cannot be
            // left inconsistent, so recover the guard rather than failing every
            // later test.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            // `var_os` preserves non-UTF-8 values that `var(..).ok()` would lose.
            let original = env::var_os(key);
            env::set_var(key, value);
            Self {
                key,
                original,
                _lock: lock,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // `drop` runs before the `_lock` field is released, so the restore
            // still happens under the lock.
            match self.original.take() {
                Some(value) => env::set_var(self.key, value),
                None => env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn test_extract_from_carrier() {
        // Create a valid X-Ray header
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={};Parent={};Sampled=1", trace_id, parent_id);

        // Create a carrier with the header
        let carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value)]);

        // Extract context
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);

        // Verify the extracted context is valid using TraceContextExt trait
        assert!(context.span().span_context().is_valid());
    }

    #[test]
    fn test_extract_from_env_var() {
        // Set up a test environment variable, restored automatically on drop.
        // Uses a format that's known to be valid with the XrayPropagator.
        let trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let parent_id = "53995c3f42cd8ad8";
        let header_value = format!("Root={};Parent={};Sampled=1", trace_id, parent_id);
        let _env = EnvVarGuard::set(TRACE_ENV_VAR, &header_value);

        // First verify the XrayPropagator itself can parse the header
        let xray_propagator = XrayPropagator::default();
        let env_carrier =
            HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value.clone())]);
        let direct_context = xray_propagator.extract(&env_carrier);
        assert!(
            direct_context.span().span_context().is_valid(),
            "XrayPropagator itself should be able to parse the header"
        );

        // Create an empty carrier - there are no headers, but the env var should be used
        let empty_carrier = HashMap::<String, String>::new();

        // Extract using our custom propagator's extract_with_context method directly
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&Context::current(), &empty_carrier);

        // Verify the extracted context is valid
        assert!(
            context.span().span_context().is_valid(),
            "Expected valid context from env var via extract_with_context"
        );
    }

    #[test]
    fn test_inject_context() {
        // Create a test span context
        let span_context = SpanContext::new(
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap(),
            SpanId::from_hex("53995c3f42cd8ad8").unwrap(),
            TraceFlags::SAMPLED,
            true,
            TraceState::default(),
        );

        // Create context with the span context
        let context = Context::current().with_remote_span_context(span_context);

        // Create an injector
        let mut injector = HashMap::<String, String>::new();

        // Inject context
        let propagator = LambdaXrayPropagator::default();
        propagator.inject_context(&context, &mut injector);

        // Verify the injected header
        assert!(injector.contains_key(AWS_XRAY_TRACE_HEADER));
        let header = injector.get(AWS_XRAY_TRACE_HEADER).unwrap();
        assert!(header.contains("Root=1-5759e988-bd862e3fe1be46a994272793"));
        assert!(header.contains("Parent=53995c3f42cd8ad8"));
        assert!(header.contains("Sampled=1"));
    }

    #[test]
    fn test_precedence() {
        // Set up a test environment variable (this should NOT be used if carrier is valid).
        // It is restored automatically when the guard drops.
        let env_trace_id = "1-5759e988-bd862e3fe1be46a994272793";
        let env_parent_id = "53995c3f42cd8ad8";
        let env_header = format!("Root={};Parent={};Sampled=1", env_trace_id, env_parent_id);
        let _env = EnvVarGuard::set(TRACE_ENV_VAR, &env_header);

        // Create a different valid X-Ray header for the carrier
        let carrier_trace_id = "1-58406520-a006649127e371903a2de979";
        let carrier_parent_id = "4c721bf33e3caf8f";
        let carrier_header = format!(
            "Root={};Parent={};Sampled=1",
            carrier_trace_id, carrier_parent_id
        );

        // Create a carrier with the header
        let carrier = HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), carrier_header)]);

        // Extract context
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);

        // Verify the extracted context used the carrier, not the env var
        let span = context.span();
        let span_context = span.span_context();
        assert!(span_context.is_valid());
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("58406520a006649127e371903a2de979").unwrap()
        );
    }
}