1use crate::request::LambdaHttpEvent;
7use core::convert::TryFrom;
8use core::future::Future;
9use lambda_runtime::{Error as LambdaError, LambdaEvent, Service as LambdaService};
10use std::pin::Pin;
11use std::sync::Arc;
12
13pub async fn launch_rocket_on_lambda<P: rocket::Phase>(
39 r: rocket::Rocket<P>,
40) -> Result<(), LambdaError> {
41 lambda_runtime::run(RocketHandler(Arc::new(
42 rocket::local::asynchronous::Client::untracked(r).await?,
43 )))
44 .await?;
45
46 Ok(())
47}
48
/// Lambda service that forwards every API Gateway event to one shared local
/// Rocket client.
///
/// The client sits behind an `Arc` so that each invocation's `'static` future can
/// hold its own handle to it (`call` clones the `Arc` into the future).
struct RocketHandler(Arc<rocket::local::asynchronous::Client>);
51
52impl LambdaService<LambdaEvent<LambdaHttpEvent<'_>>> for RocketHandler {
53 type Response = serde_json::Value;
54 type Error = rocket::Error;
55 type Future = Pin<Box<dyn Future<Output = Result<serde_json::Value, Self::Error>> + Send>>;
56
57 fn poll_ready(
59 &mut self,
60 _cx: &mut core::task::Context<'_>,
61 ) -> core::task::Poll<Result<(), Self::Error>> {
62 core::task::Poll::Ready(Ok(()))
63 }
64
65 fn call(&mut self, req: LambdaEvent<LambdaHttpEvent<'_>>) -> Self::Future {
69 use serde_json::json;
70
71 let event = req.payload;
72 let _context = req.context;
73
74 let client_br = event.client_supports_brotli();
76 let multi_value = event.multi_value();
78
79 let decode_result = RequestDecode::try_from(event);
81 let client = self.0.clone();
82 let fut = async move {
83 match decode_result {
84 Ok(req_decode) => {
85 let local_request = req_decode.make_request(&client);
87
88 let response = local_request.dispatch().await;
90
91 api_gateway_response_from_rocket(response, client_br, multi_value).await
93 }
94 Err(_request_err) => {
95 Ok(json!({
97 "isBase64Encoded": false,
98 "statusCode": 400u16,
99 "headers": { "content-type": "text/plain"},
100 "body": "Bad Request" }))
102 }
103 }
104 };
105 Box::pin(fut)
106 }
107}
108
/// Owned snapshot, with no lifetimes, of the parts of an API Gateway event
/// needed to build a Rocket `LocalRequest` (see `RequestDecode::make_request`).
struct RequestDecode {
    /// Path plus query string from `LambdaHttpEvent::path_query`, still
    /// percent-encoded (the tests expect e.g. `/path%20with/space`).
    path_and_query: String,
    /// HTTP method parsed from the event.
    method: rocket::http::Method,
    /// Caller's IP address; `0.0.0.0` when the event provides none.
    source_ip: std::net::IpAddr,
    /// Cookie strings from the event (presumably `name=value` form — confirm);
    /// parsed with `Cookie::parse_encoded` when the request is built, and
    /// skipped if they fail to parse.
    cookies: Vec<String>,
    /// Request headers, copied into owned `'static` values.
    headers: Vec<rocket::http::Header<'static>>,
    /// Request body bytes (the `_B64` test fixtures show base64 bodies arrive
    /// here already decoded).
    body: Vec<u8>,
}
120
121impl TryFrom<LambdaHttpEvent<'_>> for RequestDecode {
122 type Error = LambdaError;
123
124 fn try_from(event: LambdaHttpEvent) -> Result<Self, Self::Error> {
126 use rocket::http::{Header, Method};
127 use std::net::IpAddr;
128 use std::str::FromStr;
129
130 let path_and_query = event.path_query();
132
133 let method = Method::from_str(&event.method()).map_err(|_| "InvalidMethod")?;
135 let source_ip = event
136 .source_ip()
137 .unwrap_or(IpAddr::from([0u8, 0u8, 0u8, 0u8]));
138
139 let cookies = event.cookies().iter().map(|c| c.to_string()).collect();
141
142 let headers = event
144 .headers()
145 .iter()
146 .map(|(k, v)| Header::new(k.to_string(), v.to_string()))
147 .collect::<Vec<Header>>();
148
149 let body = event.body()?;
151
152 Ok(Self {
153 path_and_query,
154 method,
155 source_ip,
156 cookies,
157 headers,
158 body,
159 })
160 }
161}
162
163impl RequestDecode {
164 fn make_request<'c, 's: 'c>(
166 &'s self,
167 client: &'c rocket::local::asynchronous::Client,
168 ) -> rocket::local::asynchronous::LocalRequest<'c> {
169 use rocket::http::Cookie;
170
171 let req = client
173 .req(self.method, &self.path_and_query)
174 .remote(std::net::SocketAddr::from((self.source_ip, 0u16)))
175 .body(&self.body);
176
177 let req = self.cookies.iter().fold(req, |req, cookie_name_val| {
179 if let Ok(cookie) = Cookie::parse_encoded(cookie_name_val) {
180 req.cookie(cookie)
181 } else {
182 req
183 }
184 });
185
186 let req = self
188 .headers
189 .iter()
190 .fold(req, |req, header| req.header(header.clone()));
191
192 req
193 }
194}
195
196impl crate::brotli::ResponseCompression for rocket::local::asynchronous::LocalResponse<'_> {
197 fn content_encoding<'a>(&'a self) -> Option<&'a str> {
199 self.headers().get_one("content-encoding")
200 }
201
202 fn content_type<'a>(&'a self) -> Option<&'a str> {
204 self.headers().get_one("content-type")
205 }
206}
207
208async fn api_gateway_response_from_rocket(
210 response: rocket::local::asynchronous::LocalResponse<'_>,
211 client_support_br: bool,
212 multi_value: bool,
213) -> Result<serde_json::Value, rocket::Error> {
214 use crate::brotli::ResponseCompression;
215 use serde_json::json;
216
217 let status_code = response.status().code;
219
220 let mut cookies = Vec::<String>::new();
222 let mut headers = serde_json::Map::new();
223 for header in response.headers().iter() {
224 let header_name = header.name.into_string();
225 let header_value = header.value.into_owned();
226 if multi_value {
227 if let Some(values) = headers.get_mut(&header_name) {
229 if let Some(value_ary) = values.as_array_mut() {
230 value_ary.push(json!(header_value));
231 }
232 } else {
233 headers.insert(header_name, json!([header_value]));
234 }
235 } else {
236 if &header_name == "set-cookie" {
238 cookies.push(header_value);
239 } else {
240 headers.insert(header_name, json!(header_value));
241 }
242 }
243 }
244
245 let compress = client_support_br && response.can_brotli_compress();
247 let body_bytes = response.into_bytes().await.unwrap_or_default();
248 let body_base64 = if compress {
249 if multi_value {
250 headers.insert("content-encoding".to_string(), json!(["br"]));
251 } else {
252 headers.insert("content-encoding".to_string(), json!("br"));
253 }
254 crate::brotli::compress_response_body(&body_bytes)
255 } else {
256 base64::encode(body_bytes)
257 };
258
259 if multi_value {
260 Ok(json!({
261 "isBase64Encoded": true,
262 "statusCode": status_code,
263 "multiValueHeaders": headers,
264 "body": body_base64
265 }))
266 } else {
267 Ok(json!({
268 "isBase64Encoded": true,
269 "statusCode": status_code,
270 "cookies": cookies,
271 "headers": headers,
272 "body": body_base64
273 }))
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{request::LambdaHttpEvent, test_consts::*};
    use rocket::{async_test, local::asynchronous::Client};
    use std::path::PathBuf;

    /// Parses an API Gateway event fixture and decodes it, panicking on any failure.
    fn prepare_request(event_str: &str) -> RequestDecode {
        let event: LambdaHttpEvent = serde_json::from_str(event_str).unwrap();
        RequestDecode::try_from(event).unwrap()
    }

    /// Builds a local client around a Rocket application with nothing mounted.
    async fn empty_client() -> Client {
        Client::untracked(rocket::build()).await.unwrap()
    }

    #[async_test]
    async fn test_path_decode() {
        let client = empty_client().await;

        // (event, expected raw path+query, expected decoded segments)
        let cases: &[(&str, &str, &str)] = &[
            (API_GATEWAY_V2_GET_ROOT_NOQUERY, "/", ""),
            (API_GATEWAY_REST_GET_ROOT_NOQUERY, "/stage/", "stage"),
            (API_GATEWAY_V2_GET_SOMEWHERE_NOQUERY, "/somewhere", "somewhere"),
            (
                API_GATEWAY_REST_GET_SOMEWHERE_NOQUERY,
                "/stage/somewhere",
                "stage/somewhere",
            ),
            (
                API_GATEWAY_V2_GET_SPACEPATH_NOQUERY,
                "/path%20with/space",
                "path with/space",
            ),
            (
                API_GATEWAY_REST_GET_SPACEPATH_NOQUERY,
                "/stage/path%20with/space",
                "stage/path with/space",
            ),
            (
                API_GATEWAY_V2_GET_PERCENTPATH_NOQUERY,
                "/path%25with/percent",
                "path%with/percent",
            ),
            (
                API_GATEWAY_REST_GET_PERCENTPATH_NOQUERY,
                "/stage/path%25with/percent",
                "stage/path%with/percent",
            ),
            (
                API_GATEWAY_V2_GET_UTF8PATH_NOQUERY,
                "/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D",
                "日本語/ファイル名",
            ),
            (
                API_GATEWAY_REST_GET_UTF8PATH_NOQUERY,
                "/stage/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D",
                "stage/日本語/ファイル名",
            ),
        ];

        for &(event, expected_path_query, expected_segments) in cases {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(decode.path_and_query, expected_path_query);
            assert_eq!(
                req.inner().segments(0..),
                Ok(PathBuf::from(expected_segments))
            );
        }
    }

    #[async_test]
    async fn test_query_decode() {
        let client = empty_client().await;

        // (event, expected raw path+query, expected decoded segments); all carry key=value.
        let routed: &[(&str, &str, &str)] = &[
            (API_GATEWAY_V2_GET_ROOT_ONEQUERY, "/?key=value", ""),
            (API_GATEWAY_REST_GET_ROOT_ONEQUERY, "/stage/?key=value", "stage"),
            (
                API_GATEWAY_V2_GET_SOMEWHERE_ONEQUERY,
                "/somewhere?key=value",
                "somewhere",
            ),
            (
                API_GATEWAY_REST_GET_SOMEWHERE_ONEQUERY,
                "/stage/somewhere?key=value",
                "stage/somewhere",
            ),
        ];
        for &(event, expected_path_query, expected_segments) in routed {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(decode.path_and_query, expected_path_query);
            assert_eq!(
                req.inner().segments(0..),
                Ok(PathBuf::from(expected_segments))
            );
            assert_eq!(req.inner().query_value::<&str>("key").unwrap(), Ok("value"));
        }

        // (event, query key, expected decoded value)
        let queried: &[(&str, &str, &str)] = &[
            (API_GATEWAY_V2_GET_SOMEWHERE_TWOQUERY, "key1", "value1"),
            (API_GATEWAY_V2_GET_SOMEWHERE_TWOQUERY, "key2", "value2"),
            (API_GATEWAY_REST_GET_SOMEWHERE_TWOQUERY, "key1", "value1"),
            (API_GATEWAY_REST_GET_SOMEWHERE_TWOQUERY, "key2", "value2"),
            (API_GATEWAY_V2_GET_SOMEWHERE_SPACEQUERY, "key", "value1 value2"),
            (API_GATEWAY_REST_GET_SOMEWHERE_SPACEQUERY, "key", "value1 value2"),
            (API_GATEWAY_V2_GET_SOMEWHERE_UTF8QUERY, "key", "日本語"),
            (API_GATEWAY_REST_GET_SOMEWHERE_UTF8QUERY, "key", "日本語"),
        ];
        for &(event, key, expected_value) in queried {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(
                req.inner().query_value::<&str>(key).unwrap(),
                Ok(expected_value)
            );
        }
    }

    #[async_test]
    async fn test_remote_ip_decode() {
        use std::net::IpAddr;
        use std::str::FromStr;

        let client = empty_client().await;

        // (event, expected caller IP)
        let cases: &[(&str, &str)] = &[
            (API_GATEWAY_V2_GET_ROOT_ONEQUERY, "1.2.3.4"),
            (API_GATEWAY_REST_GET_ROOT_ONEQUERY, "1.2.3.4"),
            (API_GATEWAY_V2_GET_REMOTE_IPV6, "2404:6800:400a:80c::2004"),
            (API_GATEWAY_REST_GET_REMOTE_IPV6, "2404:6800:400a:80c::2004"),
        ];
        for &(event, ip_text) in cases {
            let expected_ip = IpAddr::from_str(ip_text).unwrap();
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(decode.source_ip, expected_ip);
            assert_eq!(req.inner().client_ip(), Some(expected_ip));
        }
    }

    #[async_test]
    async fn test_form_post() {
        use rocket::http::ContentType;
        use rocket::http::Method;

        let client = empty_client().await;

        // Plain and base64-encoded bodies must decode to the same form payload.
        let events: [&str; 4] = [
            API_GATEWAY_V2_POST_FORM_URLENCODED,
            API_GATEWAY_REST_POST_FORM_URLENCODED,
            API_GATEWAY_V2_POST_FORM_URLENCODED_B64,
            API_GATEWAY_REST_POST_FORM_URLENCODED_B64,
        ];
        for &event in events.iter() {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(&decode.body, b"key1=value1&key2=value2&Ok=Ok");
            assert_eq!(req.inner().method(), Method::Post);
            assert_eq!(req.inner().content_type(), Some(&ContentType::Form));
        }
    }

    #[async_test]
    async fn test_parse_header() {
        let client = empty_client().await;

        let events: [&str; 2] = [
            API_GATEWAY_V2_GET_ROOT_NOQUERY,
            API_GATEWAY_REST_GET_ROOT_NOQUERY,
        ];
        for &event in events.iter() {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            let req_headers = req.inner().headers();
            assert_eq!(req_headers.get_one("x-forwarded-port"), Some("443"));
            assert_eq!(req_headers.get_one("x-forwarded-proto"), Some("https"));
        }
    }

    #[async_test]
    async fn test_parse_cookies() {
        let client = empty_client().await;

        let cookieless: [&str; 2] = [
            API_GATEWAY_V2_GET_ROOT_NOQUERY,
            API_GATEWAY_REST_GET_ROOT_NOQUERY,
        ];
        for &event in cookieless.iter() {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(req.inner().cookies().iter().count(), 0);
        }

        // (event, cookie name, expected cookie value)
        let cases: &[(&str, &str, &str)] = &[
            (API_GATEWAY_V2_GET_ONE_COOKIE, "cookie1", "value1"),
            (API_GATEWAY_REST_GET_ONE_COOKIE, "cookie1", "value1"),
            (API_GATEWAY_V2_GET_TWO_COOKIES, "cookie1", "value1"),
            (API_GATEWAY_V2_GET_TWO_COOKIES, "cookie2", "value2"),
            (API_GATEWAY_REST_GET_TWO_COOKIES, "cookie1", "value1"),
            (API_GATEWAY_REST_GET_TWO_COOKIES, "cookie2", "value2"),
        ];
        for &(event, cookie_name, expected_value) in cases {
            let decode = prepare_request(event);
            let req = decode.make_request(&client);
            assert_eq!(
                req.inner().cookies().get(cookie_name).unwrap().value(),
                expected_value
            );
        }
    }
}