Skip to main content

aws_smithy_http_server/routing/
lambda_handler.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use http::uri;
7use lambda_http::{Request, RequestExt};
8use std::{
9    fmt::Debug,
10    task::{Context, Poll},
11};
12use tower::Service;
13
14type ServiceRequest = http::Request<crate::body::BoxBodySync>;
15
16/// A [`Service`] that takes a `lambda_http::Request` and converts
17/// it to `http::Request<BoxBody>`.
18///
19/// **This version is only guaranteed to be compatible with
20/// [`lambda_http`](https://docs.rs/lambda_http) ^1.** Please ensure that your service crate's
21/// `Cargo.toml` depends on a compatible version.
22///
23/// [`Service`]: tower::Service
24#[derive(Debug, Clone)]
25pub struct LambdaHandler<S> {
26    service: S,
27}
28
29impl<S> LambdaHandler<S> {
30    pub fn new(service: S) -> Self {
31        Self { service }
32    }
33}
34
35impl<S> Service<Request> for LambdaHandler<S>
36where
37    S: Service<ServiceRequest>,
38{
39    type Error = S::Error;
40    type Response = S::Response;
41    type Future = S::Future;
42
43    #[inline]
44    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
45        self.service.poll_ready(cx)
46    }
47
48    fn call(&mut self, event: Request) -> Self::Future {
49        self.service.call(convert_event(event))
50    }
51}
52
53/// Converts a `lambda_http::Request` into a `http::Request<crate::body::BoxBodySync>`
54/// Issue: <https://github.com/smithy-lang/smithy-rs/issues/1125>
55///
56/// While converting the event the [API Gateway Stage] portion of the URI
57/// is removed from the uri that gets returned as a new `http::Request`.
58///
59/// [API Gateway Stage]: https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-stages.html
60fn convert_event(request: Request) -> ServiceRequest {
61    let raw_path: &str = request.extensions().raw_http_path();
62    let path: &str = request.uri().path();
63
64    let (parts, body) = if !raw_path.is_empty() && raw_path != path {
65        let mut path = raw_path.to_owned(); // Clone only when we need to strip out the stage.
66        let (mut parts, body) = request.into_parts();
67
68        let uri_parts: uri::Parts = parts.uri.into();
69        let path_and_query = uri_parts
70            .path_and_query
71            .expect("request URI does not have `PathAndQuery`");
72
73        if let Some(query) = path_and_query.query() {
74            path.push('?');
75            path.push_str(query);
76        }
77
78        parts.uri = uri::Uri::builder()
79            .authority(uri_parts.authority.expect("request URI does not have authority set"))
80            .scheme(uri_parts.scheme.expect("request URI does not have scheme set"))
81            .path_and_query(path)
82            .build()
83            .expect("unable to construct new URI");
84
85        (parts, body)
86    } else {
87        request.into_parts()
88    };
89
90    let body = match body {
91        lambda_http::Body::Empty => crate::body::empty_sync(),
92        lambda_http::Body::Text(s) => crate::body::to_boxed_sync(s),
93        lambda_http::Body::Binary(v) => crate::body::to_boxed_sync(v),
94        _ => {
95            tracing::error!("Unknown `lambda_http::Body` variant encountered, falling back to empty body");
96            crate::body::empty_sync()
97        }
98    };
99
100    http::Request::from_parts(parts, body)
101}
102
103#[cfg(test)]
104mod tests {
105    use super::*;
106    use bytes::Bytes;
107    use lambda_http::RequestExt;
108
109    /// Test utility to collect all bytes from a body.
110    async fn collect_bytes<B>(body: B) -> Result<Bytes, crate::Error>
111    where
112        B: http_body::Body,
113        B::Error: Into<crate::error::BoxError>,
114    {
115        use http_body_util::BodyExt;
116        let collected = body.collect().await.map_err(crate::Error::new)?;
117        Ok(collected.to_bytes())
118    }
119
120    #[test]
121    fn traits() {
122        use crate::test_helpers::*;
123
124        assert_send::<LambdaHandler<()>>();
125        assert_sync::<LambdaHandler<()>>();
126    }
127
128    #[test]
129    fn raw_http_path() {
130        // lambda_http::Request doesn't have a fn `builder`
131        let event = http::Request::builder()
132            .uri("https://id.execute-api.us-east-1.amazonaws.com/prod/resources/1")
133            .body(())
134            .expect("unable to build Request");
135        let (parts, _) = event.into_parts();
136
137        // the lambda event will have a raw path which is the path without stage name in it
138        let event =
139            lambda_http::Request::from_parts(parts, lambda_http::Body::Empty).with_raw_http_path("/resources/1");
140        let request = convert_event(event);
141
142        assert_eq!(request.uri().path(), "/resources/1");
143    }
144
145    #[tokio::test]
146    async fn body_conversion_empty() {
147        let event = http::Request::builder()
148            .uri("https://id.execute-api.us-east-1.amazonaws.com/test")
149            .body(())
150            .expect("unable to build Request");
151        let (parts, _) = event.into_parts();
152        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
153        let request = convert_event(event);
154        let bytes = collect_bytes(request.into_body()).await.unwrap();
155        assert_eq!(bytes.len(), 0);
156    }
157
158    #[tokio::test]
159    async fn body_conversion_text() {
160        let event = http::Request::builder()
161            .uri("https://id.execute-api.us-east-1.amazonaws.com/test")
162            .body(())
163            .expect("unable to build Request");
164        let (parts, _) = event.into_parts();
165        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Text("hello world".to_string()));
166        let request = convert_event(event);
167        let bytes = collect_bytes(request.into_body()).await.unwrap();
168        assert_eq!(bytes, "hello world");
169    }
170
171    #[tokio::test]
172    async fn body_conversion_binary() {
173        let event = http::Request::builder()
174            .uri("https://id.execute-api.us-east-1.amazonaws.com/test")
175            .body(())
176            .expect("unable to build Request");
177        let (parts, _) = event.into_parts();
178        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Binary(vec![1, 2, 3, 4, 5]));
179        let request = convert_event(event);
180        let bytes = collect_bytes(request.into_body()).await.unwrap();
181        assert_eq!(bytes.as_ref(), &[1, 2, 3, 4, 5]);
182    }
183
184    #[test]
185    fn uri_with_query_string() {
186        let event = http::Request::builder()
187            .uri("https://id.execute-api.us-east-1.amazonaws.com/prod/resources/1?foo=bar&baz=qux")
188            .body(())
189            .expect("unable to build Request");
190        let (parts, _) = event.into_parts();
191        let event =
192            lambda_http::Request::from_parts(parts, lambda_http::Body::Empty).with_raw_http_path("/resources/1");
193        let request = convert_event(event);
194
195        assert_eq!(request.uri().path(), "/resources/1");
196        assert_eq!(request.uri().query(), Some("foo=bar&baz=qux"));
197    }
198
199    #[test]
200    fn uri_without_stage_stripping() {
201        // When raw_http_path is empty or matches the path, no stripping should occur
202        let event = http::Request::builder()
203            .uri("https://id.execute-api.us-east-1.amazonaws.com/resources/1")
204            .body(())
205            .expect("unable to build Request");
206        let (parts, _) = event.into_parts();
207        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
208        let request = convert_event(event);
209
210        assert_eq!(request.uri().path(), "/resources/1");
211    }
212
213    #[test]
214    fn headers_are_preserved() {
215        let event = http::Request::builder()
216            .uri("https://id.execute-api.us-east-1.amazonaws.com/test")
217            .header("content-type", "application/json")
218            .header("x-custom-header", "custom-value")
219            .body(())
220            .expect("unable to build Request");
221        let (parts, _) = event.into_parts();
222        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
223        let request = convert_event(event);
224
225        assert_eq!(request.headers().get("content-type").unwrap(), "application/json");
226        assert_eq!(request.headers().get("x-custom-header").unwrap(), "custom-value");
227    }
228
229    #[test]
230    fn extensions_are_preserved() {
231        let event = http::Request::builder()
232            .uri("https://id.execute-api.us-east-1.amazonaws.com/test")
233            .body(())
234            .expect("unable to build Request");
235        let (mut parts, _) = event.into_parts();
236
237        // Add a test extension
238        #[derive(Debug, Clone, PartialEq)]
239        struct TestExtension(String);
240        parts.extensions.insert(TestExtension("test-value".to_string()));
241
242        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
243        let request = convert_event(event);
244
245        let ext = request.extensions().get::<TestExtension>();
246        assert!(ext.is_some());
247        assert_eq!(ext.unwrap(), &TestExtension("test-value".to_string()));
248    }
249
250    #[test]
251    fn method_is_preserved() {
252        let event = http::Request::builder()
253            .method("POST")
254            .uri("https://id.execute-api.us-east-1.amazonaws.com/test")
255            .body(())
256            .expect("unable to build Request");
257        let (parts, _) = event.into_parts();
258        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
259        let request = convert_event(event);
260
261        assert_eq!(request.method(), http::Method::POST);
262    }
263
264    #[tokio::test]
265    async fn lambda_handler_service_integration() {
266        use tower::ServiceExt;
267
268        // Create a simple service that echoes the URI path
269        let inner_service = tower::service_fn(|req: ServiceRequest| async move {
270            let path = req.uri().path().to_string();
271            let response = http::Response::builder()
272                .status(200)
273                .body(crate::body::to_boxed(path))
274                .unwrap();
275            Ok::<_, std::convert::Infallible>(response)
276        });
277
278        let mut lambda_handler = LambdaHandler::new(inner_service);
279
280        // Create a lambda request
281        let event = http::Request::builder()
282            .uri("https://id.execute-api.us-east-1.amazonaws.com/prod/test/path")
283            .body(())
284            .expect("unable to build Request");
285        let (parts, _) = event.into_parts();
286        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty).with_raw_http_path("/test/path");
287
288        // Call the service
289        let response = lambda_handler.ready().await.unwrap().call(event).await.unwrap();
290
291        // Verify response
292        assert_eq!(response.status(), 200);
293        let body_bytes = collect_bytes(response.into_body()).await.unwrap();
294        assert_eq!(body_bytes, "/test/path");
295    }
296
297    #[tokio::test]
298    async fn lambda_handler_with_request_body() {
299        use tower::ServiceExt;
300
301        // Create a service that processes the request body
302        let inner_service = tower::service_fn(|req: ServiceRequest| async move {
303            let body_bytes = collect_bytes(req.into_body()).await.unwrap();
304            let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
305
306            let response_body = format!("Received: {body_str}");
307            let response = http::Response::builder()
308                .status(200)
309                .header("content-type", "text/plain")
310                .body(crate::body::to_boxed(response_body))
311                .unwrap();
312            Ok::<_, std::convert::Infallible>(response)
313        });
314
315        let mut lambda_handler = LambdaHandler::new(inner_service);
316
317        // Create a lambda request with JSON body
318        let event = http::Request::builder()
319            .method("POST")
320            .uri("https://id.execute-api.us-east-1.amazonaws.com/api/process")
321            .header("content-type", "application/json")
322            .body(())
323            .expect("unable to build Request");
324        let (parts, _) = event.into_parts();
325        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Text(r#"{"key":"value"}"#.to_string()));
326
327        // Call the service
328        let response = lambda_handler.ready().await.unwrap().call(event).await.unwrap();
329
330        // Verify response
331        assert_eq!(response.status(), 200);
332        assert_eq!(response.headers().get("content-type").unwrap(), "text/plain");
333        let body_bytes = collect_bytes(response.into_body()).await.unwrap();
334        assert_eq!(body_bytes, r#"Received: {"key":"value"}"#);
335    }
336
337    #[tokio::test]
338    async fn lambda_handler_response_headers() {
339        use tower::ServiceExt;
340
341        // Create a service that returns custom headers
342        let inner_service = tower::service_fn(|_req: ServiceRequest| async move {
343            let response = http::Response::builder()
344                .status(201)
345                .header("x-custom-header", "custom-value")
346                .header("content-type", "application/json")
347                .header("x-request-id", "12345")
348                .body(crate::body::to_boxed(r#"{"status":"created"}"#))
349                .unwrap();
350            Ok::<_, std::convert::Infallible>(response)
351        });
352
353        let mut lambda_handler = LambdaHandler::new(inner_service);
354
355        let event = http::Request::builder()
356            .uri("https://id.execute-api.us-east-1.amazonaws.com/api/create")
357            .body(())
358            .expect("unable to build Request");
359        let (parts, _) = event.into_parts();
360        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
361
362        // Call the service
363        let response = lambda_handler.ready().await.unwrap().call(event).await.unwrap();
364
365        // Verify all response components
366        assert_eq!(response.status(), 201);
367        assert_eq!(response.headers().get("x-custom-header").unwrap(), "custom-value");
368        assert_eq!(response.headers().get("content-type").unwrap(), "application/json");
369        assert_eq!(response.headers().get("x-request-id").unwrap(), "12345");
370
371        let body_bytes = collect_bytes(response.into_body()).await.unwrap();
372        assert_eq!(body_bytes, r#"{"status":"created"}"#);
373    }
374
375    #[tokio::test]
376    async fn lambda_handler_error_response() {
377        use tower::ServiceExt;
378
379        // Create a service that returns an error status
380        let inner_service = tower::service_fn(|_req: ServiceRequest| async move {
381            let response = http::Response::builder()
382                .status(404)
383                .header("content-type", "application/json")
384                .body(crate::body::to_boxed(r#"{"error":"not found"}"#))
385                .unwrap();
386            Ok::<_, std::convert::Infallible>(response)
387        });
388
389        let mut lambda_handler = LambdaHandler::new(inner_service);
390
391        let event = http::Request::builder()
392            .uri("https://id.execute-api.us-east-1.amazonaws.com/api/missing")
393            .body(())
394            .expect("unable to build Request");
395        let (parts, _) = event.into_parts();
396        let event = lambda_http::Request::from_parts(parts, lambda_http::Body::Empty);
397
398        // Call the service
399        let response = lambda_handler.ready().await.unwrap().call(event).await.unwrap();
400
401        // Verify error response
402        assert_eq!(response.status(), 404);
403        let body_bytes = collect_bytes(response.into_body()).await.unwrap();
404        assert_eq!(body_bytes, r#"{"error":"not found"}"#);
405    }
406}