// opentelemetry_lambda_extension/service.rs

//! Tower services for Lambda extension lifecycle and telemetry processing.
//!
//! This module provides Tower `Service` implementations that integrate with the
//! `lambda_extension` crate for proper lifecycle management. Using the official
//! Lambda extension library ensures correct handling of SHUTDOWN events and
//! telemetry delivery timing.
//!
//! The services use a shared `RwLock` to coordinate shutdown with telemetry
//! processing. The `TelemetryService` holds a read lock while processing events,
//! and the `EventsService` acquires a write lock on SHUTDOWN before performing
//! the final flush. This ensures all in-flight telemetry is processed before
//! shutdown completes.

14use crate::aggregator::SignalAggregator;
15use crate::config::Config;
16use crate::conversion::{MetricsConverter, TelemetryProcessor};
17use crate::exporter::OtlpExporter;
18use crate::flush::FlushManager;
19use crate::receiver::Signal;
20use lambda_extension::{Error, LambdaEvent, LambdaTelemetry, LambdaTelemetryRecord, NextEvent};
21use opentelemetry_proto::tonic::resource::v1::Resource;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::time::Duration;
27use tokio::sync::{Mutex, RwLock, oneshot};
28use tower::Service;
29
/// Shared state for extension services.
///
/// This holds the components that need to be shared between the events
/// processor and telemetry processor services.
pub struct ExtensionState {
    /// Buffers incoming signals until a flush exports them.
    pub(crate) aggregator: Arc<SignalAggregator>,
    /// Exports signal batches (see `flush_all` / `final_flush`).
    pub(crate) exporter: Arc<OtlpExporter>,
    /// Tracks invocations/flushes and decides when a flush should happen.
    pub(crate) flush_manager: Arc<Mutex<FlushManager>>,
    /// Converts internal telemetry events into exportable signals.
    pub(crate) telemetry_processor: Arc<Mutex<TelemetryProcessor>>,
    /// Builds one-off metrics (e.g. the shutdown metric) against the resource.
    pub(crate) metrics_converter: MetricsConverter,
    // Retained on the state but not read anywhere in this module today.
    #[allow(dead_code)]
    pub(crate) config: Config,
    /// Lock to coordinate shutdown with telemetry processing.
    ///
    /// `TelemetryService` acquires a read lock while processing events.
    /// `EventsService` acquires a write lock on SHUTDOWN before final flush.
    /// This ensures all in-flight telemetry is processed before shutdown.
    processing_lock: RwLock<()>,
    /// Channel to signal that shutdown processing is complete.
    ///
    /// The sender is stored in a Mutex so it can be taken when shutdown occurs.
    /// The receiver should be used with `tokio::select!` to exit the event loop.
    shutdown_tx: Mutex<Option<oneshot::Sender<()>>>,
}
54
55impl ExtensionState {
56    /// Creates new extension state with the given configuration and resource.
57    ///
58    /// Returns the state and a receiver that will be signalled when shutdown is complete.
59    /// Use the receiver with `tokio::select!` to exit the event loop gracefully.
60    pub fn new(
61        config: Config,
62        resource: Resource,
63    ) -> Result<(Self, oneshot::Receiver<()>), crate::exporter::ExportError> {
64        let exporter = OtlpExporter::new(config.exporter.clone())?;
65        let (shutdown_tx, shutdown_rx) = oneshot::channel();
66
67        let state = Self {
68            aggregator: Arc::new(SignalAggregator::new(config.flush.clone())),
69            exporter: Arc::new(exporter),
70            flush_manager: Arc::new(Mutex::new(FlushManager::new(config.flush.clone()))),
71            telemetry_processor: Arc::new(Mutex::new(TelemetryProcessor::new(resource.clone()))),
72            metrics_converter: MetricsConverter::new(resource),
73            config,
74            processing_lock: RwLock::new(()),
75            shutdown_tx: Mutex::new(Some(shutdown_tx)),
76        };
77
78        Ok((state, shutdown_rx))
79    }
80
81    /// Signals that shutdown processing is complete.
82    ///
83    /// This should be called after `final_flush()` to allow the event loop to exit.
84    pub async fn signal_shutdown_complete(&self) {
85        if let Some(tx) = self.shutdown_tx.lock().await.take() {
86            let _ = tx.send(());
87            tracing::debug!("Shutdown complete signal sent");
88        }
89    }
90
91    /// Performs a flush of all pending signals to the exporter.
92    pub async fn flush_all(&self) {
93        let batches = self.aggregator.get_all_batches().await;
94        let mut flush_manager = self.flush_manager.lock().await;
95
96        for batch in batches {
97            let result = self.exporter.export(batch).await;
98            match result {
99                crate::exporter::ExportResult::Success => {
100                    flush_manager.record_flush();
101                }
102                crate::exporter::ExportResult::Fallback
103                | crate::exporter::ExportResult::Skipped => {
104                    flush_manager.record_flush_timeout();
105                }
106            }
107        }
108    }
109
110    /// Waits for any in-progress telemetry processing to complete.
111    ///
112    /// This acquires a write lock on the processing lock, which blocks until
113    /// all read locks (held by `TelemetryService` during processing) are released.
114    /// The timeout prevents indefinite blocking if something goes wrong.
115    pub async fn wait_for_processing_complete(&self, timeout: Duration) {
116        let result = tokio::time::timeout(timeout, self.processing_lock.write()).await;
117        if result.is_err() {
118            tracing::warn!(
119                timeout_ms = timeout.as_millis(),
120                "Timed out waiting for telemetry processing to complete"
121            );
122        }
123        // Lock is immediately dropped, we just needed to wait for it
124    }
125
126    /// Performs a final flush draining all signals.
127    pub async fn final_flush(&self) {
128        tracing::info!("Performing final flush");
129
130        let batches = self.aggregator.drain_all().await;
131        let count = batches.len();
132
133        for batch in batches {
134            let result = self.exporter.export(batch).await;
135            tracing::debug!(?result, "Final flush batch");
136        }
137
138        let dropped = self.aggregator.dropped_count().await;
139        if dropped > 0 {
140            tracing::warn!(
141                dropped = dropped,
142                "Signals were dropped due to queue limits"
143            );
144        }
145
146        tracing::info!(batches = count, dropped = dropped, "Final flush complete");
147    }
148}
149
/// Tower service for processing Lambda extension lifecycle events.
///
/// This service handles INVOKE and SHUTDOWN events from the Extensions API.
/// On SHUTDOWN, it performs a final flush of all buffered telemetry.
pub struct EventsService {
    // Shared with `TelemetryService`; coordination happens via `ExtensionState`.
    state: Arc<ExtensionState>,
}
157
158impl EventsService {
159    /// Creates a new events service with the given shared state.
160    pub fn new(state: Arc<ExtensionState>) -> Self {
161        Self { state }
162    }
163}
164
165impl Service<LambdaEvent> for EventsService {
166    type Response = ();
167    type Error = Error;
168    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
169
170    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171        Poll::Ready(Ok(()))
172    }
173
174    fn call(&mut self, event: LambdaEvent) -> Self::Future {
175        let state = Arc::clone(&self.state);
176
177        Box::pin(async move {
178            match event.next {
179                NextEvent::Invoke(invoke) => {
180                    tracing::debug!(request_id = %invoke.request_id, "Received INVOKE event");
181
182                    // Record invocation for adaptive flush pattern detection
183                    {
184                        let mut flush_manager = state.flush_manager.lock().await;
185                        flush_manager.record_invocation();
186                    }
187
188                    // Check if we should flush based on pending count
189                    let pending = state.aggregator.pending_count().await;
190                    let should_flush = {
191                        let flush_manager = state.flush_manager.lock().await;
192                        flush_manager
193                            .should_flush(Some(invoke.deadline_ms as i64), pending, false)
194                            .is_some()
195                    };
196
197                    if should_flush {
198                        tracing::debug!(pending, "Flushing during invocation");
199                        state.flush_all().await;
200                    }
201                }
202                NextEvent::Shutdown(shutdown) => {
203                    tracing::info!(reason = ?shutdown.shutdown_reason, "Received SHUTDOWN event");
204
205                    // Wait for any in-flight telemetry processing to complete
206                    // This ensures we don't flush before the last batch of telemetry
207                    // (e.g., platform.report) has been processed and added to the aggregator
208                    state
209                        .wait_for_processing_complete(Duration::from_millis(500))
210                        .await;
211
212                    // Emit shutdown metric
213                    let shutdown_reason = format!("{:?}", shutdown.shutdown_reason);
214                    let shutdown_metric = state
215                        .metrics_converter
216                        .create_shutdown_metric(&shutdown_reason);
217                    state.aggregator.add(Signal::Metrics(shutdown_metric)).await;
218
219                    // Final flush of all signals
220                    state.final_flush().await;
221
222                    // Signal shutdown complete to exit the event loop gracefully
223                    state.signal_shutdown_complete().await;
224                }
225            }
226
227            Ok(())
228        })
229    }
230}
231
/// Tower service for processing Lambda Telemetry API events.
///
/// This service receives platform telemetry events and converts them to
/// OTLP metrics and traces, adding them to the aggregator for export.
#[derive(Clone)]
pub struct TelemetryService {
    // Shared extension state; cloning the service only clones this Arc.
    state: Arc<ExtensionState>,
}
240
241impl TelemetryService {
242    /// Creates a new telemetry service with the given shared state.
243    pub fn new(state: Arc<ExtensionState>) -> Self {
244        Self { state }
245    }
246}
247
248impl Service<Vec<LambdaTelemetry>> for TelemetryService {
249    type Response = ();
250    type Error = Error;
251    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
252
253    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
254        Poll::Ready(Ok(()))
255    }
256
257    fn call(&mut self, events: Vec<LambdaTelemetry>) -> Self::Future {
258        let state = Arc::clone(&self.state);
259
260        Box::pin(async move {
261            // Acquire read lock to prevent shutdown from flushing while we're processing
262            let _guard = state.processing_lock.read().await;
263
264            tracing::debug!(count = events.len(), "Processing telemetry events");
265
266            // Check if any event is a RuntimeDone (signals end of invocation)
267            let has_runtime_done = events
268                .iter()
269                .any(|e| matches!(e.record, LambdaTelemetryRecord::PlatformRuntimeDone { .. }));
270
271            // Convert lambda_extension telemetry events to our internal format
272            let internal_events = convert_telemetry_events(events);
273
274            // Process through our TelemetryProcessor
275            let (metrics, _traces) = {
276                let mut processor = state.telemetry_processor.lock().await;
277                processor.process_events(internal_events)
278            };
279
280            // Add metrics to aggregator
281            for metric in metrics {
282                state
283                    .aggregator
284                    .add(Signal::Metrics(
285                        opentelemetry_proto::tonic::collector::metrics::v1::ExportMetricsServiceRequest {
286                            resource_metrics: metric.resource_metrics,
287                        },
288                    ))
289                    .await;
290            }
291
292            // If we received RuntimeDone, this is the post-invoke phase.
293            // Check if we should flush based on the strategy (e.g., FlushStrategy::End).
294            if has_runtime_done {
295                let pending = state.aggregator.pending_count().await;
296                let should_flush = {
297                    let flush_manager = state.flush_manager.lock().await;
298                    flush_manager
299                        .should_flush_on_invocation_end(pending)
300                        .is_some()
301                };
302
303                if should_flush {
304                    tracing::debug!(pending, "Flushing at end of invocation (post-invoke phase)");
305                    state.flush_all().await;
306                }
307            }
308
309            Ok(())
310        })
311    }
312}
313
314/// Converts lambda_extension telemetry events to our internal format.
315fn convert_telemetry_events(events: Vec<LambdaTelemetry>) -> Vec<crate::telemetry::TelemetryEvent> {
316    use crate::telemetry::{
317        ReportMetrics, ReportRecord, RuntimeDoneRecord, RuntimeMetrics, SpanRecord, StartRecord,
318        TelemetryEvent, TracingRecord,
319    };
320
321    events
322        .into_iter()
323        .filter_map(|event| {
324            let time = event.time.to_rfc3339();
325
326            match event.record {
327                LambdaTelemetryRecord::PlatformStart {
328                    request_id,
329                    version,
330                    tracing,
331                } => Some(TelemetryEvent::Start {
332                    time,
333                    record: StartRecord {
334                        request_id,
335                        version,
336                        tracing: tracing.map(|t| TracingRecord {
337                            span_id: None,
338                            trace_type: Some(format!("{:?}", t.r#type)),
339                            value: Some(t.value),
340                        }),
341                    },
342                }),
343
344                LambdaTelemetryRecord::PlatformRuntimeDone {
345                    request_id,
346                    status,
347                    error_type: _,
348                    metrics,
349                    spans,
350                    tracing,
351                } => Some(TelemetryEvent::RuntimeDone {
352                    time,
353                    record: RuntimeDoneRecord {
354                        request_id,
355                        status: format!("{:?}", status),
356                        metrics: metrics.map(|m| RuntimeMetrics {
357                            duration_ms: m.duration_ms,
358                            produced_bytes: m.produced_bytes,
359                        }),
360                        spans: spans
361                            .into_iter()
362                            .map(|s| SpanRecord {
363                                name: s.name,
364                                start: s.start.timestamp_millis() as f64,
365                                duration_ms: s.duration_ms,
366                            })
367                            .collect(),
368                        tracing: tracing.map(|t| TracingRecord {
369                            span_id: None,
370                            trace_type: Some(format!("{:?}", t.r#type)),
371                            value: Some(t.value),
372                        }),
373                    },
374                }),
375
376                LambdaTelemetryRecord::PlatformReport {
377                    request_id,
378                    status,
379                    error_type: _,
380                    metrics,
381                    spans: _,
382                    tracing,
383                } => Some(TelemetryEvent::Report {
384                    time,
385                    record: ReportRecord {
386                        request_id,
387                        status: format!("{:?}", status),
388                        metrics: ReportMetrics {
389                            duration_ms: metrics.duration_ms,
390                            billed_duration_ms: metrics.billed_duration_ms,
391                            memory_size_mb: metrics.memory_size_mb,
392                            max_memory_used_mb: metrics.max_memory_used_mb,
393                            init_duration_ms: metrics.init_duration_ms,
394                            restore_duration_ms: metrics.restore_duration_ms,
395                        },
396                        tracing: tracing.map(|t| TracingRecord {
397                            span_id: None,
398                            trace_type: Some(format!("{:?}", t.r#type)),
399                            value: Some(t.value),
400                        }),
401                    },
402                }),
403
404                // Log other events but don't convert them
405                _ => {
406                    tracing::trace!(?event, "Ignoring non-platform telemetry event");
407                    None
408                }
409            }
410        })
411        .collect()
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use lambda_extension::LambdaTelemetry;

    #[test]
    fn test_extension_state_creation() {
        let config = Config::default();
        let sdk_resource = crate::resource::detect_resource();
        let proto_resource = crate::resource::to_proto_resource(&sdk_resource);

        // Exporter construction could fail in odd environments; for a unit
        // test we simply assert the happy path.
        let created = ExtensionState::new(config, proto_resource);
        assert!(created.is_ok());
        let (_state, _shutdown_rx) = created.unwrap();
    }

    #[test]
    fn test_simulator_telemetry_format_deserialization() {
        // This is the exact format our simulator sends
        let json = r#"[{"time":"2025-11-30T22:29:09.581655Z","type":"platform.start","record":{"requestId":"38432cb4-cb8b-4162-982d-923d3c3f6d10","tracing":{"type":"X-Amzn-Trace-Id","value":"Root=1-692cc535-0338d3516cb745b7b41f878e"},"version":"$LATEST"}}]"#;

        let parsed: Result<Vec<LambdaTelemetry>, _> = serde_json::from_str(json);
        match &parsed {
            Ok(events) => println!("Success: {:?}", events),
            Err(e) => println!("Error: {}", e),
        }
        assert!(parsed.is_ok(), "Failed to deserialize: {:?}", parsed.err());
    }

    #[test]
    fn test_full_simulator_batch_deserialization() {
        // Full batch similar to what the test produces
        let json = r#"[{"time":"2025-11-30T22:35:51.565094Z","type":"platform.start","record":{"requestId":"0c90003a-8970-474c-b696-fca5336ef4f5","tracing":{"type":"X-Amzn-Trace-Id","value":"Root=1-692cc6c7-f2ce8d3383524609b99c07a9"},"version":"$LATEST"}},{"time":"2025-11-30T22:35:51.565857Z","type":"platform.initRuntimeDone","record":{"initializationType":"on-demand","phase":"init","status":"success"}},{"time":"2025-11-30T22:35:51.565857Z","type":"platform.initReport","record":{"initializationType":"on-demand","phase":"init","status":"success","metrics":{"durationMs":565.4}}},{"time":"2025-11-30T22:35:51.578834Z","type":"platform.runtimeDone","record":{"requestId":"0c90003a-8970-474c-b696-fca5336ef4f5","status":"success","metrics":{"durationMs":13.74},"spans":[],"tracing":{"type":"X-Amzn-Trace-Id","value":"Root=1-692cc6c7-f2ce8d3383524609b99c07a9"}}},{"time":"2025-11-30T22:35:51.578909Z","type":"platform.report","record":{"requestId":"0c90003a-8970-474c-b696-fca5336ef4f5","status":"success","metrics":{"durationMs":13.74,"billedDurationMs":100,"memorySizeMB":128,"maxMemoryUsedMB":64},"tracing":{"type":"X-Amzn-Trace-Id","value":"Root=1-692cc6c7-f2ce8d3383524609b99c07a9"}}}]"#;

        let parsed: Result<Vec<LambdaTelemetry>, _> = serde_json::from_str(json);
        match &parsed {
            Ok(events) => {
                println!("Success: {} events parsed", events.len());
                for (i, event) in events.iter().enumerate() {
                    println!("  Event {}: {:?}", i, event);
                }
            }
            Err(e) => println!("Error: {}", e),
        }
        assert!(parsed.is_ok(), "Failed to deserialize: {:?}", parsed.err());
    }
}
461}