opentelemetry_lambda_tower/extractors/
sns.rs

//! SNS event extractor for notification triggers.
//!
//! Extracts trace context from SNS message attributes, checking:
//! 1. `message_attributes` via the globally configured text-map propagator
//!    (e.g. W3C `traceparent` injected by OTel-instrumented producers)
//! 2. `message_attributes` for `AWSTraceHeader` in X-Ray format
6
7use crate::extractor::TraceContextExtractor;
8use aws_lambda_events::sns::{MessageAttribute, SnsEvent, SnsRecord};
9use lambda_runtime::Context as LambdaContext;
10use opentelemetry::Context;
11use opentelemetry::propagation::Extractor;
12use opentelemetry::trace::{
13    Link, SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
14};
15use opentelemetry_semantic_conventions::attribute::{
16    MESSAGING_BATCH_MESSAGE_COUNT, MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID,
17    MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
18};
19use std::collections::HashMap;
20use tracing::Span;
21
/// Extractor for SNS notification events.
///
/// Trace context is looked up in each record's `message_attributes`, in order:
/// 1. Whatever keys the globally configured text-map propagator reads
///    (e.g. W3C `traceparent`/`tracestate` when the W3C propagator is
///    installed), as injected by OTel-instrumented producers.
/// 2. An `AWSTraceHeader` attribute in X-Ray format, parsed directly as a
///    fallback when the propagator yields no valid context.
///
/// Per OpenTelemetry semantic conventions for messaging systems, this extractor:
/// - Does NOT set a parent context (returns current context)
/// - Creates span links for each message's trace context
///
/// Span links are used rather than a parent-child relationship because
/// message delivery is asynchronous: the consumer's work is related to the
/// producer's trace without being a direct continuation of it.
///
/// # Example
///
/// ```ignore
/// use opentelemetry_lambda_tower::{OtelTracingLayer, SnsEventExtractor};
///
/// let layer = OtelTracingLayer::new(SnsEventExtractor::new());
/// ```
#[derive(Clone, Debug, Default)]
pub struct SnsEventExtractor;
46
47impl SnsEventExtractor {
48    /// Creates a new SNS event extractor.
49    pub fn new() -> Self {
50        Self
51    }
52
53    /// Extracts the topic name from an SNS topic ARN.
54    ///
55    /// ARN format: `arn:aws:sns:{region}:{account}:{topic-name}`
56    fn topic_name_from_arn(arn: &str) -> Option<&str> {
57        arn.rsplit(':').next()
58    }
59}
60
61impl TraceContextExtractor<SnsEvent> for SnsEventExtractor {
62    fn extract_context(&self, _event: &SnsEvent) -> Context {
63        Context::current()
64    }
65
66    fn extract_links(&self, event: &SnsEvent) -> Vec<Link> {
67        event
68            .records
69            .iter()
70            .filter_map(extract_link_from_record)
71            .collect()
72    }
73
74    fn trigger_type(&self) -> &'static str {
75        "pubsub"
76    }
77
78    fn span_name(&self, event: &SnsEvent, lambda_ctx: &LambdaContext) -> String {
79        let topic_name = event
80            .records
81            .first()
82            .map(|r| &r.sns.topic_arn)
83            .and_then(|arn| Self::topic_name_from_arn(arn))
84            .unwrap_or(&lambda_ctx.env_config.function_name);
85
86        format!("{} process", topic_name)
87    }
88
89    fn record_attributes(&self, event: &SnsEvent, span: &Span) {
90        span.record(MESSAGING_SYSTEM, "aws_sns");
91        span.record(MESSAGING_OPERATION_TYPE, "process");
92
93        if let Some(record) = event.records.first() {
94            if let Some(topic_name) = Self::topic_name_from_arn(&record.sns.topic_arn) {
95                span.record(MESSAGING_DESTINATION_NAME, topic_name);
96            }
97
98            span.record(MESSAGING_MESSAGE_ID, record.sns.message_id.as_str());
99        }
100
101        span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
102    }
103}
104
105/// Extracts a span link from an SNS record's message attributes.
106///
107/// Uses the globally configured propagator to extract trace context, then
108/// falls back to parsing `AWSTraceHeader` in X-Ray format as a Lambda-specific default.
109fn extract_link_from_record(record: &SnsRecord) -> Option<Link> {
110    if let Some(span_context) =
111        extract_trace_context_from_message_attributes(&record.sns.message_attributes)
112    {
113        return Some(Link::new(span_context, vec![], 0));
114    }
115
116    if let Some(trace_attr) = record.sns.message_attributes.get("AWSTraceHeader")
117        && !trace_attr.value.is_empty()
118        && let Some(span_context) = parse_xray_trace_header(&trace_attr.value)
119    {
120        return Some(Link::new(span_context, vec![], 0));
121    }
122
123    None
124}
125
126/// Extracts trace context from SNS message attributes using the global propagator.
127///
128/// The propagator determines which headers to look for (e.g. `traceparent` for W3C,
129/// `X-Amzn-Trace-Id` for X-Ray, `X-B3-*` for Zipkin, etc.).
130fn extract_trace_context_from_message_attributes(
131    message_attributes: &HashMap<String, MessageAttribute>,
132) -> Option<SpanContext> {
133    let extractor = SnsMessageAttributeExtractor(message_attributes);
134    let ctx =
135        opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
136
137    let span_context = ctx.span().span_context().clone();
138    if span_context.is_valid() {
139        Some(span_context)
140    } else {
141        None
142    }
143}
144
145/// Adapter to extract trace context from SNS message attributes.
146struct SnsMessageAttributeExtractor<'a>(&'a HashMap<String, MessageAttribute>);
147
148impl Extractor for SnsMessageAttributeExtractor<'_> {
149    fn get(&self, key: &str) -> Option<&str> {
150        self.0.get(key).map(|attr| attr.value.as_str())
151    }
152
153    fn keys(&self) -> Vec<&str> {
154        self.0.keys().map(|k| k.as_str()).collect()
155    }
156}
157
158/// Parses an X-Ray trace header into a SpanContext.
159///
160/// X-Ray format: `Root=1-{epoch}-{random};Parent={span-id};Sampled={0|1}`
161fn parse_xray_trace_header(header: &str) -> Option<SpanContext> {
162    let mut trace_id_str = None;
163    let mut parent_id_str = None;
164    let mut sampled = false;
165
166    for part in header.split(';') {
167        let part = part.trim();
168        if let Some(root) = part.strip_prefix("Root=") {
169            trace_id_str = convert_xray_trace_id(root);
170        } else if let Some(parent) = part.strip_prefix("Parent=") {
171            parent_id_str = Some(parent.to_string());
172        } else if part == "Sampled=1" {
173            sampled = true;
174        }
175    }
176
177    let trace_id_hex = trace_id_str?;
178    let parent_id_hex = parent_id_str?;
179
180    let trace_id_bytes = hex_to_bytes::<16>(&trace_id_hex)?;
181    let trace_id = TraceId::from_bytes(trace_id_bytes);
182
183    let span_id_bytes = hex_to_bytes::<8>(&parent_id_hex)?;
184    let span_id = SpanId::from_bytes(span_id_bytes);
185
186    let flags = if sampled {
187        TraceFlags::SAMPLED
188    } else {
189        TraceFlags::default()
190    };
191
192    Some(SpanContext::new(
193        trace_id,
194        span_id,
195        flags,
196        true,
197        TraceState::default(),
198    ))
199}
200
/// Converts an X-Ray trace ID to a 32-character hex string.
///
/// X-Ray trace IDs have the form `1-{8 hex digits}-{24 hex digits}`.
/// The two hex segments are concatenated to form the OpenTelemetry trace ID.
///
/// Returns `None` unless all of the following hold:
/// - there are exactly three `-`-separated parts;
/// - the version is `1`;
/// - the segment lengths are 8 and 24.
///
/// Hex-digit validity is not checked here; that is left to the caller's
/// decoding step.
fn convert_xray_trace_id(xray_id: &str) -> Option<String> {
    let mut parts = xray_id.split('-');
    let (version, epoch, random) = (parts.next()?, parts.next()?, parts.next()?);

    if parts.next().is_some() || version != "1" || epoch.len() != 8 || random.len() != 24 {
        return None;
    }

    Some(format!("{epoch}{random}"))
}
212
213/// Converts a hex string to a fixed-size byte array.
214fn hex_to_bytes<const N: usize>(hex: &str) -> Option<[u8; N]> {
215    if hex.len() != N * 2 {
216        return None;
217    }
218
219    let mut bytes = [0u8; N];
220    for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
221        let high = hex_char_to_nibble(chunk[0])?;
222        let low = hex_char_to_nibble(chunk[1])?;
223        bytes[i] = (high << 4) | low;
224    }
225    Some(bytes)
226}
227
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Every other byte, including non-ASCII bytes, yields `None`.
fn hex_char_to_nibble(c: u8) -> Option<u8> {
    // `char::to_digit` only recognises ASCII digits and letters, so widening
    // a byte >= 0x80 to a Latin-1 char can never be mistaken for a digit.
    (c as char).to_digit(16).map(|value| value as u8)
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::sns::{MessageAttribute, SnsMessage};
    use chrono::Utc;
    use std::collections::HashMap;

    const TOPIC_ARN: &str = "arn:aws:sns:us-east-1:123456789:my-topic";
    const SUBSCRIPTION_ARN: &str = "arn:aws:sns:us-east-1:123456789:my-topic:sub-123";
    const XRAY_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";
    const EXPECTED_TRACE_ID: &str = "5759e988bd862e3fe1be46a994272793";
    const EXPECTED_SPAN_ID: &str = "53995c3f42cd8ad8";

    /// Builds a single-record SNS notification event whose message carries
    /// `message_attributes`.
    fn sns_event_with_attributes(
        message_attributes: HashMap<String, MessageAttribute>,
    ) -> SnsEvent {
        let mut message = SnsMessage::default();
        message.sns_message_type = "Notification".to_string();
        message.message_id = "msg-123".to_string();
        message.topic_arn = TOPIC_ARN.to_string();
        message.timestamp = Utc::now();
        message.signature_version = "1".to_string();
        message.signature = "sig".to_string();
        message.signing_cert_url = "https://cert".to_string();
        message.unsubscribe_url = "https://unsub".to_string();
        message.message = r#"{"test": "data"}"#.to_string();
        message.message_attributes = message_attributes;

        let mut record = SnsRecord::default();
        record.event_source = "aws:sns".to_string();
        record.event_version = "1.0".to_string();
        record.event_subscription_arn = SUBSCRIPTION_ARN.to_string();
        record.sns = message;

        let mut event = SnsEvent::default();
        event.records = vec![record];
        event
    }

    /// Event whose single record carries an X-Ray `AWSTraceHeader` attribute.
    fn create_test_sns_event_with_trace() -> SnsEvent {
        let mut trace_attr = MessageAttribute::default();
        trace_attr.data_type = "String".to_string();
        trace_attr.value = XRAY_HEADER.to_string();
        sns_event_with_attributes(HashMap::from([("AWSTraceHeader".to_string(), trace_attr)]))
    }

    #[test]
    fn test_trigger_type() {
        assert_eq!(SnsEventExtractor::new().trigger_type(), "pubsub");
    }

    #[test]
    fn test_topic_name_from_arn() {
        assert_eq!(
            SnsEventExtractor::topic_name_from_arn(TOPIC_ARN),
            Some("my-topic")
        );
    }

    #[test]
    fn test_extract_links_with_trace_header() {
        let event = create_test_sns_event_with_trace();
        let links = SnsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 1);
        let linked = &links[0].span_context;
        assert!(linked.is_valid());
        assert_eq!(linked.trace_id().to_string(), EXPECTED_TRACE_ID);
        assert_eq!(linked.span_id().to_string(), EXPECTED_SPAN_ID);
        assert!(linked.is_sampled());
    }

    #[test]
    fn test_extract_links_no_trace_header() {
        let event = sns_event_with_attributes(HashMap::new());
        let links = SnsEventExtractor::new().extract_links(&event);
        assert!(links.is_empty());
    }

    #[test]
    fn test_parse_xray_trace_header() {
        let ctx = parse_xray_trace_header(XRAY_HEADER).unwrap();

        assert!(ctx.is_valid());
        assert_eq!(ctx.trace_id().to_string(), EXPECTED_TRACE_ID);
        assert_eq!(ctx.span_id().to_string(), EXPECTED_SPAN_ID);
        assert!(ctx.is_sampled());
    }
}