use crate::extractors::{set_common_attributes, set_response_attributes, SpanAttributesExtractor};
use crate::TelemetryCompletionHandler;
use futures_util::ready;
use lambda_runtime::{Error, LambdaEvent};
use opentelemetry::trace::Status;
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
use std::{
future::Future,
pin::Pin,
task::{self, Poll},
};
use tower::{Layer, Service};
use tracing::{field::Empty, instrument::Instrumented, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Future wrapper that finalizes telemetry once the wrapped future resolves.
///
/// When the inner future completes, `poll` records the outcome on `span`,
/// drops the inner future and the span handle, and then calls `complete()`
/// on the completion handler before returning the result.
#[pin_project]
pub struct CompletionFuture<Fut> {
/// The wrapped future; replaced with `None` once it has resolved.
#[pin]
future: Option<Fut>,
/// Handler whose `complete()` is called on completion; taken (left `None`) when used.
completion_handler: Option<TelemetryCompletionHandler>,
/// Span that receives the response attributes or the error status; taken on completion.
span: Option<tracing::Span>,
}
impl<Fut, R> Future for CompletionFuture<Fut>
where
    Fut: Future<Output = Result<R, Error>>,
    R: Serialize + Send + 'static,
{
    type Output = Result<R, Error>;

    /// Drives the wrapped future and, once it resolves, finalizes telemetry.
    ///
    /// On success the response is serialized to JSON and handed to
    /// `set_response_attributes`; on failure the error text is recorded as the
    /// span status. Afterwards the inner future and the span handle are
    /// dropped, and only then is the completion handler invoked.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        let outcome = ready!(this
            .future
            .as_mut()
            .as_pin_mut()
            .expect("future polled after completion")
            .poll(cx));

        match &outcome {
            Ok(response) => {
                // A response that fails to serialize is ignored for attribute purposes.
                if let Ok(value) = serde_json::to_value(response) {
                    if let Some(span) = this.span.as_ref() {
                        set_response_attributes(span, &value);
                    }
                }
            }
            Err(error) => {
                if let Some(span) = this.span.as_ref() {
                    span.set_status(Status::error(error.to_string()));
                }
            }
        }

        // Release the inner future first, then our own span handle, so both
        // are gone before the completion handler runs.
        this.future.set(None);
        this.span.take();

        if let Some(handler) = this.completion_handler.take() {
            handler.complete();
        }

        Poll::Ready(outcome)
    }
}
/// Tower layer that wraps a service in an `OtelTracingService`.
///
/// `T` is the event payload type handled by the wrapped service; it must
/// implement `SpanAttributesExtractor`.
#[derive(Clone)]
pub struct OtelTracingLayer<T: SpanAttributesExtractor> {
/// Handler cloned into every service produced by this layer.
completion_handler: TelemetryCompletionHandler,
/// Name passed on to produced services; defaults to "lambda-invocation".
name: String,
/// Ties the layer to the payload type `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<T: SpanAttributesExtractor> OtelTracingLayer<T> {
    /// Builds a layer around `completion_handler` using the default span
    /// name `"lambda-invocation"`.
    pub fn new(completion_handler: TelemetryCompletionHandler) -> Self {
        let name = String::from("lambda-invocation");
        Self {
            completion_handler,
            name,
            _phantom: PhantomData,
        }
    }

    /// Returns the layer with its span name replaced by `name`.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..self
        }
    }
}
impl<S, T> Layer<S> for OtelTracingLayer<T>
where
    T: SpanAttributesExtractor + Clone,
{
    type Service = OtelTracingService<S, T>;

    /// Wraps `inner` in an `OtelTracingService` that shares this layer's
    /// completion handler and span name and starts out flagged as a cold start.
    fn layer(&self, inner: S) -> Self::Service {
        let completion_handler = self.completion_handler.clone();
        let name = self.name.clone();

        OtelTracingService {
            inner,
            completion_handler,
            name,
            is_cold_start: true,
            _phantom: PhantomData,
        }
    }
}
/// Service that runs each invocation of the wrapped service inside a span
/// and returns a `CompletionFuture` that finalizes telemetry afterwards.
#[derive(Clone)]
pub struct OtelTracingService<S, T: SpanAttributesExtractor> {
/// The wrapped service that actually handles the event.
inner: S,
/// Handler cloned into the future returned by each `call`.
completion_handler: TelemetryCompletionHandler,
/// Default `otel.name` recorded on the span; the payload's extractor may override it.
name: String,
/// `true` until the first `call`; passed to `set_common_attributes`.
is_cold_start: bool,
/// Ties the service to the payload type `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<S, F, T, R> Service<LambdaEvent<T>> for OtelTracingService<S, T>
where
    S: Service<LambdaEvent<T>, Response = R, Error = Error, Future = F> + Send,
    F: Future<Output = Result<R, Error>> + Send + 'static,
    T: SpanAttributesExtractor + DeserializeOwned + Serialize + Send + 'static,
    R: Serialize + Send + 'static,
{
    type Response = R;
    type Error = Error;
    type Future = CompletionFuture<Instrumented<S::Future>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Handles one invocation inside a freshly created root span.
    ///
    /// The span is created with no parent and seeded with the service name
    /// and the `SERVER` kind. It is then enriched with the common attributes
    /// and with whatever the payload's `extract_span_attributes` returns:
    /// span name, kind, custom attributes, links and, if a carrier is
    /// present, a parent context extracted by the global propagator.
    /// The wrapped service's `call` runs with the span entered, and the
    /// returned future is instrumented with the same span.
    fn call(&mut self, event: LambdaEvent<T>) -> Self::Future {
        let span = tracing::info_span!(
            parent: None,
            "handler",
            otel.name = Empty,
            otel.kind = Empty,
            otel.status_code = Empty,
            otel.status_message = Empty,
            requestId = %event.context.request_id,
        );

        // Defaults; the extracted attributes below may override both.
        // Borrow the name instead of cloning it on every invocation.
        span.record("otel.name", self.name.as_str());
        span.record("otel.kind", "SERVER");

        set_common_attributes(&span, &event.context, self.is_cold_start);
        // Only the first call on this service is reported as a cold start.
        self.is_cold_start = false;

        let attrs = event.payload.extract_span_attributes();
        if let Some(span_name) = attrs.span_name {
            span.record("otel.name", span_name);
        }
        if let Some(kind) = &attrs.kind {
            span.record("otel.kind", kind.to_string());
        }
        for (key, value) in &attrs.attributes {
            span.set_attribute(key.to_string(), value.to_string());
        }
        for link in attrs.links {
            span.add_link_with_attributes(link.span_context, link.attributes);
        }
        if let Some(carrier) = attrs.carrier {
            let parent_context = opentelemetry::global::get_text_map_propagator(|propagator| {
                propagator.extract(&carrier)
            });
            // Best effort: a failure to attach the parent is deliberately ignored.
            let _ = span.set_parent(parent_context);
        }
        // Set last, so it wins over an extracted attribute with the same key.
        span.set_attribute("faas.trigger", attrs.trigger.to_string());

        // Enter the span while the inner future is being constructed.
        let future = {
            let _guard = span.enter();
            self.inner.call(event)
        };

        CompletionFuture {
            future: Some(future.instrument(span.clone())),
            completion_handler: Some(self.completion_handler.clone()),
            span: Some(span),
        }
    }
}
// Unit tests for the tracing layer.
#[cfg(test)]
mod tests {
use super::*;
use crate::ProcessorMode;
use lambda_runtime::Context;
use opentelemetry::trace::TracerProvider as _;
use opentelemetry_sdk::{
trace::{SdkTracerProvider, SpanData, SpanExporter},
Resource,
};
use serial_test::serial;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tower::ServiceExt;
use tracing_subscriber::prelude::*;
/// Span exporter that only counts the spans it is asked to export.
#[derive(Debug)]
struct CountingExporter {
/// Running total of exported spans, shared so the test can read it.
export_count: Arc<AtomicUsize>,
}
impl CountingExporter {
/// Creates an exporter whose counter starts at zero.
fn new() -> Self {
Self {
export_count: Arc::new(AtomicUsize::new(0)),
}
}
}
impl SpanExporter for CountingExporter {
/// Adds the batch size to the counter and reports success immediately.
fn export(
&self,
batch: Vec<SpanData>,
) -> impl std::future::Future<Output = opentelemetry_sdk::error::OTelSdkResult> + Send
{
self.export_count.fetch_add(batch.len(), Ordering::SeqCst);
futures_util::future::ready(Ok(()))
}
/// Nothing to release; always succeeds.
fn shutdown(&mut self) -> opentelemetry_sdk::error::OTelSdkResult {
Ok(())
}
}
/// Sends one event through a service wrapped in `OtelTracingLayer` and
/// checks that at least one span reached the exporter.
#[tokio::test]
#[serial]
async fn test_basic_layer() -> Result<(), Error> {
let exporter = CountingExporter::new();
// Keep a handle on the counter before the exporter is moved into the provider.
let export_count = exporter.export_count.clone();
let provider = SdkTracerProvider::builder()
.with_simple_exporter(exporter)
.with_resource(Resource::builder().build())
.build();
let provider = Arc::new(provider);
// The returned guard is held in `_subscriber` until the end of the test.
let _subscriber = tracing_subscriber::registry::Registry::default()
.with(tracing_opentelemetry::OpenTelemetryLayer::new(
provider.tracer("test"),
))
.set_default();
let completion_handler =
TelemetryCompletionHandler::new(provider.clone(), None, ProcessorMode::Sync);
// Handler ignores its event and returns a fixed JSON body.
let handler = |_: LambdaEvent<serde_json::Value>| async {
// This span is created but never entered.
let _span = tracing::info_span!("test_span");
Ok::<_, Error>(serde_json::json!({"status": "ok"}))
};
let layer = OtelTracingLayer::new(completion_handler).with_name("test-handler");
let mut svc = tower::ServiceBuilder::new()
.layer(layer)
.service_fn(handler);
let event = LambdaEvent::new(serde_json::json!({}), Context::default());
let _ = svc.ready().await?.call(event).await?;
// NOTE(review): presumably gives the export time to finish; confirm whether
// the wait is needed with a simple exporter and `ProcessorMode::Sync`.
tokio::time::sleep(Duration::from_millis(500)).await;
assert!(export_count.load(Ordering::SeqCst) > 0);
Ok(())
}
}