// opentelemetry_lambda_extension/context.rs

//! Invocation context management for span correlation.
//!
//! This module provides the central coordinator for correlating platform spans
//! with function spans. It uses state-based correlation with watch channels
//! for timeout-bounded waiting.

use crate::config::CorrelationConfig;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{RwLock, watch};

/// Unique identifier for a Lambda invocation request.
///
/// Used as the key under which per-invocation correlation state is stored.
pub type RequestId = String;

/// Span context information extracted from a parent span.
///
/// Values of this type are compared field-by-field (`PartialEq`/`Eq`) and are
/// cloned when handed to waiters, so all fields are owned.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpanContext {
    /// The trace ID (16 bytes as hex string).
    pub trace_id: String,
    /// The span ID (8 bytes as hex string).
    pub span_id: String,
    /// Trace flags (sampled, etc.).
    pub trace_flags: u8,
    /// Trace state (vendor-specific data), if any was present.
    pub trace_state: Option<String>,
}

29/// Platform event received from the Lambda Telemetry API.
30#[derive(Debug, Clone)]
31pub struct PlatformEvent {
32    /// The type of platform event.
33    pub event_type: PlatformEventType,
34    /// Timestamp when the event occurred.
35    pub timestamp: chrono::DateTime<chrono::Utc>,
36    /// The request ID this event belongs to.
37    pub request_id: RequestId,
38    /// Additional event-specific data.
39    pub data: serde_json::Value,
40}
41
/// Types of platform events from the Telemetry API.
///
/// Marked `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlatformEventType {
    /// Function initialization started.
    InitStart,
    /// Runtime initialization completed.
    InitRuntimeDone,
    /// Init phase report.
    InitReport,
    /// Invocation started.
    Start,
    /// Runtime completed processing.
    RuntimeDone,
    /// Invocation report with metrics.
    Report,
}

60/// Context for a single invocation, tracking correlation state.
61struct InvocationContext {
62    request_id: RequestId,
63    started_at: Instant,
64    parent_span_context: Option<SpanContext>,
65    pending_platform_events: Vec<PlatformEvent>,
66    parent_ready_tx: watch::Sender<Option<SpanContext>>,
67    parent_ready_rx: watch::Receiver<Option<SpanContext>>,
68}
69
70impl InvocationContext {
71    fn new(request_id: RequestId) -> Self {
72        let (parent_ready_tx, parent_ready_rx) = watch::channel(None);
73        Self {
74            request_id,
75            started_at: Instant::now(),
76            parent_span_context: None,
77            pending_platform_events: Vec::new(),
78            parent_ready_tx,
79            parent_ready_rx,
80        }
81    }
82}
83
84/// Central coordinator for correlating platform spans with function spans.
85///
86/// This manager maintains state for each active invocation and provides
87/// methods for registering parent spans from function telemetry and
88/// waiting for correlation with platform events.
89pub struct InvocationContextManager {
90    contexts: Arc<RwLock<HashMap<RequestId, InvocationContext>>>,
91    config: CorrelationConfig,
92}
93
94impl InvocationContextManager {
95    /// Creates a new context manager with the given configuration.
96    pub fn new(config: CorrelationConfig) -> Self {
97        Self {
98            contexts: Arc::new(RwLock::new(HashMap::new())),
99            config,
100        }
101    }
102
103    /// Creates a new context manager with default configuration.
104    pub fn with_defaults() -> Self {
105        Self::new(CorrelationConfig::default())
106    }
107
108    /// Registers a new invocation context.
109    ///
110    /// Call this when receiving an INVOKE event from the Extensions API.
111    pub async fn register_invocation(&self, request_id: RequestId) {
112        let mut contexts = self.contexts.write().await;
113
114        if contexts.len() >= self.config.max_total_buffered_events {
115            self.cleanup_stale_contexts_locked(&mut contexts);
116        }
117
118        if !contexts.contains_key(&request_id) {
119            contexts.insert(request_id.clone(), InvocationContext::new(request_id));
120        }
121    }
122
123    /// Sets the parent span context for an invocation.
124    ///
125    /// Call this when a span with `faas.parent_span = true` is received
126    /// from the function's OTLP telemetry.
127    pub async fn set_parent_span(&self, request_id: &str, context: SpanContext) {
128        let mut contexts = self.contexts.write().await;
129
130        if let Some(inv_ctx) = contexts.get_mut(request_id) {
131            inv_ctx.parent_span_context = Some(context.clone());
132            let _ = inv_ctx.parent_ready_tx.send(Some(context));
133        }
134    }
135
136    /// Waits for the parent span context to become available.
137    ///
138    /// This method waits up to `max_correlation_delay` for the parent span
139    /// to be registered. Returns `None` if the timeout expires.
140    pub async fn wait_for_parent_span(&self, request_id: &str) -> Option<SpanContext> {
141        let rx = {
142            let contexts = self.contexts.read().await;
143            contexts.get(request_id)?.parent_ready_rx.clone()
144        };
145
146        let timeout = self.config.max_correlation_delay;
147
148        match tokio::time::timeout(timeout, async {
149            let mut rx = rx;
150            loop {
151                if rx.borrow().is_some() {
152                    return rx.borrow().clone();
153                }
154                if rx.changed().await.is_err() {
155                    return None;
156                }
157            }
158        })
159        .await
160        {
161            Ok(ctx) => ctx,
162            Err(_) => {
163                let contexts = self.contexts.read().await;
164                contexts
165                    .get(request_id)
166                    .and_then(|c| c.parent_span_context.clone())
167            }
168        }
169    }
170
171    /// Gets the current parent span context if available (non-blocking).
172    pub async fn get_parent_span(&self, request_id: &str) -> Option<SpanContext> {
173        let contexts = self.contexts.read().await;
174        contexts
175            .get(request_id)
176            .and_then(|c| c.parent_span_context.clone())
177    }
178
179    /// Adds a platform event to the invocation context.
180    ///
181    /// Events are buffered until the parent span is available for correlation.
182    pub async fn add_platform_event(&self, event: PlatformEvent) {
183        let mut contexts = self.contexts.write().await;
184
185        if let Some(inv_ctx) = contexts.get_mut(&event.request_id) {
186            if inv_ctx.pending_platform_events.len()
187                < self.config.max_buffered_events_per_invocation
188            {
189                inv_ctx.pending_platform_events.push(event);
190            } else {
191                tracing::warn!(
192                    request_id = %inv_ctx.request_id,
193                    "Platform event buffer full, dropping event"
194                );
195            }
196        } else {
197            let mut ctx = InvocationContext::new(event.request_id.clone());
198            ctx.pending_platform_events.push(event);
199            contexts.insert(ctx.request_id.clone(), ctx);
200        }
201    }
202
203    /// Takes all pending platform events for an invocation.
204    ///
205    /// Returns the events and clears the buffer.
206    pub async fn take_platform_events(&self, request_id: &str) -> Vec<PlatformEvent> {
207        let mut contexts = self.contexts.write().await;
208
209        if let Some(inv_ctx) = contexts.get_mut(request_id) {
210            std::mem::take(&mut inv_ctx.pending_platform_events)
211        } else {
212            Vec::new()
213        }
214    }
215
216    /// Removes an invocation context.
217    ///
218    /// Call this after the invocation is complete and all data has been flushed.
219    pub async fn remove_invocation(&self, request_id: &str) {
220        let mut contexts = self.contexts.write().await;
221        contexts.remove(request_id);
222    }
223
224    /// Returns the number of active invocation contexts.
225    pub async fn active_count(&self) -> usize {
226        let contexts = self.contexts.read().await;
227        contexts.len()
228    }
229
230    /// Cleans up stale invocation contexts.
231    ///
232    /// Removes contexts that have exceeded `max_invocation_lifetime`.
233    pub async fn cleanup_stale_contexts(&self) {
234        let mut contexts = self.contexts.write().await;
235        self.cleanup_stale_contexts_locked(&mut contexts);
236    }
237
238    fn cleanup_stale_contexts_locked(&self, contexts: &mut HashMap<RequestId, InvocationContext>) {
239        let max_lifetime = self.config.max_invocation_lifetime;
240        let now = Instant::now();
241
242        contexts.retain(|request_id, ctx| {
243            let age = now.duration_since(ctx.started_at);
244            if age > max_lifetime {
245                tracing::debug!(
246                    request_id = %request_id,
247                    age_secs = age.as_secs(),
248                    "Removing stale invocation context"
249                );
250                false
251            } else {
252                true
253            }
254        });
255    }
256
257    /// Returns whether to emit orphaned spans (spans without parent context).
258    pub fn emit_orphaned_spans(&self) -> bool {
259        self.config.emit_orphaned_spans
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Configuration with short timeouts and small limits so tests run quickly.
    fn test_config() -> CorrelationConfig {
        CorrelationConfig {
            max_correlation_delay: Duration::from_millis(100),
            max_buffered_events_per_invocation: 10,
            max_total_buffered_events: 100,
            max_invocation_lifetime: Duration::from_secs(60),
            emit_orphaned_spans: true,
        }
    }

    /// A well-formed span context shared by tests that do not care about its contents.
    fn sample_span_context() -> SpanContext {
        SpanContext {
            trace_id: "0102030405060708090a0b0c0d0e0f10".to_string(),
            span_id: "0102030405060708".to_string(),
            trace_flags: 1,
            trace_state: None,
        }
    }

    #[tokio::test]
    async fn test_register_and_get_parent_span() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-123".to_string()).await;

        assert!(manager.get_parent_span("req-123").await.is_none());

        let span_ctx = SpanContext {
            trace_id: "0102030405060708090a0b0c0d0e0f10".to_string(),
            span_id: "0102030405060708".to_string(),
            trace_flags: 1,
            trace_state: None,
        };

        manager.set_parent_span("req-123", span_ctx.clone()).await;

        let retrieved = manager.get_parent_span("req-123").await;
        assert_eq!(retrieved, Some(span_ctx));
    }

    #[tokio::test]
    async fn test_register_invocation_is_idempotent() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-twice".to_string()).await;
        manager
            .set_parent_span("req-twice", sample_span_context())
            .await;

        // Registering again must not reset the existing context.
        manager.register_invocation("req-twice".to_string()).await;

        assert_eq!(manager.active_count().await, 1);
        assert_eq!(
            manager.get_parent_span("req-twice").await,
            Some(sample_span_context())
        );
    }

    #[tokio::test]
    async fn test_set_parent_span_unknown_request_is_ignored() {
        let manager = InvocationContextManager::new(test_config());

        manager
            .set_parent_span("req-unknown", sample_span_context())
            .await;

        assert_eq!(manager.active_count().await, 0);
        assert!(manager.get_parent_span("req-unknown").await.is_none());
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_immediate() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-456".to_string()).await;

        let span_ctx = SpanContext {
            trace_id: "trace-id".to_string(),
            span_id: "span-id".to_string(),
            trace_flags: 1,
            trace_state: None,
        };

        manager.set_parent_span("req-456", span_ctx.clone()).await;

        let result = manager.wait_for_parent_span("req-456").await;
        assert_eq!(result, Some(span_ctx));
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_delayed() {
        let manager = Arc::new(InvocationContextManager::new(test_config()));

        manager.register_invocation("req-789".to_string()).await;

        let manager_clone = manager.clone();
        let set_handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            let span_ctx = SpanContext {
                trace_id: "delayed-trace".to_string(),
                span_id: "delayed-span".to_string(),
                trace_flags: 1,
                trace_state: None,
            };
            manager_clone.set_parent_span("req-789", span_ctx).await;
        });

        let result = manager.wait_for_parent_span("req-789").await;
        set_handle.await.unwrap();

        assert!(result.is_some());
        assert_eq!(result.unwrap().trace_id, "delayed-trace");
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_timeout() {
        let mut config = test_config();
        config.max_correlation_delay = Duration::from_millis(10);

        let manager = InvocationContextManager::new(config);
        manager.register_invocation("req-timeout".to_string()).await;

        let result = manager.wait_for_parent_span("req-timeout").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_wait_for_parent_span_unknown_request() {
        let manager = InvocationContextManager::new(test_config());

        // No context exists, so there is nothing to wait on.
        let result = manager.wait_for_parent_span("req-missing").await;
        assert!(result.is_none());
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_platform_events() {
        let manager = InvocationContextManager::new(test_config());

        let event = PlatformEvent {
            event_type: PlatformEventType::Start,
            timestamp: chrono::Utc::now(),
            request_id: "req-events".to_string(),
            data: serde_json::json!({"requestId": "req-events"}),
        };

        manager.add_platform_event(event).await;

        let events = manager.take_platform_events("req-events").await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].request_id, "req-events");

        let events_again = manager.take_platform_events("req-events").await;
        assert!(events_again.is_empty());
    }

    #[tokio::test]
    async fn test_remove_invocation() {
        let manager = InvocationContextManager::new(test_config());

        manager.register_invocation("req-remove".to_string()).await;
        assert_eq!(manager.active_count().await, 1);

        manager.remove_invocation("req-remove").await;
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_stale_contexts() {
        let mut config = test_config();
        config.max_invocation_lifetime = Duration::from_millis(10);

        let manager = InvocationContextManager::new(config);

        manager.register_invocation("req-stale".to_string()).await;
        assert_eq!(manager.active_count().await, 1);

        tokio::time::sleep(Duration::from_millis(20)).await;

        manager.cleanup_stale_contexts().await;
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_event_buffer_limit() {
        let mut config = test_config();
        config.max_buffered_events_per_invocation = 2;

        let manager = InvocationContextManager::new(config);
        manager.register_invocation("req-limit".to_string()).await;

        for i in 0..5 {
            let event = PlatformEvent {
                event_type: PlatformEventType::Start,
                timestamp: chrono::Utc::now(),
                request_id: "req-limit".to_string(),
                data: serde_json::json!({"index": i}),
            };
            manager.add_platform_event(event).await;
        }

        let events = manager.take_platform_events("req-limit").await;
        assert_eq!(events.len(), 2);
    }
}