lambda_runtime/layers/
trace.rs

1use tower::{Layer, Service};
2use tracing::{instrument::Instrumented, Instrument};
3
4use crate::{Context, LambdaInvocation};
5use lambda_runtime_api_client::BoxError;
6use std::task;
7
/// Tower middleware to create a tracing span for invocations of the Lambda function.
///
/// The layer itself carries no configuration; it only wraps an inner service in a
/// [`TracingService`].
#[derive(Debug, Clone, Default)]
pub struct TracingLayer {}

impl TracingLayer {
    /// Create a new tracing layer.
    ///
    /// Equivalent to [`TracingLayer::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
18
19impl<S> Layer<S> for TracingLayer {
20    type Service = TracingService<S>;
21
22    fn layer(&self, inner: S) -> Self::Service {
23        TracingService { inner }
24    }
25}
26
/// Tower service returned by [`TracingLayer`].
///
/// Wraps an inner service `S`; every call is instrumented with a span built from
/// the invocation's context. `Debug` and `Clone` are derived, so they are
/// available whenever the inner service provides them.
#[derive(Debug, Clone)]
pub struct TracingService<S> {
    inner: S,
}
31
32impl<S> Service<LambdaInvocation> for TracingService<S>
33where
34    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
35{
36    type Response = ();
37    type Error = BoxError;
38    type Future = Instrumented<S::Future>;
39
40    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
41        self.inner.poll_ready(cx)
42    }
43
44    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
45        let span = request_span(&req.context);
46        let future = {
47            // Enter the span before calling the inner service
48            // to ensure that it's assigned as parent of the inner spans.
49            let _guard = span.enter();
50            self.inner.call(req)
51        };
52        future.instrument(span)
53    }
54}
55
56/* ------------------------------------------- UTILS ------------------------------------------- */
57
58fn request_span(ctx: &Context) -> tracing::Span {
59    match &ctx.xray_trace_id {
60        Some(trace_id) => {
61            tracing::info_span!(
62                "Lambda runtime invoke",
63                requestId = &ctx.request_id,
64                xrayTraceId = trace_id
65            )
66        }
67        None => {
68            tracing::info_span!("Lambda runtime invoke", requestId = &ctx.request_id)
69        }
70    }
71}