opentelemetry_lambda_tower/extractors/
sns.rs1use crate::extractor::TraceContextExtractor;
8use aws_lambda_events::sns::{MessageAttribute, SnsEvent, SnsRecord};
9use lambda_runtime::Context as LambdaContext;
10use opentelemetry::Context;
11use opentelemetry::propagation::Extractor;
12use opentelemetry::trace::{
13 Link, SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState,
14};
15use opentelemetry_semantic_conventions::attribute::{
16 MESSAGING_BATCH_MESSAGE_COUNT, MESSAGING_DESTINATION_NAME, MESSAGING_MESSAGE_ID,
17 MESSAGING_OPERATION_TYPE, MESSAGING_SYSTEM,
18};
19use std::collections::HashMap;
20use tracing::Span;
21
/// Extracts trace context, span naming and messaging attributes from SNS-triggered
/// Lambda invocations.
#[derive(Clone, Debug, Default)]
pub struct SnsEventExtractor;

impl SnsEventExtractor {
    /// Creates a new, stateless SNS extractor.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Returns the last `:`-separated segment of `arn` (the topic name for an SNS
    /// topic ARN such as `arn:aws:sns:us-east-1:123456789:my-topic`).
    ///
    /// Returns `None` when that segment is empty (e.g. an empty ARN or one ending in
    /// `:`). `rsplit` always yields at least one item, so without this filter the
    /// function could never return `None`. Callers rely on `None` to fall back to
    /// another name and to skip recording an empty destination.
    fn topic_name_from_arn(arn: &str) -> Option<&str> {
        arn.rsplit(':').next().filter(|name| !name.is_empty())
    }
}
60
61impl TraceContextExtractor<SnsEvent> for SnsEventExtractor {
62 fn extract_context(&self, _event: &SnsEvent) -> Context {
63 Context::current()
64 }
65
66 fn extract_links(&self, event: &SnsEvent) -> Vec<Link> {
67 event
68 .records
69 .iter()
70 .filter_map(extract_link_from_record)
71 .collect()
72 }
73
74 fn trigger_type(&self) -> &'static str {
75 "pubsub"
76 }
77
78 fn span_name(&self, event: &SnsEvent, lambda_ctx: &LambdaContext) -> String {
79 let topic_name = event
80 .records
81 .first()
82 .map(|r| &r.sns.topic_arn)
83 .and_then(|arn| Self::topic_name_from_arn(arn))
84 .unwrap_or(&lambda_ctx.env_config.function_name);
85
86 format!("{} process", topic_name)
87 }
88
89 fn record_attributes(&self, event: &SnsEvent, span: &Span) {
90 span.record(MESSAGING_SYSTEM, "aws_sns");
91 span.record(MESSAGING_OPERATION_TYPE, "process");
92
93 if let Some(record) = event.records.first() {
94 if let Some(topic_name) = Self::topic_name_from_arn(&record.sns.topic_arn) {
95 span.record(MESSAGING_DESTINATION_NAME, topic_name);
96 }
97
98 span.record(MESSAGING_MESSAGE_ID, record.sns.message_id.as_str());
99 }
100
101 span.record(MESSAGING_BATCH_MESSAGE_COUNT, event.records.len() as i64);
102 }
103}
104
105fn extract_link_from_record(record: &SnsRecord) -> Option<Link> {
110 if let Some(span_context) =
111 extract_trace_context_from_message_attributes(&record.sns.message_attributes)
112 {
113 return Some(Link::new(span_context, vec![], 0));
114 }
115
116 if let Some(trace_attr) = record.sns.message_attributes.get("AWSTraceHeader")
117 && !trace_attr.value.is_empty()
118 && let Some(span_context) = parse_xray_trace_header(&trace_attr.value)
119 {
120 return Some(Link::new(span_context, vec![], 0));
121 }
122
123 None
124}
125
126fn extract_trace_context_from_message_attributes(
131 message_attributes: &HashMap<String, MessageAttribute>,
132) -> Option<SpanContext> {
133 let extractor = SnsMessageAttributeExtractor(message_attributes);
134 let ctx =
135 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
136
137 let span_context = ctx.span().span_context().clone();
138 if span_context.is_valid() {
139 Some(span_context)
140 } else {
141 None
142 }
143}
144
145struct SnsMessageAttributeExtractor<'a>(&'a HashMap<String, MessageAttribute>);
147
148impl Extractor for SnsMessageAttributeExtractor<'_> {
149 fn get(&self, key: &str) -> Option<&str> {
150 self.0.get(key).map(|attr| attr.value.as_str())
151 }
152
153 fn keys(&self) -> Vec<&str> {
154 self.0.keys().map(|k| k.as_str()).collect()
155 }
156}
157
158fn parse_xray_trace_header(header: &str) -> Option<SpanContext> {
162 let mut trace_id_str = None;
163 let mut parent_id_str = None;
164 let mut sampled = false;
165
166 for part in header.split(';') {
167 let part = part.trim();
168 if let Some(root) = part.strip_prefix("Root=") {
169 trace_id_str = convert_xray_trace_id(root);
170 } else if let Some(parent) = part.strip_prefix("Parent=") {
171 parent_id_str = Some(parent.to_string());
172 } else if part == "Sampled=1" {
173 sampled = true;
174 }
175 }
176
177 let trace_id_hex = trace_id_str?;
178 let parent_id_hex = parent_id_str?;
179
180 let trace_id_bytes = hex_to_bytes::<16>(&trace_id_hex)?;
181 let trace_id = TraceId::from_bytes(trace_id_bytes);
182
183 let span_id_bytes = hex_to_bytes::<8>(&parent_id_hex)?;
184 let span_id = SpanId::from_bytes(span_id_bytes);
185
186 let flags = if sampled {
187 TraceFlags::SAMPLED
188 } else {
189 TraceFlags::default()
190 };
191
192 Some(SpanContext::new(
193 trace_id,
194 span_id,
195 flags,
196 true,
197 TraceState::default(),
198 ))
199}
200
/// Converts an X-Ray trace id (`1-<8 hex epoch>-<24 hex unique id>`) into the
/// 32-hex-digit form used for an OpenTelemetry trace id.
///
/// Returns `None` unless the input has exactly three `-`-separated segments, the
/// version segment is `1`, and the remaining segments are exactly 8 and 24 ASCII
/// hex digits. Checking each segment, not just the combined length, prevents
/// malformed ids (e.g. 7 + 25 digits) from being accepted as a different trace id.
fn convert_xray_trace_id(xray_id: &str) -> Option<String> {
    let mut segments = xray_id.split('-');
    let (Some("1"), Some(epoch), Some(unique), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return None;
    };

    let is_hex_of_len =
        |s: &str, len: usize| s.len() == len && s.bytes().all(|b| b.is_ascii_hexdigit());
    if !is_hex_of_len(epoch, 8) || !is_hex_of_len(unique, 24) {
        return None;
    }

    Some(format!("{epoch}{unique}"))
}
212
/// Decodes exactly `2 * N` hex digits (upper- or lower-case) into an `N`-byte array.
///
/// Returns `None` if the input length (in bytes) is not `2 * N` or if any byte is
/// not an ASCII hex digit.
fn hex_to_bytes<const N: usize>(hex: &str) -> Option<[u8; N]> {
    let digits = hex.as_bytes();
    if digits.len() != 2 * N {
        return None;
    }

    let mut decoded = [0u8; N];
    for (slot, pair) in decoded.iter_mut().zip(digits.chunks_exact(2)) {
        let high = hex_char_to_nibble(pair[0])?;
        let low = hex_char_to_nibble(pair[1])?;
        *slot = (high << 4) | low;
    }
    Some(decoded)
}

/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its 4-bit value.
fn hex_char_to_nibble(c: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits; the result is <= 15, so the
    // narrowing cast cannot truncate.
    char::from(c).to_digit(16).map(|value| value as u8)
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::sns::{MessageAttribute, SnsMessage};
    use chrono::Utc;
    use std::collections::HashMap;

    const TOPIC_ARN: &str = "arn:aws:sns:us-east-1:123456789:my-topic";
    const SUBSCRIPTION_ARN: &str = "arn:aws:sns:us-east-1:123456789:my-topic:sub-123";
    const XRAY_HEADER: &str =
        "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1";

    /// Wraps a notification carrying `attributes` in a single-record SNS event.
    fn single_record_event(attributes: HashMap<String, MessageAttribute>) -> SnsEvent {
        let mut message = SnsMessage::default();
        message.sns_message_type = "Notification".to_string();
        message.message_id = "msg-123".to_string();
        message.topic_arn = TOPIC_ARN.to_string();
        message.timestamp = Utc::now();
        message.signature_version = "1".to_string();
        message.signature = "sig".to_string();
        message.signing_cert_url = "https://cert".to_string();
        message.unsubscribe_url = "https://unsub".to_string();
        message.message = r#"{"test": "data"}"#.to_string();
        message.message_attributes = attributes;

        let mut record = SnsRecord::default();
        record.event_source = "aws:sns".to_string();
        record.event_version = "1.0".to_string();
        record.event_subscription_arn = SUBSCRIPTION_ARN.to_string();
        record.sns = message;

        let mut event = SnsEvent::default();
        event.records = vec![record];
        event
    }

    /// Single-record event whose message carries an `AWSTraceHeader` attribute.
    fn create_test_sns_event_with_trace() -> SnsEvent {
        let mut header = MessageAttribute::default();
        header.data_type = "String".to_string();
        header.value = XRAY_HEADER.to_string();

        let mut attributes = HashMap::new();
        attributes.insert("AWSTraceHeader".to_string(), header);
        single_record_event(attributes)
    }

    #[test]
    fn test_trigger_type() {
        assert_eq!(SnsEventExtractor::new().trigger_type(), "pubsub");
    }

    #[test]
    fn test_topic_name_from_arn() {
        let topic = SnsEventExtractor::topic_name_from_arn(TOPIC_ARN);
        assert_eq!(topic, Some("my-topic"));
    }

    #[test]
    fn test_extract_links_with_trace_header() {
        let event = create_test_sns_event_with_trace();
        let links = SnsEventExtractor::new().extract_links(&event);

        assert_eq!(links.len(), 1);
        let ctx = &links[0].span_context;
        assert!(ctx.is_valid());
        assert_eq!(ctx.trace_id().to_string(), "5759e988bd862e3fe1be46a994272793");
        assert_eq!(ctx.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(ctx.is_sampled());
    }

    #[test]
    fn test_extract_links_no_trace_header() {
        let event = single_record_event(HashMap::new());
        let links = SnsEventExtractor::new().extract_links(&event);
        assert!(links.is_empty());
    }

    #[test]
    fn test_parse_xray_trace_header() {
        let ctx = parse_xray_trace_header(XRAY_HEADER).unwrap();

        assert!(ctx.is_valid());
        assert_eq!(ctx.trace_id().to_string(), "5759e988bd862e3fe1be46a994272793");
        assert_eq!(ctx.span_id().to_string(), "53995c3f42cd8ad8");
        assert!(ctx.is_sampled());
    }
}