1use crate::{
2 layers::{CatchPanicService, RuntimeApiClientService, RuntimeApiResponseService},
3 requests::{IntoRequest, NextEventRequest},
4 types::{invoke_request_id, IntoFunctionResponse, LambdaEvent},
5 Config, Context, Diagnostic,
6};
7#[cfg(feature = "concurrency-tokio")]
8use futures::stream::FuturesUnordered;
9use http_body_util::BodyExt;
10use lambda_runtime_api_client::{BoxError, Client as ApiClient};
11use serde::{Deserialize, Serialize};
12#[cfg(feature = "concurrency-tokio")]
13use std::fmt;
14use std::{env, fmt::Debug, future::Future, sync::Arc};
15use tokio_stream::{Stream, StreamExt};
16use tower::{Layer, Service, ServiceExt};
17use tracing::trace;
18#[cfg(feature = "concurrency-tokio")]
19use tracing::{debug, error, info_span, warn, Instrument};
20
/// A single invocation pulled from the Runtime API `/next` endpoint, bundled
/// with everything the service stack needs to process and answer it.
pub struct LambdaInvocation {
    /// HTTP response head from the `/next` call (status and headers).
    pub parts: http::response::Parts,
    /// Fully-buffered raw event payload.
    pub body: bytes::Bytes,
    /// Invocation context built from the `/next` response headers.
    pub context: Context,
}
32
/// The Lambda runtime: pairs a Tower service `S` (the wrapped handler stack)
/// with the configuration and Runtime API client used to poll for events.
pub struct Runtime<S> {
    // Tower service invoked once per incoming event.
    service: S,
    // Function configuration loaded from the Lambda environment variables.
    config: Arc<Config>,
    // Shared Runtime API client used to poll `/next` and post results.
    client: Arc<ApiClient>,
    // Number of pollers; values > 1 only take effect with the
    // `concurrency-tokio` feature (sourced from AWS_LAMBDA_MAX_CONCURRENCY).
    concurrency_limit: u32,
}
64
impl<F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
    Runtime<
        RuntimeApiClientService<
            RuntimeApiResponseService<
                CatchPanicService<'_, F>,
                EventPayload,
                Response,
                BufferedResponse,
                StreamingResponse,
                StreamItem,
                StreamError,
            >,
        >,
    >
where
    F: Service<LambdaEvent<EventPayload>, Response = Response>,
    F::Future: Future<Output = Result<Response, F::Error>>,
    F::Error: Into<Diagnostic> + Debug,
    EventPayload: for<'de> Deserialize<'de>,
    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
    BufferedResponse: Serialize,
    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
    StreamItem: Into<bytes::Bytes> + Send,
    StreamError: Into<BoxError> + Send + Debug,
{
    /// Create a runtime that wraps `handler` in the default service stack
    /// (panic catching, event/response (de)serialization, Runtime API client).
    ///
    /// Loads [`Config`] from the Lambda environment variables and sizes the
    /// API client's connection pool from `AWS_LAMBDA_MAX_CONCURRENCY`
    /// (defaulting to 1; the extra `.max(1)` guards against a zero limit).
    ///
    /// # Panics
    /// Panics if the Runtime API client cannot be constructed.
    pub fn new(handler: F) -> Self {
        trace!("Loading config from env");
        let config = Arc::new(Config::from_env());
        let concurrency_limit = max_concurrency_from_env().unwrap_or(1).max(1);
        let pool_size = concurrency_limit as usize;
        let client = Arc::new(
            ApiClient::builder()
                .with_pool_size(pool_size)
                .build()
                .expect("Unable to create a runtime client"),
        );
        Self {
            service: wrap_handler(handler, client.clone()),
            config,
            client,
            concurrency_limit,
        }
    }
}
126
127impl<S> Runtime<S> {
128 pub fn layer<L>(self, layer: L) -> Runtime<L::Service>
151 where
152 L: Layer<S>,
153 L::Service: Service<LambdaInvocation, Response = (), Error = BoxError>,
154 {
155 Runtime {
156 client: self.client,
157 config: self.config,
158 service: layer.layer(self.service),
159 concurrency_limit: self.concurrency_limit,
160 }
161 }
162}
163
#[cfg(feature = "concurrency-tokio")]
impl<S> Runtime<S>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError> + Clone + Send + 'static,
    S::Future: Send,
{
    /// Run the runtime with up to `concurrency_limit` concurrent pollers.
    ///
    /// When the limit is 1 (or `AWS_LAMBDA_MAX_CONCURRENCY` is unset) this
    /// falls back to the same sequential loop as [`Runtime::run`].
    ///
    /// # Panics
    /// Panics when called outside of a Tokio runtime, since worker tasks are
    /// spawned with `tokio::spawn`.
    #[cfg_attr(docsrs, doc(cfg(feature = "concurrency-tokio")))]
    pub async fn run_concurrent(self) -> Result<(), BoxError> {
        if tokio::runtime::Handle::try_current().is_err() {
            panic!("`run_concurrent` must be called from within a Tokio runtime");
        }

        if self.concurrency_limit > 1 {
            // In concurrent mode the process-wide _X_AMZN_TRACE_ID env var is
            // deliberately not set (workers would race on it); trace ids are
            // available on each invocation's context instead.
            trace!("Concurrent mode: _X_AMZN_TRACE_ID is not set; use context.xray_trace_id");
            Self::run_concurrent_inner(self.service, self.config, self.client, self.concurrency_limit).await
        } else {
            debug!(
                "Concurrent polling disabled (AWS_LAMBDA_MAX_CONCURRENCY unset or <= 1); falling back to sequential polling"
            );
            let incoming = incoming(&self.client);
            Self::run_with_incoming(self.service, self.config, incoming).await
        }
    }

    /// Spawn `concurrency_limit` worker tasks, each running its own poll
    /// loop, and wait for all of them to stop.
    ///
    /// Workers are expected to run forever: every exit (clean return, error,
    /// or panic) is logged and recorded while the remaining workers keep
    /// running. Once the set drains, returns `Err(ConcurrentWorkerErrors)`
    /// if anything was recorded, `Ok(())` otherwise.
    async fn run_concurrent_inner(
        service: S,
        config: Arc<Config>,
        client: Arc<ApiClient>,
        concurrency_limit: u32,
    ) -> Result<(), BoxError> {
        let limit = concurrency_limit as usize;

        // Each JoinHandle resolves to (task id, result) so failures can be
        // attributed to a specific task in logs and the aggregate error.
        let mut workers: FuturesUnordered<tokio::task::JoinHandle<(tokio::task::Id, Result<(), BoxError>)>> =
            FuturesUnordered::new();
        let spawn_worker = |service: S, config: Arc<Config>, client: Arc<ApiClient>| {
            tokio::spawn(async move {
                let task_id = tokio::task::id();
                let result = concurrent_worker_loop(service, config, client).await;
                (task_id, result)
            })
        };
        // Clone for all but the last worker; the final spawn consumes the
        // originals, avoiding one redundant clone of the service stack.
        for _ in 1..limit {
            workers.push(spawn_worker(service.clone(), config.clone(), client.clone()));
        }
        workers.push(spawn_worker(service, config, client));

        let mut errors: Vec<WorkerError> = Vec::new();
        let mut remaining_workers = limit;
        while let Some(result) = futures::StreamExt::next(&mut workers).await {
            remaining_workers = remaining_workers.saturating_sub(1);
            match result {
                Ok((task_id, Ok(()))) => {
                    // The worker loop has no normal exit path, so even a
                    // clean return is treated as an error condition.
                    error!(
                        task_id = %task_id,
                        remaining_workers,
                        "Concurrent worker exited cleanly (unexpected - loop should run forever)"
                    );
                    errors.push(WorkerError::CleanExit(task_id));
                }
                Ok((task_id, Err(err))) => {
                    error!(
                        task_id = %task_id,
                        error = %err,
                        remaining_workers,
                        "Concurrent worker exited with error"
                    );
                    errors.push(WorkerError::Failure(task_id, err));
                }
                Err(join_err) => {
                    // JoinError: the worker task panicked (or was cancelled).
                    let task_id = join_err.id();
                    let err: BoxError = Box::new(join_err);
                    error!(
                        task_id = %task_id,
                        error = %err,
                        remaining_workers,
                        "Concurrent worker panicked"
                    );
                    errors.push(WorkerError::Failure(task_id, err));
                }
            }
        }

        match errors.len() {
            0 => Ok(()),
            _ => Err(Box::new(ConcurrentWorkerErrors { errors })),
        }
    }
}
279
#[cfg(feature = "concurrency-tokio")]
/// How a single concurrent worker task terminated.
#[derive(Debug)]
enum WorkerError {
    /// The worker returned `Ok(())` — unexpected, since its loop is endless.
    CleanExit(tokio::task::Id),
    /// The worker returned an error, or panicked (boxed `JoinError`).
    Failure(tokio::task::Id, BoxError),
}
286
#[cfg(feature = "concurrency-tokio")]
/// Aggregate error returned by `run_concurrent` once every worker has
/// stopped; its `Display` impl renders the collected exits as JSON.
#[derive(Debug)]
struct ConcurrentWorkerErrors {
    errors: Vec<WorkerError>,
}
292
#[cfg(feature = "concurrency-tokio")]
/// JSON shape emitted by `ConcurrentWorkerErrors`'s `Display` implementation.
#[derive(Serialize)]
struct ConcurrentWorkerErrorsPayload<'a> {
    message: &'a str,
    // Task ids of workers that exited cleanly; omitted from JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    clean: Vec<String>,
    // Task id + error text for failed workers; omitted from JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    failures: Vec<WorkerFailurePayload>,
}
302
#[cfg(feature = "concurrency-tokio")]
/// One failed-worker entry in the serialized error payload.
#[derive(Serialize)]
struct WorkerFailurePayload {
    id: String,
    err: String,
}
309
#[cfg(feature = "concurrency-tokio")]
impl fmt::Display for ConcurrentWorkerErrors {
    /// Render the recorded worker exits as a single JSON object so the
    /// aggregate error stays machine-readable in logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Bucket the exits directly into their serializable forms in one pass.
        let mut clean_ids: Vec<String> = Vec::new();
        let mut failure_entries: Vec<WorkerFailurePayload> = Vec::new();
        for worker_error in &self.errors {
            match worker_error {
                WorkerError::CleanExit(task_id) => clean_ids.push(task_id.to_string()),
                WorkerError::Failure(task_id, err) => failure_entries.push(WorkerFailurePayload {
                    id: task_id.to_string(),
                    err: err.to_string(),
                }),
            }
        }

        // Headline message: only claim "all clean" when there was at least
        // one clean exit and no failures at all.
        let message = if failure_entries.is_empty() && !clean_ids.is_empty() {
            "all concurrent workers exited cleanly (unexpected - loop should run forever)"
        } else {
            "concurrent workers exited unexpectedly"
        };

        let payload = ConcurrentWorkerErrorsPayload {
            message,
            clean: clean_ids,
            failures: failure_entries,
        };
        serde_json::to_string(&payload)
            .map_err(|_| fmt::Error)
            .and_then(|json| write!(f, "{json}"))
    }
}
346
#[cfg(feature = "concurrency-tokio")]
// Marker impl only: `Display` already carries the full JSON payload and there
// is no single underlying cause to expose via `source()`.
impl std::error::Error for ConcurrentWorkerErrors {}
349
impl<S> Runtime<S>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
{
    /// Run the sequential polling loop until the event stream ends or an
    /// invocation fails.
    ///
    /// If `AWS_LAMBDA_MAX_CONCURRENCY` is set but the `concurrency-tokio`
    /// feature is not enabled, a warning is emitted and execution proceeds
    /// sequentially anyway.
    pub async fn run(self) -> Result<(), BoxError> {
        if let Some(raw) = concurrency_env_value() {
            log_or_print!(
                tracing: tracing::warn!(
                    "AWS_LAMBDA_MAX_CONCURRENCY is set to '{raw}', but the concurrency-tokio feature is not enabled; running sequentially",
                ),
                fallback: eprintln!("AWS_LAMBDA_MAX_CONCURRENCY is set to '{raw}', but the concurrency-tokio feature is not enabled; running sequentially")
            );
        }
        let incoming = incoming(&self.client);
        Self::run_with_incoming(self.service, self.config, incoming).await
    }

    /// Core sequential loop: pull events off `incoming` one at a time and
    /// process each, syncing `_X_AMZN_TRACE_ID` per invocation (the `true`
    /// flag). Also used directly by tests to inject a finite event stream.
    pub(crate) async fn run_with_incoming(
        mut service: S,
        config: Arc<Config>,
        incoming: impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send,
    ) -> Result<(), BoxError> {
        tokio::pin!(incoming);
        while let Some(next_event_response) = incoming.next().await {
            trace!("New event arrived (run loop)");
            let event = next_event_response?;
            process_invocation(&mut service, &config, event, true).await?;
        }
        Ok(())
    }
}
393
#[allow(clippy::type_complexity)]
/// Assemble the default service stack around a raw handler service:
/// `CatchPanicService` (converts handler panics into errors) wrapped by
/// `RuntimeApiResponseService` (deserializes events, serializes responses)
/// wrapped by `RuntimeApiClientService` (reports the outcome to the API).
fn wrap_handler<'a, F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>(
    handler: F,
    client: Arc<ApiClient>,
) -> RuntimeApiClientService<
    RuntimeApiResponseService<
        CatchPanicService<'a, F>,
        EventPayload,
        Response,
        BufferedResponse,
        StreamingResponse,
        StreamItem,
        StreamError,
    >,
>
where
    F: Service<LambdaEvent<EventPayload>, Response = Response>,
    F::Future: Future<Output = Result<Response, F::Error>>,
    F::Error: Into<Diagnostic> + Debug,
    EventPayload: for<'de> Deserialize<'de>,
    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
    BufferedResponse: Serialize,
    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
    StreamItem: Into<bytes::Bytes> + Send,
    StreamError: Into<BoxError> + Send + Debug,
{
    let safe_service = CatchPanicService::new(handler);
    let response_service = RuntimeApiResponseService::new(safe_service);
    RuntimeApiClientService::new(response_service, client)
}
426
/// An endless stream of `/next` responses: each iteration long-polls the
/// Runtime API once and yields the raw HTTP response (or the poll error).
fn incoming(
    client: &ApiClient,
) -> impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send + '_ {
    async_stream::stream! {
        loop {
            trace!("Waiting for next event (incoming loop)");
            // Building the /next request is infallible in practice; a failure
            // here would be a programming error, hence the expect.
            let req = NextEventRequest.into_req().expect("Unable to construct request");
            let res = client.call(req).await;
            yield res;
        }
    }
}
439
#[cfg(feature = "concurrency-tokio")]
/// Issue a single `/next` long-poll request against the Runtime API.
/// Unlike `incoming`, request-construction errors are returned, not panicked.
async fn next_event_future(client: &ApiClient) -> Result<http::Response<hyper::body::Incoming>, BoxError> {
    let req = NextEventRequest.into_req()?;
    client.call(req).await
}
446
/// Parse `AWS_LAMBDA_MAX_CONCURRENCY` into a positive poller count.
///
/// Returns `None` when the variable is unset, not a valid `u32`, or zero.
fn max_concurrency_from_env() -> Option<u32> {
    let raw = env::var("AWS_LAMBDA_MAX_CONCURRENCY").ok()?;
    match raw.parse::<u32>() {
        Ok(parsed) if parsed > 0 => Some(parsed),
        _ => None,
    }
}
453
/// Raw value of `AWS_LAMBDA_MAX_CONCURRENCY` when the variable is set (and
/// valid unicode); used only for diagnostics in `Runtime::run`.
fn concurrency_env_value() -> Option<String> {
    match env::var("AWS_LAMBDA_MAX_CONCURRENCY") {
        Ok(raw) => Some(raw),
        Err(_) => None,
    }
}
457
#[cfg(feature = "concurrency-tokio")]
/// A single worker's poll loop: repeatedly fetch the next event from the
/// Runtime API and process it, forever.
///
/// Poll (`/next`) failures are logged and retried; processing failures
/// propagate and terminate this worker. `_X_AMZN_TRACE_ID` is never touched
/// here (the `false` flag below) because workers share the process env.
async fn concurrent_worker_loop<S>(mut service: S, config: Arc<Config>, client: Arc<ApiClient>) -> Result<(), BoxError>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
    S::Future: Send,
{
    // Attach the task id to a span so all logs from this worker correlate.
    let task_id = tokio::task::id();
    let span = info_span!("worker", task_id = %task_id);
    loop {
        let event = match next_event_future(client.as_ref()).instrument(span.clone()).await {
            Ok(event) => event,
            Err(e) => {
                // NOTE(review): retries immediately with no backoff; a
                // persistently failing endpoint would spin this loop.
                warn!(task_id = %task_id, error = %e, "Error polling /next, retrying");
                continue;
            }
        };

        process_invocation(&mut service, &config, event, false)
            .instrument(span.clone())
            .await?;
    }
}
480
/// Drive one invocation through `service`: buffer the event body, build the
/// [`Context`] from the response headers, and call the service once ready.
///
/// `set_amzn_trace_env` controls whether the process-wide `_X_AMZN_TRACE_ID`
/// env var is synced from the invocation context; only the sequential runtime
/// passes `true` (concurrent workers would race on the shared variable).
async fn process_invocation<S>(
    service: &mut S,
    config: &Arc<Config>,
    event: http::Response<hyper::body::Incoming>,
    set_amzn_trace_env: bool,
) -> Result<(), BoxError>
where
    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
{
    let (parts, incoming) = event.into_parts();

    // Debug builds only: treat 204 No Content as "no event available" and
    // skip it, which keeps local tooling (mock Runtime API servers) alive.
    // NOTE(review): in release builds a 204 would fall through to the header
    // parsing below; presumably the real Runtime API never returns 204 here
    // — confirm against the Runtime API contract.
    #[cfg(debug_assertions)]
    if parts.status == http::StatusCode::NO_CONTENT {
        return Ok(());
    }

    let body = incoming.collect().await?.to_bytes();
    let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
    let invocation = LambdaInvocation { parts, body, context };

    if set_amzn_trace_env {
        amzn_trace_env(&invocation.context);
    }

    // Respect Tower's readiness contract before dispatching the invocation.
    let ready = service.ready().await?;

    ready.call(invocation).await?;
    Ok(())
}
518
519fn amzn_trace_env(ctx: &Context) {
520 match &ctx.xray_trace_id {
521 Some(trace_id) => env::set_var("_X_AMZN_TRACE_ID", trace_id),
522 None => env::remove_var("_X_AMZN_TRACE_ID"),
523 }
524}
525
526#[cfg(test)]
531mod endpoint_tests {
532 use super::{incoming, wrap_handler};
533 use crate::{
534 requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
535 Config, Diagnostic, Error, Runtime,
536 };
537 use bytes::Bytes;
538 use futures::future::BoxFuture;
539 use http::{HeaderValue, Method, Request, Response, StatusCode};
540 use http_body_util::{BodyExt, Full};
541 use httpmock::prelude::*;
542
543 use hyper::{body::Incoming, service::service_fn};
544 use hyper_util::{
545 rt::{tokio::TokioIo, TokioExecutor},
546 server::conn::auto::Builder as ServerBuilder,
547 };
548 use lambda_runtime_api_client::Client;
549 use std::{
550 convert::Infallible,
551 env,
552 sync::{
553 atomic::{AtomicUsize, Ordering},
554 Arc,
555 },
556 time::Duration,
557 };
558 use tokio::{net::TcpListener, sync::Notify};
559 use tokio_stream::StreamExt;
560
    /// `/next` polling: the client hits the correct path and the request-id /
    /// deadline headers and body are passed through untouched.
    #[tokio::test]
    async fn test_next_event() -> Result<(), Error> {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let mock = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = NextEventRequest.into_req()?;
        let rsp = client.call(req).await.expect("Unable to send request");

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        assert_eq!(
            rsp.headers()["lambda-runtime-aws-request-id"],
            &HeaderValue::from_static(request_id)
        );
        assert_eq!(
            rsp.headers()["lambda-runtime-deadline-ms"],
            &HeaderValue::from_static(deadline)
        );

        let body = rsp.into_body().collect().await?.to_bytes();
        assert_eq!("{}", std::str::from_utf8(&body)?);
        Ok(())
    }
597
    /// Successful completion: the serialized response body is POSTed to the
    /// invocation's `/response` endpoint.
    #[tokio::test]
    async fn test_ok_response() -> Result<(), Error> {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/response")
                .body("\"{}\"");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "{}");
        let req = req.into_req()?;

        let rsp = client.call(req).await?;

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        Ok(())
    }
621
    /// Handler failure: the `Diagnostic` JSON is POSTed to the invocation's
    /// `/error` endpoint with the `unhandled` error-type header.
    #[tokio::test]
    async fn test_error_response() -> Result<(), Error> {
        let diagnostic = Diagnostic {
            error_type: "InvalidEventDataError".into(),
            error_message: "Error parsing event data".into(),
        };
        let body = serde_json::to_string(&diagnostic)?;

        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/error")
                .header("lambda-runtime-function-error-type", "unhandled")
                .body(body);
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = EventErrorRequest {
            request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
            diagnostic,
        };
        let req = req.into_req()?;
        let rsp = client.call(req).await?;

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        Ok(())
    }
653
    /// Full sequential loop against a mock Runtime API: one `/next` event is
    /// consumed and the echoed payload is POSTed back to `/response`.
    #[tokio::test]
    async fn successful_end_to_end_run() -> Result<(), Error> {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let next_request = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });
        let next_response = server.mock(|when, then| {
            when.method(POST)
                .path(format!("/2018-06-01/runtime/invocation/{request_id}/response"))
                .body("{}");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        // Identity handler: echo the event payload back as the response.
        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
            let (event, _) = event.into_parts();
            Ok(event)
        }
        let f = crate::service_fn(func);

        // Populate the env vars Config::from_env reads, but only when absent
        // so a real environment (or a parallel test) is not clobbered.
        if env::var("AWS_LAMBDA_RUNTIME_API").is_err() {
            env::set_var("AWS_LAMBDA_RUNTIME_API", server.base_url());
        }
        if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() {
            env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn");
        }
        if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() {
            env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128");
        }
        if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() {
            env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1");
        }
        if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() {
            env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream");
        }
        if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() {
            env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log");
        }
        let config = Config::from_env();

        let client = Arc::new(client);
        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(config),
            service: wrap_handler(f, client),
            concurrency_limit: 1,
        };
        let client = &runtime.client;
        // take(1) turns the endless poll stream into a single-event run.
        let incoming = incoming(client).take(1);
        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;

        next_request.assert_async().await;
        next_response.assert_async().await;
        Ok(())
    }
720
    /// Drive one event through the sequential runtime with a handler that
    /// panics, asserting the panic is converted into a POST to `/error`
    /// carrying the `unhandled` error-type header.
    async fn run_panicking_handler<F>(func: F) -> Result<(), Error>
    where
        F: FnMut(crate::LambdaEvent<serde_json::Value>) -> BoxFuture<'static, Result<serde_json::Value, Error>>
            + Send
            + 'static,
    {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let next_request = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });

        let next_response = server.mock(|when, then| {
            when.method(POST)
                .path(format!("/2018-06-01/runtime/invocation/{request_id}/error"))
                .header("lambda-runtime-function-error-type", "unhandled");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let f = crate::service_fn(func);

        // Build Config directly instead of from env to keep the test hermetic.
        let config = Arc::new(Config {
            function_name: "test_fn".to_string(),
            memory: 128,
            version: "1".to_string(),
            log_stream: "test_stream".to_string(),
            log_group: "test_log".to_string(),
        });

        let client = Arc::new(client);
        let runtime = Runtime {
            client: client.clone(),
            config,
            service: wrap_handler(f, client),
            concurrency_limit: 1,
        };
        let client = &runtime.client;
        let incoming = incoming(client).take(1);
        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;

        next_request.assert_async().await;
        next_response.assert_async().await;
        Ok(())
    }
775
    /// Panic raised inside the handler's async body is reported as an error.
    #[tokio::test]
    async fn panic_in_async_run() -> Result<(), Error> {
        run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await
    }
780
    /// Panic raised before the handler even returns its future is also
    /// reported as an error.
    #[tokio::test]
    async fn panic_outside_async_run() -> Result<(), Error> {
        run_panicking_handler(|_| {
            panic!("This is intentionally here");
        })
        .await
    }
788
    /// One worker receives a malformed `/next` response (missing request-id
    /// header) and crashes; the surviving worker must still process its good
    /// event before the runtime terminates with an aggregate error.
    #[cfg(feature = "concurrency-tokio")]
    #[tokio::test]
    async fn concurrent_worker_crash_does_not_stop_other_workers() -> Result<(), Error> {
        let next_calls = Arc::new(AtomicUsize::new(0));
        let response_calls = Arc::new(AtomicUsize::new(0));
        let first_error_served = Arc::new(Notify::new());

        // Hand-rolled hyper server: httpmock cannot sequence per-call
        // responses with the ordering control this test needs.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let base: http::Uri = format!("http://{addr}").parse().unwrap();

        let server_handle = {
            let next_calls = next_calls.clone();
            let response_calls = response_calls.clone();
            let first_error_served = first_error_served.clone();
            tokio::spawn(async move {
                loop {
                    let (tcp, _) = match listener.accept().await {
                        Ok(v) => v,
                        Err(_) => return,
                    };

                    let next_calls = next_calls.clone();
                    let response_calls = response_calls.clone();
                    let first_error_served = first_error_served.clone();
                    let service = service_fn(move |req: Request<Incoming>| {
                        let next_calls = next_calls.clone();
                        let response_calls = response_calls.clone();
                        let first_error_served = first_error_served.clone();
                        async move {
                            let (parts, body) = req.into_parts();
                            let method = parts.method;
                            let path = parts.uri.path().to_string();

                            if method == Method::POST {
                                // Drain POST bodies so the connection stays usable.
                                let _ = body.collect().await;
                            }

                            if method == Method::GET && path == "/2018-06-01/runtime/invocation/next" {
                                let call_index = next_calls.fetch_add(1, Ordering::SeqCst);
                                match call_index {
                                    0 => {
                                        // First poll: malformed (no request-id
                                        // header) -> crashes that worker.
                                        first_error_served.notify_one();
                                        let res = Response::builder()
                                            .status(StatusCode::OK)
                                            .header("lambda-runtime-deadline-ms", "1542409706888")
                                            .body(Full::new(Bytes::from_static(b"{}")))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                    1 => {
                                        // Second poll: wait until the bad event
                                        // went out, then serve a valid event.
                                        first_error_served.notified().await;
                                        let res = Response::builder()
                                            .status(StatusCode::OK)
                                            .header("content-type", "application/json")
                                            .header("lambda-runtime-aws-request-id", "good-request")
                                            .header("lambda-runtime-deadline-ms", "1542409706888")
                                            .body(Full::new(Bytes::from_static(b"{}")))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                    2 => {
                                        // Third poll: malformed again, so the
                                        // surviving worker also crashes and
                                        // run_concurrent can return.
                                        let res = Response::builder()
                                            .status(StatusCode::OK)
                                            .header("lambda-runtime-deadline-ms", "1542409706888")
                                            .body(Full::new(Bytes::from_static(b"{}")))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                    _ => {
                                        let res = Response::builder()
                                            .status(StatusCode::NO_CONTENT)
                                            .body(Full::new(Bytes::new()))
                                            .unwrap();
                                        return Ok::<_, Infallible>(res);
                                    }
                                }
                            }

                            if method == Method::POST && path.ends_with("/response") {
                                response_calls.fetch_add(1, Ordering::SeqCst);
                                let res = Response::builder()
                                    .status(StatusCode::OK)
                                    .body(Full::new(Bytes::new()))
                                    .unwrap();
                                return Ok::<_, Infallible>(res);
                            }

                            let res = Response::builder()
                                .status(StatusCode::NOT_FOUND)
                                .body(Full::new(Bytes::new()))
                                .unwrap();
                            Ok::<_, Infallible>(res)
                        }
                    });

                    let io = TokioIo::new(tcp);
                    tokio::spawn(async move {
                        if let Err(err) = ServerBuilder::new(TokioExecutor::new())
                            .serve_connection(io, service)
                            .await
                        {
                            eprintln!("Error serving connection: {err:?}");
                        }
                    });
                }
            })
        };

        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
            Ok(event.payload)
        }

        let handler = crate::service_fn(func);
        let client = Arc::new(Client::builder().with_endpoint(base).build()?);
        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(Config {
                function_name: "test_fn".to_string(),
                memory: 128,
                version: "1".to_string(),
                log_stream: "test_stream".to_string(),
                log_group: "test_log".to_string(),
            }),
            service: wrap_handler(handler, client),
            concurrency_limit: 2,
        };

        let res = tokio::time::timeout(Duration::from_secs(2), runtime.run_concurrent()).await;
        assert!(res.is_ok(), "run_concurrent timed out");
        assert!(
            res.unwrap().is_err(),
            "expected runtime to terminate once all workers crashed"
        );

        assert_eq!(
            response_calls.load(Ordering::SeqCst),
            1,
            "expected remaining worker to keep running after a worker crash"
        );

        server_handle.abort();
        Ok(())
    }
937
    /// With three concurrent workers and overlapping invocations, every log
    /// line emitted inside a handler must sit under the tracing span of its
    /// own request — no cross-invocation leakage of `requestId`.
    #[cfg(feature = "concurrency-tokio")]
    #[tokio::test]
    async fn test_concurrent_structured_logging_isolation() -> Result<(), Error> {
        use std::collections::HashSet;
        use tracing::info;
        use tracing_capture::{CaptureLayer, SharedStorage};
        use tracing_subscriber::layer::SubscriberExt;

        // Capture all tracing events so span ancestry can be inspected later.
        let storage = SharedStorage::default();
        let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage));
        let _guard = tracing::subscriber::set_default(subscriber);

        let request_count = Arc::new(AtomicUsize::new(0));
        let done = Arc::new(tokio::sync::Notify::new());
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let base: http::Uri = format!("http://{addr}").parse()?;

        let server_handle = {
            let request_count = request_count.clone();
            let done = done.clone();
            tokio::spawn(async move {
                loop {
                    let (tcp, _) = match listener.accept().await {
                        Ok(v) => v,
                        Err(_) => return,
                    };

                    let request_count = request_count.clone();
                    let done = done.clone();
                    let service = service_fn(move |req: Request<Incoming>| {
                        let request_count = request_count.clone();
                        let done = done.clone();
                        async move {
                            let (parts, body) = req.into_parts();
                            if parts.method == Method::POST {
                                let _ = body.collect().await;
                            }

                            if parts.method == Method::GET && parts.uri.path() == "/2018-06-01/runtime/invocation/next"
                            {
                                let count = request_count.fetch_add(1, Ordering::SeqCst);
                                // Serve 300 uniquely-numbered events, then
                                // signal completion and return 204s.
                                if count < 300 {
                                    let request_id = format!("test-request-{}", count + 1);
                                    let res = Response::builder()
                                        .status(StatusCode::OK)
                                        .header("lambda-runtime-aws-request-id", &request_id)
                                        .header("lambda-runtime-deadline-ms", "9999999999999")
                                        .body(Full::new(Bytes::from_static(b"{}")))
                                        .unwrap();
                                    return Ok::<_, Infallible>(res);
                                } else {
                                    done.notify_one();
                                    let res = Response::builder()
                                        .status(StatusCode::NO_CONTENT)
                                        .body(Full::new(Bytes::new()))
                                        .unwrap();
                                    return Ok::<_, Infallible>(res);
                                }
                            }

                            if parts.method == Method::POST && parts.uri.path().contains("/response") {
                                let res = Response::builder()
                                    .status(StatusCode::OK)
                                    .body(Full::new(Bytes::new()))
                                    .unwrap();
                                return Ok::<_, Infallible>(res);
                            }

                            let res = Response::builder()
                                .status(StatusCode::NOT_FOUND)
                                .body(Full::new(Bytes::new()))
                                .unwrap();
                            Ok::<_, Infallible>(res)
                        }
                    });

                    let io = TokioIo::new(tcp);
                    tokio::spawn(async move {
                        let _ = ServerBuilder::new(TokioExecutor::new())
                            .serve_connection(io, service)
                            .await;
                    });
                }
            })
        };

        // Handler logs its request id, then sleeps so invocations overlap
        // across workers — the condition under which leakage would show up.
        async fn test_handler(event: crate::LambdaEvent<serde_json::Value>) -> Result<(), Error> {
            let request_id = &event.context.request_id;
            info!(observed_request_id = request_id);
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(())
        }

        let handler = crate::service_fn(test_handler);
        let client = Arc::new(Client::builder().with_endpoint(base).build()?);

        use crate::layers::trace::TracingLayer;
        use tower::ServiceBuilder;
        let service = ServiceBuilder::new()
            .layer(TracingLayer::new())
            .service(wrap_handler(handler, client.clone()));

        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(Config {
                function_name: "test_fn".to_string(),
                memory: 128,
                version: "1".to_string(),
                log_stream: "test_stream".to_string(),
                log_group: "test_log".to_string(),
            }),
            service,
            concurrency_limit: 3,
        };

        let runtime_handle = tokio::spawn(async move { runtime.run_concurrent().await });

        // Wait for all 300 events to be served, then give in-flight handlers
        // time to finish logging before tearing everything down.
        done.notified().await;
        tokio::time::sleep(Duration::from_millis(500)).await;

        runtime_handle.abort();
        server_handle.abort();

        let storage = storage.lock();
        let events: Vec<_> = storage
            .all_events()
            .filter(|e| e.value("observed_request_id").is_some())
            .collect();

        assert!(
            events.len() >= 300,
            "Should have at least 300 log entries, got {}",
            events.len()
        );

        let mut seen_ids = HashSet::new();
        for event in &events {
            let observed_id = event["observed_request_id"].as_str().unwrap();

            // The enclosing invoke span's requestId must match what the
            // handler observed from its own context.
            let span_request_id = event
                .ancestors()
                .find(|s| s.metadata().name() == "Lambda runtime invoke")
                .and_then(|s| s.value("requestId"))
                .and_then(|v| v.as_str())
                .expect("Event should have a Lambda runtime invoke ancestor with requestId");

            assert!(
                observed_id.starts_with("test-request-"),
                "Request ID should match pattern: {}",
                observed_id
            );
            assert!(
                seen_ids.insert(observed_id.to_string()),
                "Request ID should be unique: {}",
                observed_id
            );

            assert_eq!(
                observed_id, span_request_id,
                "Span request ID should match logged request ID: span={}, logged={}",
                span_request_id, observed_id
            );
        }

        Ok(())
    }
1111}