1use crate::{
2 layers::{CatchPanicService, RuntimeApiClientService, RuntimeApiResponseService},
3 requests::{IntoRequest, NextEventRequest},
4 types::{invoke_request_id, IntoFunctionResponse, LambdaEvent},
5 Config, Context, Diagnostic,
6};
7use http_body_util::BodyExt;
8use lambda_runtime_api_client::{BoxError, Client as ApiClient};
9use serde::{Deserialize, Serialize};
10use std::{env, fmt::Debug, future::Future, sync::Arc};
11use tokio_stream::{Stream, StreamExt};
12use tower::{Layer, Service, ServiceExt};
13use tracing::trace;
14
/// One invocation received from the Lambda Runtime API, as handed to the
/// runtime's service stack by the run loop.
pub struct LambdaInvocation {
    /// Status and headers of the "next event" HTTP response.
    pub parts: http::response::Parts,
    /// Fully buffered body of the "next event" HTTP response.
    pub body: bytes::Bytes,
    /// Invocation context built from the response headers and the runtime config.
    pub context: Context,
}
26
/// Lambda runtime: polls the Runtime API for events and dispatches each one
/// to the service `S`.
pub struct Runtime<S> {
    /// Service stack that processes every [`LambdaInvocation`].
    service: S,
    /// Function configuration, shared with every invocation's [`Context`].
    config: Arc<Config>,
    /// HTTP client used to talk to the Lambda Runtime API.
    client: Arc<ApiClient>,
}
57
58impl<F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
59 Runtime<
60 RuntimeApiClientService<
61 RuntimeApiResponseService<
62 CatchPanicService<'_, F>,
63 EventPayload,
64 Response,
65 BufferedResponse,
66 StreamingResponse,
67 StreamItem,
68 StreamError,
69 >,
70 >,
71 >
72where
73 F: Service<LambdaEvent<EventPayload>, Response = Response>,
74 F::Future: Future<Output = Result<Response, F::Error>>,
75 F::Error: Into<Diagnostic> + Debug,
76 EventPayload: for<'de> Deserialize<'de>,
77 Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
78 BufferedResponse: Serialize,
79 StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
80 StreamItem: Into<bytes::Bytes> + Send,
81 StreamError: Into<BoxError> + Send + Debug,
82{
83 pub fn new(handler: F) -> Self {
93 trace!("Loading config from env");
94 let config = Arc::new(Config::from_env());
95 let client = Arc::new(ApiClient::builder().build().expect("Unable to create a runtime client"));
96 Self {
97 service: wrap_handler(handler, client.clone()),
98 config,
99 client,
100 }
101 }
102}
103
104impl<S> Runtime<S> {
105 pub fn layer<L>(self, layer: L) -> Runtime<L::Service>
128 where
129 L: Layer<S>,
130 L::Service: Service<LambdaInvocation, Response = (), Error = BoxError>,
131 {
132 Runtime {
133 client: self.client,
134 config: self.config,
135 service: layer.layer(self.service),
136 }
137 }
138}
139
140impl<S> Runtime<S>
141where
142 S: Service<LambdaInvocation, Response = (), Error = BoxError>,
143{
144 pub async fn run(self) -> Result<(), BoxError> {
146 let incoming = incoming(&self.client);
147 Self::run_with_incoming(self.service, self.config, incoming).await
148 }
149
150 pub(crate) async fn run_with_incoming(
153 mut service: S,
154 config: Arc<Config>,
155 incoming: impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send,
156 ) -> Result<(), BoxError> {
157 tokio::pin!(incoming);
158 while let Some(next_event_response) = incoming.next().await {
159 trace!("New event arrived (run loop)");
160 let event = next_event_response?;
161 let (parts, incoming) = event.into_parts();
162
163 #[cfg(debug_assertions)]
164 if parts.status == http::StatusCode::NO_CONTENT {
165 continue;
169 }
170
171 let body = incoming.collect().await?.to_bytes();
174 let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
175 let invocation = LambdaInvocation { parts, body, context };
176
177 amzn_trace_env(&invocation.context);
179
180 let ready = service.ready().await?;
182
183 ready.call(invocation).await?;
185 }
186 Ok(())
187 }
188}
189
190#[allow(clippy::type_complexity)]
193fn wrap_handler<'a, F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>(
194 handler: F,
195 client: Arc<ApiClient>,
196) -> RuntimeApiClientService<
197 RuntimeApiResponseService<
198 CatchPanicService<'a, F>,
199 EventPayload,
200 Response,
201 BufferedResponse,
202 StreamingResponse,
203 StreamItem,
204 StreamError,
205 >,
206>
207where
208 F: Service<LambdaEvent<EventPayload>, Response = Response>,
209 F::Future: Future<Output = Result<Response, F::Error>>,
210 F::Error: Into<Diagnostic> + Debug,
211 EventPayload: for<'de> Deserialize<'de>,
212 Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
213 BufferedResponse: Serialize,
214 StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
215 StreamItem: Into<bytes::Bytes> + Send,
216 StreamError: Into<BoxError> + Send + Debug,
217{
218 let safe_service = CatchPanicService::new(handler);
219 let response_service = RuntimeApiResponseService::new(safe_service);
220 RuntimeApiClientService::new(response_service, client)
221}
222
223fn incoming(
224 client: &ApiClient,
225) -> impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send + '_ {
226 async_stream::stream! {
227 loop {
228 trace!("Waiting for next event (incoming loop)");
229 let req = NextEventRequest.into_req().expect("Unable to construct request");
230 let res = client.call(req).await;
231 yield res;
232 }
233 }
234}
235
236fn amzn_trace_env(ctx: &Context) {
237 match &ctx.xray_trace_id {
238 Some(trace_id) => env::set_var("_X_AMZN_TRACE_ID", trace_id),
239 None => env::remove_var("_X_AMZN_TRACE_ID"),
240 }
241}
242
/// Tests for the Runtime API request types and the run loop, using an
/// `httpmock` server in place of the Lambda Runtime API.
#[cfg(test)]
mod endpoint_tests {
    use super::{incoming, wrap_handler};
    use crate::{
        requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
        Config, Diagnostic, Error, Runtime,
    };
    use futures::future::BoxFuture;
    use http::{HeaderValue, StatusCode};
    use http_body_util::BodyExt;
    use httpmock::prelude::*;

    use lambda_runtime_api_client::Client;
    use std::{env, sync::Arc};
    use tokio_stream::StreamExt;

    /// `NextEventRequest` hits `GET .../invocation/next`; the request-id and
    /// deadline headers and the body of the response come back intact.
    #[tokio::test]
    async fn test_next_event() -> Result<(), Error> {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let mock = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = NextEventRequest.into_req()?;
        let rsp = client.call(req).await.expect("Unable to send request");

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        assert_eq!(
            rsp.headers()["lambda-runtime-aws-request-id"],
            &HeaderValue::from_static(request_id)
        );
        assert_eq!(
            rsp.headers()["lambda-runtime-deadline-ms"],
            &HeaderValue::from_static(deadline)
        );

        let body = rsp.into_body().collect().await?.to_bytes();
        assert_eq!("{}", std::str::from_utf8(&body)?);
        Ok(())
    }

    /// `EventCompletionRequest` POSTs to `.../invocation/{id}/response`.
    #[tokio::test]
    async fn test_ok_response() -> Result<(), Error> {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            // The `&str` payload "{}" is expected JSON-encoded as a string: "\"{}\"".
            when.method(POST)
                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/response")
                .body("\"{}\"");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "{}");
        let req = req.into_req()?;

        let rsp = client.call(req).await?;

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        Ok(())
    }

    /// `EventErrorRequest` POSTs the JSON-serialized `Diagnostic` to
    /// `.../invocation/{id}/error` with `lambda-runtime-function-error-type: unhandled`.
    #[tokio::test]
    async fn test_error_response() -> Result<(), Error> {
        let diagnostic = Diagnostic {
            error_type: "InvalidEventDataError".into(),
            error_message: "Error parsing event data".into(),
        };
        let body = serde_json::to_string(&diagnostic)?;

        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/error")
                .header("lambda-runtime-function-error-type", "unhandled")
                .body(body);
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let req = EventErrorRequest {
            request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
            diagnostic,
        };
        let req = req.into_req()?;
        let rsp = client.call(req).await?;

        mock.assert_async().await;
        assert_eq!(rsp.status(), StatusCode::OK);
        Ok(())
    }

    /// One full iteration of the run loop: fetch an event from `/next` and post
    /// the echoing handler's result to `/{id}/response`.
    #[tokio::test]
    async fn successful_end_to_end_run() -> Result<(), Error> {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let next_request = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });
        let next_response = server.mock(|when, then| {
            when.method(POST)
                .path(format!("/2018-06-01/runtime/invocation/{}/response", request_id))
                .body("{}");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        // Handler that echoes the event payload back as its response.
        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
            let (event, _) = event.into_parts();
            Ok(event)
        }
        let f = crate::service_fn(func);

        // `Config::from_env` below reads these; only fill in values the
        // environment does not already provide.
        // NOTE(review): `env::set_var` mutates process-global state shared with
        // tests running in parallel threads.
        if env::var("AWS_LAMBDA_RUNTIME_API").is_err() {
            env::set_var("AWS_LAMBDA_RUNTIME_API", server.base_url());
        }
        if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() {
            env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn");
        }
        if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() {
            env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128");
        }
        if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() {
            env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1");
        }
        if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() {
            env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream");
        }
        if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() {
            env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log");
        }
        let config = Config::from_env();

        let client = Arc::new(client);
        let runtime = Runtime {
            client: client.clone(),
            config: Arc::new(config),
            service: wrap_handler(f, client),
        };
        let client = &runtime.client;
        // `incoming` never ends on its own; take a single event so the loop returns.
        let incoming = incoming(client).take(1);
        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;

        next_request.assert_async().await;
        next_response.assert_async().await;
        Ok(())
    }

    /// Runs one invocation with a handler expected to panic and checks that the
    /// panic is reported to `/{id}/error` with error type `unhandled`, while the
    /// run loop itself still returns `Ok`.
    async fn run_panicking_handler<F>(func: F) -> Result<(), Error>
    where
        F: FnMut(crate::LambdaEvent<serde_json::Value>) -> BoxFuture<'static, Result<serde_json::Value, Error>>
            + Send
            + 'static,
    {
        let server = MockServer::start();
        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
        let deadline = "1542409706888";

        let next_request = server.mock(|when, then| {
            when.method(GET).path("/2018-06-01/runtime/invocation/next");
            then.status(200)
                .header("content-type", "application/json")
                .header("lambda-runtime-aws-request-id", request_id)
                .header("lambda-runtime-deadline-ms", deadline)
                .body("{}");
        });

        let next_response = server.mock(|when, then| {
            when.method(POST)
                .path(format!("/2018-06-01/runtime/invocation/{}/error", request_id))
                .header("lambda-runtime-function-error-type", "unhandled");
            then.status(200).body("");
        });

        let base = server.base_url().parse().expect("Invalid mock server Uri");
        let client = Client::builder().with_endpoint(base).build()?;

        let f = crate::service_fn(func);

        // Built directly rather than via `Config::from_env`, so no env vars are needed.
        let config = Arc::new(Config {
            function_name: "test_fn".to_string(),
            memory: 128,
            version: "1".to_string(),
            log_stream: "test_stream".to_string(),
            log_group: "test_log".to_string(),
        });

        let client = Arc::new(client);
        let runtime = Runtime {
            client: client.clone(),
            config,
            service: wrap_handler(f, client),
        };
        let client = &runtime.client;
        // Limit the endless event stream to one event so the loop returns.
        let incoming = incoming(client).take(1);
        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;

        next_request.assert_async().await;
        next_response.assert_async().await;
        Ok(())
    }

    /// The handler panics inside the future it returns.
    #[tokio::test]
    async fn panic_in_async_run() -> Result<(), Error> {
        run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await
    }

    /// The handler panics while being called, before any future is produced.
    #[tokio::test]
    async fn panic_outside_async_run() -> Result<(), Error> {
        run_panicking_handler(|_| {
            panic!("This is intentionally here");
        })
        .await
    }
}
488}