lambda_http/
response.rs

1//! Response types
2
3use crate::request::RequestOrigin;
4#[cfg(feature = "alb")]
5use aws_lambda_events::alb::AlbTargetGroupResponse;
6#[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
7use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8#[cfg(feature = "apigw_http")]
9use aws_lambda_events::apigw::ApiGatewayV2httpResponse;
10use aws_lambda_events::encodings::Body;
11use encoding_rs::Encoding;
12use http::{
13    header::{CONTENT_ENCODING, CONTENT_TYPE},
14    HeaderMap, Response, StatusCode,
15};
16use http_body::Body as HttpBody;
17use http_body_util::BodyExt;
18use mime::{Mime, CHARSET};
19use serde::Serialize;
20use std::{
21    borrow::Cow,
22    fmt,
23    future::{ready, Future},
24    pin::Pin,
25};
26
/// Response header a handler can set to `text` to force the body to be returned as decoded
/// text even when its `Content-Type` is not recognized as textual.
///
/// It is only consulted when a `Content-Type` header is present and no `Content-Encoding`
/// header is set.
const X_LAMBDA_HTTP_CONTENT_ENCODING: &str = "x-lambda-http-content-encoding";

// See list of common MIME types:
// - https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types
// - https://github.com/ietf-wg-httpapi/mediatypes/blob/main/draft-ietf-httpapi-yaml-mediatypes.md
/// `Content-Type` prefixes whose bodies are returned as text instead of base64 binary.
const TEXT_ENCODING_PREFIXES: [&str; 5] = [
    "text",
    "application/json",
    "application/javascript",
    "application/xml",
    "application/yaml",
];

/// Structured-syntax suffixes (e.g. `image/svg+xml`, `application/problem+json`) whose
/// bodies are returned as text instead of base64 binary.
const TEXT_ENCODING_SUFFIXES: [&str; 3] = ["+xml", "+yaml", "+json"];
41
/// Representation of Lambda response
///
/// The enum is serialized `untagged`, so the emitted JSON is exactly the shape of the
/// wrapped integration-specific response, with no variant name around it.
#[non_exhaustive]
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
    /// Response for API Gateway v1 requests; also used for WebSocket requests.
    #[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
    ApiGatewayV1(ApiGatewayProxyResponse),
    /// Response for API Gateway v2 (`apigw_http`) requests.
    #[cfg(feature = "apigw_http")]
    ApiGatewayV2(ApiGatewayV2httpResponse),
    /// Response for Application Load Balancer target-group requests.
    #[cfg(feature = "alb")]
    Alb(AlbTargetGroupResponse),
    /// Raw JSON value returned as-is for pass-through requests.
    #[cfg(feature = "pass_through")]
    PassThrough(serde_json::Value),
}
57
58/// Transformation from http type to internal type
59impl LambdaResponse {
60    pub(crate) fn from_response(request_origin: &RequestOrigin, value: Response<Body>) -> Self {
61        let (parts, bod) = value.into_parts();
62        let (is_base64_encoded, body) = match bod {
63            Body::Empty => (false, None),
64            b @ Body::Text(_) => (false, Some(b)),
65            b @ Body::Binary(_) => (true, Some(b)),
66            _ => (false, None),
67        };
68
69        let headers = parts.headers;
70        let status_code = parts.status.as_u16();
71
72        match request_origin {
73            #[cfg(feature = "apigw_rest")]
74            RequestOrigin::ApiGatewayV1 => LambdaResponse::ApiGatewayV1({
75                let mut response = ApiGatewayProxyResponse::default();
76
77                response.body = body;
78                response.is_base64_encoded = is_base64_encoded;
79                response.status_code = status_code as i64;
80                // Explicitly empty, as API gateway v1 will merge "headers" and
81                // "multi_value_headers" fields together resulting in duplicate response headers.
82                response.headers = HeaderMap::new();
83                response.multi_value_headers = headers;
84                // Today, this implementation doesn't provide any additional fields
85                #[cfg(feature = "catch-all-fields")]
86                {
87                    response.other = Default::default();
88                }
89                response
90            }),
91            #[cfg(feature = "apigw_http")]
92            RequestOrigin::ApiGatewayV2 => {
93                use http::header::SET_COOKIE;
94                let mut headers = headers;
95                // ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
96                // so remove them from the headers.
97                let cookies = headers
98                    .get_all(SET_COOKIE)
99                    .iter()
100                    .cloned()
101                    .map(|v| v.to_str().ok().unwrap_or_default().to_string())
102                    .collect();
103                headers.remove(SET_COOKIE);
104
105                LambdaResponse::ApiGatewayV2({
106                    let mut response = ApiGatewayV2httpResponse::default();
107                    response.body = body;
108                    response.is_base64_encoded = is_base64_encoded;
109                    response.status_code = status_code as i64;
110                    response.cookies = cookies;
111                    // API gateway v2 doesn't have "multi_value_headers" field. Duplicate headers
112                    // are combined with commas and included in the headers field.
113                    response.headers = headers;
114                    response.multi_value_headers = HeaderMap::new();
115                    // Today, this implementation doesn't provide any additional fields
116                    #[cfg(feature = "catch-all-fields")]
117                    {
118                        response.other = Default::default();
119                    }
120                    response
121                })
122            }
123            #[cfg(feature = "alb")]
124            RequestOrigin::Alb => LambdaResponse::Alb({
125                let mut response = AlbTargetGroupResponse::default();
126
127                response.body = body;
128                response.is_base64_encoded = is_base64_encoded;
129                response.status_code = status_code as i64;
130                // ALB responses are used for ALB integration, which can be configured to use
131                // either "headers" or "multi_value_headers" field. We need to return both fields
132                // to ensure both configuration work correctly.
133                response.headers = headers.clone();
134                response.multi_value_headers = headers;
135                response.status_description = Some(format!(
136                    "{} {}",
137                    status_code,
138                    parts.status.canonical_reason().unwrap_or_default()
139                ));
140                // Today, this implementation doesn't provide any additional fields
141                #[cfg(feature = "catch-all-fields")]
142                {
143                    response.other = Default::default();
144                }
145                response
146            }),
147            #[cfg(feature = "apigw_websockets")]
148            RequestOrigin::WebSocket => LambdaResponse::ApiGatewayV1({
149                let mut response = ApiGatewayProxyResponse::default();
150                response.body = body;
151                response.is_base64_encoded = is_base64_encoded;
152                response.status_code = status_code as i64;
153                // Explicitly empty, as API gateway v1 will merge "headers" and
154                // "multi_value_headers" fields together resulting in duplicate response headers.
155                response.headers = HeaderMap::new();
156                response.multi_value_headers = headers;
157                // Today, this implementation doesn't provide any additional fields
158                #[cfg(feature = "catch-all-fields")]
159                {
160                    response.other = Default::default();
161                }
162                response
163            }),
164            #[cfg(feature = "pass_through")]
165            RequestOrigin::PassThrough => {
166                match body {
167                    // text body must be a valid json string
168                    Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())},
169                    // binary body and other cases return Value::Null
170                    _ => LambdaResponse::PassThrough(serde_json::Value::Null),
171                }
172            }
173            #[cfg(not(any(
174                feature = "apigw_rest",
175                feature = "apigw_http",
176                feature = "alb",
177                feature = "apigw_websockets"
178            )))]
179            _ => compile_error!("Either feature `apigw_rest`, `apigw_http`, `alb`, or `apigw_websockets` must be enabled for the `lambda-http` crate."),
180        }
181    }
182}
183
/// Trait for generating responses
///
/// Types that implement this trait can be used as return types for handler functions.
/// The conversion is asynchronous: implementations return a boxed, `Send` future
/// ([`ResponseFuture`]) so that bodies which must first be read to completion can be
/// converted inside the future.
pub trait IntoResponse {
    /// Transform into a `Response<Body>` Future
    fn into_response(self) -> ResponseFuture;
}
191
192impl<B> IntoResponse for Response<B>
193where
194    B: ConvertBody + Send + 'static,
195{
196    fn into_response(self) -> ResponseFuture {
197        let (parts, body) = self.into_parts();
198        let headers = parts.headers.clone();
199
200        let fut = async { Response::from_parts(parts, body.convert(headers).await) };
201
202        Box::pin(fut)
203    }
204}
205
206impl IntoResponse for String {
207    fn into_response(self) -> ResponseFuture {
208        Box::pin(ready(Response::new(Body::from(self))))
209    }
210}
211
212impl IntoResponse for &str {
213    fn into_response(self) -> ResponseFuture {
214        Box::pin(ready(Response::new(Body::from(self))))
215    }
216}
217
218impl IntoResponse for &[u8] {
219    fn into_response(self) -> ResponseFuture {
220        Box::pin(ready(Response::new(Body::from(self))))
221    }
222}
223
224impl IntoResponse for Vec<u8> {
225    fn into_response(self) -> ResponseFuture {
226        Box::pin(ready(Response::new(Body::from(self))))
227    }
228}
229
230impl IntoResponse for serde_json::Value {
231    fn into_response(self) -> ResponseFuture {
232        Box::pin(async move {
233            Response::builder()
234                .header(CONTENT_TYPE, "application/json")
235                .body(
236                    serde_json::to_string(&self)
237                        .expect("unable to serialize serde_json::Value")
238                        .into(),
239                )
240                .expect("unable to build http::Response")
241        })
242    }
243}
244
245impl IntoResponse for (StatusCode, String) {
246    fn into_response(self) -> ResponseFuture {
247        let (status, body) = self;
248        Box::pin(ready(
249            Response::builder()
250                .status(status)
251                .body(Body::from(body))
252                .expect("unable to build http::Response"),
253        ))
254    }
255}
256
257impl IntoResponse for (StatusCode, &str) {
258    fn into_response(self) -> ResponseFuture {
259        let (status, body) = self;
260        Box::pin(ready(
261            Response::builder()
262                .status(status)
263                .body(Body::from(body))
264                .expect("unable to build http::Response"),
265        ))
266    }
267}
268
269impl IntoResponse for (StatusCode, &[u8]) {
270    fn into_response(self) -> ResponseFuture {
271        let (status, body) = self;
272        Box::pin(ready(
273            Response::builder()
274                .status(status)
275                .body(Body::from(body))
276                .expect("unable to build http::Response"),
277        ))
278    }
279}
280
281impl IntoResponse for (StatusCode, Vec<u8>) {
282    fn into_response(self) -> ResponseFuture {
283        let (status, body) = self;
284        Box::pin(ready(
285            Response::builder()
286                .status(status)
287                .body(Body::from(body))
288                .expect("unable to build http::Response"),
289        ))
290    }
291}
292
293impl IntoResponse for (StatusCode, serde_json::Value) {
294    fn into_response(self) -> ResponseFuture {
295        let (status, body) = self;
296        Box::pin(async move {
297            Response::builder()
298                .status(status)
299                .header(CONTENT_TYPE, "application/json")
300                .body(
301                    serde_json::to_string(&body)
302                        .expect("unable to serialize serde_json::Value")
303                        .into(),
304                )
305                .expect("unable to build http::Response")
306        })
307    }
308}
309
/// Boxed, `Send` future resolving to the `Response<Body>` produced by [`IntoResponse`].
pub type ResponseFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
311
312pub trait ConvertBody {
313    fn convert(self, parts: HeaderMap) -> BodyFuture;
314}
315
316impl<B> ConvertBody for B
317where
318    B: HttpBody + Unpin + Send + 'static,
319    B::Data: Send,
320    B::Error: fmt::Debug,
321{
322    fn convert(self, headers: HeaderMap) -> BodyFuture {
323        if headers.get(CONTENT_ENCODING).is_some() {
324            return convert_to_binary(self);
325        }
326
327        let content_type = if let Some(value) = headers.get(CONTENT_TYPE) {
328            value.to_str().unwrap_or_default()
329        } else {
330            // Content-Type and Content-Encoding not set, passthrough as utf8 text
331            return convert_to_text(self, "utf-8");
332        };
333
334        for prefix in TEXT_ENCODING_PREFIXES {
335            if content_type.starts_with(prefix) {
336                return convert_to_text(self, content_type);
337            }
338        }
339
340        for suffix in TEXT_ENCODING_SUFFIXES {
341            let mut parts = content_type.trim().split(';');
342            let mime_type = parts.next().unwrap_or_default();
343            if mime_type.ends_with(suffix) {
344                return convert_to_text(self, content_type);
345            }
346        }
347
348        if let Some(value) = headers.get(X_LAMBDA_HTTP_CONTENT_ENCODING) {
349            if value == "text" {
350                return convert_to_text(self, content_type);
351            }
352        }
353
354        convert_to_binary(self)
355    }
356}
357
358fn convert_to_binary<B>(body: B) -> BodyFuture
359where
360    B: HttpBody + Unpin + Send + 'static,
361    B::Data: Send,
362    B::Error: fmt::Debug,
363{
364    Box::pin(async move {
365        Body::from(
366            body.collect()
367                .await
368                .expect("unable to read bytes from body")
369                .to_bytes()
370                .to_vec(),
371        )
372    })
373}
374
375fn convert_to_text<B>(body: B, content_type: &str) -> BodyFuture
376where
377    B: HttpBody + Unpin + Send + 'static,
378    B::Data: Send,
379    B::Error: fmt::Debug,
380{
381    let mime_type = content_type.parse::<Mime>();
382
383    let encoding = match mime_type.as_ref() {
384        Ok(mime) => mime.get_param(CHARSET).unwrap_or(mime::UTF_8),
385        Err(_) => mime::UTF_8,
386    };
387
388    let label = encoding.as_ref().as_bytes();
389    let encoding = Encoding::for_label(label).unwrap_or(encoding_rs::UTF_8);
390
391    // assumes utf-8
392    Box::pin(async move {
393        let bytes = body.collect().await.expect("unable to read bytes from body").to_bytes();
394        let (content, _, _) = encoding.decode(&bytes);
395
396        match content {
397            Cow::Borrowed(content) => Body::from(content),
398            Cow::Owned(content) => Body::from(content),
399        }
400    })
401}
402
/// Boxed, `Send` future resolving to the `aws_lambda_events` [`Body`] produced by
/// [`ConvertBody::convert`].
pub type BodyFuture = Pin<Box<dyn Future<Output = Body> + Send>>;
404
#[cfg(test)]
mod tests {
    use super::{Body, IntoResponse, LambdaResponse, RequestOrigin, X_LAMBDA_HTTP_CONTENT_ENCODING};
    use http::{
        header::{CONTENT_ENCODING, CONTENT_TYPE},
        Response, StatusCode,
    };
    use lambda_runtime_api_client::body::Body as HyperBody;
    use serde_json::{self, json};

    // Non-JSON textual fixture used to check that text bodies survive conversion unchanged.
    const SVG_LOGO: &str = include_str!("../tests/data/svg_logo.svg");

    // A JSON value becomes a text body with an `application/json` content type.
    #[tokio::test]
    async fn json_into_response() {
        let response = json!({ "hello": "lambda"}).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    #[tokio::test]
    async fn text_into_response() {
        let response = "text".into_response().await;
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // Byte slices are returned as binary bodies.
    #[tokio::test]
    async fn bytes_into_response() {
        let response = "text".as_bytes().into_response().await;
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    #[tokio::test]
    async fn json_with_status_code_into_response() {
        let response = (StatusCode::CREATED, json!({ "hello": "lambda"})).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }

        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    #[tokio::test]
    async fn text_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text").into_response().await;

        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    #[tokio::test]
    async fn bytes_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text".as_bytes()).into_response().await;
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // A Content-Encoding header forces a binary body, serialized base64 with isBase64Encoded.
    #[tokio::test]
    async fn content_encoding_header() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            .header(CONTENT_ENCODING, "gzip")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-encoding":"gzip"},"multiValueHeaders":{},"body":"MDAwMDAw","isBase64Encoded":true,"cookies":[]}"#
        )
    }

    #[tokio::test]
    async fn content_type_header() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json"},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // With charset=utf-16, the six ASCII '0' bytes decode as three U+3030 characters.
    #[tokio::test]
    async fn charset_content_type_header() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // Same as above, but the media type is only recognized as text through its `+json` suffix.
    #[tokio::test]
    async fn charset_content_type_header_suffix() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/graphql-response+json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/graphql-response+json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // Without Content-Type or Content-Encoding the body is passed through as UTF-8 text.
    #[tokio::test]
    async fn content_headers_unset() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // API Gateway v1 responses put every header in multiValueHeaders and leave headers empty.
    #[test]
    fn serialize_multi_value_headers() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV1,
            Response::builder()
                .header("multi", "a")
                .header("multi", "b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
        )
    }

    // API Gateway v2 responses move Set-Cookie headers into the "cookies" array.
    #[test]
    fn serialize_cookies() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV2,
            Response::builder()
                .header("set-cookie", "cookie1=a")
                .header("set-cookie", "cookie2=b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            "{\"statusCode\":200,\"headers\":{},\"multiValueHeaders\":{},\"isBase64Encoded\":false,\"cookies\":[\"cookie1=a\",\"cookie2=b\"]}",
            json
        )
    }

    #[tokio::test]
    async fn content_type_xml_as_text() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg+xml")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg+xml")
        )
    }

    #[tokio::test]
    async fn content_type_custom_encoding_as_text() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let response = Response::builder()
            // this Content-Type is not recognized as text; without the
            // X_LAMBDA_HTTP_CONTENT_ENCODING header below it would yield a binary response
            .header(CONTENT_TYPE, "image/svg")
            .header(X_LAMBDA_HTTP_CONTENT_ENCODING, "text")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg")
        )
    }

    // YAML is recognized both by the `application/yaml` prefix and by the `+yaml` suffix.
    #[tokio::test]
    async fn content_type_yaml_as_text() {
        // Exercise the generic `ConvertBody` path by using the runtime client's body type
        // (`HyperBody`) instead of `aws_lambda_events::encodings::Body`
        let yaml = r#"---
foo: bar
        "#;

        let formats = ["application/yaml", "custom/vdn+yaml"];

        for format in formats {
            let response = Response::builder()
                .header(CONTENT_TYPE, format)
                .body(HyperBody::from(yaml.as_bytes()))
                .expect("unable to build http::Response");
            let response = response.into_response().await;

            match response.body() {
                Body::Text(body) => assert_eq!(yaml, body),
                _ => panic!("invalid body"),
            }
            assert_eq!(
                response
                    .headers()
                    .get(CONTENT_TYPE)
                    .map(|h| h.to_str().expect("invalid header")),
                Some(format)
            )
        }
    }
}