1use axum::response::IntoResponse;
2use http_body_util::BodyExt;
3use lambda_http::RequestExt;
4use std::{future::Future, pin::Pin};
5use tower::Layer;
6use tower_service::Service;
7
/// Tower layer that wraps an axum-style service so it can be driven by the `lambda_http`
/// runtime.
///
/// The default configuration forwards the incoming request URI unchanged; call
/// [`LambdaLayer::trim_stage`] to rebuild the request path from `raw_http_path()` instead.
#[derive(Debug, Default, Clone, Copy)]
pub struct LambdaLayer {
    /// When `true`, the wrapped service sees the path returned by `raw_http_path()`
    /// rather than the path of the incoming request URI.
    trim_stage: bool,
}
12
13impl LambdaLayer {
14 pub fn trim_stage(mut self) -> Self {
15 self.trim_stage = true;
16 self
17 }
18}
19
20impl<S> Layer<S> for LambdaLayer {
21 type Service = LambdaService<S>;
22
23 fn layer(&self, inner: S) -> Self::Service {
24 LambdaService {
25 inner,
26 layer: *self,
27 }
28 }
29}
30
31pub struct LambdaService<S> {
32 inner: S,
33 layer: LambdaLayer,
34}
35
36impl<S> Service<lambda_http::Request> for LambdaService<S>
37where
38 S: Service<axum::http::Request<axum::body::Body>>,
39 S::Response: axum::response::IntoResponse + Send + 'static,
40 S::Error: std::error::Error + Send + Sync + 'static,
41 S::Future: Send + 'static,
42{
43 type Response = lambda_http::Response<lambda_http::Body>;
44 type Error = lambda_http::Error;
45 type Future =
46 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
47
48 fn poll_ready(
49 &mut self,
50 cx: &mut std::task::Context<'_>,
51 ) -> std::task::Poll<Result<(), Self::Error>> {
52 self.inner.poll_ready(cx).map_err(Into::into)
53 }
54
55 fn call(&mut self, req: lambda_http::Request) -> Self::Future {
56 let uri = req.uri().clone();
57 let rawpath = req.raw_http_path().to_owned();
58 let (mut parts, body) = req.into_parts();
59 let body = match body {
60 lambda_http::Body::Empty => axum::body::Body::default(),
61 lambda_http::Body::Text(t) => t.into(),
62 lambda_http::Body::Binary(v) => v.into(),
63 };
64
65 if self.layer.trim_stage {
66 let mut url = match uri.host() {
67 None => rawpath,
68 Some(host) => format!(
69 "{}://{}{}",
70 uri.scheme_str().unwrap_or("https"),
71 host,
72 rawpath
73 ),
74 };
75
76 if let Some(query) = uri.query() {
77 url.push('?');
78 url.push_str(query);
79 }
80 parts.uri = url.parse::<hyper::Uri>().unwrap();
81 }
82
83 let request = axum::http::Request::from_parts(parts, body);
84
85 let fut = self.inner.call(request);
86 let fut = async move {
87 let resp = fut.await?;
88 let (parts, body) = resp.into_response().into_parts();
89 let bytes = body.into_data_stream().collect().await?.to_bytes();
90 let bytes: &[u8] = &bytes;
91 let resp: hyper::Response<lambda_http::Body> = match std::str::from_utf8(bytes) {
92 Ok(s) => hyper::Response::from_parts(parts, s.into()),
93 Err(_) => hyper::Response::from_parts(parts, bytes.into()),
94 };
95 Ok(resp)
96 };
97
98 Box::pin(fut)
99 }
100}