opentelemetry_lambda_tower/extractors/
sqs.rs

//! SQS event extractor for message queue triggers.
//!
//! Extracts trace context from each SQS message, checking in order:
//! 1. `message_attributes`, read through the globally configured propagator
//!    (e.g. W3C `traceparent` injected by OTel-instrumented producers)
//! 2. `attributes` (system attributes) for `AWSTraceHeader` in X-Ray format
//!
//! Supports both [`SqsEvent`] (string bodies) and [`SqsEventObj<T>`] (typed bodies).
8
9use crate::extractor::TraceContextExtractor;
10use aws_lambda_events::sqs::{SqsEvent, SqsEventObj};
11use lambda_runtime::Context as LambdaContext;
12use opentelemetry::Context;
13use opentelemetry::propagation::Extractor;
14use opentelemetry::trace::{
15    Link, SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
16};
17use opentelemetry_semantic_conventions::attribute::{
18    MESSAGING_BATCH_MESSAGE_COUNT, MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID,
19    MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
20};
21use serde::Serialize;
22use std::collections::HashMap;
23use tracing::Span;
24
/// Extractor for SQS message events.
///
/// Each SQS message may carry trace context in two locations:
/// 1. `message_attributes` - read through the globally configured propagator
///    (e.g. W3C `traceparent`/`tracestate` injected by OTel-instrumented producers)
/// 2. `attributes` (system attributes) - `AWSTraceHeader` in X-Ray format
///
/// Both are checked per message; context found in `message_attributes` takes
/// precedence over the X-Ray system attribute.
///
/// Per OpenTelemetry semantic conventions for messaging systems, this extractor:
/// - Does NOT set a parent context (returns current context)
/// - Creates span links for each message's trace context
///
/// This approach is appropriate because messages in a batch may originate from
/// different traces, and the async nature of message queues means span links
/// are more semantically correct than parent-child relationships.
///
/// # Example
///
/// ```ignore
/// use opentelemetry_lambda_tower::{OtelTracingLayer, SqsEventExtractor};
///
/// let layer = OtelTracingLayer::new(SqsEventExtractor::new());
/// ```
#[derive(Clone, Debug, Default)]
pub struct SqsEventExtractor;

impl SqsEventExtractor {
    /// Creates a new SQS event extractor.
    pub fn new() -> Self {
        Self
    }

    /// Extracts the queue name from an event source ARN.
    ///
    /// ARN format: `arn:aws:sqs:{region}:{account}:{queue-name}`
    ///
    /// Returns the text after the last `:` (or the whole input when it contains
    /// no `:`). Returns `None` when that segment is empty, e.g. for an empty
    /// string or an ARN ending in `:`, so callers fall back to their default
    /// instead of using a blank queue name.
    fn queue_name_from_arn(arn: &str) -> Option<&str> {
        arn.rsplit(':').next().filter(|name| !name.is_empty())
    }
}
64
65impl TraceContextExtractor<SqsEvent> for SqsEventExtractor {
66    fn extract_context(&self, _event: &SqsEvent) -> Context {
67        Context::current()
68    }
69
70    fn extract_links(&self, event: &SqsEvent) -> Vec<Link> {
71        event
72            .records
73            .iter()
74            .filter_map(|msg| {
75                extract_link_from_sqs_message(&msg.message_attributes, &msg.attributes)
76            })
77            .collect()
78    }
79
80    fn trigger_type(&self) -> &'static str {
81        "pubsub"
82    }
83
84    fn span_name(&self, event: &SqsEvent, lambda_ctx: &LambdaContext) -> String {
85        let queue_name = event
86            .records
87            .first()
88            .and_then(|r| r.event_source_arn.as_deref())
89            .and_then(Self::queue_name_from_arn)
90            .unwrap_or(&lambda_ctx.env_config.function_name);
91
92        format!("{} process", queue_name)
93    }
94
95    fn record_attributes(&self, event: &SqsEvent, span: &Span) {
96        span.record(MESSAGING_SYSTEM, "aws_sqs");
97        span.record(MESSAGING_OPERATION_TYPE, "process");
98
99        if let Some(record) = event.records.first()
100            && let Some(ref arn) = record.event_source_arn
101            && let Some(queue_name) = Self::queue_name_from_arn(arn)
102        {
103            span.record(MESSAGING_DESTINATION_NAME, queue_name);
104        }
105
106        span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
107
108        if event.records.len() == 1
109            && let Some(ref msg_id) = event.records[0].message_id
110        {
111            span.record(MESSAGING_MESSAGE_ID, msg_id.as_str());
112        }
113    }
114}
115
116impl<T: Serialize + Send + Sync + 'static> TraceContextExtractor<SqsEventObj<T>>
117    for SqsEventExtractor
118{
119    fn extract_context(&self, _event: &SqsEventObj<T>) -> Context {
120        Context::current()
121    }
122
123    fn extract_links(&self, event: &SqsEventObj<T>) -> Vec<Link> {
124        event
125            .records
126            .iter()
127            .filter_map(|msg| {
128                extract_link_from_sqs_message(&msg.message_attributes, &msg.attributes)
129            })
130            .collect()
131    }
132
133    fn trigger_type(&self) -> &'static str {
134        "pubsub"
135    }
136
137    fn span_name(&self, event: &SqsEventObj<T>, lambda_ctx: &LambdaContext) -> String {
138        let queue_name = event
139            .records
140            .first()
141            .and_then(|r| r.event_source_arn.as_deref())
142            .and_then(Self::queue_name_from_arn)
143            .unwrap_or(&lambda_ctx.env_config.function_name);
144
145        format!("{} process", queue_name)
146    }
147
148    fn record_attributes(&self, event: &SqsEventObj<T>, span: &Span) {
149        span.record(MESSAGING_SYSTEM, "aws_sqs");
150        span.record(MESSAGING_OPERATION_TYPE, "process");
151
152        if let Some(record) = event.records.first()
153            && let Some(ref arn) = record.event_source_arn
154            && let Some(queue_name) = Self::queue_name_from_arn(arn)
155        {
156            span.record(MESSAGING_DESTINATION_NAME, queue_name);
157        }
158
159        span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
160
161        if event.records.len() == 1
162            && let Some(ref msg_id) = event.records[0].message_id
163        {
164            span.record(MESSAGING_MESSAGE_ID, msg_id.as_str());
165        }
166    }
167}
168
169use aws_lambda_events::sqs::SqsMessageAttribute;
170
171/// Extracts a span link from SQS message attributes.
172///
173/// Uses the globally configured propagator to extract trace context from `message_attributes`,
174/// then falls back to parsing `AWSTraceHeader` from system `attributes` in X-Ray format.
175fn extract_link_from_sqs_message(
176    message_attributes: &HashMap<String, SqsMessageAttribute>,
177    system_attributes: &HashMap<String, String>,
178) -> Option<Link> {
179    if let Some(span_context) = extract_trace_context_from_message_attributes(message_attributes) {
180        return Some(Link::new(span_context, vec![], 0));
181    }
182
183    if let Some(trace_header) = system_attributes.get("AWSTraceHeader")
184        && let Some(span_context) = parse_xray_trace_header(trace_header)
185    {
186        return Some(Link::new(span_context, vec![], 0));
187    }
188
189    None
190}
191
192/// Extracts trace context from SQS message attributes using the global propagator.
193///
194/// The propagator determines which keys to look for (e.g. `traceparent` for W3C,
195/// `X-Amzn-Trace-Id` for X-Ray, `X-B3-*` for Zipkin, etc.).
196fn extract_trace_context_from_message_attributes(
197    message_attributes: &HashMap<String, SqsMessageAttribute>,
198) -> Option<SpanContext> {
199    let extractor = SqsMessageAttributeExtractor(message_attributes);
200    let ctx =
201        opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
202
203    let span_context = ctx.span().span_context().clone();
204    if span_context.is_valid() {
205        Some(span_context)
206    } else {
207        None
208    }
209}
210
211/// Adapter to extract trace context from SQS message attributes.
212struct SqsMessageAttributeExtractor<'a>(&'a HashMap<String, SqsMessageAttribute>);
213
214impl Extractor for SqsMessageAttributeExtractor<'_> {
215    fn get(&self, key: &str) -> Option<&str> {
216        self.0
217            .get(key)
218            .and_then(|attr| attr.string_value.as_deref())
219    }
220
221    fn keys(&self) -> Vec<&str> {
222        self.0.keys().map(|k| k.as_str()).collect()
223    }
224}
225
226/// Parses an X-Ray trace header into a SpanContext.
227///
228/// X-Ray format: `Root=1-{epoch}-{random};Parent={span-id};Sampled={0|1}`
229///
230/// # Example
231///
232/// ```
233/// use opentelemetry_lambda_tower::extractors::sqs::parse_xray_trace_header;
234///
235/// let header = "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";
236/// let ctx = parse_xray_trace_header(header);
237/// assert!(ctx.is_some());
238/// ```
239pub fn parse_xray_trace_header(header: &str) -> Option<SpanContext> {
240    let mut trace_id_str = None;
241    let mut parent_id_str = None;
242    let mut sampled = false;
243
244    for part in header.split(';') {
245        let part = part.trim();
246        if let Some(root) = part.strip_prefix("Root=") {
247            trace_id_str = convert_xray_trace_id(root);
248        } else if let Some(parent) = part.strip_prefix("Parent=") {
249            parent_id_str = Some(parent.to_string());
250        } else if part == "Sampled=1" {
251            sampled = true;
252        }
253    }
254
255    let trace_id_hex = trace_id_str?;
256    let parent_id_hex = parent_id_str?;
257
258    let trace_id_bytes = hex_to_bytes::<16>(&trace_id_hex)?;
259    let trace_id = TraceId::from_bytes(trace_id_bytes);
260
261    let span_id_bytes = hex_to_bytes::<8>(&parent_id_hex)?;
262    let span_id = SpanId::from_bytes(span_id_bytes);
263
264    let flags = if sampled {
265        TraceFlags::SAMPLED
266    } else {
267        TraceFlags::default()
268    };
269
270    Some(SpanContext::new(
271        trace_id,
272        span_id,
273        flags,
274        true,
275        TraceState::default(),
276    ))
277}
278
/// Converts an X-Ray trace id (`1-{epoch}-{random}`) into a 32-character
/// trace id string by concatenating the epoch and random segments.
///
/// Returns `None` unless the id has exactly three `-`-separated segments, the
/// version segment is `1`, the epoch segment is 8 characters long and the
/// random segment is 24 characters long. Whether the characters are valid hex
/// digits is not checked here; that is left to the caller's hex decoding.
fn convert_xray_trace_id(xray_id: &str) -> Option<String> {
    let mut segments = xray_id.split('-');
    let version = segments.next()?;
    let epoch = segments.next()?;
    let unique = segments.next()?;

    // Checking each segment (not just the combined length) rejects ids whose
    // dash is misplaced, such as a 9 + 23 split.
    let well_formed = segments.next().is_none()
        && version == "1"
        && epoch.len() == 8
        && unique.len() == 24;

    well_formed.then(|| format!("{epoch}{unique}"))
}
289
290fn hex_to_bytes<const N: usize>(hex: &str) -> Option<[u8; N]> {
291    if hex.len() != N * 2 {
292        return None;
293    }
294
295    let mut bytes = [0u8; N];
296    for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
297        let high = hex_char_to_nibble(chunk[0])?;
298        let low = hex_char_to_nibble(chunk[1])?;
299        bytes[i] = (high << 4) | low;
300    }
301    Some(bytes)
302}
303
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Returns `None` for any other byte.
fn hex_char_to_nibble(c: u8) -> Option<u8> {
    char::from(c)
        .to_digit(16)
        .and_then(|value| u8::try_from(value).ok())
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::sqs::SqsMessage;

    /// Sampled X-Ray header shared by most tests.
    const SAMPLED_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";
    /// Same trace as `SAMPLED_HEADER`, but marked unsampled.
    const UNSAMPLED_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=0";
    /// A second, unsampled trace used for multi-message batches.
    const OTHER_UNSAMPLED_HEADER: &str =
        "Root=1-67890abc-def0123456789abcdef01234;Parent=1234567890abcdef;Sampled=0";

    /// Builds a message whose system attributes hold only the given
    /// `AWSTraceHeader` (or nothing), with empty message attributes.
    fn message_with_trace_header(header: Option<&str>) -> SqsMessage {
        let mut system_attributes = HashMap::new();
        if let Some(value) = header {
            system_attributes.insert("AWSTraceHeader".to_string(), value.to_string());
        }

        let mut message = SqsMessage::default();
        message.attributes = system_attributes;
        message.message_attributes = HashMap::new();
        message
    }

    /// Wraps the given records in an event.
    fn event_with_records(records: Vec<SqsMessage>) -> SqsEvent {
        let mut event = SqsEvent::default();
        event.records = records;
        event
    }

    /// Single-record event with a sampled X-Ray header and a queue ARN.
    fn create_test_sqs_event() -> SqsEvent {
        let mut message = message_with_trace_header(Some(SAMPLED_HEADER));
        message.message_id = Some("msg-123".to_string());
        message.receipt_handle = Some("receipt-123".to_string());
        message.body = Some(r#"{"test": "data"}"#.to_string());
        message.event_source = Some("aws:sqs".to_string());
        message.event_source_arn = Some("arn:aws:sqs:us-east-1:123456789:my-queue".to_string());
        message.aws_region = Some("us-east-1".to_string());

        event_with_records(vec![message])
    }

    fn create_test_lambda_context() -> LambdaContext {
        LambdaContext::default()
    }

    #[test]
    fn test_trigger_type() {
        let extractor = SqsEventExtractor::new();
        let trigger =
            <SqsEventExtractor as TraceContextExtractor<SqsEvent>>::trigger_type(&extractor);
        assert_eq!(trigger, "pubsub");
    }

    #[test]
    fn test_span_name_includes_queue() {
        let event = create_test_sqs_event();
        let lambda_ctx = create_test_lambda_context();

        let name = SqsEventExtractor::new().span_name(&event, &lambda_ctx);
        assert_eq!(name, "my-queue process");
    }

    #[test]
    fn test_queue_name_from_arn() {
        let standard =
            SqsEventExtractor::queue_name_from_arn("arn:aws:sqs:us-east-1:123456789:my-queue");
        assert_eq!(standard, Some("my-queue"));

        let fifo = SqsEventExtractor::queue_name_from_arn(
            "arn:aws:sqs:eu-west-1:987654321:another-queue.fifo",
        );
        assert_eq!(fifo, Some("another-queue.fifo"));
    }

    #[test]
    fn test_extract_links_from_xray_header() {
        let event = create_test_sqs_event();

        let links = SqsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 1);
        let linked = &links[0].span_context;
        assert!(linked.is_valid());
        assert_eq!(
            linked.trace_id().to_string(),
            "5759e988bd862e3fe1be46a994272793"
        );
        assert_eq!(linked.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(linked.is_sampled());
    }

    #[test]
    fn test_extract_links_multiple_messages() {
        let event = event_with_records(vec![
            message_with_trace_header(Some(SAMPLED_HEADER)),
            message_with_trace_header(Some(OTHER_UNSAMPLED_HEADER)),
        ]);

        let links = SqsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 2);
        assert!(links[0].span_context.is_sampled());
        assert!(!links[1].span_context.is_sampled());
    }

    #[test]
    fn test_extract_links_no_trace_header() {
        let event = event_with_records(vec![message_with_trace_header(None)]);

        let links = SqsEventExtractor::new().extract_links(&event);
        assert!(links.is_empty());
    }

    #[test]
    fn test_parse_xray_trace_header() {
        let ctx = parse_xray_trace_header(SAMPLED_HEADER).unwrap();

        assert!(ctx.is_valid());
        assert_eq!(
            ctx.trace_id().to_string(),
            "5759e988bd862e3fe1be46a994272793"
        );
        assert_eq!(ctx.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(ctx.is_sampled());
        assert!(ctx.is_remote());
    }

    #[test]
    fn test_parse_xray_trace_header_unsampled() {
        let ctx = parse_xray_trace_header(UNSAMPLED_HEADER).unwrap();
        assert!(!ctx.is_sampled());
    }

    #[test]
    fn test_parse_xray_trace_header_invalid() {
        for malformed in ["invalid", "Root=invalid;Parent=abc", "Root=1-abc-def"] {
            assert!(parse_xray_trace_header(malformed).is_none());
        }
    }

    #[test]
    fn test_convert_xray_trace_id() {
        let converted = convert_xray_trace_id("1-5759e988-bd862e3fe1be46a994272793");
        assert_eq!(
            converted,
            Some("5759e988bd862e3fe1be46a994272793".to_string())
        );
    }

    #[test]
    fn test_hex_to_bytes() {
        let decoded: [u8; 4] = hex_to_bytes("deadbeef").unwrap();
        assert_eq!(decoded, [0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_hex_to_bytes_invalid() {
        // Too short, too long, and a non-hex digit respectively.
        for malformed in ["deadbee", "deadbeefx", "deadbeeg"] {
            assert!(hex_to_bytes::<4>(malformed).is_none());
        }
    }
}
498}