1use crate::request::LambdaHttpEvent;
6use core::convert::TryFrom;
7use core::future::Future;
8use lambda_runtime::{Error as LambdaError, LambdaEvent, Service as LambdaService};
9use std::convert::Infallible;
10use std::pin::Pin;
11
12type HyperRequest = hyper::Request<hyper::Body>;
13type HyperResponse<B> = hyper::Response<B>;
14
15pub async fn run_hyper_on_lambda<S, B>(svc: S) -> Result<(), LambdaError>
71where
72 S: hyper::service::Service<HyperRequest, Response = HyperResponse<B>, Error = Infallible>
73 + 'static,
74 B: hyper::body::HttpBody,
75 <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static,
76{
77 lambda_runtime::run(HyperHandler(svc)).await?;
78 Ok(())
79}
80
81struct HyperHandler<S, B>(S)
83where
84 S: hyper::service::Service<HyperRequest, Response = HyperResponse<B>, Error = Infallible>
85 + 'static,
86 B: hyper::body::HttpBody,
87 <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static;
88
89impl<S, B> LambdaService<LambdaEvent<LambdaHttpEvent<'_>>> for HyperHandler<S, B>
90where
91 S: hyper::service::Service<HyperRequest, Response = HyperResponse<B>, Error = Infallible>
92 + 'static,
93 B: hyper::body::HttpBody,
94 <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static,
95{
96 type Response = serde_json::Value;
97 type Error = Infallible;
98 type Future = Pin<Box<dyn Future<Output = Result<serde_json::Value, Self::Error>>>>;
99
100 fn poll_ready(
102 &mut self,
103 cx: &mut core::task::Context<'_>,
104 ) -> core::task::Poll<Result<(), Self::Error>> {
105 self.0.poll_ready(cx)
106 }
107
108 fn call(&mut self, req: LambdaEvent<LambdaHttpEvent<'_>>) -> Self::Future {
112 use serde_json::json;
113
114 let event = req.payload;
115 let _context = req.context;
116
117 let client_br = event.client_supports_brotli();
119 let multi_value = event.multi_value();
121
122 let hyper_request = HyperRequest::try_from(event);
124
125 let svc_call = hyper_request.map(|req| self.0.call(req));
127
128 let fut = async move {
129 match svc_call {
130 Ok(svc_fut) => {
131 if let Ok(response) = svc_fut.await {
133 api_gateway_response_from_hyper(response, client_br, multi_value)
135 .await
136 .or_else(|_err| {
137 Ok(json!({
138 "isBase64Encoded": false,
139 "statusCode": 500u16,
140 "headers": { "content-type": "text/plain"},
141 "body": "Internal Server Error"
142 }))
143 })
144 } else {
145 Ok(json!({
147 "isBase64Encoded": false,
148 "statusCode": 500u16,
149 "headers": { "content-type": "text/plain"},
150 "body": "Internal Server Error"
151 }))
152 }
153 }
154 Err(_request_err) => {
155 Ok(json!({
157 "isBase64Encoded": false,
158 "statusCode": 400u16,
159 "headers": { "content-type": "text/plain"},
160 "body": "Bad Request"
161 }))
162 }
163 }
164 };
165 Box::pin(fut)
166 }
167}
168
169impl TryFrom<LambdaHttpEvent<'_>> for HyperRequest {
170 type Error = LambdaError;
171
172 fn try_from(event: LambdaHttpEvent) -> Result<Self, Self::Error> {
174 use hyper::header::{HeaderName, HeaderValue};
175 use hyper::Method;
176 use std::str::FromStr;
177
178 let uri = format!(
180 "https://{}{}",
181 event.hostname().unwrap_or("localhost"),
182 event.path_query()
183 );
184
185 let method = Method::try_from(event.method())?;
187
188 let mut reqbuilder = hyper::Request::builder().method(method).uri(&uri);
190
191 if let Some(headers_mut) = reqbuilder.headers_mut() {
193 for (k, v) in event.headers() {
194 if let (Ok(k), Ok(v)) = (
195 HeaderName::from_str(k as &str),
196 HeaderValue::from_str(&v as &str),
197 ) {
198 headers_mut.insert(k, v);
199 }
200 }
201 }
202
203 let req = reqbuilder.body(hyper::Body::from(event.body()?))?;
205
206 Ok(req)
207 }
208}
209
210impl<B> crate::brotli::ResponseCompression for HyperResponse<B> {
211 fn content_encoding<'a>(&'a self) -> Option<&'a str> {
213 self.headers()
214 .get(hyper::header::CONTENT_ENCODING)
215 .and_then(|val| val.to_str().ok())
216 }
217
218 fn content_type<'a>(&'a self) -> Option<&'a str> {
220 self.headers()
221 .get(hyper::header::CONTENT_TYPE)
222 .and_then(|val| val.to_str().ok())
223 }
224}
225
226async fn api_gateway_response_from_hyper<B>(
228 response: HyperResponse<B>,
229 client_support_br: bool,
230 multi_value: bool,
231) -> Result<serde_json::Value, LambdaError>
232where
233 B: hyper::body::HttpBody,
234 <B as hyper::body::HttpBody>::Error: std::error::Error + Send + Sync + 'static,
235{
236 use crate::brotli::ResponseCompression;
237 use hyper::header::SET_COOKIE;
238 use serde_json::json;
239
240 let compress = client_support_br && response.can_brotli_compress();
242
243 let (parts, res_body) = response.into_parts();
245
246 let status_code = parts.status.as_u16();
248
249 let mut cookies = Vec::<String>::new();
251 let mut headers = serde_json::Map::new();
252 for (k, v) in parts.headers.iter() {
253 if let Ok(value_str) = v.to_str() {
254 if multi_value {
255 if let Some(values) = headers.get_mut(k.as_str()) {
257 if let Some(value_ary) = values.as_array_mut() {
258 value_ary.push(json!(value_str));
259 }
260 } else {
261 headers.insert(k.as_str().to_string(), json!([value_str]));
262 }
263 } else {
264 if k == SET_COOKIE {
266 cookies.push(value_str.to_string());
267 } else {
268 headers.insert(k.as_str().to_string(), json!(value_str));
269 }
270 }
271 }
272 }
273
274 let body_bytes = hyper::body::to_bytes(res_body).await?;
276 let body_base64 = if compress {
277 if multi_value {
278 headers.insert("content-encoding".to_string(), json!(["br"]));
279 } else {
280 headers.insert("content-encoding".to_string(), json!("br"));
281 }
282 crate::brotli::compress_response_body(&body_bytes)
283 } else {
284 base64::encode(body_bytes)
285 };
286
287 if multi_value {
288 Ok(json!({
289 "isBase64Encoded": true,
290 "statusCode": status_code,
291 "multiValueHeaders": headers,
292 "body": body_base64
293 }))
294 } else {
295 Ok(json!({
296 "isBase64Encoded": true,
297 "statusCode": status_code,
298 "cookies": cookies,
299 "headers": headers,
300 "body": body_base64
301 }))
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{request::LambdaHttpEvent, test_consts::*};

    /// Deserializes a JSON event fixture and converts it into a hyper request.
    ///
    /// Panics (failing the calling test) if either step fails.
    fn prepare_request(event_str: &str) -> HyperRequest {
        let reqjson: LambdaHttpEvent = serde_json::from_str(event_str)
            .expect("test fixture should deserialize into a LambdaHttpEvent");
        HyperRequest::try_from(reqjson).expect("test fixture should convert into a hyper request")
    }

    #[test]
    fn test_path_decode() {
        let req = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        assert_eq!(req.uri().path(), "/");
        let req = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        assert_eq!(req.uri().path(), "/stage/");

        let req = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_NOQUERY);
        assert_eq!(req.uri().path(), "/somewhere");
        let req = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_NOQUERY);
        assert_eq!(req.uri().path(), "/stage/somewhere");

        let req = prepare_request(API_GATEWAY_V2_GET_SPACEPATH_NOQUERY);
        assert_eq!(req.uri().path(), "/path%20with/space");
        let req = prepare_request(API_GATEWAY_REST_GET_SPACEPATH_NOQUERY);
        assert_eq!(req.uri().path(), "/stage/path%20with/space");

        let req = prepare_request(API_GATEWAY_V2_GET_PERCENTPATH_NOQUERY);
        assert_eq!(req.uri().path(), "/path%25with/percent");
        let req = prepare_request(API_GATEWAY_REST_GET_PERCENTPATH_NOQUERY);
        assert_eq!(req.uri().path(), "/stage/path%25with/percent");

        let req = prepare_request(API_GATEWAY_V2_GET_UTF8PATH_NOQUERY);
        assert_eq!(
            req.uri().path(),
            "/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D"
        );
        let req = prepare_request(API_GATEWAY_REST_GET_UTF8PATH_NOQUERY);
        assert_eq!(
            req.uri().path(),
            "/stage/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E5%90%8D"
        );
    }

    #[test]
    fn test_query_decode() {
        let req = prepare_request(API_GATEWAY_V2_GET_ROOT_ONEQUERY);
        assert_eq!(req.uri().query(), Some("key=value"));
        let req = prepare_request(API_GATEWAY_REST_GET_ROOT_ONEQUERY);
        assert_eq!(req.uri().query(), Some("key=value"));

        let req = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_ONEQUERY);
        assert_eq!(req.uri().query(), Some("key=value"));
        let req = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_ONEQUERY);
        assert_eq!(req.uri().query(), Some("key=value"));

        let req = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_TWOQUERY);
        assert_eq!(req.uri().query(), Some("key1=value1&key2=value2"));
        // Parameter order is not fixed for this fixture, so either order is accepted.
        let req = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_TWOQUERY);
        assert!(
            req.uri().query() == Some("key1=value1&key2=value2")
                || req.uri().query() == Some("key2=value2&key1=value1")
        );

        let req = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_SPACEQUERY);
        assert_eq!(req.uri().query(), Some("key=value1+value2"));
        let req = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_SPACEQUERY);
        assert_eq!(req.uri().query(), Some("key=value1%20value2"));

        let req = prepare_request(API_GATEWAY_V2_GET_SOMEWHERE_UTF8QUERY);
        assert_eq!(req.uri().query(), Some("key=%E6%97%A5%E6%9C%AC%E8%AA%9E"));
        let req = prepare_request(API_GATEWAY_REST_GET_SOMEWHERE_UTF8QUERY);
        assert_eq!(req.uri().query(), Some("key=%E6%97%A5%E6%9C%AC%E8%AA%9E"));
    }

    #[tokio::test]
    async fn test_form_post() {
        use hyper::body::to_bytes;
        use hyper::Method;

        let req = prepare_request(API_GATEWAY_V2_POST_FORM_URLENCODED);
        assert_eq!(req.method(), Method::POST);
        assert_eq!(
            to_bytes(req.into_body()).await.unwrap().as_ref(),
            b"key1=value1&key2=value2&Ok=Ok"
        );
        let req = prepare_request(API_GATEWAY_REST_POST_FORM_URLENCODED);
        assert_eq!(req.method(), Method::POST);
        assert_eq!(
            to_bytes(req.into_body()).await.unwrap().as_ref(),
            b"key1=value1&key2=value2&Ok=Ok"
        );

        let req = prepare_request(API_GATEWAY_V2_POST_FORM_URLENCODED_B64);
        assert_eq!(req.method(), Method::POST);
        assert_eq!(
            to_bytes(req.into_body()).await.unwrap().as_ref(),
            b"key1=value1&key2=value2&Ok=Ok"
        );
        let req = prepare_request(API_GATEWAY_REST_POST_FORM_URLENCODED_B64);
        assert_eq!(req.method(), Method::POST);
        assert_eq!(
            to_bytes(req.into_body()).await.unwrap().as_ref(),
            b"key1=value1&key2=value2&Ok=Ok"
        );
    }

    #[test]
    fn test_parse_header() {
        let req = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        assert_eq!(req.headers().get("x-forwarded-port").unwrap(), &"443");
        assert_eq!(req.headers().get("x-forwarded-proto").unwrap(), &"https");
        let req = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        assert_eq!(req.headers().get("x-forwarded-port").unwrap(), &"443");
        assert_eq!(req.headers().get("x-forwarded-proto").unwrap(), &"https");
    }

    #[test]
    fn test_parse_cookies() {
        let req = prepare_request(API_GATEWAY_V2_GET_ROOT_NOQUERY);
        assert_eq!(req.headers().get("cookie"), None);
        let req = prepare_request(API_GATEWAY_REST_GET_ROOT_NOQUERY);
        assert_eq!(req.headers().get("cookie"), None);

        let req = prepare_request(API_GATEWAY_V2_GET_ONE_COOKIE);
        assert_eq!(req.headers().get("cookie").unwrap(), &"cookie1=value1");
        let req = prepare_request(API_GATEWAY_REST_GET_ONE_COOKIE);
        assert_eq!(req.headers().get("cookie").unwrap(), &"cookie1=value1");

        let req = prepare_request(API_GATEWAY_V2_GET_TWO_COOKIES);
        assert_eq!(
            req.headers().get("cookie").unwrap(),
            &"cookie1=value1; cookie2=value2"
        );
        let req = prepare_request(API_GATEWAY_REST_GET_TWO_COOKIES);
        assert_eq!(
            req.headers().get("cookie").unwrap(),
            &"cookie1=value1; cookie2=value2"
        );
    }
}