1use crate::request::RequestOrigin;
4#[cfg(feature = "alb")]
5use aws_lambda_events::alb::AlbTargetGroupResponse;
6#[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
7use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8#[cfg(feature = "apigw_http")]
9use aws_lambda_events::apigw::ApiGatewayV2httpResponse;
10use aws_lambda_events::encodings::Body;
11use encoding_rs::Encoding;
12use http::{
13 header::{CONTENT_ENCODING, CONTENT_TYPE},
14 HeaderMap, Response, StatusCode,
15};
16use http_body::Body as HttpBody;
17use http_body_util::BodyExt;
18use mime::{Mime, CHARSET};
19use serde::Serialize;
20use std::{
21 borrow::Cow,
22 fmt,
23 future::{ready, Future},
24 pin::Pin,
25};
26
/// Response header a handler can set to `text` to force its body to be sent to Lambda as
/// text even when the `Content-Type` is not otherwise recognised as textual.
const X_LAMBDA_HTTP_CONTENT_ENCODING: &str = "x-lambda-http-content-encoding";

/// `Content-Type` prefixes whose bodies are decoded and sent as text.
const TEXT_ENCODING_PREFIXES: [&str; 5] = [
    "text",
    "application/json",
    "application/javascript",
    "application/xml",
    "application/yaml",
];

/// Structured-syntax suffixes that mark a media type as textual (e.g. `image/svg+xml`).
const TEXT_ENCODING_SUFFIXES: [&str; 3] = [
    "+xml",
    "+yaml",
    "+json",
];
41
/// Response payload handed back to the Lambda runtime, one variant per supported
/// integration.
///
/// `#[serde(untagged)]` means only the wrapped value is serialized, without a variant name.
/// Each variant exists only when its integration feature is enabled.
#[non_exhaustive]
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
    /// Response for API Gateway REST (v1) APIs; also used for WebSocket APIs.
    #[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
    ApiGatewayV1(ApiGatewayProxyResponse),
    /// Response for API Gateway HTTP (v2) APIs.
    #[cfg(feature = "apigw_http")]
    ApiGatewayV2(ApiGatewayV2httpResponse),
    /// Response for Application Load Balancer target groups.
    #[cfg(feature = "alb")]
    Alb(AlbTargetGroupResponse),
    /// Arbitrary JSON value returned as-is for pass-through invocations.
    #[cfg(feature = "pass_through")]
    PassThrough(serde_json::Value),
}
57
58impl LambdaResponse {
60 pub(crate) fn from_response(request_origin: &RequestOrigin, value: Response<Body>) -> Self {
61 let (parts, bod) = value.into_parts();
62 let (is_base64_encoded, body) = match bod {
63 Body::Empty => (false, None),
64 b @ Body::Text(_) => (false, Some(b)),
65 b @ Body::Binary(_) => (true, Some(b)),
66 _ => (false, None),
67 };
68
69 let headers = parts.headers;
70 let status_code = parts.status.as_u16();
71
72 match request_origin {
73 #[cfg(feature = "apigw_rest")]
74 RequestOrigin::ApiGatewayV1 => LambdaResponse::ApiGatewayV1({
75 let mut response = ApiGatewayProxyResponse::default();
76
77 response.body = body;
78 response.is_base64_encoded = is_base64_encoded;
79 response.status_code = status_code as i64;
80 response.headers = HeaderMap::new();
83 response.multi_value_headers = headers;
84 #[cfg(feature = "catch-all-fields")]
86 {
87 response.other = Default::default();
88 }
89 response
90 }),
91 #[cfg(feature = "apigw_http")]
92 RequestOrigin::ApiGatewayV2 => {
93 use http::header::SET_COOKIE;
94 let mut headers = headers;
95 let cookies = headers
98 .get_all(SET_COOKIE)
99 .iter()
100 .map(|v| v.to_str().ok().unwrap_or_default().to_string())
101 .collect();
102 headers.remove(SET_COOKIE);
103
104 LambdaResponse::ApiGatewayV2({
105 let mut response = ApiGatewayV2httpResponse::default();
106 response.body = body;
107 response.is_base64_encoded = is_base64_encoded;
108 response.status_code = status_code as i64;
109 response.cookies = cookies;
110 response.headers = headers;
113 response.multi_value_headers = HeaderMap::new();
114 #[cfg(feature = "catch-all-fields")]
116 {
117 response.other = Default::default();
118 }
119 response
120 })
121 }
122 #[cfg(feature = "alb")]
123 RequestOrigin::Alb => LambdaResponse::Alb({
124 let mut response = AlbTargetGroupResponse::default();
125
126 response.body = body;
127 response.is_base64_encoded = is_base64_encoded;
128 response.status_code = status_code as i64;
129 response.headers = headers.clone();
133 response.multi_value_headers = headers;
134 response.status_description = Some(format!(
135 "{} {}",
136 status_code,
137 parts.status.canonical_reason().unwrap_or_default()
138 ));
139 #[cfg(feature = "catch-all-fields")]
141 {
142 response.other = Default::default();
143 }
144 response
145 }),
146 #[cfg(feature = "apigw_websockets")]
147 RequestOrigin::WebSocket => LambdaResponse::ApiGatewayV1({
148 let mut response = ApiGatewayProxyResponse::default();
149 response.body = body;
150 response.is_base64_encoded = is_base64_encoded;
151 response.status_code = status_code as i64;
152 response.headers = HeaderMap::new();
155 response.multi_value_headers = headers;
156 #[cfg(feature = "catch-all-fields")]
158 {
159 response.other = Default::default();
160 }
161 response
162 }),
163 #[cfg(feature = "pass_through")]
164 RequestOrigin::PassThrough => {
165 match body {
166 Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())},
168 _ => LambdaResponse::PassThrough(serde_json::Value::Null),
170 }
171 }
172 #[cfg(not(any(
173 feature = "apigw_rest",
174 feature = "apigw_http",
175 feature = "alb",
176 feature = "apigw_websockets"
177 )))]
178 _ => compile_error!("Either feature `apigw_rest`, `apigw_http`, `alb`, or `apigw_websockets` must be enabled for the `lambda-http` crate."),
179 }
180 }
181}
182
183pub trait IntoResponse {
187 fn into_response(self) -> ResponseFuture;
189}
190
191impl<B> IntoResponse for Response<B>
192where
193 B: ConvertBody + Send + 'static,
194{
195 fn into_response(self) -> ResponseFuture {
196 let (parts, body) = self.into_parts();
197 let headers = parts.headers.clone();
198
199 let fut = async { Response::from_parts(parts, body.convert(headers).await) };
200
201 Box::pin(fut)
202 }
203}
204
205impl IntoResponse for String {
206 fn into_response(self) -> ResponseFuture {
207 Box::pin(ready(Response::new(Body::from(self))))
208 }
209}
210
211impl IntoResponse for &str {
212 fn into_response(self) -> ResponseFuture {
213 Box::pin(ready(Response::new(Body::from(self))))
214 }
215}
216
217impl IntoResponse for &[u8] {
218 fn into_response(self) -> ResponseFuture {
219 Box::pin(ready(Response::new(Body::from(self))))
220 }
221}
222
223impl IntoResponse for Vec<u8> {
224 fn into_response(self) -> ResponseFuture {
225 Box::pin(ready(Response::new(Body::from(self))))
226 }
227}
228
229impl IntoResponse for serde_json::Value {
230 fn into_response(self) -> ResponseFuture {
231 Box::pin(async move {
232 Response::builder()
233 .header(CONTENT_TYPE, "application/json")
234 .body(
235 serde_json::to_string(&self)
236 .expect("unable to serialize serde_json::Value")
237 .into(),
238 )
239 .expect("unable to build http::Response")
240 })
241 }
242}
243
244impl IntoResponse for (StatusCode, String) {
245 fn into_response(self) -> ResponseFuture {
246 let (status, body) = self;
247 Box::pin(ready(
248 Response::builder()
249 .status(status)
250 .body(Body::from(body))
251 .expect("unable to build http::Response"),
252 ))
253 }
254}
255
256impl IntoResponse for (StatusCode, &str) {
257 fn into_response(self) -> ResponseFuture {
258 let (status, body) = self;
259 Box::pin(ready(
260 Response::builder()
261 .status(status)
262 .body(Body::from(body))
263 .expect("unable to build http::Response"),
264 ))
265 }
266}
267
268impl IntoResponse for (StatusCode, &[u8]) {
269 fn into_response(self) -> ResponseFuture {
270 let (status, body) = self;
271 Box::pin(ready(
272 Response::builder()
273 .status(status)
274 .body(Body::from(body))
275 .expect("unable to build http::Response"),
276 ))
277 }
278}
279
280impl IntoResponse for (StatusCode, Vec<u8>) {
281 fn into_response(self) -> ResponseFuture {
282 let (status, body) = self;
283 Box::pin(ready(
284 Response::builder()
285 .status(status)
286 .body(Body::from(body))
287 .expect("unable to build http::Response"),
288 ))
289 }
290}
291
292impl IntoResponse for (StatusCode, serde_json::Value) {
293 fn into_response(self) -> ResponseFuture {
294 let (status, body) = self;
295 Box::pin(async move {
296 Response::builder()
297 .status(status)
298 .header(CONTENT_TYPE, "application/json")
299 .body(
300 serde_json::to_string(&body)
301 .expect("unable to serialize serde_json::Value")
302 .into(),
303 )
304 .expect("unable to build http::Response")
305 })
306 }
307}
308
309pub type ResponseFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
310
311pub trait ConvertBody {
312 fn convert(self, parts: HeaderMap) -> BodyFuture;
313}
314
315impl<B> ConvertBody for B
316where
317 B: HttpBody + Unpin + Send + 'static,
318 B::Data: Send,
319 B::Error: fmt::Debug,
320{
321 fn convert(self, headers: HeaderMap) -> BodyFuture {
322 if headers.get(CONTENT_ENCODING).is_some() {
323 return convert_to_binary(self);
324 }
325
326 let content_type = if let Some(value) = headers.get(CONTENT_TYPE) {
327 value.to_str().unwrap_or_default()
328 } else {
329 return convert_to_text(self, "utf-8");
331 };
332
333 for prefix in TEXT_ENCODING_PREFIXES {
334 if content_type.starts_with(prefix) {
335 return convert_to_text(self, content_type);
336 }
337 }
338
339 for suffix in TEXT_ENCODING_SUFFIXES {
340 let mut parts = content_type.trim().split(';');
341 let mime_type = parts.next().unwrap_or_default();
342 if mime_type.ends_with(suffix) {
343 return convert_to_text(self, content_type);
344 }
345 }
346
347 if let Some(value) = headers.get(X_LAMBDA_HTTP_CONTENT_ENCODING) {
348 if value == "text" {
349 return convert_to_text(self, content_type);
350 }
351 }
352
353 convert_to_binary(self)
354 }
355}
356
357fn convert_to_binary<B>(body: B) -> BodyFuture
358where
359 B: HttpBody + Unpin + Send + 'static,
360 B::Data: Send,
361 B::Error: fmt::Debug,
362{
363 Box::pin(async move {
364 Body::from(
365 body.collect()
366 .await
367 .expect("unable to read bytes from body")
368 .to_bytes()
369 .to_vec(),
370 )
371 })
372}
373
374fn convert_to_text<B>(body: B, content_type: &str) -> BodyFuture
375where
376 B: HttpBody + Unpin + Send + 'static,
377 B::Data: Send,
378 B::Error: fmt::Debug,
379{
380 let mime_type = content_type.parse::<Mime>();
381
382 let encoding = match mime_type.as_ref() {
383 Ok(mime) => mime.get_param(CHARSET).unwrap_or(mime::UTF_8),
384 Err(_) => mime::UTF_8,
385 };
386
387 let label = encoding.as_ref().as_bytes();
388 let encoding = Encoding::for_label(label).unwrap_or(encoding_rs::UTF_8);
389
390 Box::pin(async move {
392 let bytes = body.collect().await.expect("unable to read bytes from body").to_bytes();
393 let (content, _, _) = encoding.decode(&bytes);
394
395 match content {
396 Cow::Borrowed(content) => Body::from(content),
397 Cow::Owned(content) => Body::from(content),
398 }
399 })
400}
401
402pub type BodyFuture = Pin<Box<dyn Future<Output = Body> + Send>>;
403
#[cfg(test)]
mod tests {
    use super::{Body, IntoResponse, LambdaResponse, RequestOrigin, X_LAMBDA_HTTP_CONTENT_ENCODING};
    use http::{
        header::{CONTENT_ENCODING, CONTENT_TYPE},
        Response, StatusCode,
    };
    use lambda_runtime_api_client::body::Body as HyperBody;
    use serde_json::{self, json};

    // SVG fixture used to check that `+xml` and explicitly flagged bodies stay text.
    const SVG_LOGO: &str = include_str!("../tests/data/svg_logo.svg");

    // A JSON value becomes a compact text body with an `application/json` content type.
    #[tokio::test]
    async fn json_into_response() {
        let response = json!({ "hello": "lambda"}).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // A string slice becomes a text body.
    #[tokio::test]
    async fn text_into_response() {
        let response = "text".into_response().await;
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // A byte slice becomes a binary body, even if the bytes happen to be valid UTF-8.
    #[tokio::test]
    async fn bytes_into_response() {
        let response = "text".as_bytes().into_response().await;
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // The `(StatusCode, Value)` form keeps both the status and the JSON content type.
    #[tokio::test]
    async fn json_with_status_code_into_response() {
        let response = (StatusCode::CREATED, json!({ "hello": "lambda"})).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }

        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    #[tokio::test]
    async fn text_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text").into_response().await;

        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    #[tokio::test]
    async fn bytes_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text".as_bytes()).into_response().await;
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // A `Content-Encoding` header forces a base64-encoded binary body in the v2 payload.
    #[tokio::test]
    async fn content_encoding_header() {
        // The `HyperBody` body routes `into_response` through the header-driven `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_ENCODING, "gzip")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-encoding":"gzip"},"multiValueHeaders":{},"body":"MDAwMDAw","isBase64Encoded":true,"cookies":[]}"#
        )
    }

    // A textual content type keeps the body as plain (non-base64) text.
    #[tokio::test]
    async fn content_type_header() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json"},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // The `charset` parameter drives decoding: under utf-16, each pair of `0x30` bytes
    // decodes as U+3030 ('〰').
    #[tokio::test]
    async fn charset_content_type_header() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // Same as above, but the media type is recognised as text via its `+json` suffix.
    #[tokio::test]
    async fn charset_content_type_header_suffix() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/graphql-response+json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/graphql-response+json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // With neither content header set, the body is treated as UTF-8 text.
    #[tokio::test]
    async fn content_headers_unset() {
        let response = Response::builder()
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // REST (v1) responses put repeated headers only in `multiValueHeaders`.
    #[test]
    fn serialize_multi_value_headers() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV1,
            Response::builder()
                .header("multi", "a")
                .header("multi", "b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
        )
    }

    // HTTP (v2) responses move every `Set-Cookie` header into `cookies`.
    #[test]
    fn serialize_cookies() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV2,
            Response::builder()
                .header("set-cookie", "cookie1=a")
                .header("set-cookie", "cookie2=b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            "{\"statusCode\":200,\"headers\":{},\"multiValueHeaders\":{},\"isBase64Encoded\":false,\"cookies\":[\"cookie1=a\",\"cookie2=b\"]}",
            json
        )
    }

    // `image/svg+xml` is text because of its `+xml` suffix.
    #[tokio::test]
    async fn content_type_xml_as_text() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg+xml")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg+xml")
        )
    }

    // An unrecognised media type can be forced to text with the custom encoding header.
    #[tokio::test]
    async fn content_type_custom_encoding_as_text() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg")
            .header(X_LAMBDA_HTTP_CONTENT_ENCODING, "text")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg")
        )
    }

    // YAML is text both via the `application/yaml` prefix and via the `+yaml` suffix.
    #[tokio::test]
    async fn content_type_yaml_as_text() {
        let yaml = r#"---
foo: bar
        "#;

        let formats = ["application/yaml", "custom/vdn+yaml"];

        for format in formats {
            let response = Response::builder()
                .header(CONTENT_TYPE, format)
                .body(HyperBody::from(yaml.as_bytes()))
                .expect("unable to build http::Response");
            let response = response.into_response().await;

            match response.body() {
                Body::Text(body) => assert_eq!(yaml, body),
                _ => panic!("invalid body"),
            }
            assert_eq!(
                response
                    .headers()
                    .get(CONTENT_TYPE)
                    .map(|h| h.to_str().expect("invalid header")),
                Some(format)
            )
        }
    }
}