1use crate::request::RequestOrigin;
4#[cfg(feature = "alb")]
5use aws_lambda_events::alb::AlbTargetGroupResponse;
6#[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
7use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8#[cfg(feature = "apigw_http")]
9use aws_lambda_events::apigw::ApiGatewayV2httpResponse;
10use aws_lambda_events::encodings::Body;
11use encoding_rs::Encoding;
12use http::{
13 header::{CONTENT_ENCODING, CONTENT_TYPE},
14 HeaderMap, Response, StatusCode,
15};
16use http_body::Body as HttpBody;
17use http_body_util::BodyExt;
18use mime::{Mime, CHARSET};
19use serde::Serialize;
20use std::{
21 borrow::Cow,
22 fmt,
23 future::{ready, Future},
24 pin::Pin,
25};
26
/// Name of the response header consulted by `ConvertBody::convert` as a last resort:
/// when its value is `text`, the body is converted to text even though the
/// `Content-Type` matched none of the textual prefixes or suffixes below.
const X_LAMBDA_HTTP_CONTENT_ENCODING: &str = "x-lambda-http-content-encoding";

/// A `Content-Type` header value that starts with any of these strings causes the
/// body to be converted to text rather than binary.
const TEXT_ENCODING_PREFIXES: [&str; 5] = [
    "text",
    "application/json",
    "application/javascript",
    "application/xml",
    "application/yaml",
];

/// A media type (the part of `Content-Type` before the first `;`) that ends with any
/// of these strings causes the body to be converted to text rather than binary.
const TEXT_ENCODING_SUFFIXES: [&str; 3] = ["+xml", "+yaml", "+json"];
41
/// Serializable response payload, with one variant per supported integration.
///
/// The enum is serialized untagged, so only the wrapped payload is emitted and the
/// variant name never appears in the output. Each variant only exists when the
/// corresponding cargo feature is enabled.
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
    /// Payload produced for the `ApiGatewayV1` and `WebSocket` request origins.
    #[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
    ApiGatewayV1(ApiGatewayProxyResponse),
    /// Payload produced for the `ApiGatewayV2` request origin.
    #[cfg(feature = "apigw_http")]
    ApiGatewayV2(ApiGatewayV2httpResponse),
    /// Payload produced for the `Alb` request origin.
    #[cfg(feature = "alb")]
    Alb(AlbTargetGroupResponse),
    /// Arbitrary JSON value produced for the `PassThrough` request origin.
    #[cfg(feature = "pass_through")]
    PassThrough(serde_json::Value),
}
56
57impl LambdaResponse {
59 pub(crate) fn from_response(request_origin: &RequestOrigin, value: Response<Body>) -> Self {
60 let (parts, bod) = value.into_parts();
61 let (is_base64_encoded, body) = match bod {
62 Body::Empty => (false, None),
63 b @ Body::Text(_) => (false, Some(b)),
64 b @ Body::Binary(_) => (true, Some(b)),
65 };
66
67 let headers = parts.headers;
68 let status_code = parts.status.as_u16();
69
70 match request_origin {
71 #[cfg(feature = "apigw_rest")]
72 RequestOrigin::ApiGatewayV1 => LambdaResponse::ApiGatewayV1(ApiGatewayProxyResponse {
73 body,
74 is_base64_encoded,
75 status_code: status_code as i64,
76 headers: HeaderMap::new(),
79 multi_value_headers: headers,
80 }),
81 #[cfg(feature = "apigw_http")]
82 RequestOrigin::ApiGatewayV2 => {
83 use http::header::SET_COOKIE;
84 let mut headers = headers;
85 let cookies = headers
88 .get_all(SET_COOKIE)
89 .iter()
90 .cloned()
91 .map(|v| v.to_str().ok().unwrap_or_default().to_string())
92 .collect();
93 headers.remove(SET_COOKIE);
94
95 LambdaResponse::ApiGatewayV2(ApiGatewayV2httpResponse {
96 body,
97 is_base64_encoded,
98 status_code: status_code as i64,
99 cookies,
100 headers,
103 multi_value_headers: HeaderMap::new(),
104 })
105 }
106 #[cfg(feature = "alb")]
107 RequestOrigin::Alb => LambdaResponse::Alb(AlbTargetGroupResponse {
108 body,
109 status_code: status_code as i64,
110 is_base64_encoded,
111 headers: headers.clone(),
115 multi_value_headers: headers,
116 status_description: Some(format!(
117 "{} {}",
118 status_code,
119 parts.status.canonical_reason().unwrap_or_default()
120 )),
121 }),
122 #[cfg(feature = "apigw_websockets")]
123 RequestOrigin::WebSocket => LambdaResponse::ApiGatewayV1(ApiGatewayProxyResponse {
124 body,
125 is_base64_encoded,
126 status_code: status_code as i64,
127 headers: HeaderMap::new(),
130 multi_value_headers: headers,
131 }),
132 #[cfg(feature = "pass_through")]
133 RequestOrigin::PassThrough => {
134 match body {
135 Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())},
137 _ => LambdaResponse::PassThrough(serde_json::Value::Null),
139 }
140 }
141 #[cfg(not(any(
142 feature = "apigw_rest",
143 feature = "apigw_http",
144 feature = "alb",
145 feature = "apigw_websockets"
146 )))]
147 _ => compile_error!("Either feature `apigw_rest`, `apigw_http`, `alb`, or `apigw_websockets` must be enabled for the `lambda-http` crate."),
148 }
149 }
150}
151
/// Conversion of a value into a [`ResponseFuture`].
///
/// This file implements it for `Response<B>` (where `B: ConvertBody`), for strings,
/// byte buffers and `serde_json::Value`, and for `(StatusCode, _)` pairs of those.
pub trait IntoResponse {
    /// Consumes `self` and returns a future that resolves to a `Response<Body>`.
    fn into_response(self) -> ResponseFuture;
}
159
160impl<B> IntoResponse for Response<B>
161where
162 B: ConvertBody + Send + 'static,
163{
164 fn into_response(self) -> ResponseFuture {
165 let (parts, body) = self.into_parts();
166 let headers = parts.headers.clone();
167
168 let fut = async { Response::from_parts(parts, body.convert(headers).await) };
169
170 Box::pin(fut)
171 }
172}
173
174impl IntoResponse for String {
175 fn into_response(self) -> ResponseFuture {
176 Box::pin(ready(Response::new(Body::from(self))))
177 }
178}
179
180impl IntoResponse for &str {
181 fn into_response(self) -> ResponseFuture {
182 Box::pin(ready(Response::new(Body::from(self))))
183 }
184}
185
186impl IntoResponse for &[u8] {
187 fn into_response(self) -> ResponseFuture {
188 Box::pin(ready(Response::new(Body::from(self))))
189 }
190}
191
192impl IntoResponse for Vec<u8> {
193 fn into_response(self) -> ResponseFuture {
194 Box::pin(ready(Response::new(Body::from(self))))
195 }
196}
197
198impl IntoResponse for serde_json::Value {
199 fn into_response(self) -> ResponseFuture {
200 Box::pin(async move {
201 Response::builder()
202 .header(CONTENT_TYPE, "application/json")
203 .body(
204 serde_json::to_string(&self)
205 .expect("unable to serialize serde_json::Value")
206 .into(),
207 )
208 .expect("unable to build http::Response")
209 })
210 }
211}
212
213impl IntoResponse for (StatusCode, String) {
214 fn into_response(self) -> ResponseFuture {
215 let (status, body) = self;
216 Box::pin(ready(
217 Response::builder()
218 .status(status)
219 .body(Body::from(body))
220 .expect("unable to build http::Response"),
221 ))
222 }
223}
224
225impl IntoResponse for (StatusCode, &str) {
226 fn into_response(self) -> ResponseFuture {
227 let (status, body) = self;
228 Box::pin(ready(
229 Response::builder()
230 .status(status)
231 .body(Body::from(body))
232 .expect("unable to build http::Response"),
233 ))
234 }
235}
236
237impl IntoResponse for (StatusCode, &[u8]) {
238 fn into_response(self) -> ResponseFuture {
239 let (status, body) = self;
240 Box::pin(ready(
241 Response::builder()
242 .status(status)
243 .body(Body::from(body))
244 .expect("unable to build http::Response"),
245 ))
246 }
247}
248
249impl IntoResponse for (StatusCode, Vec<u8>) {
250 fn into_response(self) -> ResponseFuture {
251 let (status, body) = self;
252 Box::pin(ready(
253 Response::builder()
254 .status(status)
255 .body(Body::from(body))
256 .expect("unable to build http::Response"),
257 ))
258 }
259}
260
261impl IntoResponse for (StatusCode, serde_json::Value) {
262 fn into_response(self) -> ResponseFuture {
263 let (status, body) = self;
264 Box::pin(async move {
265 Response::builder()
266 .status(status)
267 .header(CONTENT_TYPE, "application/json")
268 .body(
269 serde_json::to_string(&body)
270 .expect("unable to serialize serde_json::Value")
271 .into(),
272 )
273 .expect("unable to build http::Response")
274 })
275 }
276}
277
/// Boxed, pinned, `Send` future that resolves to a `Response<Body>`; this is the
/// return type of `IntoResponse::into_response`.
pub type ResponseFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
279
/// Conversion of a response body into a [`BodyFuture`], guided by the response headers.
pub trait ConvertBody {
    /// Consumes the body and returns a future resolving to a `Body`.
    ///
    /// `parts` is the header map of the response the body belongs to; the blanket
    /// implementation in this file uses it to choose between text and binary output.
    fn convert(self, parts: HeaderMap) -> BodyFuture;
}
283
284impl<B> ConvertBody for B
285where
286 B: HttpBody + Unpin + Send + 'static,
287 B::Data: Send,
288 B::Error: fmt::Debug,
289{
290 fn convert(self, headers: HeaderMap) -> BodyFuture {
291 if headers.get(CONTENT_ENCODING).is_some() {
292 return convert_to_binary(self);
293 }
294
295 let content_type = if let Some(value) = headers.get(CONTENT_TYPE) {
296 value.to_str().unwrap_or_default()
297 } else {
298 return convert_to_text(self, "utf-8");
300 };
301
302 for prefix in TEXT_ENCODING_PREFIXES {
303 if content_type.starts_with(prefix) {
304 return convert_to_text(self, content_type);
305 }
306 }
307
308 for suffix in TEXT_ENCODING_SUFFIXES {
309 let mut parts = content_type.trim().split(';');
310 let mime_type = parts.next().unwrap_or_default();
311 if mime_type.ends_with(suffix) {
312 return convert_to_text(self, content_type);
313 }
314 }
315
316 if let Some(value) = headers.get(X_LAMBDA_HTTP_CONTENT_ENCODING) {
317 if value == "text" {
318 return convert_to_text(self, content_type);
319 }
320 }
321
322 convert_to_binary(self)
323 }
324}
325
326fn convert_to_binary<B>(body: B) -> BodyFuture
327where
328 B: HttpBody + Unpin + Send + 'static,
329 B::Data: Send,
330 B::Error: fmt::Debug,
331{
332 Box::pin(async move {
333 Body::from(
334 body.collect()
335 .await
336 .expect("unable to read bytes from body")
337 .to_bytes()
338 .to_vec(),
339 )
340 })
341}
342
343fn convert_to_text<B>(body: B, content_type: &str) -> BodyFuture
344where
345 B: HttpBody + Unpin + Send + 'static,
346 B::Data: Send,
347 B::Error: fmt::Debug,
348{
349 let mime_type = content_type.parse::<Mime>();
350
351 let encoding = match mime_type.as_ref() {
352 Ok(mime) => mime.get_param(CHARSET).unwrap_or(mime::UTF_8),
353 Err(_) => mime::UTF_8,
354 };
355
356 let label = encoding.as_ref().as_bytes();
357 let encoding = Encoding::for_label(label).unwrap_or(encoding_rs::UTF_8);
358
359 Box::pin(async move {
361 let bytes = body.collect().await.expect("unable to read bytes from body").to_bytes();
362 let (content, _, _) = encoding.decode(&bytes);
363
364 match content {
365 Cow::Borrowed(content) => Body::from(content),
366 Cow::Owned(content) => Body::from(content),
367 }
368 })
369}
370
/// Boxed, pinned, `Send` future that resolves to a `Body`; this is the return type of
/// `ConvertBody::convert`.
pub type BodyFuture = Pin<Box<dyn Future<Output = Body> + Send>>;
372
// Unit tests for the `IntoResponse` conversions, the header-driven text/binary body
// selection, and the JSON serialization of `LambdaResponse`.
#[cfg(test)]
mod tests {
    use super::{Body, IntoResponse, LambdaResponse, RequestOrigin, X_LAMBDA_HTTP_CONTENT_ENCODING};
    use http::{
        header::{CONTENT_ENCODING, CONTENT_TYPE},
        Response, StatusCode,
    };
    use lambda_runtime_api_client::body::Body as HyperBody;
    use serde_json::{self, json};

    // SVG fixture used by the tests that check XML-like content is kept as text.
    const SVG_LOGO: &str = include_str!("../tests/data/svg_logo.svg");

    // A JSON value becomes a text body with an `application/json` content type.
    #[tokio::test]
    async fn json_into_response() {
        let response = json!({ "hello": "lambda"}).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // A string slice becomes a text body.
    #[tokio::test]
    async fn text_into_response() {
        let response = "text".into_response().await;
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // A byte slice becomes a binary body.
    #[tokio::test]
    async fn bytes_into_response() {
        let response = "text".as_bytes().into_response().await;
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // A `(StatusCode, Value)` pair keeps the status and still sets the JSON content type.
    #[tokio::test]
    async fn json_with_status_code_into_response() {
        let response = (StatusCode::CREATED, json!({ "hello": "lambda"})).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }

        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // A `(StatusCode, &str)` pair keeps the status and yields a text body.
    #[tokio::test]
    async fn text_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text").into_response().await;

        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // A `(StatusCode, &[u8])` pair keeps the status and yields a binary body.
    #[tokio::test]
    async fn bytes_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text".as_bytes()).into_response().await;
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // With `Content-Encoding` set, the body is treated as binary and serialized base64-encoded.
    #[tokio::test]
    async fn content_encoding_header() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_ENCODING, "gzip")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-encoding":"gzip"},"multiValueHeaders":{},"body":"MDAwMDAw","isBase64Encoded":true,"cookies":[]}"#
        )
    }

    // An `application/json` content type keeps the body as plain text.
    #[tokio::test]
    async fn content_type_header() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json"},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // The `charset` parameter drives decoding: six ASCII `0` bytes read as UTF-16
    // come out as three U+3030 characters.
    #[tokio::test]
    async fn charset_content_type_header() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // Same charset handling when the type is recognized as text through its `+json` suffix.
    #[tokio::test]
    async fn charset_content_type_header_suffix() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/graphql-response+json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/graphql-response+json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // With neither content header set, the body is treated as text.
    #[tokio::test]
    async fn content_headers_unset() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // For API Gateway v1, repeated headers are serialized under `multiValueHeaders`
    // and `headers` stays empty.
    #[test]
    fn serialize_multi_value_headers() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV1,
            Response::builder()
                .header("multi", "a")
                .header("multi", "b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
        )
    }

    // For API Gateway v2, `set-cookie` headers are moved into the `cookies` array.
    #[test]
    fn serialize_cookies() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV2,
            Response::builder()
                .header("set-cookie", "cookie1=a")
                .header("set-cookie", "cookie2=b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            "{\"statusCode\":200,\"headers\":{},\"multiValueHeaders\":{},\"isBase64Encoded\":false,\"cookies\":[\"cookie1=a\",\"cookie2=b\"]}",
            json
        )
    }

    // A `+xml` suffix (`image/svg+xml`) keeps the body as text and leaves the header unchanged.
    #[tokio::test]
    async fn content_type_xml_as_text() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg+xml")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg+xml")
        )
    }

    // An otherwise unrecognized content type is kept as text when the
    // `x-lambda-http-content-encoding: text` header is present.
    #[tokio::test]
    async fn content_type_custom_encoding_as_text() {
        // `HyperBody` is used so the body goes through the `ConvertBody` path.
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg")
            .header(X_LAMBDA_HTTP_CONTENT_ENCODING, "text")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg")
        )
    }

    // YAML is kept as text both through the `application/yaml` prefix and the `+yaml` suffix.
    #[tokio::test]
    async fn content_type_yaml_as_text() {
        let yaml = r#"---
foo: bar
        "#;

        let formats = ["application/yaml", "custom/vdn+yaml"];

        for format in formats {
            let response = Response::builder()
                .header(CONTENT_TYPE, format)
                .body(HyperBody::from(yaml.as_bytes()))
                .expect("unable to build http::Response");
            let response = response.into_response().await;

            match response.body() {
                Body::Text(body) => assert_eq!(yaml, body),
                _ => panic!("invalid body"),
            }
            assert_eq!(
                response
                    .headers()
                    .get(CONTENT_TYPE)
                    .map(|h| h.to_str().expect("invalid header")),
                Some(format)
            )
        }
    }
}