1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use axum::response::IntoResponse;
use http::Uri;
use lambda_http::RequestExt;
use std::{future::Future, pin::Pin};
use tower::Layer;
use tower_service::Service;

/// A [`tower::Layer`] that adapts a service taking
/// `axum::http::Request<axum::body::Body>` so it can be driven with
/// [`lambda_http::Request`]s and produce `lambda_http` responses.
///
/// Construct it with [`Default::default`] and optionally enable stage trimming
/// with [`LambdaLayer::trim_stage`].
#[derive(Debug, Default, Clone, Copy)]
pub struct LambdaLayer {
    /// When `true`, the request URI's path is replaced with
    /// `RequestExt::raw_http_path()` before the request reaches the inner service.
    trim_stage: bool,
}

impl LambdaLayer {
    /// Returns a copy of this layer with stage trimming enabled.
    ///
    /// With trimming on, the wrapped service sees the path reported by
    /// `lambda_http::RequestExt::raw_http_path` instead of the original request
    /// URI's path. The original scheme, authority and query string are kept.
    ///
    /// `LambdaLayer` is `Copy` and this method takes `self` by value. Discarding
    /// the return value therefore leaves the caller's layer unchanged.
    #[must_use]
    pub fn trim_stage(mut self) -> Self {
        self.trim_stage = true;
        self
    }
}

/// Wraps an inner service in a [`LambdaService`] carrying this layer's settings.
impl<S> Layer<S> for LambdaLayer {
    type Service = LambdaService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        // `LambdaLayer` is `Copy`, so dereferencing copies the settings
        // (clippy: `clone_on_copy`).
        LambdaService {
            inner,
            layer: *self,
        }
    }
}

/// The service produced by [`LambdaLayer`].
///
/// It accepts [`lambda_http::Request`]s, forwards them to `inner` as axum
/// requests, and converts the responses back into `lambda_http` responses.
/// It is `Clone` whenever the inner service is.
#[derive(Clone)]
pub struct LambdaService<S> {
    /// The wrapped axum-compatible service.
    inner: S,
    /// Settings copied from the layer that created this service.
    layer: LambdaLayer,
}

impl<S> Service<lambda_http::Request> for LambdaService<S>
where
    S: Service<axum::http::Request<axum::body::Body>>,
    S::Response: axum::response::IntoResponse + Send + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
{
    type Response = lambda_http::Response<lambda_http::Body>;
    type Error = lambda_http::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, req: lambda_http::Request) -> Self::Future {
        let uri = req.uri().clone();
        let rawpath = req.raw_http_path();
        let (mut parts, body) = req.into_parts();
        let body = match body {
            lambda_http::Body::Empty => axum::body::Body::default(),
            lambda_http::Body::Text(t) => t.into(),
            lambda_http::Body::Binary(v) => v.into(),
        };

        if self.layer.trim_stage {
            let mut url = match uri.host() {
                None => rawpath,
                Some(host) => format!(
                    "{}://{}{}",
                    uri.scheme_str().unwrap_or("https"),
                    host,
                    rawpath
                ),
            };

            if let Some(query) = uri.query() {
                url.push('?');
                url.push_str(&query);
            }
            parts.uri = url.parse::<Uri>().unwrap();
        }

        let request = axum::http::Request::from_parts(parts, body);

        let fut = self.inner.call(request);
        let fut = async move {
            let resp = fut.await?;
            let (parts, body) = resp.into_response().into_parts();
            let bytes = hyper::body::to_bytes(body).await?;
            let bytes: &[u8] = &bytes;
            let resp: hyper::Response<lambda_http::Body> = match std::str::from_utf8(bytes) {
                Ok(s) => hyper::Response::from_parts(parts, s.into()),
                Err(_) => hyper::Response::from_parts(parts, bytes.into()),
            };
            Ok(resp)
        };

        Box::pin(fut)
    }
}