opentelemetry_lambda_tower/extractors/
sns.rs

//! SNS event extractor for notification triggers.
//!
//! Extracts trace context from SNS message attributes. SNS follows patterns
//! similar to SQS for trace propagation.
5
6use crate::extractor::TraceContextExtractor;
7use aws_lambda_events::sns::{SnsEvent, SnsRecord};
8use lambda_runtime::Context as LambdaContext;
9use opentelemetry::Context;
10use opentelemetry::trace::{Link, SpanContext, SpanId, TraceFlags, TraceId, TraceState};
11use opentelemetry_semantic_conventions::attribute::{
12    MESSAGING_BATCH_MESSAGE_COUNT, MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID,
13    MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
14};
15use tracing::Span;
16
/// Extractor for SNS notification events.
///
/// SNS events can carry trace context in the `AWSTraceHeader` message
/// attribute using X-Ray format. This extractor:
///
/// 1. Does NOT set a parent context (returns current context)
/// 2. Creates span links for each message's trace context
///
/// This follows OpenTelemetry semantic conventions for messaging systems,
/// where the async nature of message queues means span links are more
/// appropriate than parent-child relationships.
///
/// # Example
///
/// ```ignore
/// use opentelemetry_lambda_tower::{OtelTracingLayer, SnsEventExtractor};
///
/// let layer = OtelTracingLayer::new(SnsEventExtractor::new());
/// ```
#[derive(Clone, Debug, Default)]
pub struct SnsEventExtractor;

impl SnsEventExtractor {
    /// Creates a new SNS event extractor.
    pub fn new() -> Self {
        Self
    }

    /// Extracts the topic name from an SNS topic ARN.
    ///
    /// ARN format: `arn:aws:sns:{region}:{account}:{topic-name}`
    ///
    /// The topic name is the text after the last `:`; an input containing no
    /// `:` is returned whole. Returns `None` when that final segment is empty
    /// (an empty ARN, or one ending in `:`), so callers fall back to another
    /// name instead of emitting a blank one.
    fn topic_name_from_arn(arn: &str) -> Option<&str> {
        // `rsplit` always yields at least one (possibly empty) segment, so the
        // emptiness filter is what makes `None` reachable at all.
        arn.rsplit(':').next().filter(|name| !name.is_empty())
    }
}
52
53impl TraceContextExtractor<SnsEvent> for SnsEventExtractor {
54    fn extract_context(&self, _event: &SnsEvent) -> Context {
55        Context::current()
56    }
57
58    fn extract_links(&self, event: &SnsEvent) -> Vec<Link> {
59        event
60            .records
61            .iter()
62            .filter_map(extract_link_from_record)
63            .collect()
64    }
65
66    fn trigger_type(&self) -> &'static str {
67        "pubsub"
68    }
69
70    fn span_name(&self, event: &SnsEvent, lambda_ctx: &LambdaContext) -> String {
71        let topic_name = event
72            .records
73            .first()
74            .map(|r| &r.sns.topic_arn)
75            .and_then(|arn| Self::topic_name_from_arn(arn))
76            .unwrap_or(&lambda_ctx.env_config.function_name);
77
78        format!("{} process", topic_name)
79    }
80
81    fn record_attributes(&self, event: &SnsEvent, span: &Span) {
82        span.record(MESSAGING_SYSTEM, "aws_sns");
83        span.record(MESSAGING_OPERATION_TYPE, "process");
84
85        if let Some(record) = event.records.first() {
86            if let Some(topic_name) = Self::topic_name_from_arn(&record.sns.topic_arn) {
87                span.record(MESSAGING_DESTINATION_NAME, topic_name);
88            }
89
90            span.record(MESSAGING_MESSAGE_ID, record.sns.message_id.as_str());
91        }
92
93        span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
94    }
95}
96
97/// Extracts a span link from an SNS record's AWSTraceHeader message attribute.
98fn extract_link_from_record(record: &SnsRecord) -> Option<Link> {
99    let trace_attr = record.sns.message_attributes.get("AWSTraceHeader")?;
100
101    if trace_attr.value.is_empty() {
102        return None;
103    }
104
105    let span_context = parse_xray_trace_header(&trace_attr.value)?;
106
107    Some(Link::new(span_context, vec![], 0))
108}
109
110/// Parses an X-Ray trace header into a SpanContext.
111///
112/// X-Ray format: `Root=1-{epoch}-{random};Parent={span-id};Sampled={0|1}`
113fn parse_xray_trace_header(header: &str) -> Option<SpanContext> {
114    let mut trace_id_str = None;
115    let mut parent_id_str = None;
116    let mut sampled = false;
117
118    for part in header.split(';') {
119        let part = part.trim();
120        if let Some(root) = part.strip_prefix("Root=") {
121            trace_id_str = convert_xray_trace_id(root);
122        } else if let Some(parent) = part.strip_prefix("Parent=") {
123            parent_id_str = Some(parent.to_string());
124        } else if part == "Sampled=1" {
125            sampled = true;
126        }
127    }
128
129    let trace_id_hex = trace_id_str?;
130    let parent_id_hex = parent_id_str?;
131
132    let trace_id_bytes = hex_to_bytes::<16>(&trace_id_hex)?;
133    let trace_id = TraceId::from_bytes(trace_id_bytes);
134
135    let span_id_bytes = hex_to_bytes::<8>(&parent_id_hex)?;
136    let span_id = SpanId::from_bytes(span_id_bytes);
137
138    let flags = if sampled {
139        TraceFlags::SAMPLED
140    } else {
141        TraceFlags::default()
142    };
143
144    Some(SpanContext::new(
145        trace_id,
146        span_id,
147        flags,
148        true,
149        TraceState::default(),
150    ))
151}
152
/// Converts an X-Ray trace ID to a 32-character hex string (W3C layout).
///
/// The input must be exactly `1-{epoch}-{random}`, where `epoch` is 8 hex
/// digits and `random` is 24 hex digits. The result is `epoch` followed by
/// `random`. Returns `None` for any other version, for segments of the wrong
/// length (including a mis-placed `-` that still totals 32 digits), or for
/// non-hex characters.
fn convert_xray_trace_id(xray_id: &str) -> Option<String> {
    let rest = xray_id.strip_prefix("1-")?;
    let (epoch, random) = rest.split_once('-')?;

    // Checking each segment (not just the combined length) rejects inputs like
    // `1-5759e98-8bd8…` whose digits would otherwise be silently shifted.
    if epoch.len() != 8 || random.len() != 24 {
        return None;
    }
    if !epoch
        .bytes()
        .chain(random.bytes())
        .all(|b| b.is_ascii_hexdigit())
    {
        return None;
    }

    Some(format!("{}{}", epoch, random))
}
164
/// Decodes a hex string of exactly `2 * N` characters into `N` bytes.
///
/// Lowercase and uppercase digits are both accepted. Returns `None` if the
/// byte length is not `2 * N` or any character is not a hex digit.
fn hex_to_bytes<const N: usize>(hex: &str) -> Option<[u8; N]> {
    let digits = hex.as_bytes();
    if digits.len() != 2 * N {
        return None;
    }

    let mut out = [0u8; N];
    for (slot, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (hex_char_to_nibble(pair[0])? << 4) | hex_char_to_nibble(pair[1])?;
    }
    Some(out)
}

/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Any other byte yields `None`.
fn hex_char_to_nibble(c: u8) -> Option<u8> {
    // `to_digit` only recognises ASCII digits/letters, so bytes >= 0x80 (which
    // map to non-ASCII chars) are rejected just like other non-hex bytes.
    char::from(c)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
189
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::sns::{MessageAttribute, SnsMessage};
    use chrono::Utc;
    use std::collections::HashMap;

    const TOPIC_ARN: &str = "arn:aws:sns:us-east-1:123456789:my-topic";
    const XRAY_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";
    const EXPECTED_TRACE_ID: &str = "5759e988bd862e3fe1be46a994272793";
    const EXPECTED_SPAN_ID: &str = "53995c3f42cd8ad8";

    /// Builds a one-record SNS notification event carrying `attributes`;
    /// every other field holds fixed placeholder values.
    fn sns_event_with_attributes(attributes: HashMap<String, MessageAttribute>) -> SnsEvent {
        let mut message = SnsMessage::default();
        message.sns_message_type = "Notification".to_string();
        message.message_id = "msg-123".to_string();
        message.topic_arn = TOPIC_ARN.to_string();
        message.timestamp = Utc::now();
        message.signature_version = "1".to_string();
        message.signature = "sig".to_string();
        message.signing_cert_url = "https://cert".to_string();
        message.unsubscribe_url = "https://unsub".to_string();
        message.message = r#"{"test": "data"}"#.to_string();
        message.message_attributes = attributes;

        let mut record = SnsRecord::default();
        record.event_source = "aws:sns".to_string();
        record.event_version = "1.0".to_string();
        record.event_subscription_arn =
            "arn:aws:sns:us-east-1:123456789:my-topic:sub-123".to_string();
        record.sns = message;

        let mut event = SnsEvent::default();
        event.records = vec![record];
        event
    }

    /// Builds a one-record SNS event whose `AWSTraceHeader` attribute holds a
    /// sampled X-Ray trace header.
    fn create_test_sns_event_with_trace() -> SnsEvent {
        let mut trace_attr = MessageAttribute::default();
        trace_attr.data_type = "String".to_string();
        trace_attr.value = XRAY_HEADER.to_string();

        let attributes = HashMap::from([("AWSTraceHeader".to_string(), trace_attr)]);
        sns_event_with_attributes(attributes)
    }

    #[test]
    fn test_trigger_type() {
        assert_eq!(SnsEventExtractor::new().trigger_type(), "pubsub");
    }

    #[test]
    fn test_topic_name_from_arn() {
        let topic = SnsEventExtractor::topic_name_from_arn(TOPIC_ARN);
        assert_eq!(topic, Some("my-topic"));
    }

    #[test]
    fn test_extract_links_with_trace_header() {
        let event = create_test_sns_event_with_trace();
        let links = SnsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 1);
        let span_context = &links[0].span_context;
        assert!(span_context.is_valid());
        assert_eq!(span_context.trace_id().to_string(), EXPECTED_TRACE_ID);
        assert_eq!(span_context.span_id().to_string(), EXPECTED_SPAN_ID);
        assert!(span_context.is_sampled());
    }

    #[test]
    fn test_extract_links_no_trace_header() {
        let event = sns_event_with_attributes(HashMap::new());
        let links = SnsEventExtractor::new().extract_links(&event);
        assert!(links.is_empty());
    }

    #[test]
    fn test_parse_xray_trace_header() {
        let ctx = parse_xray_trace_header(XRAY_HEADER).unwrap();

        assert!(ctx.is_valid());
        assert_eq!(ctx.trace_id().to_string(), EXPECTED_TRACE_ID);
        assert_eq!(ctx.span_id().to_string(), EXPECTED_SPAN_ID);
        assert!(ctx.is_sampled());
    }
}