opentelemetry_lambda_tower/extractors/
sns.rs1use crate::extractor::TraceContextExtractor;
7use aws_lambda_events::sns::{SnsEvent, SnsRecord};
8use lambda_runtime::Context as LambdaContext;
9use opentelemetry::Context;
10use opentelemetry::trace::{Link, SpanContext, SpanId, TraceFlags, TraceId, TraceState};
11use opentelemetry_semantic_conventions::attribute::{
12 MESSAGING_BATCH_MESSAGE_COUNT, MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID,
13 MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
14};
15use tracing::Span;
16
/// Trace-context extractor for AWS SNS Lambda events.
///
/// Rather than taking a parent context from the event, this extractor builds
/// span links from each record's `AWSTraceHeader` message attribute and
/// records OpenTelemetry messaging attributes for the delivering topic.
#[derive(Clone, Debug, Default)]
pub struct SnsEventExtractor;

impl SnsEventExtractor {
    /// Creates a new SNS event extractor.
    pub fn new() -> Self {
        Self
    }

    /// Returns the topic name of an SNS topic ARN: the segment after the
    /// last `:` (e.g. `my-topic` for `arn:aws:sns:us-east-1:123456789:my-topic`).
    ///
    /// Returns `None` when that segment is empty (an empty ARN or one ending
    /// in `:`), so callers fall back to another name instead of using `""`.
    fn topic_name_from_arn(arn: &str) -> Option<&str> {
        // `rsplit` always yields at least one item; without the filter this
        // could never return `None` and an empty topic name would leak out.
        arn.rsplit(':').next().filter(|name| !name.is_empty())
    }
}
52
53impl TraceContextExtractor<SnsEvent> for SnsEventExtractor {
54 fn extract_context(&self, _event: &SnsEvent) -> Context {
55 Context::current()
56 }
57
58 fn extract_links(&self, event: &SnsEvent) -> Vec<Link> {
59 event
60 .records
61 .iter()
62 .filter_map(extract_link_from_record)
63 .collect()
64 }
65
66 fn trigger_type(&self) -> &'static str {
67 "pubsub"
68 }
69
70 fn span_name(&self, event: &SnsEvent, lambda_ctx: &LambdaContext) -> String {
71 let topic_name = event
72 .records
73 .first()
74 .map(|r| &r.sns.topic_arn)
75 .and_then(|arn| Self::topic_name_from_arn(arn))
76 .unwrap_or(&lambda_ctx.env_config.function_name);
77
78 format!("{} process", topic_name)
79 }
80
81 fn record_attributes(&self, event: &SnsEvent, span: &Span) {
82 span.record(MESSAGING_SYSTEM, "aws_sns");
83 span.record(MESSAGING_OPERATION_TYPE, "process");
84
85 if let Some(record) = event.records.first() {
86 if let Some(topic_name) = Self::topic_name_from_arn(&record.sns.topic_arn) {
87 span.record(MESSAGING_DESTINATION_NAME, topic_name);
88 }
89
90 span.record(MESSAGING_MESSAGE_ID, record.sns.message_id.as_str());
91 }
92
93 span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
94 }
95}
96
97fn extract_link_from_record(record: &SnsRecord) -> Option<Link> {
99 let trace_attr = record.sns.message_attributes.get("AWSTraceHeader")?;
100
101 if trace_attr.value.is_empty() {
102 return None;
103 }
104
105 let span_context = parse_xray_trace_header(&trace_attr.value)?;
106
107 Some(Link::new(span_context, vec![], 0))
108}
109
110fn parse_xray_trace_header(header: &str) -> Option<SpanContext> {
114 let mut trace_id_str = None;
115 let mut parent_id_str = None;
116 let mut sampled = false;
117
118 for part in header.split(';') {
119 let part = part.trim();
120 if let Some(root) = part.strip_prefix("Root=") {
121 trace_id_str = convert_xray_trace_id(root);
122 } else if let Some(parent) = part.strip_prefix("Parent=") {
123 parent_id_str = Some(parent.to_string());
124 } else if part == "Sampled=1" {
125 sampled = true;
126 }
127 }
128
129 let trace_id_hex = trace_id_str?;
130 let parent_id_hex = parent_id_str?;
131
132 let trace_id_bytes = hex_to_bytes::<16>(&trace_id_hex)?;
133 let trace_id = TraceId::from_bytes(trace_id_bytes);
134
135 let span_id_bytes = hex_to_bytes::<8>(&parent_id_hex)?;
136 let span_id = SpanId::from_bytes(span_id_bytes);
137
138 let flags = if sampled {
139 TraceFlags::SAMPLED
140 } else {
141 TraceFlags::default()
142 };
143
144 Some(SpanContext::new(
145 trace_id,
146 span_id,
147 flags,
148 true,
149 TraceState::default(),
150 ))
151}
152
/// Converts an X-Ray trace ID (`1-<8 chars>-<24 chars>`) into the 32-character
/// hex string used for an OpenTelemetry trace ID by dropping the version and
/// the dashes.
///
/// Returns `None` unless the ID has exactly three `-`-separated segments, the
/// version segment is `1`, and the other two segments are 8 and 24 characters
/// long. Hex-digit validity is not checked here and is left to the caller.
fn convert_xray_trace_id(xray_id: &str) -> Option<String> {
    let mut segments = xray_id.split('-');
    let (Some("1"), Some(epoch), Some(unique), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return None;
    };

    // Check each segment, not just the combined length, so an ID whose dash is
    // misplaced (e.g. `1-5759e98-8bd862e3...`) is rejected.
    if epoch.len() != 8 || unique.len() != 24 {
        return None;
    }

    Some(format!("{epoch}{unique}"))
}
164
/// Decodes exactly `2 * N` hex digits into an `N`-byte array.
///
/// Returns `None` if the input length (in bytes) is not `2 * N` or any
/// character is not an ASCII hex digit. Upper- and lower-case are accepted.
fn hex_to_bytes<const N: usize>(hex: &str) -> Option<[u8; N]> {
    let digits = hex.as_bytes();
    if digits.len() != 2 * N {
        return None;
    }

    let mut out = [0u8; N];
    for (slot, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (hex_char_to_nibble(pair[0])? << 4) | hex_char_to_nibble(pair[1])?;
    }
    Some(out)
}

/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its 4-bit value, or
/// `None` for any other byte.
fn hex_char_to_nibble(c: u8) -> Option<u8> {
    // `to_digit(16)` only recognises ASCII hex digits, so non-ASCII bytes
    // (interpreted as Latin-1 chars here) are rejected as well.
    (c as char).to_digit(16).map(|digit| digit as u8)
}
189
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::sns::{MessageAttribute, SnsMessage};
    use chrono::Utc;
    use std::collections::HashMap;

    /// Sampled X-Ray header shared by the event fixture and the parser test.
    const TRACE_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";

    /// Builds a one-record SNS notification event for `my-topic` whose message
    /// carries `attributes`.
    fn sns_event_with_attributes(attributes: HashMap<String, MessageAttribute>) -> SnsEvent {
        let mut message = SnsMessage::default();
        message.sns_message_type = "Notification".to_string();
        message.message_id = "msg-123".to_string();
        message.topic_arn = "arn:aws:sns:us-east-1:123456789:my-topic".to_string();
        message.timestamp = Utc::now();
        message.signature_version = "1".to_string();
        message.signature = "sig".to_string();
        message.signing_cert_url = "https://cert".to_string();
        message.unsubscribe_url = "https://unsub".to_string();
        message.message = r#"{"test": "data"}"#.to_string();
        message.message_attributes = attributes;

        let mut record = SnsRecord::default();
        record.event_source = "aws:sns".to_string();
        record.event_version = "1.0".to_string();
        record.event_subscription_arn =
            "arn:aws:sns:us-east-1:123456789:my-topic:sub-123".to_string();
        record.sns = message;

        let mut event = SnsEvent::default();
        event.records = vec![record];
        event
    }

    /// The standard fixture with the sampled trace header attached as the
    /// `AWSTraceHeader` message attribute.
    fn create_test_sns_event_with_trace() -> SnsEvent {
        let mut trace_attr = MessageAttribute::default();
        trace_attr.data_type = "String".to_string();
        trace_attr.value = TRACE_HEADER.to_string();
        sns_event_with_attributes(HashMap::from([("AWSTraceHeader".to_string(), trace_attr)]))
    }

    #[test]
    fn test_trigger_type() {
        assert_eq!(SnsEventExtractor::new().trigger_type(), "pubsub");
    }

    #[test]
    fn test_topic_name_from_arn() {
        let arn = "arn:aws:sns:us-east-1:123456789:my-topic";
        assert_eq!(SnsEventExtractor::topic_name_from_arn(arn), Some("my-topic"));
    }

    #[test]
    fn test_extract_links_with_trace_header() {
        let event = create_test_sns_event_with_trace();
        let links = SnsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 1);
        let ctx = &links[0].span_context;
        assert!(ctx.is_valid());
        assert_eq!(ctx.trace_id().to_string(), "5759e988bd862e3fe1be46a994272793");
        assert_eq!(ctx.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(ctx.is_sampled());
    }

    #[test]
    fn test_extract_links_no_trace_header() {
        let event = sns_event_with_attributes(HashMap::new());
        assert!(SnsEventExtractor::new().extract_links(&event).is_empty());
    }

    #[test]
    fn test_parse_xray_trace_header() {
        let ctx = parse_xray_trace_header(TRACE_HEADER).unwrap();

        assert!(ctx.is_valid());
        assert_eq!(ctx.trace_id().to_string(), "5759e988bd862e3fe1be46a994272793");
        assert_eq!(ctx.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(ctx.is_sampled());
    }
}