lambda_otel_lite/
propagation.rs

1//! Context propagation extensions for AWS Lambda.
2//!
3//! This module provides specialized context propagators for AWS Lambda environments,
4//! including enhanced X-Ray propagation that integrates with Lambda's built-in tracing.
5
6use crate::logger::Logger;
7use opentelemetry::{
8    propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
9    Context,
10};
11use opentelemetry_aws::trace::XrayPropagator;
12use std::{collections::HashMap, env};
13
14// Add module-specific logger
15static LOGGER: Logger = Logger::const_new("propagation");
16
17// Define the X-Ray trace header constant since it's not publicly exported
18const AWS_XRAY_TRACE_HEADER: &str = "x-amzn-trace-id";
19
20/// A custom propagator that wraps the `XrayPropagator` with Lambda-specific enhancements.
21///
22/// This propagator extends the standard X-Ray propagator to automatically extract
23/// trace context from the Lambda `_X_AMZN_TRACE_ID` environment variable when no
24/// valid context is found in the provided carrier.
25///
26/// # Example
27///
28/// ```no_run
29/// use lambda_otel_lite::{init_telemetry, TelemetryConfig};
30/// use lambda_otel_lite::propagation::LambdaXrayPropagator;
31/// use opentelemetry::global;
32/// use lambda_runtime::Error;
33///
34/// # async fn example() -> Result<(), Error> {
35/// // Add the LambdaXrayPropagator
36/// let config = TelemetryConfig::builder()
37///     .with_named_propagator("tracecontext")
38///     .with_named_propagator("xray-lambda")
39///     .build();
40///
41/// let _ = init_telemetry(config).await?;
42/// # Ok(())
43/// # }
44/// ```
45#[derive(Debug, Default)]
46pub struct LambdaXrayPropagator {
47    /// The wrapped XrayPropagator instance
48    inner: XrayPropagator,
49}
50
51impl LambdaXrayPropagator {
52    /// Create a new instance of the LambdaXrayPropagator.
53    pub fn new() -> Self {
54        Self {
55            inner: XrayPropagator::default(),
56        }
57    }
58}
59
60impl TextMapPropagator for LambdaXrayPropagator {
61    fn fields(&self) -> FieldIter<'_> {
62        self.inner.fields()
63    }
64
65    fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
66        // First, try to extract from the provided carrier using the inner propagator
67        let ctx = self.inner.extract_with_context(cx, extractor);
68
69        // Check if we got a valid context from the carrier
70        let has_carrier_context = has_active_span(&ctx);
71
72        // If we didn't get a valid context from the carrier, try the environment variable
73        if !has_carrier_context {
74            if let Ok(trace_id_value) = env::var("_X_AMZN_TRACE_ID") {
75                LOGGER.debug(format!("Found _X_AMZN_TRACE_ID: {}", trace_id_value));
76
77                // Create a carrier from the environment variable
78                let mut env_carrier = HashMap::new();
79                env_carrier.insert(AWS_XRAY_TRACE_HEADER.to_string(), trace_id_value);
80
81                // Try to extract from the environment variable
82                let env_ctx = self.inner.extract_with_context(cx, &env_carrier);
83                if has_active_span(&env_ctx) {
84                    LOGGER.debug("Successfully extracted context from _X_AMZN_TRACE_ID");
85                    return env_ctx;
86                }
87            }
88        }
89
90        // Return the original context
91        ctx
92    }
93
94    fn extract(&self, extractor: &dyn Extractor) -> Context {
95        self.extract_with_context(&Context::current(), extractor)
96    }
97
98    fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
99        self.inner.inject_context(cx, injector)
100    }
101}
102
103// Helper function to check if a context has an active span
104fn has_active_span(cx: &Context) -> bool {
105    use opentelemetry::trace::TraceContextExt;
106    cx.span().span_context().is_valid()
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::{
        SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
    };
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// The Lambda environment variable the propagator falls back to.
    const ENV_VAR: &str = "_X_AMZN_TRACE_ID";

    /// Serializes every test that reads or mutates `_X_AMZN_TRACE_ID`.
    ///
    /// The process environment is global, and `cargo test` runs tests on parallel threads.
    /// Without this lock, one test's restore step can remove the variable while another
    /// test is between setting it and extracting from it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so that one failed test does not cascade
    /// into failures of every other env-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// RAII guard that restores an environment variable to its prior value on drop.
    ///
    /// Restoration also happens when the test panics on a failed assertion.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value`, remembering the previous value (if any).
        fn set(key: &'static str, value: &str) -> Self {
            let original = env::var_os(key);
            env::set_var(key, value);
            Self { key, original }
        }

        /// Removes `key`, remembering the previous value (if any).
        fn unset(key: &'static str) -> Self {
            let original = env::var_os(key);
            env::remove_var(key);
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.original.take() {
                Some(value) => env::set_var(self.key, value),
                None => env::remove_var(self.key),
            }
        }
    }

    /// Builds a sampled X-Ray trace header value from its `Root` and `Parent` parts.
    fn xray_header(root: &str, parent: &str) -> String {
        format!("Root={root};Parent={parent};Sampled=1")
    }

    #[test]
    fn test_extract_from_carrier() {
        let carrier = HashMap::from([(
            AWS_XRAY_TRACE_HEADER.to_string(),
            xray_header("1-5759e988-bd862e3fe1be46a994272793", "53995c3f42cd8ad8"),
        )]);

        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);

        // The carrier's Root/Parent must become the remote span context.
        let span = context.span();
        let span_context = span.span_context();
        assert!(span_context.is_valid());
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap()
        );
        assert_eq!(
            span_context.span_id(),
            SpanId::from_hex("53995c3f42cd8ad8").unwrap()
        );
    }

    #[test]
    fn test_extract_from_env_var() {
        // Lock first so the guard (declared after) restores the env before unlocking.
        let _env_lock = lock_env();
        let header_value = xray_header("1-5759e988-bd862e3fe1be46a994272793", "53995c3f42cd8ad8");
        let _env = EnvVarGuard::set(ENV_VAR, &header_value);

        // First verify the XrayPropagator itself can parse the header.
        let xray_propagator = XrayPropagator::default();
        let env_carrier =
            HashMap::from([(AWS_XRAY_TRACE_HEADER.to_string(), header_value.clone())]);
        let direct_context = xray_propagator.extract(&env_carrier);
        assert!(
            direct_context.span().span_context().is_valid(),
            "XrayPropagator itself should be able to parse the header"
        );

        // An empty carrier has no headers, so the env var must be used.
        let empty_carrier = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&Context::new(), &empty_carrier);

        let span = context.span();
        let span_context = span.span_context();
        assert!(
            span_context.is_valid(),
            "Expected valid context from env var via extract_with_context"
        );
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap()
        );
    }

    #[test]
    fn test_no_context_without_env_var() {
        let _env_lock = lock_env();
        let _env = EnvVarGuard::unset(ENV_VAR);

        let empty_carrier = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&Context::new(), &empty_carrier);

        assert!(!context.span().span_context().is_valid());
    }

    #[test]
    fn test_existing_context_suppresses_env_fallback() {
        let _env_lock = lock_env();
        let _env = EnvVarGuard::set(
            ENV_VAR,
            &xray_header("1-5759e988-bd862e3fe1be46a994272793", "53995c3f42cd8ad8"),
        );

        // Simulates a context already populated by an earlier propagator in a composite.
        let existing = SpanContext::new(
            TraceId::from_hex("0af7651916cd43dd8448eb211c80319c").unwrap(),
            SpanId::from_hex("b7ad6b7169203331").unwrap(),
            TraceFlags::SAMPLED,
            true,
            TraceState::default(),
        );
        let incoming = Context::new().with_remote_span_context(existing);

        let empty_carrier = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract_with_context(&incoming, &empty_carrier);

        assert_eq!(
            context.span().span_context().trace_id(),
            TraceId::from_hex("0af7651916cd43dd8448eb211c80319c").unwrap()
        );
    }

    #[test]
    fn test_inject_context() {
        let span_context = SpanContext::new(
            TraceId::from_hex("5759e988bd862e3fe1be46a994272793").unwrap(),
            SpanId::from_hex("53995c3f42cd8ad8").unwrap(),
            TraceFlags::SAMPLED,
            true,
            TraceState::default(),
        );
        let context = Context::current().with_remote_span_context(span_context);

        let mut injector = HashMap::<String, String>::new();
        let propagator = LambdaXrayPropagator::default();
        propagator.inject_context(&context, &mut injector);

        let header = injector
            .get(AWS_XRAY_TRACE_HEADER)
            .expect("X-Ray header should be injected");
        assert!(header.contains("Root=1-5759e988-bd862e3fe1be46a994272793"));
        assert!(header.contains("Parent=53995c3f42cd8ad8"));
        assert!(header.contains("Sampled=1"));
    }

    #[test]
    fn test_precedence() {
        // The env var holds one trace; the carrier holds another. The carrier must win.
        let _env_lock = lock_env();
        let _env = EnvVarGuard::set(
            ENV_VAR,
            &xray_header("1-5759e988-bd862e3fe1be46a994272793", "53995c3f42cd8ad8"),
        );

        let carrier = HashMap::from([(
            AWS_XRAY_TRACE_HEADER.to_string(),
            xray_header("1-58406520-a006649127e371903a2de979", "4c721bf33e3caf8f"),
        )]);

        let propagator = LambdaXrayPropagator::default();
        let context = propagator.extract(&carrier);

        let span = context.span();
        let span_context = span.span_context();
        assert!(span_context.is_valid());
        assert_eq!(
            span_context.trace_id(),
            TraceId::from_hex("58406520a006649127e371903a2de979").unwrap()
        );
    }
}
249}