opentelemetry_lambda_extension/
context.rs1use crate::config::CorrelationConfig;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11use tokio::sync::{RwLock, watch};
12
/// Request ID of a single invocation; used as the key for all per-invocation state.
pub type RequestId = std::string::String;

/// Identity of a span that other telemetry can be parented to.
///
/// The string fields are stored as given; no format validation happens here.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SpanContext {
    /// Trace identifier (e.g. a 32-character hex string).
    pub trace_id: String,
    /// Span identifier (e.g. a 16-character hex string).
    pub span_id: String,
    /// Raw trace-flags byte. NOTE(review): presumably W3C trace-flags (bit 0 = sampled) — confirm.
    pub trace_flags: u8,
    /// Optional vendor trace state, kept as an opaque string.
    pub trace_state: Option<String>,
}

29#[derive(Debug, Clone)]
31pub struct PlatformEvent {
32 pub event_type: PlatformEventType,
34 pub timestamp: chrono::DateTime<chrono::Utc>,
36 pub request_id: RequestId,
38 pub data: serde_json::Value,
40}
41
/// Kinds of platform events tracked for correlation.
///
/// NOTE(review): the variants presumably mirror the Lambda Telemetry API `platform.*`
/// event types — confirm against the code that constructs them.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlatformEventType {
    /// Initialization phase started.
    InitStart,
    /// Runtime finished initializing.
    InitRuntimeDone,
    /// Report for the initialization phase.
    InitReport,
    /// Invocation started.
    Start,
    /// Runtime finished handling the invocation.
    RuntimeDone,
    /// Report for the invocation.
    Report,
}

60struct InvocationContext {
62 request_id: RequestId,
63 started_at: Instant,
64 parent_span_context: Option<SpanContext>,
65 pending_platform_events: Vec<PlatformEvent>,
66 parent_ready_tx: watch::Sender<Option<SpanContext>>,
67 parent_ready_rx: watch::Receiver<Option<SpanContext>>,
68}
69
70impl InvocationContext {
71 fn new(request_id: RequestId) -> Self {
72 let (parent_ready_tx, parent_ready_rx) = watch::channel(None);
73 Self {
74 request_id,
75 started_at: Instant::now(),
76 parent_span_context: None,
77 pending_platform_events: Vec::new(),
78 parent_ready_tx,
79 parent_ready_rx,
80 }
81 }
82}
83
84pub struct InvocationContextManager {
90 contexts: Arc<RwLock<HashMap<RequestId, InvocationContext>>>,
91 config: CorrelationConfig,
92}
93
94impl InvocationContextManager {
95 pub fn new(config: CorrelationConfig) -> Self {
97 Self {
98 contexts: Arc::new(RwLock::new(HashMap::new())),
99 config,
100 }
101 }
102
103 pub fn with_defaults() -> Self {
105 Self::new(CorrelationConfig::default())
106 }
107
108 pub async fn register_invocation(&self, request_id: RequestId) {
112 let mut contexts = self.contexts.write().await;
113
114 if contexts.len() >= self.config.max_total_buffered_events {
115 self.cleanup_stale_contexts_locked(&mut contexts);
116 }
117
118 if !contexts.contains_key(&request_id) {
119 contexts.insert(request_id.clone(), InvocationContext::new(request_id));
120 }
121 }
122
123 pub async fn set_parent_span(&self, request_id: &str, context: SpanContext) {
128 let mut contexts = self.contexts.write().await;
129
130 if let Some(inv_ctx) = contexts.get_mut(request_id) {
131 inv_ctx.parent_span_context = Some(context.clone());
132 let _ = inv_ctx.parent_ready_tx.send(Some(context));
133 }
134 }
135
136 pub async fn wait_for_parent_span(&self, request_id: &str) -> Option<SpanContext> {
141 let rx = {
142 let contexts = self.contexts.read().await;
143 contexts.get(request_id)?.parent_ready_rx.clone()
144 };
145
146 let timeout = self.config.max_correlation_delay;
147
148 match tokio::time::timeout(timeout, async {
149 let mut rx = rx;
150 loop {
151 if rx.borrow().is_some() {
152 return rx.borrow().clone();
153 }
154 if rx.changed().await.is_err() {
155 return None;
156 }
157 }
158 })
159 .await
160 {
161 Ok(ctx) => ctx,
162 Err(_) => {
163 let contexts = self.contexts.read().await;
164 contexts
165 .get(request_id)
166 .and_then(|c| c.parent_span_context.clone())
167 }
168 }
169 }
170
171 pub async fn get_parent_span(&self, request_id: &str) -> Option<SpanContext> {
173 let contexts = self.contexts.read().await;
174 contexts
175 .get(request_id)
176 .and_then(|c| c.parent_span_context.clone())
177 }
178
179 pub async fn add_platform_event(&self, event: PlatformEvent) {
183 let mut contexts = self.contexts.write().await;
184
185 if let Some(inv_ctx) = contexts.get_mut(&event.request_id) {
186 if inv_ctx.pending_platform_events.len()
187 < self.config.max_buffered_events_per_invocation
188 {
189 inv_ctx.pending_platform_events.push(event);
190 } else {
191 tracing::warn!(
192 request_id = %inv_ctx.request_id,
193 "Platform event buffer full, dropping event"
194 );
195 }
196 } else {
197 let mut ctx = InvocationContext::new(event.request_id.clone());
198 ctx.pending_platform_events.push(event);
199 contexts.insert(ctx.request_id.clone(), ctx);
200 }
201 }
202
203 pub async fn take_platform_events(&self, request_id: &str) -> Vec<PlatformEvent> {
207 let mut contexts = self.contexts.write().await;
208
209 if let Some(inv_ctx) = contexts.get_mut(request_id) {
210 std::mem::take(&mut inv_ctx.pending_platform_events)
211 } else {
212 Vec::new()
213 }
214 }
215
216 pub async fn remove_invocation(&self, request_id: &str) {
220 let mut contexts = self.contexts.write().await;
221 contexts.remove(request_id);
222 }
223
224 pub async fn active_count(&self) -> usize {
226 let contexts = self.contexts.read().await;
227 contexts.len()
228 }
229
230 pub async fn cleanup_stale_contexts(&self) {
234 let mut contexts = self.contexts.write().await;
235 self.cleanup_stale_contexts_locked(&mut contexts);
236 }
237
238 fn cleanup_stale_contexts_locked(&self, contexts: &mut HashMap<RequestId, InvocationContext>) {
239 let max_lifetime = self.config.max_invocation_lifetime;
240 let now = Instant::now();
241
242 contexts.retain(|request_id, ctx| {
243 let age = now.duration_since(ctx.started_at);
244 if age > max_lifetime {
245 tracing::debug!(
246 request_id = %request_id,
247 age_secs = age.as_secs(),
248 "Removing stale invocation context"
249 );
250 false
251 } else {
252 true
253 }
254 });
255 }
256
257 pub fn emit_orphaned_spans(&self) -> bool {
259 self.config.emit_orphaned_spans
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Config with short timeouts so the async tests finish quickly.
    fn test_config() -> CorrelationConfig {
        CorrelationConfig {
            max_correlation_delay: Duration::from_millis(100),
            max_buffered_events_per_invocation: 10,
            max_total_buffered_events: 100,
            max_invocation_lifetime: Duration::from_secs(60),
            emit_orphaned_spans: true,
        }
    }

    /// Builds a span context with `trace_flags = 1` and no trace state.
    fn span_ctx(trace_id: &str, span_id: &str) -> SpanContext {
        SpanContext {
            trace_id: trace_id.to_string(),
            span_id: span_id.to_string(),
            trace_flags: 1,
            trace_state: None,
        }
    }

    /// Builds a `Start` platform event for `request_id` carrying `data`.
    fn start_event(request_id: &str, data: serde_json::Value) -> PlatformEvent {
        PlatformEvent {
            event_type: PlatformEventType::Start,
            timestamp: chrono::Utc::now(),
            request_id: request_id.to_string(),
            data,
        }
    }

    #[tokio::test]
    async fn test_register_and_get_parent_span() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-123".to_string()).await;

        assert!(manager.get_parent_span("req-123").await.is_none());

        let expected = span_ctx("0102030405060708090a0b0c0d0e0f10", "0102030405060708");
        manager.set_parent_span("req-123", expected.clone()).await;

        assert_eq!(manager.get_parent_span("req-123").await, Some(expected));
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_immediate() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-456".to_string()).await;

        let expected = span_ctx("trace-id", "span-id");
        manager.set_parent_span("req-456", expected.clone()).await;

        assert_eq!(manager.wait_for_parent_span("req-456").await, Some(expected));
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_delayed() {
        let manager = Arc::new(InvocationContextManager::new(test_config()));

        manager.register_invocation("req-789".to_string()).await;

        let setter = Arc::clone(&manager);
        let set_handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            setter
                .set_parent_span("req-789", span_ctx("delayed-trace", "delayed-span"))
                .await;
        });

        let result = manager.wait_for_parent_span("req-789").await;
        set_handle.await.unwrap();

        assert_eq!(result.map(|ctx| ctx.trace_id).as_deref(), Some("delayed-trace"));
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_timeout() {
        let config = CorrelationConfig {
            max_correlation_delay: Duration::from_millis(10),
            ..test_config()
        };

        let manager = InvocationContextManager::new(config);
        manager.register_invocation("req-timeout".to_string()).await;

        assert!(manager.wait_for_parent_span("req-timeout").await.is_none());
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_returns_none_when_removed() {
        let manager = Arc::new(InvocationContextManager::new(test_config()));
        manager.register_invocation("req-gone".to_string()).await;

        let remover = Arc::clone(&manager);
        let remove_handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            remover.remove_invocation("req-gone").await;
        });

        assert!(manager.wait_for_parent_span("req-gone").await.is_none());
        remove_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_platform_events() {
        let manager = InvocationContextManager::new(test_config());

        manager
            .add_platform_event(start_event(
                "req-events",
                serde_json::json!({"requestId": "req-events"}),
            ))
            .await;

        let events = manager.take_platform_events("req-events").await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].request_id, "req-events");

        assert!(manager.take_platform_events("req-events").await.is_empty());
    }

    #[tokio::test]
    async fn test_register_preserves_events_buffered_before_registration() {
        let manager = InvocationContextManager::new(test_config());

        manager
            .add_platform_event(start_event("req-early", serde_json::json!({})))
            .await;
        manager.register_invocation("req-early".to_string()).await;

        assert_eq!(manager.active_count().await, 1);
        assert_eq!(manager.take_platform_events("req-early").await.len(), 1);
    }

    #[tokio::test]
    async fn test_unknown_request_is_noop() {
        let manager = InvocationContextManager::new(test_config());

        manager.set_parent_span("req-unknown", span_ctx("t", "s")).await;

        assert!(manager.get_parent_span("req-unknown").await.is_none());
        assert!(manager.wait_for_parent_span("req-unknown").await.is_none());
        assert!(manager.take_platform_events("req-unknown").await.is_empty());
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_invocation() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-remove".to_string()).await;
        assert_eq!(manager.active_count().await, 1);

        manager.remove_invocation("req-remove").await;
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_stale_contexts() {
        let config = CorrelationConfig {
            max_invocation_lifetime: Duration::from_millis(10),
            ..test_config()
        };

        let manager = InvocationContextManager::new(config);

        manager.register_invocation("req-stale".to_string()).await;
        assert_eq!(manager.active_count().await, 1);

        tokio::time::sleep(Duration::from_millis(20)).await;

        manager.cleanup_stale_contexts().await;
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_event_buffer_limit() {
        let config = CorrelationConfig {
            max_buffered_events_per_invocation: 2,
            ..test_config()
        };

        let manager = InvocationContextManager::new(config);
        manager.register_invocation("req-limit".to_string()).await;

        for i in 0..5 {
            manager
                .add_platform_event(start_event("req-limit", serde_json::json!({"index": i})))
                .await;
        }

        assert_eq!(manager.take_platform_events("req-limit").await.len(), 2);
    }
}