lambda_web/
rocket05.rs

1// SPDX-License-Identifier: MIT
2//!
3//! Run Rocket on AWS Lambda
4//!
5//!
6use crate::request::LambdaHttpEvent;
7use core::convert::TryFrom;
8use core::future::Future;
9use lambda_runtime::{Error as LambdaError, LambdaEvent, Service as LambdaService};
10use std::pin::Pin;
11use std::sync::Arc;
12
13/// Launch Rocket application on AWS Lambda
14///
15/// ```no_run
16/// use rocket::{self, get, routes};
17/// use lambda_web::{is_running_on_lambda, launch_rocket_on_lambda, LambdaError};
18///
19/// #[get("/hello/<name>/<age>")]
20/// fn hello(name: &str, age: u8) -> String {
21///     format!("Hello, {} year old named {}!", age, name)
22/// }
23///
24/// #[rocket::main]
25/// async fn main() -> Result<(), LambdaError> {
26///     let rocket = rocket::build().mount("/", routes![hello]);
27///     if is_running_on_lambda() {
28///         // Launch on AWS Lambda
29///         launch_rocket_on_lambda(rocket).await?;
30///     } else {
31///         // Launch local server
32///         rocket.launch().await?;
33///     }
34///     Ok(())
35/// }
36/// ```
37///
38pub async fn launch_rocket_on_lambda<P: rocket::Phase>(
39    r: rocket::Rocket<P>,
40) -> Result<(), LambdaError> {
41    lambda_runtime::run(RocketHandler(Arc::new(
42        rocket::local::asynchronous::Client::untracked(r).await?,
43    )))
44    .await?;
45
46    Ok(())
47}
48
/// Lambda_runtime handler for Rocket
///
/// A `lambda_runtime` service that forwards each API Gateway event to a Rocket
/// application through a shared local client. The client is held in an `Arc`
/// so that each invocation's `async move` future can own its own handle
/// (`RocketHandler::call` clones it) instead of borrowing `&mut self`.
struct RocketHandler(Arc<rocket::local::asynchronous::Client>);
51
52impl LambdaService<LambdaEvent<LambdaHttpEvent<'_>>> for RocketHandler {
53    type Response = serde_json::Value;
54    type Error = rocket::Error;
55    type Future = Pin<Box<dyn Future<Output = Result<serde_json::Value, Self::Error>> + Send>>;
56
57    /// Always ready in case of Rocket local client
58    fn poll_ready(
59        &mut self,
60        _cx: &mut core::task::Context<'_>,
61    ) -> core::task::Poll<Result<(), Self::Error>> {
62        core::task::Poll::Ready(Ok(()))
63    }
64
65    /// Lambda handler function
66    /// Parse Lambda event as Rocket LocalRequest,
67    /// serialize Rocket LocalResponse to Lambda JSON response
68    fn call(&mut self, req: LambdaEvent<LambdaHttpEvent<'_>>) -> Self::Future {
69        use serde_json::json;
70
71        let event = req.payload;
72        let _context = req.context;
73
74        // check if web client supports content-encoding: br
75        let client_br = event.client_supports_brotli();
76        // multi-value-headers response format
77        let multi_value = event.multi_value();
78
79        // Parse request
80        let decode_result = RequestDecode::try_from(event);
81        let client = self.0.clone();
82        let fut = async move {
83            match decode_result {
84                Ok(req_decode) => {
85                    // Request parsing succeeded, make Rocket LocalRequest
86                    let local_request = req_decode.make_request(&client);
87
88                    // Dispatch request and get response
89                    let response = local_request.dispatch().await;
90
91                    // Return response as API Gateway JSON
92                    api_gateway_response_from_rocket(response, client_br, multi_value).await
93                }
94                Err(_request_err) => {
95                    // Request parsing error
96                    Ok(json!({
97                        "isBase64Encoded": false,
98                        "statusCode": 400u16,
99                        "headers": { "content-type": "text/plain"},
100                        "body": "Bad Request" // No details for security
101                    }))
102                }
103            }
104        };
105        Box::pin(fut)
106    }
107}
108
/// Request decoded from an API Gateway JSON event.
///
/// Every field is owned, so the value can be moved into the `'static` future
/// built by `RocketHandler::call` without borrowing from the
/// `LambdaHttpEvent<'_>` it was decoded from.
struct RequestDecode {
    /// Request path, followed by `?` and the query string when one is present
    /// (e.g. `"/stage/?key=value"`). Kept percent-encoded.
    path_and_query: String,
    /// HTTP method parsed from the event.
    method: rocket::http::Method,
    /// Client IP address; `0.0.0.0` when the event does not carry one.
    source_ip: std::net::IpAddr,
    /// Raw cookie strings; parsed with `Cookie::parse_encoded` in `make_request`,
    /// which skips entries that fail to parse.
    cookies: Vec<String>,
    /// Request headers, copied verbatim onto the Rocket request.
    headers: Vec<rocket::http::Header<'static>>,
    /// Request body bytes as returned by `LambdaHttpEvent::body()`; the tests show
    /// base64-flagged bodies arrive here already decoded.
    body: Vec<u8>,
}
120
121impl TryFrom<LambdaHttpEvent<'_>> for RequestDecode {
122    type Error = LambdaError;
123
124    /// Request from API Gateway event
125    fn try_from(event: LambdaHttpEvent) -> Result<Self, Self::Error> {
126        use rocket::http::{Header, Method};
127        use std::net::IpAddr;
128        use std::str::FromStr;
129
130        // path ? query_string
131        let path_and_query = event.path_query();
132
133        // Method, Source IP
134        let method = Method::from_str(&event.method()).map_err(|_| "InvalidMethod")?;
135        let source_ip = event
136            .source_ip()
137            .unwrap_or(IpAddr::from([0u8, 0u8, 0u8, 0u8]));
138
139        // Parse cookies
140        let cookies = event.cookies().iter().map(|c| c.to_string()).collect();
141
142        // Headers
143        let headers = event
144            .headers()
145            .iter()
146            .map(|(k, v)| Header::new(k.to_string(), v.to_string()))
147            .collect::<Vec<Header>>();
148
149        // Body
150        let body = event.body()?;
151
152        Ok(Self {
153            path_and_query,
154            method,
155            source_ip,
156            cookies,
157            headers,
158            body,
159        })
160    }
161}
162
163impl RequestDecode {
164    /// Make Rocket LocalRequest
165    fn make_request<'c, 's: 'c>(
166        &'s self,
167        client: &'c rocket::local::asynchronous::Client,
168    ) -> rocket::local::asynchronous::LocalRequest<'c> {
169        use rocket::http::Cookie;
170
171        // path, method, remote address, body
172        let req = client
173            .req(self.method, &self.path_and_query)
174            .remote(std::net::SocketAddr::from((self.source_ip, 0u16)))
175            .body(&self.body);
176
177        // Copy cookies
178        let req = self.cookies.iter().fold(req, |req, cookie_name_val| {
179            if let Ok(cookie) = Cookie::parse_encoded(cookie_name_val) {
180                req.cookie(cookie)
181            } else {
182                req
183            }
184        });
185
186        // Copy headers
187        let req = self
188            .headers
189            .iter()
190            .fold(req, |req, header| req.header(header.clone()));
191
192        req
193    }
194}
195
196impl crate::brotli::ResponseCompression for rocket::local::asynchronous::LocalResponse<'_> {
197    /// Content-Encoding header value
198    fn content_encoding<'a>(&'a self) -> Option<&'a str> {
199        self.headers().get_one("content-encoding")
200    }
201
202    /// Content-Type header value
203    fn content_type<'a>(&'a self) -> Option<&'a str> {
204        self.headers().get_one("content-type")
205    }
206}
207
208/// API Gateway response from Rocket response
209async fn api_gateway_response_from_rocket(
210    response: rocket::local::asynchronous::LocalResponse<'_>,
211    client_support_br: bool,
212    multi_value: bool,
213) -> Result<serde_json::Value, rocket::Error> {
214    use crate::brotli::ResponseCompression;
215    use serde_json::json;
216
217    // HTTP status
218    let status_code = response.status().code;
219
220    // Convert header to JSON map
221    let mut cookies = Vec::<String>::new();
222    let mut headers = serde_json::Map::new();
223    for header in response.headers().iter() {
224        let header_name = header.name.into_string();
225        let header_value = header.value.into_owned();
226        if multi_value {
227            // REST API format, returns multiValueHeaders
228            if let Some(values) = headers.get_mut(&header_name) {
229                if let Some(value_ary) = values.as_array_mut() {
230                    value_ary.push(json!(header_value));
231                }
232            } else {
233                headers.insert(header_name, json!([header_value]));
234            }
235        } else {
236            // HTTP API v2 format, returns headers
237            if &header_name == "set-cookie" {
238                cookies.push(header_value);
239            } else {
240                headers.insert(header_name, json!(header_value));
241            }
242        }
243    }
244
245    // check if response should be compressed
246    let compress = client_support_br && response.can_brotli_compress();
247    let body_bytes = response.into_bytes().await.unwrap_or_default();
248    let body_base64 = if compress {
249        if multi_value {
250            headers.insert("content-encoding".to_string(), json!(["br"]));
251        } else {
252            headers.insert("content-encoding".to_string(), json!("br"));
253        }
254        crate::brotli::compress_response_body(&body_bytes)
255    } else {
256        base64::encode(body_bytes)
257    };
258
259    if multi_value {
260        Ok(json!({
261            "isBase64Encoded": true,
262            "statusCode": status_code,
263            "multiValueHeaders": headers,
264            "body": body_base64
265        }))
266    } else {
267        Ok(json!({
268            "isBase64Encoded": true,
269            "statusCode": status_code,
270            "cookies": cookies,
271            "headers": headers,
272            "body": body_base64
273        }))
274    }
275}
276
#[cfg(test)]
mod tests {
    // Request-decoding tests. Each case parses an API Gateway event fixture
    // from `test_consts` (an HTTP API v2 variant and a REST API variant) into a
    // `RequestDecode`, builds a Rocket `LocalRequest` from it, and checks what
    // Rocket sees.
    use super::*;
    use crate::{request::LambdaHttpEvent, test_consts::*};
    use rocket::{async_test, local::asynchronous::Client};
    use std::path::PathBuf;

    // Parse an API Gateway event JSON string into a `RequestDecode`,
    // panicking if deserialization or decoding fails.
    fn prepare_request(event_str: &str) -> RequestDecode {
        let reqjson: LambdaHttpEvent = serde_json::from_str(event_str).unwrap();
        let decode = RequestDecode::try_from(reqjson).unwrap();
        decode
    }

    // `path_and_query` keeps percent-encoding; Rocket's path segments are decoded.
    // The REST fixtures carry a "/stage" prefix.
    #[async_test]
    async fn test_path_decode() {
        let rocket = rocket::build();
        let client = Client::untracked(rocket).await.unwrap();

        let decode = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/");
        assert_eq!(req.inner().segments(0..), Ok(PathBuf::new()));
        let decode = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/stage/");
        assert_eq!(req.inner().segments(0..), Ok(PathBuf::from("stage")));

        let decode = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/somewhere");
        assert_eq!(req.inner().segments(0..), Ok(PathBuf::from("somewhere")));
        let decode = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/stage/somewhere");
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("stage/somewhere"))
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_SPACEPATH_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/path%20with/space");
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("path with/space"))
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_SPACEPATH_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/stage/path%20with/space");
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("stage/path with/space"))
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_PERCENTPATH_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/path%25with/percent");
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("path%with/percent"))
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_PERCENTPATH_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/stage/path%25with/percent");
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("stage/path%with/percent"))
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_UTF8PATH_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            &decode.path_and_query,
            "/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D"
        );
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("日本語/ファイル名"))
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_UTF8PATH_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            &decode.path_and_query,
            "/stage/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D"
        );
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("stage/日本語/ファイル名"))
        );
    }

    // Query strings are appended to the path after "?" and are readable
    // through Rocket's decoded `query_value`.
    #[async_test]
    async fn test_query_decode() {
        let rocket = rocket::build();
        let client = Client::untracked(rocket).await.unwrap();

        let decode = prepare_request(API_GATEWAY_V2_GET_ROOT_ONEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/?key=value");
        assert_eq!(req.inner().segments(0..), Ok(PathBuf::new()));
        assert_eq!(req.inner().query_value::<&str>("key").unwrap(), Ok("value"));
        let decode = prepare_request(API_GATEWAY_REST_GET_ROOT_ONEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/stage/?key=value");
        assert_eq!(req.inner().segments(0..), Ok(PathBuf::from("stage")));
        assert_eq!(req.inner().query_value::<&str>("key").unwrap(), Ok("value"));

        let decode = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_ONEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/somewhere?key=value");
        assert_eq!(req.inner().segments(0..), Ok(PathBuf::from("somewhere")));
        assert_eq!(req.inner().query_value::<&str>("key").unwrap(), Ok("value"));
        let decode = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_ONEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(&decode.path_and_query, "/stage/somewhere?key=value");
        assert_eq!(
            req.inner().segments(0..),
            Ok(PathBuf::from("stage/somewhere"))
        );
        assert_eq!(req.inner().query_value::<&str>("key").unwrap(), Ok("value"));

        let decode = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_TWOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().query_value::<&str>("key1").unwrap(),
            Ok("value1")
        );
        assert_eq!(
            req.inner().query_value::<&str>("key2").unwrap(),
            Ok("value2")
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_TWOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().query_value::<&str>("key1").unwrap(),
            Ok("value1")
        );
        assert_eq!(
            req.inner().query_value::<&str>("key2").unwrap(),
            Ok("value2")
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_SPACEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().query_value::<&str>("key").unwrap(),
            Ok("value1 value2")
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_SPACEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().query_value::<&str>("key").unwrap(),
            Ok("value1 value2")
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_UTF8QUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().query_value::<&str>("key").unwrap(),
            Ok("日本語")
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_UTF8QUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().query_value::<&str>("key").unwrap(),
            Ok("日本語")
        );
    }

    // The event's source IP (IPv4 and IPv6) becomes Rocket's `client_ip()`.
    #[async_test]
    async fn test_remote_ip_decode() {
        use std::net::IpAddr;
        use std::str::FromStr;

        let rocket = rocket::build();
        let client = Client::untracked(rocket).await.unwrap();

        let decode = prepare_request(API_GATEWAY_V2_GET_ROOT_ONEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(decode.source_ip, IpAddr::from_str("1.2.3.4").unwrap());
        assert_eq!(
            req.inner().client_ip(),
            Some(IpAddr::from_str("1.2.3.4").unwrap())
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_ROOT_ONEQUERY);
        let req = decode.make_request(&client);
        assert_eq!(decode.source_ip, IpAddr::from_str("1.2.3.4").unwrap());
        assert_eq!(
            req.inner().client_ip(),
            Some(IpAddr::from_str("1.2.3.4").unwrap())
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_REMOTE_IPV6);
        let req = decode.make_request(&client);
        assert_eq!(
            decode.source_ip,
            IpAddr::from_str("2404:6800:400a:80c::2004").unwrap()
        );
        assert_eq!(
            req.inner().client_ip(),
            Some(IpAddr::from_str("2404:6800:400a:80c::2004").unwrap())
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_REMOTE_IPV6);
        let req = decode.make_request(&client);
        assert_eq!(
            decode.source_ip,
            IpAddr::from_str("2404:6800:400a:80c::2004").unwrap()
        );
        assert_eq!(
            req.inner().client_ip(),
            Some(IpAddr::from_str("2404:6800:400a:80c::2004").unwrap())
        );
    }

    // URL-encoded form POST bodies arrive intact, both plain and base64-flagged,
    // and the method and content type are preserved.
    #[async_test]
    async fn test_form_post() {
        use rocket::http::ContentType;
        use rocket::http::Method;
        let rocket = rocket::build();
        let client = Client::untracked(rocket).await.unwrap();

        let decode = prepare_request(API_GATEWAY_V2_POST_FORM_URLENCODED);
        let req = decode.make_request(&client);
        assert_eq!(&decode.body, b"key1=value1&key2=value2&Ok=Ok");
        assert_eq!(req.inner().method(), Method::Post);
        assert_eq!(req.inner().content_type(), Some(&ContentType::Form));
        let decode = prepare_request(API_GATEWAY_REST_POST_FORM_URLENCODED);
        let req = decode.make_request(&client);
        assert_eq!(&decode.body, b"key1=value1&key2=value2&Ok=Ok");
        assert_eq!(req.inner().method(), Method::Post);
        assert_eq!(req.inner().content_type(), Some(&ContentType::Form));

        // Base64 encoded
        let decode = prepare_request(API_GATEWAY_V2_POST_FORM_URLENCODED_B64);
        let req = decode.make_request(&client);
        assert_eq!(&decode.body, b"key1=value1&key2=value2&Ok=Ok");
        assert_eq!(req.inner().method(), Method::Post);
        assert_eq!(req.inner().content_type(), Some(&ContentType::Form));
        let decode = prepare_request(API_GATEWAY_REST_POST_FORM_URLENCODED_B64);
        let req = decode.make_request(&client);
        assert_eq!(&decode.body, b"key1=value1&key2=value2&Ok=Ok");
        assert_eq!(req.inner().method(), Method::Post);
        assert_eq!(req.inner().content_type(), Some(&ContentType::Form));
    }

    // Event headers are copied onto the Rocket request.
    #[async_test]
    async fn test_parse_header() {
        let rocket = rocket::build();
        let client = Client::untracked(rocket).await.unwrap();

        let decode = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().headers().get_one("x-forwarded-port"),
            Some("443")
        );
        assert_eq!(
            req.inner().headers().get_one("x-forwarded-proto"),
            Some("https")
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().headers().get_one("x-forwarded-port"),
            Some("443")
        );
        assert_eq!(
            req.inner().headers().get_one("x-forwarded-proto"),
            Some("https")
        );
    }

    // Event cookies (none, one, two) appear in Rocket's cookie jar.
    #[async_test]
    async fn test_parse_cookies() {
        let rocket = rocket::build();
        let client = Client::untracked(rocket).await.unwrap();

        let decode = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(req.inner().cookies().iter().count(), 0);
        let decode = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        let req = decode.make_request(&client);
        assert_eq!(req.inner().cookies().iter().count(), 0);

        let decode = prepare_request(API_GATEWAY_V2_GET_ONE_COOKIE);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().cookies().get("cookie1").unwrap().value(),
            "value1"
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_ONE_COOKIE);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().cookies().get("cookie1").unwrap().value(),
            "value1"
        );

        let decode = prepare_request(API_GATEWAY_V2_GET_TWO_COOKIES);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().cookies().get("cookie1").unwrap().value(),
            "value1"
        );
        assert_eq!(
            req.inner().cookies().get("cookie2").unwrap().value(),
            "value2"
        );
        let decode = prepare_request(API_GATEWAY_REST_GET_TWO_COOKIES);
        let req = decode.make_request(&client);
        assert_eq!(
            req.inner().cookies().get("cookie1").unwrap().value(),
            "value1"
        );
        assert_eq!(
            req.inner().cookies().get("cookie2").unwrap().value(),
            "value2"
        );
    }
}