lambda_web/
hyper014.rs

1// SPDX-License-Identifier: MIT
2//!
3//! Run hyper based web framework on AWS Lambda
4//!
5use crate::request::LambdaHttpEvent;
6use core::convert::TryFrom;
7use core::future::Future;
8use lambda_runtime::{Error as LambdaError, LambdaEvent, Service as LambdaService};
9use std::convert::Infallible;
10use std::pin::Pin;
11
/// Request type passed to the wrapped hyper service: a `hyper::Request` carrying a `hyper::Body`.
type HyperRequest = hyper::Request<hyper::Body>;
/// Response type returned by the wrapped hyper service, generic over its body type `B`.
type HyperResponse<B> = hyper::Response<B>;
14
15/// Run hyper based web framework on AWS Lambda
16///
17/// axum 0.3 example:
18///
19/// ```no_run
20/// use axum::{routing::get, Router};
21/// use lambda_web::{is_running_on_lambda, run_hyper_on_lambda, LambdaError};
22/// use std::net::SocketAddr;
23///
24/// // basic handler that responds with a static string
25/// async fn root() -> &'static str {
26///     "Hello, World!"
27/// }
28///
29/// #[tokio::main]
30/// async fn main() -> Result<(), LambdaError> {
31///     // build our application with a route
32///     let app = Router::new()
33///         // `GET /` goes to `root`
34///         .route("/", get(root));
35///
36///     if is_running_on_lambda() {
37///         // Run app on AWS Lambda
38///         run_hyper_on_lambda(app).await?;
39///     } else {
40///         // Run app on local server
41///         let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
42///         axum::Server::bind(&addr).serve(app.into_make_service()).await?;
43///     }
44///     Ok(())
45/// }
46/// ```
47///
48/// warp 0.3 example:
49///
50/// ```no_run
51/// use warp::Filter;
52/// use lambda_web::{is_running_on_lambda, run_hyper_on_lambda, LambdaError};
53///
54/// #[tokio::main]
55/// async fn main() -> Result<(),LambdaError> {
56///     // GET /hello/warp => 200 OK with body "Hello, warp!"
57///     let hello = warp::path!("hello" / String)
58///         .map(|name| format!("Hello, {}!", name));
59///
60///     if is_running_on_lambda() {
61///         // Run app on AWS Lambda
62///         run_hyper_on_lambda(warp::service(hello)).await?;
63///     } else {
64///         // Run app on local server
65///         warp::serve(hello).run(([127, 0, 0, 1], 8080)).await;
66///     }
67///     Ok(())
68/// }
69/// ```
70pub async fn run_hyper_on_lambda<S, B>(svc: S) -> Result<(), LambdaError>
71where
72    S: hyper::service::Service<HyperRequest, Response = HyperResponse<B>, Error = Infallible>
73        + 'static,
74    B: hyper::body::HttpBody,
75    <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static,
76{
77    lambda_runtime::run(HyperHandler(svc)).await?;
78    Ok(())
79}
80
81/// Lambda_runtime handler for hyper
82struct HyperHandler<S, B>(S)
83where
84    S: hyper::service::Service<HyperRequest, Response = HyperResponse<B>, Error = Infallible>
85        + 'static,
86    B: hyper::body::HttpBody,
87    <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static;
88
89impl<S, B> LambdaService<LambdaEvent<LambdaHttpEvent<'_>>> for HyperHandler<S, B>
90where
91    S: hyper::service::Service<HyperRequest, Response = HyperResponse<B>, Error = Infallible>
92        + 'static,
93    B: hyper::body::HttpBody,
94    <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static,
95{
96    type Response = serde_json::Value;
97    type Error = Infallible;
98    type Future = Pin<Box<dyn Future<Output = Result<serde_json::Value, Self::Error>>>>;
99
100    /// Returns Poll::Ready when servie can process more requrests.
101    fn poll_ready(
102        &mut self,
103        cx: &mut core::task::Context<'_>,
104    ) -> core::task::Poll<Result<(), Self::Error>> {
105        self.0.poll_ready(cx)
106    }
107
108    /// Lambda handler function
109    /// Parse Lambda event as hyper request,
110    /// serialize hyper response to Lambda JSON response
111    fn call(&mut self, req: LambdaEvent<LambdaHttpEvent<'_>>) -> Self::Future {
112        use serde_json::json;
113
114        let event = req.payload;
115        let _context = req.context;
116
117        // check if web client supports content-encoding: br
118        let client_br = event.client_supports_brotli();
119        // multi-value-headers response format
120        let multi_value = event.multi_value();
121
122        // Parse request
123        let hyper_request = HyperRequest::try_from(event);
124
125        // Call hyper service when request parsing succeeded
126        let svc_call = hyper_request.map(|req| self.0.call(req));
127
128        let fut = async move {
129            match svc_call {
130                Ok(svc_fut) => {
131                    // Request parsing succeeded
132                    if let Ok(response) = svc_fut.await {
133                        // Returns as API Gateway response
134                        api_gateway_response_from_hyper(response, client_br, multi_value)
135                            .await
136                            .or_else(|_err| {
137                                Ok(json!({
138                                    "isBase64Encoded": false,
139                                    "statusCode": 500u16,
140                                    "headers": { "content-type": "text/plain"},
141                                    "body": "Internal Server Error"
142                                }))
143                            })
144                    } else {
145                        // Some hyper error -> 500 Internal Server Error
146                        Ok(json!({
147                            "isBase64Encoded": false,
148                            "statusCode": 500u16,
149                            "headers": { "content-type": "text/plain"},
150                            "body": "Internal Server Error"
151                        }))
152                    }
153                }
154                Err(_request_err) => {
155                    // Request parsing error
156                    Ok(json!({
157                        "isBase64Encoded": false,
158                        "statusCode": 400u16,
159                        "headers": { "content-type": "text/plain"},
160                        "body": "Bad Request"
161                    }))
162                }
163            }
164        };
165        Box::pin(fut)
166    }
167}
168
169impl TryFrom<LambdaHttpEvent<'_>> for HyperRequest {
170    type Error = LambdaError;
171
172    /// hyper Request from API Gateway event
173    fn try_from(event: LambdaHttpEvent) -> Result<Self, Self::Error> {
174        use hyper::header::{HeaderName, HeaderValue};
175        use hyper::Method;
176        use std::str::FromStr;
177
178        // URI
179        let uri = format!(
180            "https://{}{}",
181            event.hostname().unwrap_or("localhost"),
182            event.path_query()
183        );
184
185        // Method
186        let method = Method::try_from(event.method())?;
187
188        // Construct hyper request
189        let mut reqbuilder = hyper::Request::builder().method(method).uri(&uri);
190
191        // headers
192        if let Some(headers_mut) = reqbuilder.headers_mut() {
193            for (k, v) in event.headers() {
194                if let (Ok(k), Ok(v)) = (
195                    HeaderName::from_str(k as &str),
196                    HeaderValue::from_str(&v as &str),
197                ) {
198                    headers_mut.insert(k, v);
199                }
200            }
201        }
202
203        // Body
204        let req = reqbuilder.body(hyper::Body::from(event.body()?))?;
205
206        Ok(req)
207    }
208}
209
210impl<B> crate::brotli::ResponseCompression for HyperResponse<B> {
211    /// Content-Encoding header value
212    fn content_encoding<'a>(&'a self) -> Option<&'a str> {
213        self.headers()
214            .get(hyper::header::CONTENT_ENCODING)
215            .and_then(|val| val.to_str().ok())
216    }
217
218    /// Content-Type header value
219    fn content_type<'a>(&'a self) -> Option<&'a str> {
220        self.headers()
221            .get(hyper::header::CONTENT_TYPE)
222            .and_then(|val| val.to_str().ok())
223    }
224}
225
226/// API Gateway response from hyper response
227async fn api_gateway_response_from_hyper<B>(
228    response: HyperResponse<B>,
229    client_support_br: bool,
230    multi_value: bool,
231) -> Result<serde_json::Value, LambdaError>
232where
233    B: hyper::body::HttpBody,
234    <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static,
235{
236    use crate::brotli::ResponseCompression;
237    use hyper::header::SET_COOKIE;
238    use serde_json::json;
239
240    // Check if response should be compressed
241    let compress = client_support_br && response.can_brotli_compress();
242
243    // Divide resonse into headers and body
244    let (parts, res_body) = response.into_parts();
245
246    // HTTP status
247    let status_code = parts.status.as_u16();
248
249    // Convert header to JSON map
250    let mut cookies = Vec::<String>::new();
251    let mut headers = serde_json::Map::new();
252    for (k, v) in parts.headers.iter() {
253        if let Ok(value_str) = v.to_str() {
254            if multi_value {
255                // REST API format, returns multiValueHeaders
256                if let Some(values) = headers.get_mut(k.as_str()) {
257                    if let Some(value_ary) = values.as_array_mut() {
258                        value_ary.push(json!(value_str));
259                    }
260                } else {
261                    headers.insert(k.as_str().to_string(), json!([value_str]));
262                }
263            } else {
264                // HTTP API v2 format, returns headers
265                if k == SET_COOKIE {
266                    cookies.push(value_str.to_string());
267                } else {
268                    headers.insert(k.as_str().to_string(), json!(value_str));
269                }
270            }
271        }
272    }
273
274    // Compress, base64 encode the response body
275    let body_bytes = hyper::body::to_bytes(res_body).await?;
276    let body_base64 = if compress {
277        if multi_value {
278            headers.insert("content-encoding".to_string(), json!(["br"]));
279        } else {
280            headers.insert("content-encoding".to_string(), json!("br"));
281        }
282        crate::brotli::compress_response_body(&body_bytes)
283    } else {
284        base64::encode(body_bytes)
285    };
286
287    if multi_value {
288        Ok(json!({
289            "isBase64Encoded": true,
290            "statusCode": status_code,
291            "multiValueHeaders": headers,
292            "body": body_base64
293        }))
294    } else {
295        Ok(json!({
296            "isBase64Encoded": true,
297            "statusCode": status_code,
298            "cookies": cookies,
299            "headers": headers,
300            "body": body_base64
301        }))
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{request::LambdaHttpEvent, test_consts::*};

    /// Deserializes a Lambda event JSON string and converts it into a hyper request.
    fn prepare_request(event_str: &str) -> HyperRequest {
        let event: LambdaHttpEvent = serde_json::from_str(event_str).unwrap();
        HyperRequest::try_from(event).unwrap()
    }

    /// Asserts that the request built from `event_str` has the given URI path.
    #[track_caller]
    fn assert_path(event_str: &str, expected_path: &str) {
        assert_eq!(prepare_request(event_str).uri().path(), expected_path);
    }

    /// Asserts that the request built from `event_str` has the given query string.
    #[track_caller]
    fn assert_query(event_str: &str, expected_query: &str) {
        assert_eq!(prepare_request(event_str).uri().query(), Some(expected_query));
    }

    /// Asserts that header `name` of `req` is present and equals `expected`.
    #[track_caller]
    fn assert_header(req: &HyperRequest, name: &str, expected: &str) {
        assert_eq!(req.headers().get(name).unwrap(), &expected);
    }

    /// Asserts that the request built from `event_str` is a POST with the expected form body.
    async fn assert_form_post(event_str: &str) {
        let req = prepare_request(event_str);
        assert_eq!(req.method(), hyper::Method::POST);
        assert_eq!(
            hyper::body::to_bytes(req.into_body()).await.unwrap().as_ref(),
            b"key1=value1&key2=value2&Ok=Ok"
        );
    }

    #[test]
    fn test_path_decode() {
        assert_path(API_GATEWAY_V2_GET_ROOT_NOQUERY, "/");
        assert_path(API_GATEWAY_REST_GET_ROOT_NOQUERY, "/stage/");

        assert_path(API_GATEWAY_V2_GET_SOMEWHERE_NOQUERY, "/somewhere");
        assert_path(API_GATEWAY_REST_GET_SOMEWHERE_NOQUERY, "/stage/somewhere");

        assert_path(API_GATEWAY_V2_GET_SPACEPATH_NOQUERY, "/path%20with/space");
        assert_path(
            API_GATEWAY_REST_GET_SPACEPATH_NOQUERY,
            "/stage/path%20with/space",
        );

        assert_path(API_GATEWAY_V2_GET_PERCENTPATH_NOQUERY, "/path%25with/percent");
        assert_path(
            API_GATEWAY_REST_GET_PERCENTPATH_NOQUERY,
            "/stage/path%25with/percent",
        );

        assert_path(
            API_GATEWAY_V2_GET_UTF8PATH_NOQUERY,
            "/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D",
        );
        assert_path(
            API_GATEWAY_REST_GET_UTF8PATH_NOQUERY,
            "/stage/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D",
        );
    }

    #[test]
    fn test_query_decode() {
        assert_query(API_GATEWAY_V2_GET_ROOT_ONEQUERY, "key=value");
        assert_query(API_GATEWAY_REST_GET_ROOT_ONEQUERY, "key=value");

        assert_query(API_GATEWAY_V2_GET_SOMEWHERE_ONEQUERY, "key=value");
        assert_query(API_GATEWAY_REST_GET_SOMEWHERE_ONEQUERY, "key=value");

        assert_query(API_GATEWAY_V2_GET_SOMEWHERE_TWOQUERY, "key1=value1&key2=value2");
        // The REST event may yield the two parameters in either order.
        let req = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_TWOQUERY);
        let query = req.uri().query();
        assert!(
            query == Some("key1=value1&key2=value2") || query == Some("key2=value2&key1=value1")
        );

        assert_query(API_GATEWAY_V2_GET_SOMEWHERE_SPACEQUERY, "key=value1+value2");
        assert_query(API_GATEWAY_REST_GET_SOMEWHERE_SPACEQUERY, "key=value1%20value2");

        assert_query(
            API_GATEWAY_V2_GET_SOMEWHERE_UTF8QUERY,
            "key=%E6%97%A5%E6%9C%AC%E8%AA%9E",
        );
        assert_query(
            API_GATEWAY_REST_GET_SOMEWHERE_UTF8QUERY,
            "key=%E6%97%A5%E6%9C%AC%E8%AA%9E",
        );
    }

    #[tokio::test]
    async fn test_form_post() {
        assert_form_post(API_GATEWAY_V2_POST_FORM_URLENCODED).await;
        assert_form_post(API_GATEWAY_REST_POST_FORM_URLENCODED).await;

        // Base64 encoded
        assert_form_post(API_GATEWAY_V2_POST_FORM_URLENCODED_B64).await;
        assert_form_post(API_GATEWAY_REST_POST_FORM_URLENCODED_B64).await;
    }

    #[test]
    fn test_parse_header() {
        let req = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        assert_header(&req, "x-forwarded-port", "443");
        assert_header(&req, "x-forwarded-proto", "https");

        let req = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        assert_header(&req, "x-forwarded-port", "443");
        assert_header(&req, "x-forwarded-proto", "https");
    }

    #[test]
    fn test_parse_cookies() {
        // No cookie header at all
        let req = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        assert_eq!(req.headers().get("cookie"), None);
        let req = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        assert_eq!(req.headers().get("cookie"), None);

        // One cookie
        assert_header(
            &prepare_request(API_GATEWAY_V2_GET_ONE_COOKIE),
            "cookie",
            "cookie1=value1",
        );
        assert_header(
            &prepare_request(API_GATEWAY_REST_GET_ONE_COOKIE),
            "cookie",
            "cookie1=value1",
        );

        // Two cookies arrive as a single header
        assert_header(
            &prepare_request(API_GATEWAY_V2_GET_TWO_COOKIES),
            "cookie",
            "cookie1=value1; cookie2=value2",
        );
        assert_header(
            &prepare_request(API_GATEWAY_REST_GET_TWO_COOKIES),
            "cookie",
            "cookie1=value1; cookie2=value2",
        );
    }
}