use crate::logger::Logger;
use crate::ProcessorMode;
use lambda_extension::{service_fn, Error, Extension, NextEvent};
use opentelemetry_sdk::trace::SdkTracerProvider;
use std::sync::Arc;
use tokio::{
signal::unix::{signal, SignalKind},
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
Mutex,
},
};
// Logger shared by the code in this file, constructed from the string "extension"
// (presumably the logger's name/prefix — confirm in `crate::logger`).
static LOGGER: Logger = Logger::const_new("extension");
/// Extension state: waits for a "request done" signal on a channel, then flushes the
/// tracer provider (see `invoke`).
pub struct OtelInternalExtension {
/// Receiving half of the completion channel awaited by `invoke`. Kept behind an async
/// `Mutex` so it can be polled through `&self`.
request_done_receiver: Mutex<UnboundedReceiver<()>>,
/// Provider whose `force_flush` is called after each completion signal.
tracer_provider: Arc<SdkTracerProvider>,
}
impl OtelInternalExtension {
pub fn new(
request_done_receiver: UnboundedReceiver<()>,
tracer_provider: Arc<SdkTracerProvider>,
) -> Self {
Self {
request_done_receiver: Mutex::new(request_done_receiver),
tracer_provider,
}
}
pub async fn invoke(&self, event: lambda_extension::LambdaEvent) -> Result<(), Error> {
if let NextEvent::Invoke(_e) = event.next {
self.request_done_receiver
.lock()
.await
.recv()
.await
.ok_or_else(|| Error::from("channel closed"))?;
if let Err(err) = self.tracer_provider.force_flush() {
LOGGER.error(format!(
"OtelInternalExtension.invoke.Error: Error flushing tracer provider: {err:?}"
));
}
}
Ok(())
}
}
/// Registers the internal Lambda extension and starts its background tasks.
///
/// A fresh unbounded channel is created: the receiver is handed to an
/// [`OtelInternalExtension`] and the sender is returned to the caller. Each `()` sent
/// on it lets one pending `invoke` call proceed to flush `tracer_provider`.
///
/// The extension subscribes to `INVOKE` events only when `processor_mode` is
/// [`ProcessorMode::Async`]; in every other mode it registers with an empty event list.
///
/// Two detached tasks are spawned:
/// * one drives the registered extension and logs the error if `run` fails;
/// * one waits for SIGTERM, force-flushes the tracer provider and exits the process
///   with status 0.
///
/// # Errors
/// Returns the error produced by `register()` when registration fails. Failure to
/// install the SIGTERM handler is logged and does not fail registration.
pub(crate) async fn register_extension(
    tracer_provider: Arc<SdkTracerProvider>,
    processor_mode: ProcessorMode,
) -> Result<UnboundedSender<()>, Error> {
    LOGGER.debug("OtelInternalExtension.register_extension: starting registration");

    let (request_done_sender, request_done_receiver) = unbounded_channel::<()>();
    let extension = Arc::new(OtelInternalExtension::new(
        request_done_receiver,
        Arc::clone(&tracer_provider),
    ));

    // Pick the subscription list as an expression instead of mutating the builder.
    let events: &'static [&'static str] = if matches!(processor_mode, ProcessorMode::Async) {
        &["INVOKE"]
    } else {
        &[]
    };

    let registered_extension = Extension::new()
        .with_events(events)
        .with_events_processor(service_fn(move |event| {
            let extension = Arc::clone(&extension);
            async move { extension.invoke(event).await }
        }))
        .with_extension_name("otel-internal")
        .register()
        .await?;

    tokio::spawn(async move {
        if let Err(err) = registered_extension.run().await {
            LOGGER.error(format!(
                "OtelInternalExtension.run.Error: Error running extension: {err:?}"
            ));
        }
    });

    tokio::spawn(async move {
        // Installing a signal handler is a fallible OS call; log the failure through
        // the normal logger instead of panicking inside a detached task.
        let mut sigterm = match signal(SignalKind::terminate()) {
            Ok(stream) => stream,
            Err(err) => {
                LOGGER.error(format!(
                    "OtelInternalExtension.SIGTERM.Error: Error registering SIGTERM handler: {err:?}"
                ));
                return;
            }
        };

        if sigterm.recv().await.is_some() {
            LOGGER.debug("OtelInternalExtension.SIGTERM: SIGTERM received, flushing spans");
            if let Err(err) = tracer_provider.force_flush() {
                LOGGER.error(format!(
                    "OtelInternalExtension.SIGTERM.Error: Error during shutdown: {err:?}"
                ));
            }
            LOGGER.debug("OtelInternalExtension.SIGTERM: Shutdown complete");
            std::process::exit(0);
        }
    });

    Ok(request_done_sender)
}
#[cfg(test)]
mod tests {
use super::*;
use lambda_extension::{InvokeEvent, LambdaEvent, ShutdownEvent};
use opentelemetry::trace::{Tracer, TracerProvider as _};
use opentelemetry::Context;
use opentelemetry_sdk::{
error::{OTelSdkError, OTelSdkResult},
trace::{SdkTracerProvider, Span, SpanData, SpanExporter, SpanProcessor},
Resource,
};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
};
use std::time::Duration;
/// Span exporter test double that records what it is asked to export.
///
/// Both fields are `Arc`s, so clones of an exporter share the same counter and span list.
#[derive(Debug, Default, Clone)]
struct TestExporter {
/// Running total of spans passed to `export`.
export_count: Arc<AtomicUsize>,
/// Every span passed to `export`, in arrival order.
spans: Arc<Mutex<Vec<SpanData>>>,
}
impl TestExporter {
fn new() -> Self {
Self {
export_count: Arc::new(AtomicUsize::new(0)),
spans: Arc::new(Mutex::new(Vec::new())),
}
}
#[allow(dead_code)]
fn get_spans(&self) -> Vec<SpanData> {
self.spans.lock().unwrap().clone()
}
}
impl SpanExporter for TestExporter {
fn export(
&self,
spans: Vec<SpanData>,
) -> impl std::future::Future<Output = opentelemetry_sdk::error::OTelSdkResult> + Send
{
self.export_count.fetch_add(spans.len(), Ordering::SeqCst);
self.spans.lock().unwrap().extend(spans);
futures_util::future::ready(Ok(()))
}
}
/// Builds a tracer provider wired to a [`TestExporter`] through the simple exporter
/// and an empty resource.
///
/// Returns the provider together with a handle that shares the exporter's recorded state.
fn setup_test_provider() -> (Arc<SdkTracerProvider>, Arc<TestExporter>) {
    let recorder = TestExporter::new();
    let empty_resource = Resource::builder_empty().build();

    let tracer_provider = SdkTracerProvider::builder()
        .with_simple_exporter(recorder.clone())
        .with_resource(empty_resource)
        .build();

    (Arc::new(tracer_provider), Arc::new(recorder))
}
/// Builds a tracer provider wired to a [`TestExporter`] through the batch exporter
/// and an empty resource.
///
/// Returns the provider together with a handle that shares the exporter's recorded state.
fn setup_batch_test_provider() -> (Arc<SdkTracerProvider>, Arc<TestExporter>) {
    let recorder = TestExporter::new();
    let empty_resource = Resource::builder_empty().build();

    let tracer_provider = SdkTracerProvider::builder()
        .with_batch_exporter(recorder.clone())
        .with_resource(empty_resource)
        .build();

    (Arc::new(tracer_provider), Arc::new(recorder))
}
/// Span processor test double whose `force_flush` always fails.
#[derive(Debug)]
struct FailingSpanProcessor;
impl SpanProcessor for FailingSpanProcessor {
// Span start and end notifications are ignored.
fn on_start(&self, _span: &mut Span, _cx: &Context) {}
fn on_end(&self, _span: SpanData) {}
// Always reports an internal failure so callers' flush-error paths can be exercised.
fn force_flush(&self) -> OTelSdkResult {
Err(OTelSdkError::InternalFailure("force flush failed".into()))
}
// Shutdown succeeds unconditionally; the timeout is unused.
fn shutdown_with_timeout(&self, _timeout: Duration) -> OTelSdkResult {
Ok(())
}
}
/// An invoke event completes successfully once a completion signal is sent.
#[tokio::test]
async fn test_extension_invoke_handling() -> Result<(), Error> {
    let (provider, _) = setup_test_provider();
    let (done_tx, done_rx) = unbounded_channel();
    let extension = OtelInternalExtension::new(done_rx, provider);

    let event = LambdaEvent {
        next: NextEvent::Invoke(InvokeEvent {
            deadline_ms: 1000,
            request_id: String::from("test-id"),
            invoked_function_arn: String::from("test-arn"),
            tracing: Default::default(),
        }),
    };

    // Run the handler on its own task, then release it with one completion signal.
    let task = tokio::spawn(async move { extension.invoke(event).await });
    done_tx.send(()).unwrap();

    let outcome = task.await.unwrap();
    assert!(outcome.is_ok());
    Ok(())
}
/// An invoke event fails when the completion channel has no sender left.
#[tokio::test]
async fn test_extension_channel_closed() -> Result<(), Error> {
    let (provider, _) = setup_test_provider();
    let (done_tx, done_rx) = unbounded_channel();
    let extension = OtelInternalExtension::new(done_rx, provider);

    let event = LambdaEvent {
        next: NextEvent::Invoke(InvokeEvent {
            deadline_ms: 1000,
            request_id: String::from("test-id"),
            invoked_function_arn: String::from("test-arn"),
            tracing: Default::default(),
        }),
    };

    // Closing the only sender makes the receiver yield `None`.
    drop(done_tx);

    let outcome = extension.invoke(event).await;
    assert!(outcome.is_err());
    Ok(())
}
/// A shutdown event returns `Ok` without any completion signal being sent.
#[tokio::test]
async fn test_extension_ignores_shutdown_events() -> Result<(), Error> {
    let (provider, _) = setup_test_provider();
    // The sender is kept alive but never used.
    let (_sender, done_rx) = unbounded_channel();
    let extension = OtelInternalExtension::new(done_rx, provider);

    let shutdown = ShutdownEvent {
        shutdown_reason: String::from("SPINDOWN"),
        deadline_ms: 1000,
    };
    let event = LambdaEvent {
        next: NextEvent::Shutdown(shutdown),
    };

    let outcome = extension.invoke(event).await;
    assert!(outcome.is_ok());
    Ok(())
}
/// A span ended before the invocation reaches the batch exporter once `invoke` has
/// flushed the provider.
#[tokio::test]
async fn test_extension_force_flush_exports_pending_spans() -> Result<(), Error> {
    let (provider, exporter) = setup_batch_test_provider();
    // The tracer stays bound for the whole test, as in the original setup.
    let tracer = provider.tracer("extension-test");
    let (done_tx, done_rx) = unbounded_channel();

    // Start a span and end it right away so exactly one span is pending.
    drop(tracer.start("pending-span"));

    let extension = OtelInternalExtension::new(done_rx, provider);
    let event = LambdaEvent {
        next: NextEvent::Invoke(InvokeEvent {
            deadline_ms: 1000,
            request_id: String::from("test-id"),
            invoked_function_arn: String::from("test-arn"),
            tracing: Default::default(),
        }),
    };

    let task = tokio::spawn(async move { extension.invoke(event).await });
    done_tx.send(()).unwrap();

    let outcome = task.await.unwrap();
    assert!(outcome.is_ok());

    let exported_total = exporter.export_count.load(Ordering::SeqCst);
    assert_eq!(exported_total, 1);
    assert_eq!(exporter.get_spans().len(), 1);
    Ok(())
}
/// A failing `force_flush` is not surfaced to the caller: `invoke` still returns `Ok`.
#[tokio::test]
async fn test_extension_invoke_returns_ok_when_force_flush_fails() -> Result<(), Error> {
    let failing_provider = SdkTracerProvider::builder()
        .with_span_processor(FailingSpanProcessor)
        .with_resource(Resource::builder_empty().build())
        .build();
    let provider = Arc::new(failing_provider);

    let (done_tx, done_rx) = unbounded_channel();
    let extension = OtelInternalExtension::new(done_rx, provider);

    let event = LambdaEvent {
        next: NextEvent::Invoke(InvokeEvent {
            deadline_ms: 1000,
            request_id: String::from("test-id"),
            invoked_function_arn: String::from("test-arn"),
            tracing: Default::default(),
        }),
    };

    let task = tokio::spawn(async move { extension.invoke(event).await });
    done_tx.send(()).unwrap();

    let outcome = task.await.unwrap();
    assert!(outcome.is_ok());
    Ok(())
}
}