lambda_otel_lite/
layer.rs

1//! Tower middleware for OpenTelemetry tracing in AWS Lambda functions.
2//!
3//! This module provides a Tower middleware layer that automatically creates OpenTelemetry spans
4//! for Lambda invocations. It supports automatic extraction of span attributes from common AWS
5//! event types and allows for custom attribute extraction through a flexible trait system.
6//!
7//! # When to Use the Tower Layer
8//!
9//! The Tower layer approach is recommended when:
10//! - You need middleware composition (e.g., combining with other Tower layers)
11//! - You want a more idiomatic Rust approach using traits
12//! - Your application has complex middleware requirements
13//! - You're already using Tower in your application
14//!
15//! For simpler use cases, consider using the handler wrapper approach instead.
16//!
17//! # Architecture
18//!
19//! The layer operates by wrapping a Lambda service with OpenTelemetry instrumentation:
20//! 1. Creates a span for each Lambda invocation
21//! 2. Extracts attributes from the event using either:
22//!    - Built-in implementations of `SpanAttributesExtractor` for supported event types
23//!    - Custom implementations of `SpanAttributesExtractor` for user-defined types
24//!    - A closure-based extractor for one-off customizations
25//! 3. Propagates context from incoming requests via headers
26//! 4. Tracks response status codes
27//! 5. Signals completion for span export through the `TelemetryCompletionHandler`
28//!
29//! # Features
30//!
31//! - Automatic span creation for Lambda invocations
32//! - Built-in support for common AWS event types:
33//!   - API Gateway v1/v2 (HTTP method, path, route, protocol)
34//!   - Application Load Balancer (HTTP method, path, target group ARN)
35//! - Extensible attribute extraction through the `SpanAttributesExtractor` trait
36//! - Custom attribute extraction through closure-based extractors
37//! - Automatic context propagation from HTTP headers
38//! - Response status code tracking
39//!
40//! # Basic Usage
41//!
42//! ```no_run
43//! use lambda_otel_lite::{init_telemetry, OtelTracingLayer, TelemetryConfig};
44//! use lambda_runtime::{service_fn, Error, LambdaEvent, Runtime};
45//! use aws_lambda_events::event::apigw::ApiGatewayV2httpRequest;
46//! use tower::ServiceBuilder;
47//!
48//! async fn function_handler(event: LambdaEvent<ApiGatewayV2httpRequest>) -> Result<serde_json::Value, Error> {
49//!     Ok(serde_json::json!({"statusCode": 200}))
50//! }
51//!
52//! #[tokio::main]
53//! async fn main() -> Result<(), Error> {
54//!     let (_, completion_handler) = init_telemetry(TelemetryConfig::default()).await?;
55//!     
56//!     let service = ServiceBuilder::new()
57//!         .layer(OtelTracingLayer::new(completion_handler)
58//!             .with_name("my-handler"))
59//!         .service_fn(function_handler);
60//!
61//!     Runtime::new(service).run().await
62//! }
63//! ```
64//!
65//! # Custom Attribute Extraction
66//!
67//! You can implement the `SpanAttributesExtractor` trait for your own event types:
68//!
69//! ```rust,no_run
70//! use lambda_otel_lite::{SpanAttributes, SpanAttributesExtractor};
71//! use std::collections::HashMap;
72//! use opentelemetry::Value;
73//! struct MyEvent {
74//!     user_id: String,
75//! }
76//!
77//! impl SpanAttributesExtractor for MyEvent {
78//!     fn extract_span_attributes(&self) -> SpanAttributes {
79//!         let mut attributes = HashMap::new();
80//!         attributes.insert("user.id".to_string(), Value::String(self.user_id.clone().into()));
81//!         SpanAttributes {
82//!             attributes,
83//!             ..SpanAttributes::default()
84//!         }
85//!     }
86//! }
87//! ```
88//!
89//! # Context Propagation
90//!
91//! The layer automatically extracts and propagates tracing context from HTTP headers
92//! in supported event types. This enables distributed tracing across service boundaries.
93//! The W3C Trace Context format is used for propagation.
94//!
95//! # Response Tracking
96//!
97//! For HTTP responses, the layer automatically:
98//! - Sets `http.status_code` from the response statusCode
99//! - Sets span status to ERROR for 5xx responses
100//! - Sets span status to OK for all other responses
101
102use crate::extractors::{set_common_attributes, set_response_attributes, SpanAttributesExtractor};
103use crate::TelemetryCompletionHandler;
104use futures_util::ready;
105use lambda_runtime::{Error, LambdaEvent};
106use opentelemetry::trace::Status;
107use pin_project::pin_project;
108use serde::{de::DeserializeOwned, Serialize};
109use std::marker::PhantomData;
110use std::{
111    future::Future,
112    pin::Pin,
113    task::{self, Poll},
114};
115use tower::{Layer, Service};
116use tracing::{field::Empty, instrument::Instrumented, Instrument};
117use tracing_opentelemetry::OpenTelemetrySpanExt;
118
119/// Future that calls complete() on the completion handler when the inner future completes.
120///
121/// This future wraps the inner service future to ensure that spans are properly completed
122/// and exported. It:
123/// 1. Polls the inner future to completion
124/// 2. Extracts response attributes (e.g., status code)
125/// 3. Sets span status based on response
126/// 4. Signals completion through the completion handler
127///
128/// This type is created automatically by `OtelTracingService` - you shouldn't need to
129/// construct it directly.
130#[pin_project]
131pub struct CompletionFuture<Fut> {
132    #[pin]
133    future: Option<Fut>,
134    completion_handler: Option<TelemetryCompletionHandler>,
135    span: Option<tracing::Span>,
136}
137
138impl<Fut, R> Future for CompletionFuture<Fut>
139where
140    Fut: Future<Output = Result<R, Error>>,
141    R: Serialize + Send + 'static,
142{
143    type Output = Result<R, Error>;
144
145    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
146        let ready = ready!(self
147            .as_mut()
148            .project()
149            .future
150            .as_pin_mut()
151            .expect("future polled after completion")
152            .poll(cx));
153
154        // Extract response attributes if it's a successful response
155        if let Ok(response) = &ready {
156            if let Ok(value) = serde_json::to_value(response) {
157                if let Some(span) = self.span.as_ref() {
158                    set_response_attributes(span, &value);
159                }
160            }
161        } else if let Err(error) = &ready {
162            if let Some(span) = self.span.as_ref() {
163                // Set error status according to OpenTelemetry spec
164                span.set_status(Status::error(error.to_string()));
165            }
166        }
167
168        // Drop the future and span before calling complete
169        Pin::set(&mut self.as_mut().project().future, None);
170        let this = self.project();
171        this.span.take(); // Take ownership and drop the span
172
173        // Now that the span is closed, complete telemetry
174        if let Some(handler) = this.completion_handler.take() {
175            handler.complete();
176        }
177
178        Poll::Ready(ready)
179    }
180}
181
182/// Tower middleware to create an OpenTelemetry tracing span for Lambda invocations.
183///
184/// This layer wraps a Lambda service to automatically create and configure OpenTelemetry
185/// spans for each invocation. It supports:
186/// - Automatic span creation with configurable names
187/// - Automatic attribute extraction from supported event types
188/// - Context propagation from HTTP headers
189/// - Response status tracking
190///
191/// # Example
192///
193/// ```no_run
194/// use lambda_otel_lite::{init_telemetry, OtelTracingLayer, TelemetryConfig, SpanAttributes};
195/// use lambda_runtime::{service_fn, Error, LambdaEvent, Runtime};
196/// use aws_lambda_events::event::apigw::ApiGatewayV2httpRequest;
197/// use tower::ServiceBuilder;
198///
199/// async fn handler(event: LambdaEvent<ApiGatewayV2httpRequest>) -> Result<serde_json::Value, Error> {
200///     Ok(serde_json::json!({ "statusCode": 200 }))
201/// }
202///
203/// # async fn example() -> Result<(), Error> {
204/// let (_, completion_handler) = init_telemetry(TelemetryConfig::default()).await?;
205///
206/// // Create a layer with custom name
207/// let layer = OtelTracingLayer::new(completion_handler)
208///     .with_name("api-handler");
209///
210/// // Apply the layer to your handler
211/// let service = ServiceBuilder::new()
212///     .layer(layer)
213///     .service_fn(handler);
214///
215/// Runtime::new(service).run().await
216/// # }
217/// ```
218#[derive(Clone)]
219pub struct OtelTracingLayer<T: SpanAttributesExtractor> {
220    completion_handler: TelemetryCompletionHandler,
221    name: String,
222    _phantom: PhantomData<T>,
223}
224
225impl<T: SpanAttributesExtractor> OtelTracingLayer<T> {
226    /// Create a new OpenTelemetry tracing layer with the required completion handler.
227    ///
228    /// The completion handler is used to signal when spans should be exported. It's typically
229    /// obtained from [`init_telemetry`](crate::init_telemetry).
230    ///
231    /// # Arguments
232    ///
233    /// * `completion_handler` - Handler for managing span export timing
234    pub fn new(completion_handler: TelemetryCompletionHandler) -> Self {
235        Self {
236            completion_handler,
237            name: "lambda-invocation".to_string(),
238            _phantom: PhantomData,
239        }
240    }
241
242    /// Set the span name.
243    ///
244    /// This name will be used for all spans created by this layer. It should describe
245    /// the purpose of the Lambda function (e.g., "process-order", "api-handler").
246    ///
247    /// # Arguments
248    ///
249    /// * `name` - The name to use for spans
250    pub fn with_name(mut self, name: impl Into<String>) -> Self {
251        self.name = name.into();
252        self
253    }
254}
255
256impl<S, T> Layer<S> for OtelTracingLayer<T>
257where
258    T: SpanAttributesExtractor + Clone,
259{
260    type Service = OtelTracingService<S, T>;
261
262    fn layer(&self, inner: S) -> Self::Service {
263        OtelTracingService::<S, T> {
264            inner,
265            completion_handler: self.completion_handler.clone(),
266            name: self.name.clone(),
267            is_cold_start: true,
268            _phantom: PhantomData,
269        }
270    }
271}
272
273/// Tower service returned by [OtelTracingLayer].
274///
275/// This service wraps the inner Lambda service to:
276/// 1. Create a span for each invocation
277/// 2. Extract and set span attributes
278/// 3. Propagate context from headers
279/// 4. Track response status
280/// 5. Signal completion for span export
281///
282/// The service is created automatically by the layer - you shouldn't need to
283/// construct it directly.
284#[derive(Clone)]
285pub struct OtelTracingService<S, T: SpanAttributesExtractor> {
286    inner: S,
287    completion_handler: TelemetryCompletionHandler,
288    name: String,
289    is_cold_start: bool,
290    _phantom: PhantomData<T>,
291}
292
293impl<S, F, T, R> Service<LambdaEvent<T>> for OtelTracingService<S, T>
294where
295    S: Service<LambdaEvent<T>, Response = R, Error = Error, Future = F> + Send,
296    F: Future<Output = Result<R, Error>> + Send + 'static,
297    T: SpanAttributesExtractor + DeserializeOwned + Serialize + Send + 'static,
298    R: Serialize + Send + 'static,
299{
300    type Response = R;
301    type Error = Error;
302    type Future = CompletionFuture<Instrumented<S::Future>>;
303
304    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
305        self.inner.poll_ready(cx)
306    }
307
308    fn call(&mut self, event: LambdaEvent<T>) -> Self::Future {
309        let span = tracing::info_span!(
310            parent: None,
311            "handler",
312            otel.name=Empty,
313            otel.kind=Empty,
314            otel.status_code=Empty,
315            otel.status_message=Empty,
316            requestId=%event.context.request_id,
317        );
318
319        // Set the span name and default kind
320        span.record("otel.name", self.name.clone());
321        span.record("otel.kind", "SERVER");
322
323        // Set common Lambda attributes with cold start tracking
324        set_common_attributes(&span, &event.context, self.is_cold_start);
325        if self.is_cold_start {
326            self.is_cold_start = false;
327        }
328
329        // Extract attributes directly using the trait
330        let attrs = event.payload.extract_span_attributes();
331
332        // Apply extracted attributes
333        if let Some(span_name) = attrs.span_name {
334            span.record("otel.name", span_name);
335        }
336
337        if let Some(kind) = &attrs.kind {
338            span.record("otel.kind", kind.to_string());
339        }
340
341        for (key, value) in &attrs.attributes {
342            span.set_attribute(key.to_string(), value.to_string());
343        }
344
345        for link in attrs.links {
346            span.add_link_with_attributes(link.span_context, link.attributes);
347        }
348
349        // Propagate context from headers
350        if let Some(carrier) = attrs.carrier {
351            let parent_context = opentelemetry::global::get_text_map_propagator(|propagator| {
352                propagator.extract(&carrier)
353            });
354            span.set_parent(parent_context);
355        }
356
357        // Set trigger type
358        span.set_attribute("faas.trigger", attrs.trigger.to_string());
359
360        let future = {
361            let _guard = span.enter();
362            self.inner.call(event)
363        };
364
365        CompletionFuture {
366            future: Some(future.instrument(span.clone())),
367            completion_handler: Some(self.completion_handler.clone()),
368            span: Some(span),
369        }
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProcessorMode;
    use futures_util::future::BoxFuture;
    use lambda_runtime::Context;
    use opentelemetry::trace::TracerProvider as _;
    use opentelemetry_sdk::{
        trace::{SdkTracerProvider, SpanData, SpanExporter},
        Resource,
    };
    use serial_test::serial;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tower::ServiceExt;
    use tracing_subscriber::prelude::*;

    /// Test exporter that tallies how many spans it has been asked to export.
    #[derive(Debug)]
    struct CountingExporter {
        export_count: Arc<AtomicUsize>,
    }

    impl CountingExporter {
        /// Start with a zeroed, shareable counter.
        fn new() -> Self {
            let export_count = Arc::new(AtomicUsize::new(0));
            Self { export_count }
        }
    }

    impl SpanExporter for CountingExporter {
        /// Add the batch size to the counter and report success immediately.
        fn export(
            &mut self,
            batch: Vec<SpanData>,
        ) -> BoxFuture<'static, opentelemetry_sdk::error::OTelSdkResult> {
            let received = batch.len();
            self.export_count.fetch_add(received, Ordering::SeqCst);
            Box::pin(futures_util::future::ready(Ok(())))
        }

        fn shutdown(&mut self) -> opentelemetry_sdk::error::OTelSdkResult {
            Ok(())
        }
    }

    /// A handler wrapped in the layer should result in at least one exported span.
    #[tokio::test]
    #[serial]
    async fn test_basic_layer() -> Result<(), Error> {
        let counting_exporter = CountingExporter::new();
        // Keep a handle on the counter; the exporter itself moves into the provider.
        let exported_spans = Arc::clone(&counting_exporter.export_count);

        let tracer_provider = Arc::new(
            SdkTracerProvider::builder()
                .with_simple_exporter(counting_exporter)
                .with_resource(Resource::builder().build())
                .build(),
        );

        // Install a subscriber for the duration of this test only.
        let _subscriber_guard = tracing_subscriber::registry::Registry::default()
            .with(tracing_opentelemetry::OpenTelemetryLayer::new(
                tracer_provider.tracer("test"),
            ))
            .set_default();

        let completion_handler = TelemetryCompletionHandler::new(
            Arc::clone(&tracer_provider),
            None,
            ProcessorMode::Sync,
        );

        let handler = |_: LambdaEvent<serde_json::Value>| async {
            // Create a span inside the handler as well
            let _span = tracing::info_span!("test_span");
            Ok::<_, Error>(serde_json::json!({"status": "ok"}))
        };

        let mut service = tower::ServiceBuilder::new()
            .layer(OtelTracingLayer::new(completion_handler).with_name("test-handler"))
            .service_fn(handler);

        let invocation = LambdaEvent::new(serde_json::json!({}), Context::default());

        let _response = service.ready().await?.call(invocation).await?;

        // Give the exporter some time to receive the spans
        tokio::time::sleep(Duration::from_millis(500)).await;

        let exported = exported_spans.load(Ordering::SeqCst);
        assert!(exported > 0);

        Ok(())
    }
}