use crate::extractors::{set_common_attributes, set_response_attributes, SpanAttributesExtractor};
use crate::TelemetryCompletionHandler;
use futures_util::future::BoxFuture;
use lambda_runtime::{Error, LambdaEvent};
use serde::{de::DeserializeOwned, Serialize};
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use tracing::field::Empty;
use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Cold-start flag: `true` until the first invocation handled by `traced_handler`.
///
/// `traced_handler` atomically swaps it to `false`, so exactly one invocation in the
/// process observes `true`; that value is passed on to `set_common_attributes`.
static IS_COLD_START: AtomicBool = AtomicBool::new(true);
/// Boxed, shareable form of a traced Lambda handler, as returned by
/// [`create_traced_handler`].
///
/// Each call takes a `LambdaEvent<T>` and returns a boxed, `Send`, `'static` future
/// that resolves to the handler's `Result<R, Error>`. The closure itself is
/// `Send + Sync` and implements `Fn`, so it can be shared and invoked repeatedly.
pub type TracedHandler<T, R> =
    Box<dyn Fn(LambdaEvent<T>) -> BoxFuture<'static, Result<R, Error>> + Send + Sync>;
pub(crate) async fn traced_handler<T, R, F, Fut>(
name: &'static str,
event: LambdaEvent<T>,
completion_handler: TelemetryCompletionHandler,
handler_fn: F,
) -> Result<R, Error>
where
T: SpanAttributesExtractor + DeserializeOwned + Serialize + Send + 'static,
R: Serialize + Send + 'static,
F: FnOnce(LambdaEvent<T>) -> Fut,
Fut: Future<Output = Result<R, Error>> + Send,
{
let result = {
let span = tracing::info_span!(
parent: None,
"handler",
otel.name=Empty,
otel.kind=Empty,
otel.status_code=Empty,
otel.status_message=Empty,
requestId=%event.context.request_id,
);
span.record("otel.name", name.to_string());
span.record("otel.kind", "SERVER");
let is_cold = IS_COLD_START.swap(false, Ordering::Relaxed);
set_common_attributes(&span, &event.context, is_cold);
let attrs = event.payload.extract_span_attributes();
if let Some(span_name) = attrs.span_name {
span.record("otel.name", span_name);
}
if let Some(kind) = &attrs.kind {
span.record("otel.kind", kind.to_string());
}
for (key, value) in &attrs.attributes {
span.set_attribute(key.to_string(), value.to_string());
}
for link in attrs.links {
span.add_link_with_attributes(link.span_context, link.attributes);
}
if let Some(carrier) = attrs.carrier {
let parent_context = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&carrier)
});
let _ = span.set_parent(parent_context);
}
span.set_attribute("faas.trigger", attrs.trigger.to_string());
let result = handler_fn(event).instrument(span.clone()).await;
if let Ok(response) = &result {
if let Ok(value) = serde_json::to_value(response) {
set_response_attributes(&span, &value);
}
} else if let Err(error) = &result {
span.set_status(opentelemetry::trace::Status::error(error.to_string()));
}
result
};
completion_handler.complete();
result
}
pub fn create_traced_handler<T, R, F, Fut>(
name: &'static str,
completion_handler: TelemetryCompletionHandler,
handler_fn: F,
) -> TracedHandler<T, R>
where
T: SpanAttributesExtractor + DeserializeOwned + Serialize + Send + 'static,
R: Serialize + Send + 'static,
F: Fn(LambdaEvent<T>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<R, Error>> + Send + 'static,
{
Box::new(move |event: LambdaEvent<T>| {
let completion_handler = completion_handler.clone();
let handler_fn = handler_fn.clone();
Box::pin(traced_handler(name, event, completion_handler, handler_fn))
})
}
// Tests for `traced_handler` / `create_traced_handler`.
//
// Each test:
// - installs a thread-default tracing subscriber whose OpenTelemetry layer exports
//   into an in-memory `TestExporter`;
// - drives a handler through `create_traced_handler`;
// - inspects the exported spans.
//
// `#[serial]` keeps the tests from running concurrently. This is presumably because
// they touch process-global state such as `IS_COLD_START` and the global OpenTelemetry
// propagator; confirm before relaxing it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mode::ProcessorMode;
    use lambda_runtime::Context;
    use opentelemetry::trace::Status;
    use opentelemetry::trace::TracerProvider as _;
    use opentelemetry_sdk::{
        trace::{SdkTracerProvider, SpanData, SpanExporter},
        Resource,
    };
    use serde_json::Value;
    use serial_test::serial;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    };
    use std::time::Duration;
    use tracing_subscriber::prelude::*;

    /// In-memory span exporter used to observe what the traced handler emits.
    ///
    /// Both fields are `Arc`s, so clones of a `TestExporter` share the same counter
    /// and span buffer.
    #[derive(Debug, Default, Clone)]
    struct TestExporter {
        // Total number of spans passed to `export` (not the number of `export` calls).
        // NOTE(review): not read by any test visible here.
        export_count: Arc<AtomicUsize>,
        // Every span received by `export`, in arrival order.
        spans: Arc<Mutex<Vec<SpanData>>>,
    }

    impl TestExporter {
        /// Creates an exporter with an empty span buffer and a zero counter.
        fn new() -> Self {
            Self {
                export_count: Arc::new(AtomicUsize::new(0)),
                spans: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns a snapshot copy of all spans exported so far.
        fn get_spans(&self) -> Vec<SpanData> {
            self.spans.lock().unwrap().clone()
        }

        /// Returns the `to_string()` rendering of the first attribute on `span` whose
        /// key equals `key`, or `None` when the span has no such attribute.
        fn find_attribute(span: &SpanData, key: &str) -> Option<String> {
            span.attributes
                .iter()
                .find(|kv| kv.key.as_str() == key)
                .map(|kv| kv.value.to_string())
        }
    }

    impl SpanExporter for TestExporter {
        /// Stores the batch in memory synchronously and returns an already-ready
        /// successful future.
        fn export(
            &self,
            spans: Vec<SpanData>,
        ) -> impl std::future::Future<Output = opentelemetry_sdk::error::OTelSdkResult> + Send
        {
            self.export_count.fetch_add(spans.len(), Ordering::SeqCst);
            self.spans.lock().unwrap().extend(spans);
            futures_util::future::ready(Ok(()))
        }
    }

    /// Builds a tracer provider that exports into a fresh `TestExporter` and installs
    /// a matching `tracing` subscriber as the thread default.
    ///
    /// Returns three things:
    /// - the provider, which the tests hand to `TelemetryCompletionHandler`;
    /// - a handle to the exporter, which shares its span buffer with the clone owned
    ///   by the provider;
    /// - the subscriber guard. `set_default` keeps the subscriber installed only while
    ///   that guard lives, so callers must bind it to a named variable (`_guard`, not `_`).
    fn setup_test_provider() -> (
        Arc<SdkTracerProvider>,
        Arc<TestExporter>,
        tracing::dispatcher::DefaultGuard,
    ) {
        let exporter = Arc::new(TestExporter::new());
        let provider = SdkTracerProvider::builder()
            .with_simple_exporter(exporter.as_ref().clone())
            .with_resource(Resource::builder().build())
            .build();
        let subscriber = tracing_subscriber::registry::Registry::default()
            .with(tracing_opentelemetry::OpenTelemetryLayer::new(
                provider.tracer("test"),
            ))
            .set_default();
        (Arc::new(provider), exporter, subscriber)
    }

    /// Sleeps for `duration`, giving span export a chance to finish before the
    /// exporter is inspected.
    async fn wait_for_spans(duration: Duration) {
        tokio::time::sleep(duration).await;
    }

    // A 200 JSON response should yield a span named after the handler, with an
    // `http.status_code` attribute of "200". That attribute is presumably derived from
    // the response's `statusCode` by `set_response_attributes`, which is not visible here.
    #[tokio::test]
    #[serial]
    async fn test_successful_response() -> Result<(), Error> {
        let (provider, exporter, _guard) = setup_test_provider();
        let completion_handler =
            TelemetryCompletionHandler::new(provider, None, ProcessorMode::Sync);
        async fn handler(_: LambdaEvent<Value>) -> Result<Value, Error> {
            Ok(serde_json::json!({ "statusCode": 200, "body": "Success" }))
        }
        let traced_handler = create_traced_handler("test-handler", completion_handler, handler);
        let event = LambdaEvent::new(serde_json::json!({}), Context::default());
        let result = traced_handler(event).await?;
        wait_for_spans(Duration::from_millis(100)).await;
        let spans = exporter.get_spans();
        assert!(!spans.is_empty(), "No spans were exported");
        let span = &spans[0];
        assert_eq!(span.name, "test-handler", "Unexpected span name");
        assert_eq!(
            TestExporter::find_attribute(span, "http.status_code"),
            Some("200".to_string())
        );
        assert_eq!(result["statusCode"], 200);
        Ok(())
    }

    // The handler itself succeeds (`Ok`) but returns a 500 response. The span is still
    // expected to carry `http.status_code` = "500" and an error status.
    // NOTE(review): this relies on `set_response_attributes` treating a 5xx `statusCode`
    // as an error; that logic lives outside this file.
    #[tokio::test]
    #[serial]
    async fn test_error_response() -> Result<(), Error> {
        let (provider, exporter, _guard) = setup_test_provider();
        let completion_handler =
            TelemetryCompletionHandler::new(provider, None, ProcessorMode::Sync);
        async fn handler(_: LambdaEvent<Value>) -> Result<Value, Error> {
            Ok(serde_json::json!({
                "statusCode": 500,
                "body": "Internal Server Error"
            }))
        }
        let traced_handler = create_traced_handler("test-handler", completion_handler, handler);
        let event = LambdaEvent::new(serde_json::json!({}), Context::default());
        let result = traced_handler(event).await?;
        wait_for_spans(Duration::from_millis(100)).await;
        let spans = exporter.get_spans();
        assert!(!spans.is_empty(), "No spans were exported");
        let span = &spans[0];
        assert_eq!(span.name, "test-handler", "Unexpected span name");
        assert_eq!(
            TestExporter::find_attribute(span, "http.status_code"),
            Some("500".to_string())
        );
        assert!(matches!(span.status, Status::Error { .. }));
        assert_eq!(result["statusCode"], 500);
        Ok(())
    }

    // One boxed handler must be callable repeatedly, producing exactly one span per call.
    #[tokio::test]
    #[serial]
    async fn test_handler_reuse() -> Result<(), Error> {
        let (provider, exporter, _guard) = setup_test_provider();
        let completion_handler =
            TelemetryCompletionHandler::new(provider, None, ProcessorMode::Sync);
        async fn handler(_: LambdaEvent<Value>) -> Result<Value, Error> {
            Ok(serde_json::json!({ "status": "ok" }))
        }
        let traced_handler = create_traced_handler("test-handler", completion_handler, handler);
        for _ in 0..3 {
            let event = LambdaEvent::new(serde_json::json!({}), Context::default());
            let _ = traced_handler(event).await?;
        }
        wait_for_spans(Duration::from_millis(100)).await;
        let spans = exporter.get_spans();
        assert_eq!(spans.len(), 3, "Expected exactly 3 spans");
        for span in spans {
            assert_eq!(span.name, "test-handler", "Unexpected span name");
        }
        Ok(())
    }

    // A capturing closure (not just an `async fn`) can be wrapped; the captured value
    // must reach the response unchanged.
    #[tokio::test]
    #[serial]
    async fn test_handler_with_closure() -> Result<(), Error> {
        let (provider, exporter, _guard) = setup_test_provider();
        let completion_handler =
            TelemetryCompletionHandler::new(provider, None, ProcessorMode::Sync);
        let prefix = "test-prefix".to_string();
        // Cloning `prefix` per call keeps the closure `Fn` and makes the future own its data.
        let handler = move |_event: LambdaEvent<Value>| {
            let prefix = prefix.clone();
            async move {
                Ok(serde_json::json!({
                    "statusCode": 200,
                    "prefix": prefix
                }))
            }
        };
        let traced_handler = create_traced_handler("test-handler", completion_handler, handler);
        let event = LambdaEvent::new(serde_json::json!({}), Context::default());
        let result = traced_handler(event).await?;
        wait_for_spans(Duration::from_millis(100)).await;
        let spans = exporter.get_spans();
        assert!(!spans.is_empty(), "No spans were exported");
        assert_eq!(result["prefix"], "test-prefix");
        assert_eq!(spans[0].name, "test-handler", "Unexpected span name");
        Ok(())
    }
}