Skip to main content

lambda_runtime/
runtime.rs

1use crate::{
2    layers::{CatchPanicService, RuntimeApiClientService, RuntimeApiResponseService},
3    requests::{IntoRequest, NextEventRequest},
4    types::{invoke_request_id, IntoFunctionResponse, LambdaEvent},
5    Config, Context, Diagnostic,
6};
7#[cfg(feature = "concurrency-tokio")]
8use futures::stream::FuturesUnordered;
9use http_body_util::BodyExt;
10use lambda_runtime_api_client::{BoxError, Client as ApiClient};
11use serde::{Deserialize, Serialize};
12#[cfg(feature = "concurrency-tokio")]
13use std::fmt;
14use std::{env, fmt::Debug, future::Future, sync::Arc};
15use tokio_stream::{Stream, StreamExt};
16use tower::{Layer, Service, ServiceExt};
17use tracing::trace;
18#[cfg(feature = "concurrency-tokio")]
19use tracing::{debug, error, info_span, warn, Instrument};
20
21/* ----------------------------------------- INVOCATION ---------------------------------------- */
22
23/// A simple container that provides information about a single invocation of a Lambda function.
/// A simple container that provides information about a single invocation of a Lambda function.
pub struct LambdaInvocation {
    /// The HTTP parts of the Runtime API `/next` response that delivered this invocation.
    /// The headers carry the invocation metadata (request id, deadline, trace id, ...).
    pub parts: http::response::Parts,
    /// The raw event payload bytes for this invocation.
    pub body: bytes::Bytes,
    /// The context of the Lambda invocation (built from the response headers and config).
    pub context: Context,
}
32
33/* ------------------------------------------ RUNTIME ------------------------------------------ */
34
35/// Lambda runtime executing a handler function on incoming requests.
36///
37/// Middleware can be added to a runtime using the [Runtime::layer] method in order to execute
38/// logic prior to processing the incoming request and/or after the response has been sent back
39/// to the Lambda Runtime API.
40///
41/// # Example
42/// ```no_run
43/// use lambda_runtime::{Error, LambdaEvent, Runtime};
44/// use serde_json::Value;
45/// use tower::service_fn;
46///
47/// #[tokio::main]
48/// async fn main() -> Result<(), Error> {
49///     let func = service_fn(func);
50///     Runtime::new(func).run().await?;
51///     Ok(())
52/// }
53///
54/// async fn func(event: LambdaEvent<Value>) -> Result<Value, Error> {
55///     Ok(event.payload)
56/// }
57/// ````
58pub struct Runtime<S> {
59    service: S,
60    config: Arc<Config>,
61    client: Arc<ApiClient>,
62    concurrency_limit: u32,
63}
64
65impl<F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
66    Runtime<
67        RuntimeApiClientService<
68            RuntimeApiResponseService<
69                CatchPanicService<'_, F>,
70                EventPayload,
71                Response,
72                BufferedResponse,
73                StreamingResponse,
74                StreamItem,
75                StreamError,
76            >,
77        >,
78    >
79where
80    F: Service<LambdaEvent<EventPayload>, Response = Response>,
81    F::Future: Future<Output = Result<Response, F::Error>>,
82    F::Error: Into<Diagnostic> + Debug,
83    EventPayload: for<'de> Deserialize<'de>,
84    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
85    BufferedResponse: Serialize,
86    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
87    StreamItem: Into<bytes::Bytes> + Send,
88    StreamError: Into<BoxError> + Send + Debug,
89{
90    /// Create a new runtime that executes the provided handler for incoming requests.
91    ///
92    /// In order to start the runtime and poll for events on the [Lambda Runtime
93    /// APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html), you must call
94    /// [Runtime::run].
95    ///
96    /// Note that manually creating a [Runtime] does not add tracing to the executed handler
97    /// as is done by [super::run]. If you want to add the default tracing functionality, call
98    /// [Runtime::layer] with a [super::layers::TracingLayer].
99    ///
100    ///
101    /// # Panics
102    ///
103    /// This function panics if required Lambda environment variables are missing
104    /// (`AWS_LAMBDA_FUNCTION_NAME`, `AWS_LAMBDA_FUNCTION_MEMORY_SIZE`,
105    /// `AWS_LAMBDA_FUNCTION_VERSION`, `AWS_LAMBDA_RUNTIME_API`).
106    pub fn new(handler: F) -> Self {
107        trace!("Loading config from env");
108        let config = Arc::new(Config::from_env());
109        let concurrency_limit = max_concurrency_from_env().unwrap_or(1).max(1);
110        // Strategy: allocate all worker tasks up-front, so size the client pool to match.
111        let pool_size = concurrency_limit as usize;
112        let client = Arc::new(
113            ApiClient::builder()
114                .with_pool_size(pool_size)
115                .build()
116                .expect("Unable to create a runtime client"),
117        );
118        Self {
119            service: wrap_handler(handler, client.clone()),
120            config,
121            client,
122            concurrency_limit,
123        }
124    }
125}
126
127impl<S> Runtime<S> {
128    /// Add a new layer to this runtime. For an incoming request, this layer will be executed
129    /// before any layer that has been added prior.
130    ///
131    /// # Example
132    /// ```no_run
133    /// use lambda_runtime::{layers, Error, LambdaEvent, Runtime};
134    /// use serde_json::Value;
135    /// use tower::service_fn;
136    ///
137    /// #[tokio::main]
138    /// async fn main() -> Result<(), Error> {
139    ///     let runtime = Runtime::new(service_fn(echo)).layer(
140    ///         layers::TracingLayer::new()
141    ///     );
142    ///     runtime.run().await?;
143    ///     Ok(())
144    /// }
145    ///
146    /// async fn echo(event: LambdaEvent<Value>) -> Result<Value, Error> {
147    ///     Ok(event.payload)
148    /// }
149    /// ```
150    pub fn layer<L>(self, layer: L) -> Runtime<L::Service>
151    where
152        L: Layer<S>,
153        L::Service: Service<LambdaInvocation, Response = (), Error = BoxError>,
154    {
155        Runtime {
156            client: self.client,
157            config: self.config,
158            service: layer.layer(self.service),
159            concurrency_limit: self.concurrency_limit,
160        }
161    }
162}
163
#[cfg(feature = "concurrency-tokio")]
impl<S> Runtime<S>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError> + Clone + Send + 'static,
    S::Future: Send,
{
    /// Start the runtime and begin polling for events on the Lambda Runtime API,
    /// in a mode that is compatible with Lambda Managed Instances.
    ///
    /// When `AWS_LAMBDA_MAX_CONCURRENCY` is set to a value greater than 1, this
    /// spawns multiple tokio worker tasks to handle concurrent invocations. When the
    /// environment variable is unset or `<= 1`, it falls back to sequential
    /// behavior, so the same handler can run on both classic Lambda and Lambda
    /// Managed Instances.
    ///
    /// # Panics
    ///
    /// This function panics if called outside of a Tokio runtime.
    #[cfg_attr(docsrs, doc(cfg(feature = "concurrency-tokio")))]
    pub async fn run_concurrent(self) -> Result<(), BoxError> {
        // Fail fast with a clear message instead of letting a later `tokio::spawn` panic.
        if tokio::runtime::Handle::try_current().is_err() {
            panic!("`run_concurrent` must be called from within a Tokio runtime");
        }

        if self.concurrency_limit > 1 {
            // Concurrent mode does not write `_X_AMZN_TRACE_ID` to the process
            // environment (see `process_invocation` with `set_amzn_trace_env = false`).
            trace!("Concurrent mode: _X_AMZN_TRACE_ID is not set; use context.xray_trace_id");
            Self::run_concurrent_inner(self.service, self.config, self.client, self.concurrency_limit).await
        } else {
            debug!(
                "Concurrent polling disabled (AWS_LAMBDA_MAX_CONCURRENCY unset or <= 1); falling back to sequential polling"
            );
            let incoming = incoming(&self.client);
            Self::run_with_incoming(self.service, self.config, incoming).await
        }
    }

    /// Concurrent processing using N independent long-poll loops (for Lambda managed-concurrency).
    async fn run_concurrent_inner(
        service: S,
        config: Arc<Config>,
        client: Arc<ApiClient>,
        concurrency_limit: u32,
    ) -> Result<(), BoxError> {
        let limit = concurrency_limit as usize;

        // Use FuturesUnordered so we can observe worker exits as they happen,
        // rather than waiting for all workers to finish (join_all).
        let mut workers: FuturesUnordered<tokio::task::JoinHandle<(tokio::task::Id, Result<(), BoxError>)>> =
            FuturesUnordered::new();
        // Each worker task records its own tokio task id so exits can be attributed in logs.
        let spawn_worker = |service: S, config: Arc<Config>, client: Arc<ApiClient>| {
            tokio::spawn(async move {
                let task_id = tokio::task::id();
                let result = concurrent_worker_loop(service, config, client).await;
                (task_id, result)
            })
        };
        // Spawn one worker per concurrency slot; the last uses the owned service to avoid an extra clone.
        for _ in 1..limit {
            workers.push(spawn_worker(service.clone(), config.clone(), client.clone()));
        }
        workers.push(spawn_worker(service, config, client));

        // Track worker exits across tasks. A single worker failing should not
        // terminate the whole runtime (LMI keeps running with the remaining
        // healthy workers). We only return an error once there are no workers
        // left (i.e., we cannot keep at least 1 worker alive).
        //
        // Note: Handler errors (Err returned from user code) do NOT trigger this.
        // They are reported to Lambda via /invocation/{id}/error and the worker
        // continues. This only captures unrecoverable runtime failures like
        // API client failures, runtime panics, etc.
        let mut errors: Vec<WorkerError> = Vec::new();
        let mut remaining_workers = limit;
        while let Some(result) = futures::StreamExt::next(&mut workers).await {
            remaining_workers = remaining_workers.saturating_sub(1);
            match result {
                Ok((task_id, Ok(()))) => {
                    // `concurrent_worker_loop` runs indefinitely, so an Ok return indicates
                    // an unexpected worker exit; we still decrement because the task is gone.
                    error!(
                        task_id = %task_id,
                        remaining_workers,
                        "Concurrent worker exited cleanly (unexpected - loop should run forever)"
                    );
                    errors.push(WorkerError::CleanExit(task_id));
                }
                Ok((task_id, Err(err))) => {
                    error!(
                        task_id = %task_id,
                        error = %err,
                        remaining_workers,
                        "Concurrent worker exited with error"
                    );
                    errors.push(WorkerError::Failure(task_id, err));
                }
                Err(join_err) => {
                    // The task itself failed (panic or cancellation); recover its id from the JoinError.
                    let task_id = join_err.id();
                    let err: BoxError = Box::new(join_err);
                    error!(
                        task_id = %task_id,
                        error = %err,
                        remaining_workers,
                        "Concurrent worker panicked"
                    );
                    errors.push(WorkerError::Failure(task_id, err));
                }
            }
        }

        // Every worker exit pushes an entry, so in practice this is non-empty here;
        // the 0 arm covers the degenerate case defensively.
        match errors.len() {
            0 => Ok(()),
            _ => Err(Box::new(ConcurrentWorkerErrors { errors })),
        }
    }
}
279
/// How a single concurrent worker task terminated.
#[cfg(feature = "concurrency-tokio")]
#[derive(Debug)]
enum WorkerError {
    /// The worker's loop returned `Ok(())` — unexpected, since the loop is meant to run forever.
    CleanExit(tokio::task::Id),
    /// The worker returned an error, or its task panicked (boxed `JoinError`).
    Failure(tokio::task::Id, BoxError),
}
286
/// Aggregate error produced once every concurrent worker has exited
/// (returned through [`Runtime::run_concurrent`]).
#[cfg(feature = "concurrency-tokio")]
#[derive(Debug)]
struct ConcurrentWorkerErrors {
    /// One entry per exited worker, in the order the exits were observed.
    errors: Vec<WorkerError>,
}
292
/// JSON shape emitted by the `Display` impl of [`ConcurrentWorkerErrors`].
#[cfg(feature = "concurrency-tokio")]
#[derive(Serialize)]
struct ConcurrentWorkerErrorsPayload<'a> {
    /// Human-readable summary of why the workers stopped.
    message: &'a str,
    /// Task ids of workers that exited cleanly; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    clean: Vec<String>,
    /// Workers that exited with an error; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    failures: Vec<WorkerFailurePayload>,
}
302
/// Single failed-worker entry serialized inside [`ConcurrentWorkerErrorsPayload`].
#[cfg(feature = "concurrency-tokio")]
#[derive(Serialize)]
struct WorkerFailurePayload {
    /// Stringified tokio task id of the failed worker.
    id: String,
    /// Stringified error that terminated the worker.
    err: String,
}
309
#[cfg(feature = "concurrency-tokio")]
impl fmt::Display for ConcurrentWorkerErrors {
    /// Render the collected worker exits as a single JSON document so the
    /// aggregate error stays machine-readable in logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Partition the exits into the two serializable buckets in one pass.
        let mut clean_ids: Vec<String> = Vec::new();
        let mut failure_entries: Vec<WorkerFailurePayload> = Vec::new();
        for error in &self.errors {
            match error {
                WorkerError::CleanExit(task_id) => clean_ids.push(task_id.to_string()),
                WorkerError::Failure(task_id, err) => failure_entries.push(WorkerFailurePayload {
                    id: task_id.to_string(),
                    err: err.to_string(),
                }),
            }
        }

        let message = if failure_entries.is_empty() && !clean_ids.is_empty() {
            "all concurrent workers exited cleanly (unexpected - loop should run forever)"
        } else {
            "concurrent workers exited unexpectedly"
        };

        let payload = ConcurrentWorkerErrorsPayload {
            message,
            clean: clean_ids,
            failures: failure_entries,
        };
        // Serialization failure maps onto the only error `fmt` can express.
        let json = serde_json::to_string(&payload).map_err(|_| fmt::Error)?;
        write!(f, "{json}")
    }
}
346
// Marker impl so `ConcurrentWorkerErrors` can be carried as a `BoxError`.
#[cfg(feature = "concurrency-tokio")]
impl std::error::Error for ConcurrentWorkerErrors {}
349
impl<S> Runtime<S>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
{
    /// Start the runtime and begin polling for events on the Lambda Runtime API.
    ///
    /// The runtime will process requests sequentially.
    ///
    /// # Managed concurrency
    /// If `AWS_LAMBDA_MAX_CONCURRENCY` is set, a warning is logged.
    /// If your handler can satisfy `Clone + Send + 'static`,
    /// prefer [`Runtime::run_concurrent`] (requires the `concurrency-tokio` feature),
    /// which honors managed concurrency and falls back to sequential behavior when
    /// unset.
    pub async fn run(self) -> Result<(), BoxError> {
        // Surface a warning when managed concurrency was requested but this build
        // cannot honor it (the `concurrency-tokio` feature is disabled).
        if let Some(raw) = concurrency_env_value() {
            log_or_print!(
                tracing: tracing::warn!(
                    "AWS_LAMBDA_MAX_CONCURRENCY is set to '{raw}', but the concurrency-tokio feature is not enabled; running sequentially",
                ),
                fallback: eprintln!("AWS_LAMBDA_MAX_CONCURRENCY is set to '{raw}', but the concurrency-tokio feature is not enabled; running sequentially")
            );
        }
        let incoming = incoming(&self.client);
        Self::run_with_incoming(self.service, self.config, incoming).await
    }

    /// Internal utility function to start the runtime with a customized incoming stream.
    /// This implements the core of the [Runtime::run] method.
    pub(crate) async fn run_with_incoming(
        mut service: S,
        config: Arc<Config>,
        incoming: impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send,
    ) -> Result<(), BoxError> {
        tokio::pin!(incoming);
        while let Some(next_event_response) = incoming.next().await {
            trace!("New event arrived (run loop)");
            // A stream error here (failed poll of `/next`) aborts the run loop.
            let event = next_event_response?;
            // `true`: sequential mode owns the process environment, so it is safe
            // to synchronize `_X_AMZN_TRACE_ID` per invocation.
            process_invocation(&mut service, &config, event, true).await?;
        }
        Ok(())
    }
}
393
394/* ------------------------------------------- UTILS ------------------------------------------- */
395
396#[allow(clippy::type_complexity)]
397fn wrap_handler<'a, F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>(
398    handler: F,
399    client: Arc<ApiClient>,
400) -> RuntimeApiClientService<
401    RuntimeApiResponseService<
402        CatchPanicService<'a, F>,
403        EventPayload,
404        Response,
405        BufferedResponse,
406        StreamingResponse,
407        StreamItem,
408        StreamError,
409    >,
410>
411where
412    F: Service<LambdaEvent<EventPayload>, Response = Response>,
413    F::Future: Future<Output = Result<Response, F::Error>>,
414    F::Error: Into<Diagnostic> + Debug,
415    EventPayload: for<'de> Deserialize<'de>,
416    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
417    BufferedResponse: Serialize,
418    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
419    StreamItem: Into<bytes::Bytes> + Send,
420    StreamError: Into<BoxError> + Send + Debug,
421{
422    let safe_service = CatchPanicService::new(handler);
423    let response_service = RuntimeApiResponseService::new(safe_service);
424    RuntimeApiClientService::new(response_service, client)
425}
426
427fn incoming(
428    client: &ApiClient,
429) -> impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send + '_ {
430    async_stream::stream! {
431        loop {
432            trace!("Waiting for next event (incoming loop)");
433            let req = NextEventRequest.into_req().expect("Unable to construct request");
434            let res = client.call(req).await;
435            yield res;
436        }
437    }
438}
439
/// Creates a future that polls the `/next` endpoint.
#[cfg(feature = "concurrency-tokio")]
async fn next_event_future(client: &ApiClient) -> Result<http::Response<hyper::body::Incoming>, BoxError> {
    // Propagate request-construction errors with `?`, then await the poll.
    client.call(NextEventRequest.into_req()?).await
}
446
/// Reads `AWS_LAMBDA_MAX_CONCURRENCY` as a positive integer.
///
/// Returns `None` when the variable is unset, not valid UTF-8, not a valid
/// `u32`, or zero.
fn max_concurrency_from_env() -> Option<u32> {
    let raw = env::var("AWS_LAMBDA_MAX_CONCURRENCY").ok()?;
    match raw.parse::<u32>() {
        Ok(limit) if limit > 0 => Some(limit),
        _ => None,
    }
}
453
/// Raw value of `AWS_LAMBDA_MAX_CONCURRENCY`, if present and valid UTF-8
/// (used only for diagnostics in [`Runtime::run`]).
fn concurrency_env_value() -> Option<String> {
    match env::var("AWS_LAMBDA_MAX_CONCURRENCY") {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
457
/// Long-poll loop executed by each concurrent worker task: repeatedly fetch
/// the next event from the Runtime API and process it.
///
/// Poll errors on `/next` are logged and retried; errors from processing an
/// invocation terminate the worker and are surfaced to `run_concurrent_inner`.
#[cfg(feature = "concurrency-tokio")]
async fn concurrent_worker_loop<S>(mut service: S, config: Arc<Config>, client: Arc<ApiClient>) -> Result<(), BoxError>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
    S::Future: Send,
{
    // Tag every log line / span from this worker with its tokio task id.
    let task_id = tokio::task::id();
    let span = info_span!("worker", task_id = %task_id);
    loop {
        let event = match next_event_future(client.as_ref()).instrument(span.clone()).await {
            Ok(event) => event,
            Err(e) => {
                // Transient poll failure: keep the worker alive and try again.
                warn!(task_id = %task_id, error = %e, "Error polling /next, retrying");
                continue;
            }
        };

        // `false`: do not write `_X_AMZN_TRACE_ID` to the process environment;
        // concurrent workers would race on the process-wide variable.
        process_invocation(&mut service, &config, event, false)
            .instrument(span.clone())
            .await?;
    }
}
480
/// Process a single invocation: buffer the event body, build the
/// [`LambdaInvocation`], and drive it through `service` (which is responsible
/// for replying to the Lambda Runtime API).
///
/// `set_amzn_trace_env` controls whether the process-wide `_X_AMZN_TRACE_ID`
/// variable is synchronized with this invocation's trace id; only the
/// sequential run loop passes `true`.
async fn process_invocation<S>(
    service: &mut S,
    config: &Arc<Config>,
    event: http::Response<hyper::body::Incoming>,
    set_amzn_trace_env: bool,
) -> Result<(), BoxError>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
{
    let (parts, incoming) = event.into_parts();

    // Debug builds only: a 204 from `/next` is treated as "no event" and skipped.
    #[cfg(debug_assertions)]
    if parts.status == http::StatusCode::NO_CONTENT {
        // Ignore the event if the status code is 204.
        // This is a way to keep the runtime alive when
        // there are no events pending to be processed.
        return Ok(());
    }

    // Build the invocation such that it can be sent to the service right away
    // when it is ready
    let body = incoming.collect().await?.to_bytes();
    let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
    let invocation = LambdaInvocation { parts, body, context };

    if set_amzn_trace_env {
        // Setup Amazon's default tracing data
        amzn_trace_env(&invocation.context);
    }

    // Wait for service to be ready
    let ready = service.ready().await?;

    // Once ready, call the service which will respond to the Lambda runtime API
    ready.call(invocation).await?;
    Ok(())
}
518
519fn amzn_trace_env(ctx: &Context) {
520    match &ctx.xray_trace_id {
521        Some(trace_id) => env::set_var("_X_AMZN_TRACE_ID", trace_id),
522        None => env::remove_var("_X_AMZN_TRACE_ID"),
523    }
524}
525
526/* --------------------------------------------------------------------------------------------- */
527/*                                             TESTS                                             */
528/* --------------------------------------------------------------------------------------------- */
529
530#[cfg(test)]
531mod endpoint_tests {
532    use super::{incoming, wrap_handler};
533    use crate::{
534        requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
535        Config, Diagnostic, Error, Runtime,
536    };
537    use bytes::Bytes;
538    use futures::future::BoxFuture;
539    use http::{HeaderValue, Method, Request, Response, StatusCode};
540    use http_body_util::{BodyExt, Full};
541    use httpmock::prelude::*;
542
543    use hyper::{body::Incoming, service::service_fn};
544    use hyper_util::{
545        rt::{tokio::TokioIo, TokioExecutor},
546        server::conn::auto::Builder as ServerBuilder,
547    };
548    use lambda_runtime_api_client::Client;
549    use std::{
550        convert::Infallible,
551        env,
552        sync::{
553            atomic::{AtomicUsize, Ordering},
554            Arc,
555        },
556        time::Duration,
557    };
558    use tokio::{net::TcpListener, sync::Notify};
559    use tokio_stream::StreamExt;
560
561    #[tokio::test]
562    async fn test_next_event() -> Result<(), Error> {
563        let server = MockServer::start();
564        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
565        let deadline = "1542409706888";
566
567        let mock = server.mock(|when, then| {
568            when.method(GET).path("/2018-06-01/runtime/invocation/next");
569            then.status(200)
570                .header("content-type", "application/json")
571                .header("lambda-runtime-aws-request-id", request_id)
572                .header("lambda-runtime-deadline-ms", deadline)
573                .body("{}");
574        });
575
576        let base = server.base_url().parse().expect("Invalid mock server Uri");
577        let client = Client::builder().with_endpoint(base).build()?;
578
579        let req = NextEventRequest.into_req()?;
580        let rsp = client.call(req).await.expect("Unable to send request");
581
582        mock.assert_async().await;
583        assert_eq!(rsp.status(), StatusCode::OK);
584        assert_eq!(
585            rsp.headers()["lambda-runtime-aws-request-id"],
586            &HeaderValue::from_static(request_id)
587        );
588        assert_eq!(
589            rsp.headers()["lambda-runtime-deadline-ms"],
590            &HeaderValue::from_static(deadline)
591        );
592
593        let body = rsp.into_body().collect().await?.to_bytes();
594        assert_eq!("{}", std::str::from_utf8(&body)?);
595        Ok(())
596    }
597
598    #[tokio::test]
599    async fn test_ok_response() -> Result<(), Error> {
600        let server = MockServer::start();
601
602        let mock = server.mock(|when, then| {
603            when.method(POST)
604                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/response")
605                .body("\"{}\"");
606            then.status(200).body("");
607        });
608
609        let base = server.base_url().parse().expect("Invalid mock server Uri");
610        let client = Client::builder().with_endpoint(base).build()?;
611
612        let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "{}");
613        let req = req.into_req()?;
614
615        let rsp = client.call(req).await?;
616
617        mock.assert_async().await;
618        assert_eq!(rsp.status(), StatusCode::OK);
619        Ok(())
620    }
621
622    #[tokio::test]
623    async fn test_error_response() -> Result<(), Error> {
624        let diagnostic = Diagnostic {
625            error_type: "InvalidEventDataError".into(),
626            error_message: "Error parsing event data".into(),
627        };
628        let body = serde_json::to_string(&diagnostic)?;
629
630        let server = MockServer::start();
631        let mock = server.mock(|when, then| {
632            when.method(POST)
633                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/error")
634                .header("lambda-runtime-function-error-type", "unhandled")
635                .body(body);
636            then.status(200).body("");
637        });
638
639        let base = server.base_url().parse().expect("Invalid mock server Uri");
640        let client = Client::builder().with_endpoint(base).build()?;
641
642        let req = EventErrorRequest {
643            request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
644            diagnostic,
645        };
646        let req = req.into_req()?;
647        let rsp = client.call(req).await?;
648
649        mock.assert_async().await;
650        assert_eq!(rsp.status(), StatusCode::OK);
651        Ok(())
652    }
653
654    #[tokio::test]
655    async fn successful_end_to_end_run() -> Result<(), Error> {
656        let server = MockServer::start();
657        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
658        let deadline = "1542409706888";
659
660        let next_request = server.mock(|when, then| {
661            when.method(GET).path("/2018-06-01/runtime/invocation/next");
662            then.status(200)
663                .header("content-type", "application/json")
664                .header("lambda-runtime-aws-request-id", request_id)
665                .header("lambda-runtime-deadline-ms", deadline)
666                .body("{}");
667        });
668        let next_response = server.mock(|when, then| {
669            when.method(POST)
670                .path(format!("/2018-06-01/runtime/invocation/{request_id}/response"))
671                .body("{}");
672            then.status(200).body("");
673        });
674
675        let base = server.base_url().parse().expect("Invalid mock server Uri");
676        let client = Client::builder().with_endpoint(base).build()?;
677
678        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
679            let (event, _) = event.into_parts();
680            Ok(event)
681        }
682        let f = crate::service_fn(func);
683
684        // set env vars needed to init Config if they are not already set in the environment
685        if env::var("AWS_LAMBDA_RUNTIME_API").is_err() {
686            env::set_var("AWS_LAMBDA_RUNTIME_API", server.base_url());
687        }
688        if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() {
689            env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn");
690        }
691        if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() {
692            env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128");
693        }
694        if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() {
695            env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1");
696        }
697        if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() {
698            env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream");
699        }
700        if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() {
701            env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log");
702        }
703        let config = Config::from_env();
704
705        let client = Arc::new(client);
706        let runtime = Runtime {
707            client: client.clone(),
708            config: Arc::new(config),
709            service: wrap_handler(f, client),
710            concurrency_limit: 1,
711        };
712        let client = &runtime.client;
713        let incoming = incoming(client).take(1);
714        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;
715
716        next_request.assert_async().await;
717        next_response.assert_async().await;
718        Ok(())
719    }
720
721    async fn run_panicking_handler<F>(func: F) -> Result<(), Error>
722    where
723        F: FnMut(crate::LambdaEvent<serde_json::Value>) -> BoxFuture<'static, Result<serde_json::Value, Error>>
724            + Send
725            + 'static,
726    {
727        let server = MockServer::start();
728        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
729        let deadline = "1542409706888";
730
731        let next_request = server.mock(|when, then| {
732            when.method(GET).path("/2018-06-01/runtime/invocation/next");
733            then.status(200)
734                .header("content-type", "application/json")
735                .header("lambda-runtime-aws-request-id", request_id)
736                .header("lambda-runtime-deadline-ms", deadline)
737                .body("{}");
738        });
739
740        let next_response = server.mock(|when, then| {
741            when.method(POST)
742                .path(format!("/2018-06-01/runtime/invocation/{request_id}/error"))
743                .header("lambda-runtime-function-error-type", "unhandled");
744            then.status(200).body("");
745        });
746
747        let base = server.base_url().parse().expect("Invalid mock server Uri");
748        let client = Client::builder().with_endpoint(base).build()?;
749
750        let f = crate::service_fn(func);
751
752        let config = Arc::new(Config {
753            function_name: "test_fn".to_string(),
754            memory: 128,
755            version: "1".to_string(),
756            log_stream: "test_stream".to_string(),
757            log_group: "test_log".to_string(),
758        });
759
760        let client = Arc::new(client);
761        let runtime = Runtime {
762            client: client.clone(),
763            config,
764            service: wrap_handler(f, client),
765            concurrency_limit: 1,
766        };
767        let client = &runtime.client;
768        let incoming = incoming(client).take(1);
769        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;
770
771        next_request.assert_async().await;
772        next_response.assert_async().await;
773        Ok(())
774    }
775
776    #[tokio::test]
777    async fn panic_in_async_run() -> Result<(), Error> {
778        run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await
779    }
780
781    #[tokio::test]
782    async fn panic_outside_async_run() -> Result<(), Error> {
783        run_panicking_handler(|_| {
784            panic!("This is intentionally here");
785        })
786        .await
787    }
788
    /// With `concurrency_limit: 2`, a worker that hits a fatal error on
    /// `/invocation/next` must not tear down its sibling: the surviving worker
    /// should still fetch and successfully answer one invocation (asserted via
    /// `response_calls == 1`) before the runtime is driven to termination.
    #[cfg(feature = "concurrency-tokio")]
    #[tokio::test]
    async fn concurrent_worker_crash_does_not_stop_other_workers() -> Result<(), Error> {
        // Counts GET /invocation/next hits; its value sequences the scripted responses below.
        let next_calls = Arc::new(AtomicUsize::new(0));
        // Counts POST .../response hits, i.e. invocations the runtime completed.
        let response_calls = Arc::new(AtomicUsize::new(0));
        // Fired when the "bad" next-event response has been served, so the good
        // invocation is only handed out after the first worker has crashed.
        let first_error_served = Arc::new(Notify::new());

        // Hand-rolled mock Lambda Runtime API bound to an ephemeral port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let base: http::Uri = format!("http://{addr}").parse().unwrap();

        let server_handle = {
            let next_calls = next_calls.clone();
            let response_calls = response_calls.clone();
            let first_error_served = first_error_served.clone();
            tokio::spawn(async move {
                // Accept loop: one hyper connection task per accepted socket.
                loop {
                    let (tcp, _) = match listener.accept().await {
                        Ok(v) => v,
                        Err(_) => return,
                    };

                    let next_calls = next_calls.clone();
                    let response_calls = response_calls.clone();
                    let first_error_served = first_error_served.clone();
                    let service = service_fn(move |req: Request<Incoming>| {
                        let next_calls = next_calls.clone();
                        let response_calls = response_calls.clone();
                        let first_error_served = first_error_served.clone();
                        async move {
                            let (parts, body) = req.into_parts();
                            let method = parts.method;
                            let path = parts.uri.path().to_string();

                            if method == Method::POST {
                                // Drain request body to support keep-alive.
                                let _ = body.collect().await;
                            }

                            if method == Method::GET && path == "/2018-06-01/runtime/invocation/next" {
                                let call_index = next_calls.fetch_add(1, Ordering::SeqCst);
                                match call_index {
                                    // First worker errors (missing request id header).
                                    0 => {
                                        first_error_served.notify_one();
                                        let res = Response::builder()
                                            .status(StatusCode::OK)
                                            .header("lambda-runtime-deadline-ms", "1542409706888")
                                            .body(Full::new(Bytes::from_static(b"{}")))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                    // Second worker should keep running and process an invocation, even if another worker errors.
                                    1 => {
                                        // Wait until the crashing response above has gone out,
                                        // so the ordering under test is deterministic.
                                        first_error_served.notified().await;
                                        let res = Response::builder()
                                            .status(StatusCode::OK)
                                            .header("content-type", "application/json")
                                            .header("lambda-runtime-aws-request-id", "good-request")
                                            .header("lambda-runtime-deadline-ms", "1542409706888")
                                            .body(Full::new(Bytes::from_static(b"{}")))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                    // Finally, error the remaining worker so the runtime can terminate and the test can assert behavior.
                                    2 => {
                                        let res = Response::builder()
                                            .status(StatusCode::OK)
                                            .header("lambda-runtime-deadline-ms", "1542409706888")
                                            .body(Full::new(Bytes::from_static(b"{}")))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                    // Any later poll gets 204 No Content (nothing to do).
                                    _ => {
                                        let res = Response::builder()
                                            .status(StatusCode::NO_CONTENT)
                                            .body(Full::new(Bytes::new()))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                }
                            }

                            // Successful invocation result posted back by the runtime.
                            if method == Method::POST && path.ends_with("/response") {
                                response_calls.fetch_add(1, Ordering::SeqCst);
                                let res = Response::builder()
                                    .status(StatusCode::OK)
                                    .body(Full::new(Bytes::new()))
                                    .unwrap();
                                return Ok::<_, Infallible>(res);
                            }

                            let res = Response::builder()
                                .status(StatusCode::NOT_FOUND)
                                .body(Full::new(Bytes::new()))
                                .unwrap();
                            Ok::<_, Infallible>(res)
                        }
                    });

                    let io = TokioIo::new(tcp);
                    tokio::spawn(async move {
                        if let Err(err) = ServerBuilder::new(TokioExecutor::new())
                            .serve_connection(io, service)
                            .await
                        {
                            eprintln!("Error serving connection: {err:?}");
                        }
                    });
                }
            })
        };

        // Echo handler: returns the invocation payload unchanged.
        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
            Ok(event.payload)
        }

        let handler = crate::service_fn(func);
        let client = Arc::new(Client::builder().with_endpoint(base).build()?);
        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(Config {
                function_name: "test_fn".to_string(),
                memory: 128,
                version: "1".to_string(),
                log_stream: "test_stream".to_string(),
                log_group: "test_log".to_string(),
            }),
            service: wrap_handler(handler, client),
            // Two workers: one is scripted to crash, the other must survive.
            concurrency_limit: 2,
        };

        // The 2s timeout guards against the runtime hanging instead of terminating.
        let res = tokio::time::timeout(Duration::from_secs(2), runtime.run_concurrent()).await;
        assert!(res.is_ok(), "run_concurrent timed out");
        assert!(
            res.unwrap().is_err(),
            "expected runtime to terminate once all workers crashed"
        );

        assert_eq!(
            response_calls.load(Ordering::SeqCst),
            1,
            "expected remaining worker to keep running after a worker crash"
        );

        server_handle.abort();
        Ok(())
    }
937
    /// Serves 300 invocations across 3 concurrent workers and verifies that
    /// the per-invocation tracing span ("Lambda runtime invoke") stays
    /// correctly attached to each handler's log events: every logged request
    /// id is unique and matches the `requestId` field on its ancestor span.
    #[cfg(feature = "concurrency-tokio")]
    // Must be current-thread (the default) so the thread-local tracing
    // subscriber set via `set_default` propagates to spawned tasks.
    #[tokio::test]
    async fn test_concurrent_structured_logging_isolation() -> Result<(), Error> {
        use std::collections::HashSet;
        use tracing::info;
        use tracing_capture::{CaptureLayer, SharedStorage};
        use tracing_subscriber::layer::SubscriberExt;

        // Capture all tracing events/spans in memory for post-run inspection.
        let storage = SharedStorage::default();
        let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage));
        let _guard = tracing::subscriber::set_default(subscriber);

        // Number of /invocation/next polls served so far; also used to mint request ids.
        let request_count = Arc::new(AtomicUsize::new(0));
        // Signaled once the server has handed out all 300 invocations.
        let done = Arc::new(tokio::sync::Notify::new());
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let base: http::Uri = format!("http://{addr}").parse()?;

        let server_handle = {
            let request_count = request_count.clone();
            let done = done.clone();
            tokio::spawn(async move {
                // Accept loop: one hyper connection task per accepted socket.
                loop {
                    let (tcp, _) = match listener.accept().await {
                        Ok(v) => v,
                        Err(_) => return,
                    };

                    let request_count = request_count.clone();
                    let done = done.clone();
                    let service = service_fn(move |req: Request<Incoming>| {
                        let request_count = request_count.clone();
                        let done = done.clone();
                        async move {
                            let (parts, body) = req.into_parts();
                            if parts.method == Method::POST {
                                // Drain request body to support keep-alive.
                                let _ = body.collect().await;
                            }

                            if parts.method == Method::GET && parts.uri.path() == "/2018-06-01/runtime/invocation/next"
                            {
                                let count = request_count.fetch_add(1, Ordering::SeqCst);
                                if count < 300 {
                                    // Serve invocations with ids test-request-1 .. test-request-300.
                                    let request_id = format!("test-request-{}", count + 1);
                                    let res = Response::builder()
                                        .status(StatusCode::OK)
                                        .header("lambda-runtime-aws-request-id", &request_id)
                                        // Far-future deadline so invocations never expire mid-test.
                                        .header("lambda-runtime-deadline-ms", "9999999999999")
                                        .body(Full::new(Bytes::from_static(b"{}")))
                                        .unwrap();
                                    return Ok::<_, Infallible>(res);
                                } else {
                                    // All 300 served: tell the test and return 204 (no more work).
                                    done.notify_one();
                                    let res = Response::builder()
                                        .status(StatusCode::NO_CONTENT)
                                        .body(Full::new(Bytes::new()))
                                        .unwrap();
                                    return Ok::<_, Infallible>(res);
                                }
                            }

                            // Accept any invocation result; content is irrelevant here.
                            if parts.method == Method::POST && parts.uri.path().contains("/response") {
                                let res = Response::builder()
                                    .status(StatusCode::OK)
                                    .body(Full::new(Bytes::new()))
                                    .unwrap();
                                return Ok::<_, Infallible>(res);
                            }

                            let res = Response::builder()
                                .status(StatusCode::NOT_FOUND)
                                .body(Full::new(Bytes::new()))
                                .unwrap();
                            Ok::<_, Infallible>(res)
                        }
                    });

                    let io = TokioIo::new(tcp);
                    tokio::spawn(async move {
                        let _ = ServerBuilder::new(TokioExecutor::new())
                            .serve_connection(io, service)
                            .await;
                    });
                }
            })
        };

        // Logs the context's request id, then sleeps so multiple invocations
        // overlap in time — the condition under which span mixups would show up.
        async fn test_handler(event: crate::LambdaEvent<serde_json::Value>) -> Result<(), Error> {
            let request_id = &event.context.request_id;
            info!(observed_request_id = request_id);
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(())
        }

        let handler = crate::service_fn(test_handler);
        let client = Arc::new(Client::builder().with_endpoint(base).build()?);

        // Add tracing layer to capture span fields
        use crate::layers::trace::TracingLayer;
        use tower::ServiceBuilder;
        let service = ServiceBuilder::new()
            .layer(TracingLayer::new())
            .service(wrap_handler(handler, client.clone()));

        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(Config {
                function_name: "test_fn".to_string(),
                memory: 128,
                version: "1".to_string(),
                log_stream: "test_stream".to_string(),
                log_group: "test_log".to_string(),
            }),
            service,
            // Three overlapping workers exercise concurrent span isolation.
            concurrency_limit: 3,
        };

        let runtime_handle = tokio::spawn(async move { runtime.run_concurrent().await });

        done.notified().await;
        // Give handlers time to complete after server signals done
        tokio::time::sleep(Duration::from_millis(500)).await;

        runtime_handle.abort();
        server_handle.abort();

        // Collect only the events emitted by test_handler (identified by field name).
        let storage = storage.lock();
        let events: Vec<_> = storage
            .all_events()
            .filter(|e| e.value("observed_request_id").is_some())
            .collect();

        assert!(
            events.len() >= 300,
            "Should have at least 300 log entries, got {}",
            events.len()
        );

        let mut seen_ids = HashSet::new();
        for event in &events {
            let observed_id = event["observed_request_id"].as_str().unwrap();

            // Find the parent "Lambda runtime invoke" span and get its requestId
            let span_request_id = event
                .ancestors()
                .find(|s| s.metadata().name() == "Lambda runtime invoke")
                .and_then(|s| s.value("requestId"))
                .and_then(|v| v.as_str())
                .expect("Event should have a Lambda runtime invoke ancestor with requestId");

            assert!(
                observed_id.starts_with("test-request-"),
                "Request ID should match pattern: {}",
                observed_id
            );
            // Uniqueness: no two handler invocations logged the same id.
            assert!(
                seen_ids.insert(observed_id.to_string()),
                "Request ID should be unique: {}",
                observed_id
            );

            // Verify span request ID matches logged request ID
            assert_eq!(
                observed_id, span_request_id,
                "Span request ID should match logged request ID: span={}, logged={}",
                span_request_id, observed_id
            );
        }

        Ok(())
    }
1111}