1use crate::request::RequestOrigin;
4#[cfg(feature = "alb")]
5use aws_lambda_events::alb::AlbTargetGroupResponse;
6#[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
7use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8#[cfg(feature = "apigw_http")]
9use aws_lambda_events::apigw::ApiGatewayV2httpResponse;
10use aws_lambda_events::encodings::Body;
11use encoding_rs::Encoding;
12use http::{
13 header::{CONTENT_ENCODING, CONTENT_TYPE},
14 HeaderMap, Response, StatusCode,
15};
16use http_body::Body as HttpBody;
17use http_body_util::BodyExt;
18use mime::{Mime, CHARSET};
19use serde::Serialize;
20use std::{
21 borrow::Cow,
22 fmt,
23 future::{ready, Future},
24 pin::Pin,
25};
26
/// Response header that forces the body to be treated as text when its value is exactly `text`,
/// even if the `Content-Type` would otherwise select a binary body.
const X_LAMBDA_HTTP_CONTENT_ENCODING: &str = "x-lambda-http-content-encoding";

/// `Content-Type` prefixes whose bodies are converted to text (matched with `starts_with`
/// against the raw header value).
const TEXT_ENCODING_PREFIXES: [&str; 5] = [
    "text",
    "application/json",
    "application/javascript",
    "application/xml",
    "application/yaml",
];

/// Structured-syntax suffixes whose bodies are converted to text (matched with `ends_with`
/// against the media type, i.e. the part of the trimmed header before the first `;`).
const TEXT_ENCODING_SUFFIXES: [&str; 3] = ["+xml", "+yaml", "+json"];
41
#[non_exhaustive]
/// Response payload handed back to the Lambda service, with one variant per supported
/// integration; each variant only exists when its cargo feature is enabled.
///
/// Because of `#[serde(untagged)]`, the wrapped value is serialized directly, without
/// any variant-name wrapper.
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
    /// API Gateway REST response; WebSocket responses reuse this same shape.
    #[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
    ApiGatewayV1(ApiGatewayProxyResponse),
    /// API Gateway HTTP API response.
    #[cfg(feature = "apigw_http")]
    ApiGatewayV2(ApiGatewayV2httpResponse),
    /// Application Load Balancer target-group response.
    #[cfg(feature = "alb")]
    Alb(AlbTargetGroupResponse),
    /// Raw JSON value, produced for `RequestOrigin::PassThrough` requests.
    #[cfg(feature = "pass_through")]
    PassThrough(serde_json::Value),
}
57
58impl LambdaResponse {
60 pub(crate) fn from_response(request_origin: &RequestOrigin, value: Response<Body>) -> Self {
61 let (parts, bod) = value.into_parts();
62 let (is_base64_encoded, body) = match bod {
63 Body::Empty => (false, None),
64 b @ Body::Text(_) => (false, Some(b)),
65 b @ Body::Binary(_) => (true, Some(b)),
66 _ => (false, None),
67 };
68
69 let headers = parts.headers;
70 let status_code = parts.status.as_u16();
71
72 match request_origin {
73 #[cfg(feature = "apigw_rest")]
74 RequestOrigin::ApiGatewayV1 => LambdaResponse::ApiGatewayV1({
75 let mut response = ApiGatewayProxyResponse::default();
76
77 response.body = body;
78 response.is_base64_encoded = is_base64_encoded;
79 response.status_code = status_code as i64;
80 response.headers = HeaderMap::new();
83 response.multi_value_headers = headers;
84 #[cfg(feature = "catch-all-fields")]
86 {
87 response.other = Default::default();
88 }
89 response
90 }),
91 #[cfg(feature = "apigw_http")]
92 RequestOrigin::ApiGatewayV2 => {
93 use http::header::SET_COOKIE;
94 let mut headers = headers;
95 let cookies = headers
98 .get_all(SET_COOKIE)
99 .iter()
100 .cloned()
101 .map(|v| v.to_str().ok().unwrap_or_default().to_string())
102 .collect();
103 headers.remove(SET_COOKIE);
104
105 LambdaResponse::ApiGatewayV2({
106 let mut response = ApiGatewayV2httpResponse::default();
107 response.body = body;
108 response.is_base64_encoded = is_base64_encoded;
109 response.status_code = status_code as i64;
110 response.cookies = cookies;
111 response.headers = headers;
114 response.multi_value_headers = HeaderMap::new();
115 #[cfg(feature = "catch-all-fields")]
117 {
118 response.other = Default::default();
119 }
120 response
121 })
122 }
123 #[cfg(feature = "alb")]
124 RequestOrigin::Alb => LambdaResponse::Alb({
125 let mut response = AlbTargetGroupResponse::default();
126
127 response.body = body;
128 response.is_base64_encoded = is_base64_encoded;
129 response.status_code = status_code as i64;
130 response.headers = headers.clone();
134 response.multi_value_headers = headers;
135 response.status_description = Some(format!(
136 "{} {}",
137 status_code,
138 parts.status.canonical_reason().unwrap_or_default()
139 ));
140 #[cfg(feature = "catch-all-fields")]
142 {
143 response.other = Default::default();
144 }
145 response
146 }),
147 #[cfg(feature = "apigw_websockets")]
148 RequestOrigin::WebSocket => LambdaResponse::ApiGatewayV1({
149 let mut response = ApiGatewayProxyResponse::default();
150 response.body = body;
151 response.is_base64_encoded = is_base64_encoded;
152 response.status_code = status_code as i64;
153 response.headers = HeaderMap::new();
156 response.multi_value_headers = headers;
157 #[cfg(feature = "catch-all-fields")]
159 {
160 response.other = Default::default();
161 }
162 response
163 }),
164 #[cfg(feature = "pass_through")]
165 RequestOrigin::PassThrough => {
166 match body {
167 Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())},
169 _ => LambdaResponse::PassThrough(serde_json::Value::Null),
171 }
172 }
173 #[cfg(not(any(
174 feature = "apigw_rest",
175 feature = "apigw_http",
176 feature = "alb",
177 feature = "apigw_websockets"
178 )))]
179 _ => compile_error!("Either feature `apigw_rest`, `apigw_http`, `alb`, or `apigw_websockets` must be enabled for the `lambda-http` crate."),
180 }
181 }
182}
183
/// Conversion of a handler's return value into a Lambda-ready `Response<Body>`.
///
/// The conversion is returned as a boxed future because some implementations must read
/// an HTTP body to completion before the final `Body` is known.
pub trait IntoResponse {
    /// Returns a future that resolves to the converted response.
    fn into_response(self) -> ResponseFuture;
}
191
192impl<B> IntoResponse for Response<B>
193where
194 B: ConvertBody + Send + 'static,
195{
196 fn into_response(self) -> ResponseFuture {
197 let (parts, body) = self.into_parts();
198 let headers = parts.headers.clone();
199
200 let fut = async { Response::from_parts(parts, body.convert(headers).await) };
201
202 Box::pin(fut)
203 }
204}
205
206impl IntoResponse for String {
207 fn into_response(self) -> ResponseFuture {
208 Box::pin(ready(Response::new(Body::from(self))))
209 }
210}
211
212impl IntoResponse for &str {
213 fn into_response(self) -> ResponseFuture {
214 Box::pin(ready(Response::new(Body::from(self))))
215 }
216}
217
218impl IntoResponse for &[u8] {
219 fn into_response(self) -> ResponseFuture {
220 Box::pin(ready(Response::new(Body::from(self))))
221 }
222}
223
224impl IntoResponse for Vec<u8> {
225 fn into_response(self) -> ResponseFuture {
226 Box::pin(ready(Response::new(Body::from(self))))
227 }
228}
229
230impl IntoResponse for serde_json::Value {
231 fn into_response(self) -> ResponseFuture {
232 Box::pin(async move {
233 Response::builder()
234 .header(CONTENT_TYPE, "application/json")
235 .body(
236 serde_json::to_string(&self)
237 .expect("unable to serialize serde_json::Value")
238 .into(),
239 )
240 .expect("unable to build http::Response")
241 })
242 }
243}
244
245impl IntoResponse for (StatusCode, String) {
246 fn into_response(self) -> ResponseFuture {
247 let (status, body) = self;
248 Box::pin(ready(
249 Response::builder()
250 .status(status)
251 .body(Body::from(body))
252 .expect("unable to build http::Response"),
253 ))
254 }
255}
256
257impl IntoResponse for (StatusCode, &str) {
258 fn into_response(self) -> ResponseFuture {
259 let (status, body) = self;
260 Box::pin(ready(
261 Response::builder()
262 .status(status)
263 .body(Body::from(body))
264 .expect("unable to build http::Response"),
265 ))
266 }
267}
268
269impl IntoResponse for (StatusCode, &[u8]) {
270 fn into_response(self) -> ResponseFuture {
271 let (status, body) = self;
272 Box::pin(ready(
273 Response::builder()
274 .status(status)
275 .body(Body::from(body))
276 .expect("unable to build http::Response"),
277 ))
278 }
279}
280
281impl IntoResponse for (StatusCode, Vec<u8>) {
282 fn into_response(self) -> ResponseFuture {
283 let (status, body) = self;
284 Box::pin(ready(
285 Response::builder()
286 .status(status)
287 .body(Body::from(body))
288 .expect("unable to build http::Response"),
289 ))
290 }
291}
292
293impl IntoResponse for (StatusCode, serde_json::Value) {
294 fn into_response(self) -> ResponseFuture {
295 let (status, body) = self;
296 Box::pin(async move {
297 Response::builder()
298 .status(status)
299 .header(CONTENT_TYPE, "application/json")
300 .body(
301 serde_json::to_string(&body)
302 .expect("unable to serialize serde_json::Value")
303 .into(),
304 )
305 .expect("unable to build http::Response")
306 })
307 }
308}
309
/// Boxed, `Send` future resolving to a `Response<Body>`; returned by [`IntoResponse::into_response`].
pub type ResponseFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;

/// Conversion of an HTTP body into a Lambda `Body`, where the response headers decide
/// whether the result is text or binary.
pub trait ConvertBody {
    /// Consumes the body and returns a future resolving to the converted `Body`.
    /// `parts` is the response's header map (callers pass the response headers).
    fn convert(self, parts: HeaderMap) -> BodyFuture;
}
315
316impl<B> ConvertBody for B
317where
318 B: HttpBody + Unpin + Send + 'static,
319 B::Data: Send,
320 B::Error: fmt::Debug,
321{
322 fn convert(self, headers: HeaderMap) -> BodyFuture {
323 if headers.get(CONTENT_ENCODING).is_some() {
324 return convert_to_binary(self);
325 }
326
327 let content_type = if let Some(value) = headers.get(CONTENT_TYPE) {
328 value.to_str().unwrap_or_default()
329 } else {
330 return convert_to_text(self, "utf-8");
332 };
333
334 for prefix in TEXT_ENCODING_PREFIXES {
335 if content_type.starts_with(prefix) {
336 return convert_to_text(self, content_type);
337 }
338 }
339
340 for suffix in TEXT_ENCODING_SUFFIXES {
341 let mut parts = content_type.trim().split(';');
342 let mime_type = parts.next().unwrap_or_default();
343 if mime_type.ends_with(suffix) {
344 return convert_to_text(self, content_type);
345 }
346 }
347
348 if let Some(value) = headers.get(X_LAMBDA_HTTP_CONTENT_ENCODING) {
349 if value == "text" {
350 return convert_to_text(self, content_type);
351 }
352 }
353
354 convert_to_binary(self)
355 }
356}
357
358fn convert_to_binary<B>(body: B) -> BodyFuture
359where
360 B: HttpBody + Unpin + Send + 'static,
361 B::Data: Send,
362 B::Error: fmt::Debug,
363{
364 Box::pin(async move {
365 Body::from(
366 body.collect()
367 .await
368 .expect("unable to read bytes from body")
369 .to_bytes()
370 .to_vec(),
371 )
372 })
373}
374
375fn convert_to_text<B>(body: B, content_type: &str) -> BodyFuture
376where
377 B: HttpBody + Unpin + Send + 'static,
378 B::Data: Send,
379 B::Error: fmt::Debug,
380{
381 let mime_type = content_type.parse::<Mime>();
382
383 let encoding = match mime_type.as_ref() {
384 Ok(mime) => mime.get_param(CHARSET).unwrap_or(mime::UTF_8),
385 Err(_) => mime::UTF_8,
386 };
387
388 let label = encoding.as_ref().as_bytes();
389 let encoding = Encoding::for_label(label).unwrap_or(encoding_rs::UTF_8);
390
391 Box::pin(async move {
393 let bytes = body.collect().await.expect("unable to read bytes from body").to_bytes();
394 let (content, _, _) = encoding.decode(&bytes);
395
396 match content {
397 Cow::Borrowed(content) => Body::from(content),
398 Cow::Owned(content) => Body::from(content),
399 }
400 })
401}
402
/// Boxed, `Send` future resolving to a converted `Body`; returned by [`ConvertBody::convert`].
pub type BodyFuture = Pin<Box<dyn Future<Output = Body> + Send>>;
404
#[cfg(test)]
mod tests {
    use super::{Body, IntoResponse, LambdaResponse, RequestOrigin, X_LAMBDA_HTTP_CONTENT_ENCODING};
    use http::{
        header::{CONTENT_ENCODING, CONTENT_TYPE},
        Response, StatusCode,
    };
    use lambda_runtime_api_client::body::Body as HyperBody;
    use serde_json::{self, json};

    // SVG fixture embedded at compile time; used by the XML / custom-encoding text tests.
    const SVG_LOGO: &str = include_str!("../tests/data/svg_logo.svg");

    // A `serde_json::Value` becomes a compact JSON text body tagged `application/json`.
    #[tokio::test]
    async fn json_into_response() {
        let response = json!({ "hello": "lambda"}).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // A `&str` becomes a text body.
    #[tokio::test]
    async fn text_into_response() {
        let response = "text".into_response().await;
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // A `&[u8]` becomes a binary body.
    #[tokio::test]
    async fn bytes_into_response() {
        let response = "text".as_bytes().into_response().await;
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // `(StatusCode, Value)` keeps the status and still produces a JSON text body.
    #[tokio::test]
    async fn json_with_status_code_into_response() {
        let response = (StatusCode::CREATED, json!({ "hello": "lambda"})).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }

        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // `(StatusCode, &str)` keeps the status and produces a text body.
    #[tokio::test]
    async fn text_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text").into_response().await;

        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // `(StatusCode, &[u8])` keeps the status and produces a binary body.
    #[tokio::test]
    async fn bytes_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text".as_bytes()).into_response().await;
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // Any `Content-Encoding` forces a binary body, serialized base64 with `isBase64Encoded`.
    #[tokio::test]
    async fn content_encoding_header() {
        let response = Response::builder()
            .header(CONTENT_ENCODING, "gzip")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-encoding":"gzip"},"multiValueHeaders":{},"body":"MDAwMDAw","isBase64Encoded":true,"cookies":[]}"#
        )
    }

    // A text-like `Content-Type` keeps the body as plain text.
    #[tokio::test]
    async fn content_type_header() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json"},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // The declared charset drives decoding: "000000" read as UTF-16 yields three U+3030 chars.
    #[tokio::test]
    async fn charset_content_type_header() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // Same as above, but the text decision comes from the `+json` suffix rule.
    #[tokio::test]
    async fn charset_content_type_header_suffix() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/graphql-response+json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/graphql-response+json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // With neither content header set, the body defaults to UTF-8 text.
    #[tokio::test]
    async fn content_headers_unset() {
        let response = Response::builder()
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // REST (v1) responses put repeated headers into `multiValueHeaders`.
    #[test]
    fn serialize_multi_value_headers() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV1,
            Response::builder()
                .header("multi", "a")
                .header("multi", "b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
        )
    }

    // HTTP API (v2) responses move `Set-Cookie` values into the `cookies` list.
    #[test]
    fn serialize_cookies() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV2,
            Response::builder()
                .header("set-cookie", "cookie1=a")
                .header("set-cookie", "cookie2=b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            "{\"statusCode\":200,\"headers\":{},\"multiValueHeaders\":{},\"isBase64Encoded\":false,\"cookies\":[\"cookie1=a\",\"cookie2=b\"]}",
            json
        )
    }

    // `image/svg+xml` matches the `+xml` suffix rule and stays text.
    #[tokio::test]
    async fn content_type_xml_as_text() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg+xml")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg+xml")
        )
    }

    // `x-lambda-http-content-encoding: text` forces text for an otherwise-binary type.
    #[tokio::test]
    async fn content_type_custom_encoding_as_text() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg")
            .header(X_LAMBDA_HTTP_CONTENT_ENCODING, "text")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg")
        )
    }

    // YAML stays text via both the prefix rule and the `+yaml` suffix rule.
    #[tokio::test]
    async fn content_type_yaml_as_text() {
        let yaml = r#"---
foo: bar
        "#;

        let formats = ["application/yaml", "custom/vdn+yaml"];

        for format in formats {
            let response = Response::builder()
                .header(CONTENT_TYPE, format)
                .body(HyperBody::from(yaml.as_bytes()))
                .expect("unable to build http::Response");
            let response = response.into_response().await;

            match response.body() {
                Body::Text(body) => assert_eq!(yaml, body),
                _ => panic!("invalid body"),
            }
            assert_eq!(
                response
                    .headers()
                    .get(CONTENT_TYPE)
                    .map(|h| h.to_str().expect("invalid header")),
                Some(format)
            )
        }
    }
}