axum_aws_lambda/
lib.rs

1use axum::response::IntoResponse;
2use http_body_util::BodyExt;
3use lambda_http::RequestExt;
4use std::{future::Future, pin::Pin};
5use tower::Layer;
6use tower_service::Service;
7
/// A tower `Layer` that lets an axum/tower HTTP service be driven by `lambda_http`.
///
/// The produced service converts each `lambda_http::Request` into an axum request, calls
/// the wrapped service, and converts its response back into a `lambda_http` response.
///
/// By default request URIs are passed through unchanged; call [`LambdaLayer::trim_stage`]
/// to rebuild them from the request's raw HTTP path instead.
#[derive(Debug, Default, Clone, Copy)]
pub struct LambdaLayer {
    /// When `true`, the request URI path is replaced by `lambda_http`'s raw HTTP path
    /// (as the name suggests, intended to drop an API Gateway stage prefix).
    trim_stage: bool,
}
12
13impl LambdaLayer {
14    pub fn trim_stage(mut self) -> Self {
15        self.trim_stage = true;
16        self
17    }
18}
19
20impl<S> Layer<S> for LambdaLayer {
21    type Service = LambdaService<S>;
22
23    fn layer(&self, inner: S) -> Self::Service {
24        LambdaService {
25            inner,
26            layer: *self,
27        }
28    }
29}
30
31pub struct LambdaService<S> {
32    inner: S,
33    layer: LambdaLayer,
34}
35
36impl<S> Service<lambda_http::Request> for LambdaService<S>
37where
38    S: Service<axum::http::Request<axum::body::Body>>,
39    S::Response: axum::response::IntoResponse + Send + 'static,
40    S::Error: std::error::Error + Send + Sync + 'static,
41    S::Future: Send + 'static,
42{
43    type Response = lambda_http::Response<lambda_http::Body>;
44    type Error = lambda_http::Error;
45    type Future =
46        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
47
48    fn poll_ready(
49        &mut self,
50        cx: &mut std::task::Context<'_>,
51    ) -> std::task::Poll<Result<(), Self::Error>> {
52        self.inner.poll_ready(cx).map_err(Into::into)
53    }
54
55    fn call(&mut self, req: lambda_http::Request) -> Self::Future {
56        let uri = req.uri().clone();
57        let rawpath = req.raw_http_path().to_owned();
58        let (mut parts, body) = req.into_parts();
59        let body = match body {
60            lambda_http::Body::Empty => axum::body::Body::default(),
61            lambda_http::Body::Text(t) => t.into(),
62            lambda_http::Body::Binary(v) => v.into(),
63        };
64
65        if self.layer.trim_stage {
66            let mut url = match uri.host() {
67                None => rawpath,
68                Some(host) => format!(
69                    "{}://{}{}",
70                    uri.scheme_str().unwrap_or("https"),
71                    host,
72                    rawpath
73                ),
74            };
75
76            if let Some(query) = uri.query() {
77                url.push('?');
78                url.push_str(query);
79            }
80            parts.uri = url.parse::<hyper::Uri>().unwrap();
81        }
82
83        let request = axum::http::Request::from_parts(parts, body);
84
85        let fut = self.inner.call(request);
86        let fut = async move {
87            let resp = fut.await?;
88            let (parts, body) = resp.into_response().into_parts();
89            let bytes = body.into_data_stream().collect().await?.to_bytes();
90            let bytes: &[u8] = &bytes;
91            let resp: hyper::Response<lambda_http::Body> = match std::str::from_utf8(bytes) {
92                Ok(s) => hyper::Response::from_parts(parts, s.into()),
93                Err(_) => hyper::Response::from_parts(parts, bytes.into()),
94            };
95            Ok(resp)
96        };
97
98        Box::pin(fut)
99    }
100}