lambda_http/
response.rs

1//! Response types
2
3use crate::request::RequestOrigin;
4#[cfg(feature = "alb")]
5use aws_lambda_events::alb::AlbTargetGroupResponse;
6#[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
7use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8#[cfg(feature = "apigw_http")]
9use aws_lambda_events::apigw::ApiGatewayV2httpResponse;
10use aws_lambda_events::encodings::Body;
11use encoding_rs::Encoding;
12use http::{
13    header::{CONTENT_ENCODING, CONTENT_TYPE},
14    HeaderMap, Response, StatusCode,
15};
16use http_body::Body as HttpBody;
17use http_body_util::BodyExt;
18use mime::{Mime, CHARSET};
19use serde::Serialize;
20use std::{
21    borrow::Cow,
22    fmt,
23    future::{ready, Future},
24    pin::Pin,
25};
26
/// Response header that, when set to `text`, forces the body to be emitted as text even when
/// its `Content-Type` would otherwise be classified as binary.
const X_LAMBDA_HTTP_CONTENT_ENCODING: &str = "x-lambda-http-content-encoding";

/// Media-type prefixes whose bodies are sent as text.
///
/// Sources for common MIME types:
/// - <https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types>
/// - <https://github.com/ietf-wg-httpapi/mediatypes/blob/main/draft-ietf-httpapi-yaml-mediatypes.md>
const TEXT_ENCODING_PREFIXES: [&str; 5] = [
    "text",
    "application/json",
    "application/javascript",
    "application/xml",
    "application/yaml",
];

/// Structured-syntax suffixes (e.g. `image/svg+xml`) whose bodies are sent as text.
const TEXT_ENCODING_SUFFIXES: [&str; 3] = [
    "+xml",
    "+yaml",
    "+json",
];
41
/// Representation of Lambda response
///
/// Serialized with `#[serde(untagged)]`, so the JSON returned to the runtime is exactly the
/// wrapped integration-specific payload, without any variant name. Each variant is only
/// compiled in when its Cargo feature is enabled.
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
    /// Payload for API Gateway REST and WebSocket integrations.
    #[cfg(any(feature = "apigw_rest", feature = "apigw_websockets"))]
    ApiGatewayV1(ApiGatewayProxyResponse),
    /// Payload for API Gateway HTTP (v2) integrations.
    #[cfg(feature = "apigw_http")]
    ApiGatewayV2(ApiGatewayV2httpResponse),
    /// Payload for Application Load Balancer target-group integrations.
    #[cfg(feature = "alb")]
    Alb(AlbTargetGroupResponse),
    /// Arbitrary JSON returned as-is for pass-through invocations.
    #[cfg(feature = "pass_through")]
    PassThrough(serde_json::Value),
}
56
57/// Transformation from http type to internal type
58impl LambdaResponse {
59    pub(crate) fn from_response(request_origin: &RequestOrigin, value: Response<Body>) -> Self {
60        let (parts, bod) = value.into_parts();
61        let (is_base64_encoded, body) = match bod {
62            Body::Empty => (false, None),
63            b @ Body::Text(_) => (false, Some(b)),
64            b @ Body::Binary(_) => (true, Some(b)),
65        };
66
67        let headers = parts.headers;
68        let status_code = parts.status.as_u16();
69
70        match request_origin {
71            #[cfg(feature = "apigw_rest")]
72            RequestOrigin::ApiGatewayV1 => LambdaResponse::ApiGatewayV1(ApiGatewayProxyResponse {
73                body,
74                is_base64_encoded,
75                status_code: status_code as i64,
76                // Explicitly empty, as API gateway v1 will merge "headers" and
77                // "multi_value_headers" fields together resulting in duplicate response headers.
78                headers: HeaderMap::new(),
79                multi_value_headers: headers,
80            }),
81            #[cfg(feature = "apigw_http")]
82            RequestOrigin::ApiGatewayV2 => {
83                use http::header::SET_COOKIE;
84                let mut headers = headers;
85                // ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
86                // so remove them from the headers.
87                let cookies = headers
88                    .get_all(SET_COOKIE)
89                    .iter()
90                    .cloned()
91                    .map(|v| v.to_str().ok().unwrap_or_default().to_string())
92                    .collect();
93                headers.remove(SET_COOKIE);
94
95                LambdaResponse::ApiGatewayV2(ApiGatewayV2httpResponse {
96                    body,
97                    is_base64_encoded,
98                    status_code: status_code as i64,
99                    cookies,
100                    // API gateway v2 doesn't have "multi_value_headers" field. Duplicate headers
101                    // are combined with commas and included in the headers field.
102                    headers,
103                    multi_value_headers: HeaderMap::new(),
104                })
105            }
106            #[cfg(feature = "alb")]
107            RequestOrigin::Alb => LambdaResponse::Alb(AlbTargetGroupResponse {
108                body,
109                status_code: status_code as i64,
110                is_base64_encoded,
111                // ALB responses are used for ALB integration, which can be configured to use
112                // either "headers" or "multi_value_headers" field. We need to return both fields
113                // to ensure both configuration work correctly.
114                headers: headers.clone(),
115                multi_value_headers: headers,
116                status_description: Some(format!(
117                    "{} {}",
118                    status_code,
119                    parts.status.canonical_reason().unwrap_or_default()
120                )),
121            }),
122            #[cfg(feature = "apigw_websockets")]
123            RequestOrigin::WebSocket => LambdaResponse::ApiGatewayV1(ApiGatewayProxyResponse {
124                body,
125                is_base64_encoded,
126                status_code: status_code as i64,
127                // Explicitly empty, as API gateway v1 will merge "headers" and
128                // "multi_value_headers" fields together resulting in duplicate response headers.
129                headers: HeaderMap::new(),
130                multi_value_headers: headers,
131            }),
132            #[cfg(feature = "pass_through")]
133            RequestOrigin::PassThrough => {
134                match body {
135                    // text body must be a valid json string
136                    Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())},
137                    // binary body and other cases return Value::Null
138                    _ => LambdaResponse::PassThrough(serde_json::Value::Null),
139                }
140            }
141            #[cfg(not(any(
142                feature = "apigw_rest",
143                feature = "apigw_http",
144                feature = "alb",
145                feature = "apigw_websockets"
146            )))]
147            _ => compile_error!("Either feature `apigw_rest`, `apigw_http`, `alb`, or `apigw_websockets` must be enabled for the `lambda-http` crate."),
148        }
149    }
150}
151
/// Trait for generating responses
///
/// Types that implement this trait can be used as return types for handler functions.
/// This module implements it for `http::Response<B>` (any `B: ConvertBody`), `String`, `&str`,
/// `&[u8]`, `Vec<u8>`, `serde_json::Value`, and `(StatusCode, _)` tuples of each of those.
pub trait IntoResponse {
    /// Transform into a `Response<Body>` Future
    ///
    /// The returned [`ResponseFuture`] is a boxed, `Send` future resolving to a response whose
    /// body has been fully buffered.
    fn into_response(self) -> ResponseFuture;
}
159
160impl<B> IntoResponse for Response<B>
161where
162    B: ConvertBody + Send + 'static,
163{
164    fn into_response(self) -> ResponseFuture {
165        let (parts, body) = self.into_parts();
166        let headers = parts.headers.clone();
167
168        let fut = async { Response::from_parts(parts, body.convert(headers).await) };
169
170        Box::pin(fut)
171    }
172}
173
174impl IntoResponse for String {
175    fn into_response(self) -> ResponseFuture {
176        Box::pin(ready(Response::new(Body::from(self))))
177    }
178}
179
180impl IntoResponse for &str {
181    fn into_response(self) -> ResponseFuture {
182        Box::pin(ready(Response::new(Body::from(self))))
183    }
184}
185
186impl IntoResponse for &[u8] {
187    fn into_response(self) -> ResponseFuture {
188        Box::pin(ready(Response::new(Body::from(self))))
189    }
190}
191
192impl IntoResponse for Vec<u8> {
193    fn into_response(self) -> ResponseFuture {
194        Box::pin(ready(Response::new(Body::from(self))))
195    }
196}
197
198impl IntoResponse for serde_json::Value {
199    fn into_response(self) -> ResponseFuture {
200        Box::pin(async move {
201            Response::builder()
202                .header(CONTENT_TYPE, "application/json")
203                .body(
204                    serde_json::to_string(&self)
205                        .expect("unable to serialize serde_json::Value")
206                        .into(),
207                )
208                .expect("unable to build http::Response")
209        })
210    }
211}
212
213impl IntoResponse for (StatusCode, String) {
214    fn into_response(self) -> ResponseFuture {
215        let (status, body) = self;
216        Box::pin(ready(
217            Response::builder()
218                .status(status)
219                .body(Body::from(body))
220                .expect("unable to build http::Response"),
221        ))
222    }
223}
224
225impl IntoResponse for (StatusCode, &str) {
226    fn into_response(self) -> ResponseFuture {
227        let (status, body) = self;
228        Box::pin(ready(
229            Response::builder()
230                .status(status)
231                .body(Body::from(body))
232                .expect("unable to build http::Response"),
233        ))
234    }
235}
236
237impl IntoResponse for (StatusCode, &[u8]) {
238    fn into_response(self) -> ResponseFuture {
239        let (status, body) = self;
240        Box::pin(ready(
241            Response::builder()
242                .status(status)
243                .body(Body::from(body))
244                .expect("unable to build http::Response"),
245        ))
246    }
247}
248
249impl IntoResponse for (StatusCode, Vec<u8>) {
250    fn into_response(self) -> ResponseFuture {
251        let (status, body) = self;
252        Box::pin(ready(
253            Response::builder()
254                .status(status)
255                .body(Body::from(body))
256                .expect("unable to build http::Response"),
257        ))
258    }
259}
260
261impl IntoResponse for (StatusCode, serde_json::Value) {
262    fn into_response(self) -> ResponseFuture {
263        let (status, body) = self;
264        Box::pin(async move {
265            Response::builder()
266                .status(status)
267                .header(CONTENT_TYPE, "application/json")
268                .body(
269                    serde_json::to_string(&body)
270                        .expect("unable to serialize serde_json::Value")
271                        .into(),
272                )
273                .expect("unable to build http::Response")
274        })
275    }
276}
277
278pub type ResponseFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send>>;
279
280pub trait ConvertBody {
281    fn convert(self, parts: HeaderMap) -> BodyFuture;
282}
283
284impl<B> ConvertBody for B
285where
286    B: HttpBody + Unpin + Send + 'static,
287    B::Data: Send,
288    B::Error: fmt::Debug,
289{
290    fn convert(self, headers: HeaderMap) -> BodyFuture {
291        if headers.get(CONTENT_ENCODING).is_some() {
292            return convert_to_binary(self);
293        }
294
295        let content_type = if let Some(value) = headers.get(CONTENT_TYPE) {
296            value.to_str().unwrap_or_default()
297        } else {
298            // Content-Type and Content-Encoding not set, passthrough as utf8 text
299            return convert_to_text(self, "utf-8");
300        };
301
302        for prefix in TEXT_ENCODING_PREFIXES {
303            if content_type.starts_with(prefix) {
304                return convert_to_text(self, content_type);
305            }
306        }
307
308        for suffix in TEXT_ENCODING_SUFFIXES {
309            let mut parts = content_type.trim().split(';');
310            let mime_type = parts.next().unwrap_or_default();
311            if mime_type.ends_with(suffix) {
312                return convert_to_text(self, content_type);
313            }
314        }
315
316        if let Some(value) = headers.get(X_LAMBDA_HTTP_CONTENT_ENCODING) {
317            if value == "text" {
318                return convert_to_text(self, content_type);
319            }
320        }
321
322        convert_to_binary(self)
323    }
324}
325
326fn convert_to_binary<B>(body: B) -> BodyFuture
327where
328    B: HttpBody + Unpin + Send + 'static,
329    B::Data: Send,
330    B::Error: fmt::Debug,
331{
332    Box::pin(async move {
333        Body::from(
334            body.collect()
335                .await
336                .expect("unable to read bytes from body")
337                .to_bytes()
338                .to_vec(),
339        )
340    })
341}
342
343fn convert_to_text<B>(body: B, content_type: &str) -> BodyFuture
344where
345    B: HttpBody + Unpin + Send + 'static,
346    B::Data: Send,
347    B::Error: fmt::Debug,
348{
349    let mime_type = content_type.parse::<Mime>();
350
351    let encoding = match mime_type.as_ref() {
352        Ok(mime) => mime.get_param(CHARSET).unwrap_or(mime::UTF_8),
353        Err(_) => mime::UTF_8,
354    };
355
356    let label = encoding.as_ref().as_bytes();
357    let encoding = Encoding::for_label(label).unwrap_or(encoding_rs::UTF_8);
358
359    // assumes utf-8
360    Box::pin(async move {
361        let bytes = body.collect().await.expect("unable to read bytes from body").to_bytes();
362        let (content, _, _) = encoding.decode(&bytes);
363
364        match content {
365            Cow::Borrowed(content) => Body::from(content),
366            Cow::Owned(content) => Body::from(content),
367        }
368    })
369}
370
371pub type BodyFuture = Pin<Box<dyn Future<Output = Body> + Send>>;
372
#[cfg(test)]
mod tests {
    use super::{Body, IntoResponse, LambdaResponse, RequestOrigin, X_LAMBDA_HTTP_CONTENT_ENCODING};
    use http::{
        header::{CONTENT_ENCODING, CONTENT_TYPE},
        Response, StatusCode,
    };
    // An `http_body::Body` implementation, used so that `Response<HyperBody>` goes through the
    // blanket `ConvertBody` implementation.
    use lambda_runtime_api_client::body::Body as HyperBody;
    use serde_json::{self, json};

    // SVG fixture used by the XML / custom-encoding text tests.
    const SVG_LOGO: &str = include_str!("../tests/data/svg_logo.svg");

    // A bare `serde_json::Value` becomes a compact JSON text body labelled `application/json`.
    #[tokio::test]
    async fn json_into_response() {
        let response = json!({ "hello": "lambda"}).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // A `&str` is returned as a text body.
    #[tokio::test]
    async fn text_into_response() {
        let response = "text".into_response().await;
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // A byte slice is returned as a binary body, even when the bytes are valid text.
    #[tokio::test]
    async fn bytes_into_response() {
        let response = "text".as_bytes().into_response().await;
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // The `(StatusCode, Value)` tuple keeps the status and still sets the JSON content type.
    #[tokio::test]
    async fn json_with_status_code_into_response() {
        let response = (StatusCode::CREATED, json!({ "hello": "lambda"})).into_response().await;
        match response.body() {
            Body::Text(json) => assert_eq!(json, r#"{"hello":"lambda"}"#),
            _ => panic!("invalid body"),
        }
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }

        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("application/json")
        )
    }

    // The `(StatusCode, &str)` tuple keeps the status and produces a text body.
    #[tokio::test]
    async fn text_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text").into_response().await;

        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Text(text) => assert_eq!(text, "text"),
            _ => panic!("invalid body"),
        }
    }

    // The `(StatusCode, &[u8])` tuple keeps the status and produces a binary body.
    #[tokio::test]
    async fn bytes_with_status_code_into_response() {
        let response = (StatusCode::CREATED, "text".as_bytes()).into_response().await;
        match response.status() {
            StatusCode::CREATED => (),
            _ => panic!("invalid status code"),
        }
        match response.body() {
            Body::Binary(data) => assert_eq!(data, "text".as_bytes()),
            _ => panic!("invalid body"),
        }
    }

    // Any Content-Encoding forces a binary (base64-serialized) body, even for ASCII bytes.
    #[tokio::test]
    async fn content_encoding_header() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            .header(CONTENT_ENCODING, "gzip")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-encoding":"gzip"},"multiValueHeaders":{},"body":"MDAwMDAw","isBase64Encoded":true,"cookies":[]}"#
        )
    }

    // A text-like Content-Type keeps the body as plain (non-base64) text.
    #[tokio::test]
    async fn content_type_header() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json"},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // The charset parameter drives decoding: the six ASCII '0' bytes (0x30) read as UTF-16
    // become three U+3030 characters.
    #[tokio::test]
    async fn charset_content_type_header() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // Same as above, but the media type is recognised through its `+json` suffix.
    #[tokio::test]
    async fn charset_content_type_header_suffix() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/graphql-response+json; charset=utf-16")
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{"content-type":"application/graphql-response+json; charset=utf-16"},"multiValueHeaders":{},"body":"〰〰〰","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // With neither Content-Type nor Content-Encoding set, the body is passed through as UTF-8 text.
    #[tokio::test]
    async fn content_headers_unset() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            .body(HyperBody::from("000000".as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;
        let response = LambdaResponse::from_response(&RequestOrigin::ApiGatewayV2, response);

        let json = serde_json::to_string(&response).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"000000","isBase64Encoded":false,"cookies":[]}"#
        )
    }

    // API Gateway v1: repeated headers appear only under multiValueHeaders, and an empty body
    // produces no "body" key.
    #[test]
    fn serialize_multi_value_headers() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV1,
            Response::builder()
                .header("multi", "a")
                .header("multi", "b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            json,
            r#"{"statusCode":200,"headers":{},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
        )
    }

    // API Gateway v2: Set-Cookie headers are moved out of "headers" into the "cookies" array.
    #[test]
    fn serialize_cookies() {
        let res = LambdaResponse::from_response(
            &RequestOrigin::ApiGatewayV2,
            Response::builder()
                .header("set-cookie", "cookie1=a")
                .header("set-cookie", "cookie2=b")
                .body(Body::from(()))
                .expect("failed to create response"),
        );
        let json = serde_json::to_string(&res).expect("failed to serialize to json");
        assert_eq!(
            "{\"statusCode\":200,\"headers\":{},\"multiValueHeaders\":{},\"isBase64Encoded\":false,\"cookies\":[\"cookie1=a\",\"cookie2=b\"]}",
            json
        )
    }

    // `image/svg+xml` is recognised as text through its `+xml` suffix.
    #[tokio::test]
    async fn content_type_xml_as_text() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            .header(CONTENT_TYPE, "image/svg+xml")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg+xml")
        )
    }

    // The `x-lambda-http-content-encoding: text` header forces text for an unrecognised type.
    #[tokio::test]
    async fn content_type_custom_encoding_as_text() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let response = Response::builder()
            // this CONTENT-TYPE is not standard, and would yield a binary response
            .header(CONTENT_TYPE, "image/svg")
            .header(X_LAMBDA_HTTP_CONTENT_ENCODING, "text")
            .body(HyperBody::from(SVG_LOGO.as_bytes()))
            .expect("unable to build http::Response");
        let response = response.into_response().await;

        match response.body() {
            Body::Text(body) => assert_eq!(SVG_LOGO, body),
            _ => panic!("invalid body"),
        }
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().expect("invalid header")),
            Some("image/svg")
        )
    }

    // YAML is text both as `application/yaml` (prefix list) and via the `+yaml` suffix.
    #[tokio::test]
    async fn content_type_yaml_as_text() {
        // Drive the implementation through a generic `http_body::Body` (`HyperBody`) instead of
        // `aws_lambda_events::encodings::Body`.
        let yaml = r#"---
foo: bar
        "#;

        let formats = ["application/yaml", "custom/vdn+yaml"];

        for format in formats {
            let response = Response::builder()
                .header(CONTENT_TYPE, format)
                .body(HyperBody::from(yaml.as_bytes()))
                .expect("unable to build http::Response");
            let response = response.into_response().await;

            match response.body() {
                Body::Text(body) => assert_eq!(yaml, body),
                _ => panic!("invalid body"),
            }
            assert_eq!(
                response
                    .headers()
                    .get(CONTENT_TYPE)
                    .map(|h| h.to_str().expect("invalid header")),
                Some(format)
            )
        }
    }
}