1use crate::{http::header::SET_COOKIE, request::LambdaRequest, update_xray_trace_id_header, Request, RequestExt};
2use bytes::Bytes;
3use core::{
4 fmt::Debug,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use futures_util::{Stream, TryFutureExt};
9pub use http::{self, Response};
10use http_body::Body;
11use lambda_runtime::{
12 tower::{
13 util::{MapRequest, MapResponse},
14 ServiceBuilder, ServiceExt,
15 },
16 Diagnostic,
17};
18pub use lambda_runtime::{Error, LambdaEvent, MetadataPrelude, Service, StreamResponse};
19use std::{future::Future, marker::PhantomData};
20
/// Adapter that lifts an HTTP [`Service<Request>`] returning `Response<B>` into a
/// `Service<LambdaEvent<LambdaRequest>>` whose responses are streamed.
///
/// Only the wrapped service is stored; `B` (the response body type) appears
/// solely as a phantom type parameter.
#[non_exhaustive]
pub struct StreamAdapter<'a, S, B> {
    service: S,
    // `fn() -> &'a B` keeps the same (covariant) variance in `'a` and `B` as
    // `&'a B`, but a function pointer is always `Send + Sync`. The adapter never
    // holds a `B`, so its auto traits should depend only on `S`; with `&'a B`
    // the adapter would be `Send` only when `B: Sync`, a bound nothing else in
    // this module requires of the body.
    _phantom_data: PhantomData<fn() -> &'a B>,
}
29
30impl<'a, S, B> Clone for StreamAdapter<'a, S, B>
31where
32 S: Clone,
33{
34 fn clone(&self) -> Self {
35 Self {
36 service: self.service.clone(),
37 _phantom_data: PhantomData,
38 }
39 }
40}
41
42impl<'a, S, B, E> From<S> for StreamAdapter<'a, S, B>
43where
44 S: Service<Request, Response = Response<B>, Error = E>,
45 S::Future: Send + 'a,
46 B: Body + Unpin + Send + 'static,
47 B::Data: Into<Bytes> + Send,
48 B::Error: Into<Error> + Send + Debug,
49{
50 fn from(service: S) -> Self {
51 StreamAdapter {
52 service,
53 _phantom_data: PhantomData,
54 }
55 }
56}
57
58impl<'a, S, B, E> Service<LambdaEvent<LambdaRequest>> for StreamAdapter<'a, S, B>
59where
60 S: Service<Request, Response = Response<B>, Error = E>,
61 S::Future: Send + 'a,
62 B: Body + Unpin + Send + 'static,
63 B::Data: Into<Bytes> + Send,
64 B::Error: Into<Error> + Send + Debug,
65{
66 type Response = StreamResponse<BodyStream<B>>;
67 type Error = E;
68 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'a>>;
69
70 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71 self.service.poll_ready(cx)
72 }
73
74 fn call(&mut self, req: LambdaEvent<LambdaRequest>) -> Self::Future {
75 let LambdaEvent { payload, context } = req;
76 let mut event: Request = payload.into();
77 update_xray_trace_id_header(event.headers_mut(), &context);
78 Box::pin(
79 self.service
80 .call(event.with_lambda_context(context))
81 .map_ok(into_stream_response),
82 )
83 }
84}
85
86#[allow(clippy::type_complexity)]
95fn into_stream_service<'a, S, B, E>(
96 handler: S,
97) -> MapResponse<
98 MapRequest<S, impl FnMut(LambdaEvent<LambdaRequest>) -> Request>,
99 impl FnOnce(Response<B>) -> StreamResponse<BodyStream<B>> + Clone,
100>
101where
102 S: Service<Request, Response = Response<B>, Error = E>,
103 S::Future: Send + 'a,
104 E: Debug + Into<Diagnostic>,
105 B: Body + Unpin + Send + 'static,
106 B::Data: Into<Bytes> + Send,
107 B::Error: Into<Error> + Send + Debug,
108{
109 ServiceBuilder::new()
110 .map_request(event_to_request as fn(LambdaEvent<LambdaRequest>) -> Request)
111 .service(handler)
112 .map_response(into_stream_response)
113}
114
/// Nameable function-pointer type for `event_to_request`.
///
/// `into_stream_service_cloneable` uses it to spell out a concrete
/// `MapRequest<S, EventToRequest>` in its return type, where
/// `into_stream_service` uses an opaque `impl FnMut` instead.
#[cfg(feature = "concurrency-tokio")]
type EventToRequest = fn(LambdaEvent<LambdaRequest>) -> Request;
119
/// Variant of `into_stream_service` for the concurrent runtime.
///
/// The handler must be `Clone + Send + 'static`, and the request mapper is the
/// nameable [`EventToRequest`] fn pointer rather than an opaque `impl FnMut`.
/// The result is passed to `lambda_runtime::run_concurrent` by
/// `run_with_streaming_response_concurrent`.
#[cfg(feature = "concurrency-tokio")]
// The nested tower combinator return type is inherently verbose.
#[allow(clippy::type_complexity)]
fn into_stream_service_cloneable<S, B, E>(
    handler: S,
) -> MapResponse<MapRequest<S, EventToRequest>, impl FnOnce(Response<B>) -> StreamResponse<BodyStream<B>> + Clone>
where
    S: Service<Request, Response = Response<B>, Error = E> + Clone + Send + 'static,
    S::Future: Send + 'static,
    E: Debug + Into<Diagnostic> + Send + 'static,
    B: Body + Unpin + Send + 'static,
    B::Data: Into<Bytes> + Send,
    B::Error: Into<Error> + Send + Debug,
{
    let to_request: EventToRequest = event_to_request;
    let request_mapped = ServiceBuilder::new().map_request(to_request).service(handler);
    request_mapped.map_response(into_stream_response)
}
138
139fn into_stream_response<B>(res: Response<B>) -> StreamResponse<BodyStream<B>>
141where
142 B: Body + Unpin + Send + 'static,
143 B::Data: Into<Bytes> + Send,
144 B::Error: Into<Error> + Send + Debug,
145{
146 let (parts, body) = res.into_parts();
147
148 let mut headers = parts.headers;
149 let cookies = headers
150 .get_all(SET_COOKIE)
151 .iter()
152 .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string())
153 .collect::<Vec<_>>();
154 headers.remove(SET_COOKIE);
155
156 StreamResponse {
157 metadata_prelude: MetadataPrelude {
158 headers,
159 status_code: parts.status,
160 cookies,
161 },
162 stream: BodyStream { body },
163 }
164}
165
166fn event_to_request(req: LambdaEvent<LambdaRequest>) -> Request {
167 let LambdaEvent { payload, context } = req;
168 let mut event: Request = payload.into();
169 update_xray_trace_id_header(event.headers_mut(), &context);
170 event.with_lambda_context(context)
171}
172
173pub async fn run_with_streaming_response<'a, S, B, E>(handler: S) -> Result<(), Error>
192where
193 S: Service<Request, Response = Response<B>, Error = E>,
194 S::Future: Send + 'a,
195 E: Debug + Into<Diagnostic>,
196 B: Body + Unpin + Send + 'static,
197 B::Data: Into<Bytes> + Send,
198 B::Error: Into<Error> + Send + Debug,
199{
200 lambda_runtime::run(into_stream_service(handler)).await
201}
202
/// Concurrent counterpart of `run_with_streaming_response`.
///
/// Performs the same event/response conversion but hands the service to
/// [`lambda_runtime::run_concurrent`]. The handler must therefore be
/// `Clone + Send + 'static`, and its future and error types must be
/// `Send + 'static`.
///
/// # Errors
///
/// Returns whatever error [`lambda_runtime::run_concurrent`] reports.
#[cfg(feature = "concurrency-tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "concurrency-tokio")))]
pub async fn run_with_streaming_response_concurrent<S, B, E>(handler: S) -> Result<(), Error>
where
    S: Service<Request, Response = Response<B>, Error = E> + Clone + Send + 'static,
    S::Future: Send + 'static,
    E: Debug + Into<Diagnostic> + Send + 'static,
    B: Body + Unpin + Send + 'static,
    B::Data: Into<Bytes> + Send,
    B::Error: Into<Error> + Send + Debug,
{
    let service = into_stream_service_cloneable(handler);
    lambda_runtime::run_concurrent(service).await
}
232
pin_project_lite::pin_project! {
/// [`Stream`] adapter over an HTTP [`Body`] that yields only the body's data
/// frames.
///
/// Built by `into_stream_response`. Polling ends when the body is exhausted or
/// when it yields its first non-data frame (e.g. trailers).
#[non_exhaustive]
pub struct BodyStream<B> {
    // Structurally pinned so `poll_frame` can be called through the projection.
    #[pin]
    pub(crate) body: B,
}
}
240
241impl<B> Stream for BodyStream<B>
242where
243 B: Body + Unpin + Send + 'static,
244 B::Data: Into<Bytes> + Send,
245 B::Error: Into<Error> + Send + Debug,
246{
247 type Item = Result<B::Data, B::Error>;
248
249 #[inline]
250 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
251 match futures_util::ready!(self.as_mut().project().body.poll_frame(cx)?) {
252 Some(frame) => match frame.into_data() {
253 Ok(data) => Poll::Ready(Some(Ok(data))),
254 Err(_frame) => Poll::Ready(None),
255 },
256 None => Poll::Ready(None),
257 }
258 }
259}
260
#[cfg(test)]
mod test_stream_adapter {
    use super::*;

    use crate::Body;
    use http::StatusCode;

    /// Test-only middleware that prints each incoming Lambda event and then
    /// delegates to the wrapped service.
    struct LogService<S> {
        inner: S,
    }

    impl<S> Service<LambdaEvent<LambdaRequest>> for LogService<S>
    where
        S: Service<LambdaEvent<LambdaRequest>>,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, event: LambdaEvent<LambdaRequest>) -> Self::Future {
            println!("Lambda event: {event:#?}");
            self.inner.call(event)
        }
    }

    #[test]
    fn stream_adapter_is_boxable() {
        // Handler that always answers 200 with an empty body.
        let handler = |_req: Request| async move { http::Response::builder().status(StatusCode::OK).body(Body::Empty) };
        // Stack the logging layer and the `StreamAdapter` on top of the handler.
        let svc = ServiceBuilder::new()
            .layer_fn(|inner| LogService { inner })
            .layer_fn(StreamAdapter::from)
            .service_fn(handler);
        // There are no runtime assertions: the test checks that this stack can
        // be turned into a boxed service.
        let _boxed_svc = svc.boxed();
    }
}