opentelemetry_lambda_tower/extractors/
sqs.rs1use crate::extractor::TraceContextExtractor;
10use aws_lambda_events::sqs::{SqsEvent, SqsEventObj};
11use lambda_runtime::Context as LambdaContext;
12use opentelemetry::Context;
13use opentelemetry::propagation::Extractor;
14use opentelemetry::trace::{
15 Link, SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
16};
17use opentelemetry_semantic_conventions::attribute::{
18 MESSAGING_BATCH_MESSAGE_COUNT, MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID,
19 MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
20};
21use serde::Serialize;
22use std::collections::HashMap;
23use tracing::Span;
24
/// Extracts trace context, span links, and messaging attributes from
/// SQS-triggered Lambda events.
///
/// The invocation span is not parented on any message (`extract_context`
/// returns the current context unchanged). Instead, every message whose trace
/// context can be recovered becomes a span link.
#[derive(Clone, Debug, Default)]
pub struct SqsEventExtractor;

impl SqsEventExtractor {
    /// Creates a new extractor; equivalent to [`SqsEventExtractor::default`].
    pub fn new() -> Self {
        Self
    }

    /// Returns the queue name of an SQS ARN such as
    /// `arn:aws:sqs:us-east-1:123456789012:my-queue`, i.e. its final
    /// `:`-separated segment.
    ///
    /// Returns `None` when that segment is empty (an empty ARN, or one ending in
    /// `:`), so callers fall back to their default name instead of reporting a
    /// blank queue name.
    fn queue_name_from_arn(arn: &str) -> Option<&str> {
        // `rsplit` always yields at least one item, so emptiness is the only
        // failure signal available here.
        arn.rsplit(':').next().filter(|name| !name.is_empty())
    }
}
64
65impl TraceContextExtractor<SqsEvent> for SqsEventExtractor {
66 fn extract_context(&self, _event: &SqsEvent) -> Context {
67 Context::current()
68 }
69
70 fn extract_links(&self, event: &SqsEvent) -> Vec<Link> {
71 event
72 .records
73 .iter()
74 .filter_map(|msg| {
75 extract_link_from_sqs_message(&msg.message_attributes, &msg.attributes)
76 })
77 .collect()
78 }
79
80 fn trigger_type(&self) -> &'static str {
81 "pubsub"
82 }
83
84 fn span_name(&self, event: &SqsEvent, lambda_ctx: &LambdaContext) -> String {
85 let queue_name = event
86 .records
87 .first()
88 .and_then(|r| r.event_source_arn.as_deref())
89 .and_then(Self::queue_name_from_arn)
90 .unwrap_or(&lambda_ctx.env_config.function_name);
91
92 format!("{} process", queue_name)
93 }
94
95 fn record_attributes(&self, event: &SqsEvent, span: &Span) {
96 span.record(MESSAGING_SYSTEM, "aws_sqs");
97 span.record(MESSAGING_OPERATION_TYPE, "process");
98
99 if let Some(record) = event.records.first()
100 && let Some(ref arn) = record.event_source_arn
101 && let Some(queue_name) = Self::queue_name_from_arn(arn)
102 {
103 span.record(MESSAGING_DESTINATION_NAME, queue_name);
104 }
105
106 span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
107
108 if event.records.len() == 1
109 && let Some(ref msg_id) = event.records[0].message_id
110 {
111 span.record(MESSAGING_MESSAGE_ID, msg_id.as_str());
112 }
113 }
114}
115
116impl<T: Serialize + Send + Sync + 'static> TraceContextExtractor<SqsEventObj<T>>
117 for SqsEventExtractor
118{
119 fn extract_context(&self, _event: &SqsEventObj<T>) -> Context {
120 Context::current()
121 }
122
123 fn extract_links(&self, event: &SqsEventObj<T>) -> Vec<Link> {
124 event
125 .records
126 .iter()
127 .filter_map(|msg| {
128 extract_link_from_sqs_message(&msg.message_attributes, &msg.attributes)
129 })
130 .collect()
131 }
132
133 fn trigger_type(&self) -> &'static str {
134 "pubsub"
135 }
136
137 fn span_name(&self, event: &SqsEventObj<T>, lambda_ctx: &LambdaContext) -> String {
138 let queue_name = event
139 .records
140 .first()
141 .and_then(|r| r.event_source_arn.as_deref())
142 .and_then(Self::queue_name_from_arn)
143 .unwrap_or(&lambda_ctx.env_config.function_name);
144
145 format!("{} process", queue_name)
146 }
147
148 fn record_attributes(&self, event: &SqsEventObj<T>, span: &Span) {
149 span.record(MESSAGING_SYSTEM, "aws_sqs");
150 span.record(MESSAGING_OPERATION_TYPE, "process");
151
152 if let Some(record) = event.records.first()
153 && let Some(ref arn) = record.event_source_arn
154 && let Some(queue_name) = Self::queue_name_from_arn(arn)
155 {
156 span.record(MESSAGING_DESTINATION_NAME, queue_name);
157 }
158
159 span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
160
161 if event.records.len() == 1
162 && let Some(ref msg_id) = event.records[0].message_id
163 {
164 span.record(MESSAGING_MESSAGE_ID, msg_id.as_str());
165 }
166 }
167}
168
169use aws_lambda_events::sqs::SqsMessageAttribute;
170
171fn extract_link_from_sqs_message(
176 message_attributes: &HashMap<String, SqsMessageAttribute>,
177 system_attributes: &HashMap<String, String>,
178) -> Option<Link> {
179 if let Some(span_context) = extract_trace_context_from_message_attributes(message_attributes) {
180 return Some(Link::new(span_context, vec![], 0));
181 }
182
183 if let Some(trace_header) = system_attributes.get("AWSTraceHeader")
184 && let Some(span_context) = parse_xray_trace_header(trace_header)
185 {
186 return Some(Link::new(span_context, vec![], 0));
187 }
188
189 None
190}
191
192fn extract_trace_context_from_message_attributes(
197 message_attributes: &HashMap<String, SqsMessageAttribute>,
198) -> Option<SpanContext> {
199 let extractor = SqsMessageAttributeExtractor(message_attributes);
200 let ctx =
201 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
202
203 let span_context = ctx.span().span_context().clone();
204 if span_context.is_valid() {
205 Some(span_context)
206 } else {
207 None
208 }
209}
210
211struct SqsMessageAttributeExtractor<'a>(&'a HashMap<String, SqsMessageAttribute>);
213
214impl Extractor for SqsMessageAttributeExtractor<'_> {
215 fn get(&self, key: &str) -> Option<&str> {
216 self.0
217 .get(key)
218 .and_then(|attr| attr.string_value.as_deref())
219 }
220
221 fn keys(&self) -> Vec<&str> {
222 self.0.keys().map(|k| k.as_str()).collect()
223 }
224}
225
226pub fn parse_xray_trace_header(header: &str) -> Option<SpanContext> {
240 let mut trace_id_str = None;
241 let mut parent_id_str = None;
242 let mut sampled = false;
243
244 for part in header.split(';') {
245 let part = part.trim();
246 if let Some(root) = part.strip_prefix("Root=") {
247 trace_id_str = convert_xray_trace_id(root);
248 } else if let Some(parent) = part.strip_prefix("Parent=") {
249 parent_id_str = Some(parent.to_string());
250 } else if part == "Sampled=1" {
251 sampled = true;
252 }
253 }
254
255 let trace_id_hex = trace_id_str?;
256 let parent_id_hex = parent_id_str?;
257
258 let trace_id_bytes = hex_to_bytes::<16>(&trace_id_hex)?;
259 let trace_id = TraceId::from_bytes(trace_id_bytes);
260
261 let span_id_bytes = hex_to_bytes::<8>(&parent_id_hex)?;
262 let span_id = SpanId::from_bytes(span_id_bytes);
263
264 let flags = if sampled {
265 TraceFlags::SAMPLED
266 } else {
267 TraceFlags::default()
268 };
269
270 Some(SpanContext::new(
271 trace_id,
272 span_id,
273 flags,
274 true,
275 TraceState::default(),
276 ))
277}
278
/// Converts an X-Ray trace id (`1-<8-char epoch>-<24-char unique id>`) into the
/// 32-character hex string used for an OpenTelemetry trace id.
///
/// Returns `None` unless the id has exactly three `-`-separated segments, the
/// version segment is `1`, and the other two segments are exactly 8 and 24
/// characters long. Hex validity is checked later, when the string is decoded.
fn convert_xray_trace_id(xray_id: &str) -> Option<String> {
    let mut segments = xray_id.split('-');
    let version = segments.next()?;
    let epoch = segments.next()?;
    let unique = segments.next()?;

    // Checking only the combined length would accept mis-split ids such as
    // `1-<7 chars>-<25 chars>`.
    if segments.next().is_some() || version != "1" || epoch.len() != 8 || unique.len() != 24 {
        return None;
    }
    Some(format!("{epoch}{unique}"))
}
289
/// Decodes exactly `N` bytes from a hex string of exactly `2 * N` characters.
///
/// Both upper- and lower-case digits are accepted. Returns `None` if the
/// length is wrong or any character is not a hex digit.
fn hex_to_bytes<const N: usize>(hex: &str) -> Option<[u8; N]> {
    let digits = hex.as_bytes();
    if digits.len() != 2 * N {
        return None;
    }

    let mut decoded = [0u8; N];
    for (slot, pair) in decoded.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (hex_char_to_nibble(pair[0])? << 4) | hex_char_to_nibble(pair[1])?;
    }
    Some(decoded)
}

/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its 4-bit value.
///
/// Returns `None` for any other byte.
fn hex_char_to_nibble(c: u8) -> Option<u8> {
    // Non-ASCII bytes become U+0080..=U+00FF, which `to_digit` rejects.
    // A base-16 digit is always < 16, so the narrowing cast is lossless.
    char::from(c).to_digit(16).map(|digit| digit as u8)
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::sqs::SqsMessage;

    /// Well-formed, sampled X-Ray header used by most fixtures.
    const SAMPLED_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";
    /// Well-formed, unsampled X-Ray header with different ids.
    const UNSAMPLED_HEADER: &str =
        "Root=1-67890abc-def0123456789abcdef01234;Parent=1234567890abcdef;Sampled=0";

    /// Builds a message whose only system attribute is the given `AWSTraceHeader`.
    fn message_with_trace_header(header: &str) -> SqsMessage {
        let mut message = SqsMessage::default();
        message.attributes = HashMap::from([("AWSTraceHeader".to_string(), header.to_string())]);
        message.message_attributes = HashMap::new();
        message
    }

    /// Wraps `records` in an `SqsEvent`.
    fn event_of(records: Vec<SqsMessage>) -> SqsEvent {
        let mut event = SqsEvent::default();
        event.records = records;
        event
    }

    /// Single-record event from `my-queue` carrying a sampled X-Ray header.
    fn create_test_sqs_event() -> SqsEvent {
        let mut message = message_with_trace_header(SAMPLED_HEADER);
        message.message_id = Some("msg-123".to_string());
        message.receipt_handle = Some("receipt-123".to_string());
        message.body = Some(r#"{"test": "data"}"#.to_string());
        message.event_source = Some("aws:sqs".to_string());
        message.event_source_arn = Some("arn:aws:sqs:us-east-1:123456789:my-queue".to_string());
        message.aws_region = Some("us-east-1".to_string());
        event_of(vec![message])
    }

    #[test]
    fn test_trigger_type() {
        let extractor = SqsEventExtractor::new();
        let trigger =
            <SqsEventExtractor as TraceContextExtractor<SqsEvent>>::trigger_type(&extractor);
        assert_eq!(trigger, "pubsub");
    }

    #[test]
    fn test_span_name_includes_queue() {
        let event = create_test_sqs_event();
        let lambda_ctx = LambdaContext::default();

        let name = SqsEventExtractor::new().span_name(&event, &lambda_ctx);
        assert_eq!(name, "my-queue process");
    }

    #[test]
    fn test_queue_name_from_arn() {
        let cases = [
            ("arn:aws:sqs:us-east-1:123456789:my-queue", "my-queue"),
            (
                "arn:aws:sqs:eu-west-1:987654321:another-queue.fifo",
                "another-queue.fifo",
            ),
        ];
        for (arn, expected) in cases {
            assert_eq!(SqsEventExtractor::queue_name_from_arn(arn), Some(expected));
        }
    }

    #[test]
    fn test_extract_links_from_xray_header() {
        let links = SqsEventExtractor::new().extract_links(&create_test_sqs_event());
        assert_eq!(links.len(), 1);

        let span_context = &links[0].span_context;
        assert!(span_context.is_valid());
        assert_eq!(
            span_context.trace_id().to_string(),
            "5759e988bd862e3fe1be46a994272793"
        );
        assert_eq!(span_context.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(span_context.is_sampled());
    }

    #[test]
    fn test_extract_links_multiple_messages() {
        let event = event_of(vec![
            message_with_trace_header(SAMPLED_HEADER),
            message_with_trace_header(UNSAMPLED_HEADER),
        ]);

        let links = SqsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 2);
        let sampled: Vec<bool> = links
            .iter()
            .map(|link| link.span_context.is_sampled())
            .collect();
        assert_eq!(sampled, [true, false]);
    }

    #[test]
    fn test_extract_links_no_trace_header() {
        let mut message = SqsMessage::default();
        message.attributes = HashMap::new();
        message.message_attributes = HashMap::new();

        let links = SqsEventExtractor::new().extract_links(&event_of(vec![message]));
        assert!(links.is_empty());
    }

    #[test]
    fn test_parse_xray_trace_header() {
        let ctx = parse_xray_trace_header(SAMPLED_HEADER).expect("well-formed header must parse");

        assert!(ctx.is_valid());
        assert_eq!(
            ctx.trace_id().to_string(),
            "5759e988bd862e3fe1be46a994272793"
        );
        assert_eq!(ctx.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(ctx.is_sampled());
        assert!(ctx.is_remote());
    }

    #[test]
    fn test_parse_xray_trace_header_unsampled() {
        let header = "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=0";

        let ctx = parse_xray_trace_header(header).expect("well-formed header must parse");
        assert!(!ctx.is_sampled());
    }

    #[test]
    fn test_parse_xray_trace_header_invalid() {
        for header in ["invalid", "Root=invalid;Parent=abc", "Root=1-abc-def"] {
            assert!(
                parse_xray_trace_header(header).is_none(),
                "{header:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_convert_xray_trace_id() {
        assert_eq!(
            convert_xray_trace_id("1-5759e988-bd862e3fe1be46a994272793").as_deref(),
            Some("5759e988bd862e3fe1be46a994272793")
        );
    }

    #[test]
    fn test_hex_to_bytes() {
        assert_eq!(hex_to_bytes::<4>("deadbeef"), Some([0xde, 0xad, 0xbe, 0xef]));
    }

    #[test]
    fn test_hex_to_bytes_invalid() {
        for input in ["deadbee", "deadbeefx", "deadbeeg"] {
            assert!(hex_to_bytes::<4>(input).is_none(), "{input:?} should be rejected");
        }
    }
}