lambda_runtime/
runtime.rs

1use crate::{
2    layers::{CatchPanicService, RuntimeApiClientService, RuntimeApiResponseService},
3    requests::{IntoRequest, NextEventRequest},
4    types::{invoke_request_id, IntoFunctionResponse, LambdaEvent},
5    Config, Context, Diagnostic,
6};
7use http_body_util::BodyExt;
8use lambda_runtime_api_client::{BoxError, Client as ApiClient};
9use serde::{Deserialize, Serialize};
10use std::{env, fmt::Debug, future::Future, sync::Arc};
11use tokio_stream::{Stream, StreamExt};
12use tower::{Layer, Service, ServiceExt};
13use tracing::trace;
14
15/* ----------------------------------------- INVOCATION ---------------------------------------- */
16
17/// A simple container that provides information about a single invocation of a Lambda function.
18pub struct LambdaInvocation {
19    /// The header of the request sent to invoke the Lambda function.
20    pub parts: http::response::Parts,
21    /// The body of the request sent to invoke the Lambda function.
22    pub body: bytes::Bytes,
23    /// The context of the Lambda invocation.
24    pub context: Context,
25}
26
27/* ------------------------------------------ RUNTIME ------------------------------------------ */
28
29/// Lambda runtime executing a handler function on incoming requests.
30///
31/// Middleware can be added to a runtime using the [Runtime::layer] method in order to execute
32/// logic prior to processing the incoming request and/or after the response has been sent back
33/// to the Lambda Runtime API.
34///
35/// # Example
36/// ```no_run
37/// use lambda_runtime::{Error, LambdaEvent, Runtime};
38/// use serde_json::Value;
39/// use tower::service_fn;
40///
41/// #[tokio::main]
42/// async fn main() -> Result<(), Error> {
43///     let func = service_fn(func);
44///     Runtime::new(func).run().await?;
45///     Ok(())
46/// }
47///
48/// async fn func(event: LambdaEvent<Value>) -> Result<Value, Error> {
49///     Ok(event.payload)
50/// }
51/// ````
52pub struct Runtime<S> {
53    service: S,
54    config: Arc<Config>,
55    client: Arc<ApiClient>,
56}
57
58impl<F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
59    Runtime<
60        RuntimeApiClientService<
61            RuntimeApiResponseService<
62                CatchPanicService<'_, F>,
63                EventPayload,
64                Response,
65                BufferedResponse,
66                StreamingResponse,
67                StreamItem,
68                StreamError,
69            >,
70        >,
71    >
72where
73    F: Service<LambdaEvent<EventPayload>, Response = Response>,
74    F::Future: Future<Output = Result<Response, F::Error>>,
75    F::Error: Into<Diagnostic> + Debug,
76    EventPayload: for<'de> Deserialize<'de>,
77    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
78    BufferedResponse: Serialize,
79    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
80    StreamItem: Into<bytes::Bytes> + Send,
81    StreamError: Into<BoxError> + Send + Debug,
82{
83    /// Create a new runtime that executes the provided handler for incoming requests.
84    ///
85    /// In order to start the runtime and poll for events on the [Lambda Runtime
86    /// APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html), you must call
87    /// [Runtime::run].
88    ///
89    /// Note that manually creating a [Runtime] does not add tracing to the executed handler
90    /// as is done by [super::run]. If you want to add the default tracing functionality, call
91    /// [Runtime::layer] with a [super::layers::TracingLayer].
92    pub fn new(handler: F) -> Self {
93        trace!("Loading config from env");
94        let config = Arc::new(Config::from_env());
95        let client = Arc::new(ApiClient::builder().build().expect("Unable to create a runtime client"));
96        Self {
97            service: wrap_handler(handler, client.clone()),
98            config,
99            client,
100        }
101    }
102}
103
104impl<S> Runtime<S> {
105    /// Add a new layer to this runtime. For an incoming request, this layer will be executed
106    /// before any layer that has been added prior.
107    ///
108    /// # Example
109    /// ```no_run
110    /// use lambda_runtime::{layers, Error, LambdaEvent, Runtime};
111    /// use serde_json::Value;
112    /// use tower::service_fn;
113    ///
114    /// #[tokio::main]
115    /// async fn main() -> Result<(), Error> {
116    ///     let runtime = Runtime::new(service_fn(echo)).layer(
117    ///         layers::TracingLayer::new()
118    ///     );
119    ///     runtime.run().await?;
120    ///     Ok(())
121    /// }
122    ///
123    /// async fn echo(event: LambdaEvent<Value>) -> Result<Value, Error> {
124    ///     Ok(event.payload)
125    /// }
126    /// ```
127    pub fn layer<L>(self, layer: L) -> Runtime<L::Service>
128    where
129        L: Layer<S>,
130        L::Service: Service<LambdaInvocation, Response = (), Error = BoxError>,
131    {
132        Runtime {
133            client: self.client,
134            config: self.config,
135            service: layer.layer(self.service),
136        }
137    }
138}
139
140impl<S> Runtime<S>
141where
142    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
143{
144    /// Start the runtime and begin polling for events on the Lambda Runtime API.
145    pub async fn run(self) -> Result<(), BoxError> {
146        let incoming = incoming(&self.client);
147        Self::run_with_incoming(self.service, self.config, incoming).await
148    }
149
150    /// Internal utility function to start the runtime with a customized incoming stream.
151    /// This implements the core of the [Runtime::run] method.
152    pub(crate) async fn run_with_incoming(
153        mut service: S,
154        config: Arc<Config>,
155        incoming: impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send,
156    ) -> Result<(), BoxError> {
157        tokio::pin!(incoming);
158        while let Some(next_event_response) = incoming.next().await {
159            trace!("New event arrived (run loop)");
160            let event = next_event_response?;
161            let (parts, incoming) = event.into_parts();
162
163            #[cfg(debug_assertions)]
164            if parts.status == http::StatusCode::NO_CONTENT {
165                // Ignore the event if the status code is 204.
166                // This is a way to keep the runtime alive when
167                // there are no events pending to be processed.
168                continue;
169            }
170
171            // Build the invocation such that it can be sent to the service right away
172            // when it is ready
173            let body = incoming.collect().await?.to_bytes();
174            let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
175            let invocation = LambdaInvocation { parts, body, context };
176
177            // Setup Amazon's default tracing data
178            amzn_trace_env(&invocation.context);
179
180            // Wait for service to be ready
181            let ready = service.ready().await?;
182
183            // Once ready, call the service which will respond to the Lambda runtime API
184            ready.call(invocation).await?;
185        }
186        Ok(())
187    }
188}
189
190/* ------------------------------------------- UTILS ------------------------------------------- */
191
192#[allow(clippy::type_complexity)]
193fn wrap_handler<'a, F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>(
194    handler: F,
195    client: Arc<ApiClient>,
196) -> RuntimeApiClientService<
197    RuntimeApiResponseService<
198        CatchPanicService<'a, F>,
199        EventPayload,
200        Response,
201        BufferedResponse,
202        StreamingResponse,
203        StreamItem,
204        StreamError,
205    >,
206>
207where
208    F: Service<LambdaEvent<EventPayload>, Response = Response>,
209    F::Future: Future<Output = Result<Response, F::Error>>,
210    F::Error: Into<Diagnostic> + Debug,
211    EventPayload: for<'de> Deserialize<'de>,
212    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
213    BufferedResponse: Serialize,
214    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
215    StreamItem: Into<bytes::Bytes> + Send,
216    StreamError: Into<BoxError> + Send + Debug,
217{
218    let safe_service = CatchPanicService::new(handler);
219    let response_service = RuntimeApiResponseService::new(safe_service);
220    RuntimeApiClientService::new(response_service, client)
221}
222
223fn incoming(
224    client: &ApiClient,
225) -> impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send + '_ {
226    async_stream::stream! {
227        loop {
228            trace!("Waiting for next event (incoming loop)");
229            let req = NextEventRequest.into_req().expect("Unable to construct request");
230            let res = client.call(req).await;
231            yield res;
232        }
233    }
234}
235
236fn amzn_trace_env(ctx: &Context) {
237    match &ctx.xray_trace_id {
238        Some(trace_id) => env::set_var("_X_AMZN_TRACE_ID", trace_id),
239        None => env::remove_var("_X_AMZN_TRACE_ID"),
240    }
241}
242
243/* --------------------------------------------------------------------------------------------- */
244/*                                             TESTS                                             */
245/* --------------------------------------------------------------------------------------------- */
246
#[cfg(test)]
mod endpoint_tests {
    use super::{incoming, wrap_handler};
    use crate::{
        requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
        Config, Diagnostic, Error, Runtime,
    };
    use futures::future::BoxFuture;
    use http::{HeaderValue, StatusCode};
    use http_body_util::BodyExt;
    use httpmock::prelude::*;

    use lambda_runtime_api_client::Client;
    use std::sync::Arc;
    use tokio_stream::StreamExt;

    /// Function configuration shared by the end-to-end tests.
    ///
    /// Built directly instead of through `Config::from_env`. Tests run in parallel threads of
    /// one process, so setting env vars here would race with other tests. It would also make
    /// the outcome depend on whatever the developer's environment already contains.
    fn test_config() -> Config {
        Config {
            function_name: "test_fn".to_string(),
            memory: 128,
            version: "1".to_string(),
            log_stream: "test_stream".to_string(),
            log_group: "test_log".to_string(),
        }
    }

    #[tokio::test]
    async fn test_next_event() -> Result<(), Error> {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let mock = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = NextEventRequest.into_req()?;
        let rsp = client.call(req).await.expect("Unable to send request");

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        assert_eq!(
            rsp.headers()["lambda-runtime-aws-request-id"],
            &HeaderValue::from_static(request_id)
        );
        assert_eq!(
            rsp.headers()["lambda-runtime-deadline-ms"],
            &HeaderValue::from_static(deadline)
        );

        let body = rsp.into_body().collect().await?.to_bytes();
        assert_eq!("{}", std::str::from_utf8(&body)?);
        Ok(())
    }

    #[tokio::test]
    async fn test_ok_response() -> Result<(), Error> {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/response")
                .body("\"{}\"");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "{}");
        let req = req.into_req()?;

        let rsp = client.call(req).await?;

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        Ok(())
    }

    #[tokio::test]
    async fn test_error_response() -> Result<(), Error> {
        let diagnostic = Diagnostic {
            error_type: "InvalidEventDataError".into(),
            error_message: "Error parsing event data".into(),
        };
        let body = serde_json::to_string(&diagnostic)?;

        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/error")
                .header("lambda-runtime-function-error-type", "unhandled")
                .body(body);
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = EventErrorRequest {
            request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
            diagnostic,
        };
        let req = req.into_req()?;
        let rsp = client.call(req).await?;

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        Ok(())
    }

    #[tokio::test]
    async fn successful_end_to_end_run() -> Result<(), Error> {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let next_request = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });
        let next_response = server.mock(|when, then| {
            when.method(POST)
                .path(format!("/2018-06-01/runtime/invocation/{}/response", request_id))
                .body("{}");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
            let (event, _) = event.into_parts();
            Ok(event)
        }
        let f = crate::service_fn(func);

        let client = Arc::new(client);
        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(test_config()),
            service: wrap_handler(f, client),
        };
        let client = &runtime.client;
        let incoming = incoming(client).take(1);
        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;

        next_request.assert_async().await;
        next_response.assert_async().await;
        Ok(())
    }

    async fn run_panicking_handler<F>(func: F) -> Result<(), Error>
    where
        F: FnMut(crate::LambdaEvent<serde_json::Value>) -> BoxFuture<'static, Result<serde_json::Value, Error>>
            + Send
            + 'static,
    {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let next_request = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });

        let next_response = server.mock(|when, then| {
            when.method(POST)
                .path(format!("/2018-06-01/runtime/invocation/{}/error", request_id))
                .header("lambda-runtime-function-error-type", "unhandled");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let f = crate::service_fn(func);

        let client = Arc::new(client);
        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(test_config()),
            service: wrap_handler(f, client),
        };
        let client = &runtime.client;
        let incoming = incoming(client).take(1);
        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;

        next_request.assert_async().await;
        next_response.assert_async().await;
        Ok(())
    }

    #[tokio::test]
    async fn panic_in_async_run() -> Result<(), Error> {
        run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await
    }

    #[tokio::test]
    async fn panic_outside_async_run() -> Result<(), Error> {
        run_panicking_handler(|_| {
            panic!("This is intentionally here");
        })
        .await
    }
}