rama_http/service/web/endpoint/response/into_response.rs

1use super::{IntoResponseParts, ResponseParts};
2use crate::dep::http_body::{Frame, SizeHint};
3use crate::dep::mime;
4use crate::{Body, Response};
5use crate::{
6    StatusCode,
7    dep::http::Extensions,
8    header::{self, HeaderMap, HeaderName, HeaderValue},
9};
10use bytes::{Buf, Bytes, BytesMut, buf::Chain};
11use rama_core::error::BoxError;
12use rama_http_types::dep::{http, http_body};
13use rama_utils::macros::all_the_tuples_no_last_special_case;
14use std::{
15    borrow::Cow,
16    convert::Infallible,
17    fmt,
18    pin::Pin,
19    task::{Context, Poll},
20};
21
22/// Trait for generating responses.
23///
24/// Types that implement `IntoResponse` can be returned from handlers.
25///
26/// # Implementing `IntoResponse`
27///
28/// You generally shouldn't have to implement `IntoResponse` manually, as rama
29/// provides implementations for many common types.
30pub trait IntoResponse {
31    /// Create a response.
32    fn into_response(self) -> Response;
33}
34
35/// Wrapper that can be used to turn an `IntoResponse` type into
36/// something that implements `Into<Response>`.
37pub struct StaticResponseFactory<T>(pub T);
38
39impl<T: IntoResponse> From<StaticResponseFactory<T>> for Response {
40    fn from(value: StaticResponseFactory<T>) -> Self {
41        value.0.into_response()
42    }
43}
44
45impl<T: fmt::Debug> fmt::Debug for StaticResponseFactory<T> {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        f.debug_tuple("StaticResponseFactory")
48            .field(&self.0)
49            .finish()
50    }
51}
52
53impl<T: Clone> Clone for StaticResponseFactory<T> {
54    fn clone(&self) -> Self {
55        Self(self.0.clone())
56    }
57}
58
59impl IntoResponse for StatusCode {
60    fn into_response(self) -> Response {
61        let mut res = ().into_response();
62        *res.status_mut() = self;
63        res
64    }
65}
66
67impl IntoResponse for () {
68    fn into_response(self) -> Response {
69        Body::empty().into_response()
70    }
71}
72
73impl IntoResponse for Infallible {
74    fn into_response(self) -> Response {
75        match self {}
76    }
77}
78
79impl<T, E> IntoResponse for Result<T, E>
80where
81    T: IntoResponse,
82    E: IntoResponse,
83{
84    fn into_response(self) -> Response {
85        match self {
86            Ok(value) => value.into_response(),
87            Err(err) => err.into_response(),
88        }
89    }
90}
91
92impl<B> IntoResponse for Response<B>
93where
94    B: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
95{
96    fn into_response(self) -> Response {
97        self.map(Body::new)
98    }
99}
100
101impl IntoResponse for http::response::Parts {
102    fn into_response(self) -> Response {
103        Response::from_parts(self, Body::empty())
104    }
105}
106
107impl IntoResponse for Body {
108    fn into_response(self) -> Response {
109        Response::new(self)
110    }
111}
112
113impl IntoResponse for &'static str {
114    fn into_response(self) -> Response {
115        Cow::Borrowed(self).into_response()
116    }
117}
118
119impl IntoResponse for String {
120    fn into_response(self) -> Response {
121        Cow::<'static, str>::Owned(self).into_response()
122    }
123}
124
125impl IntoResponse for Box<str> {
126    fn into_response(self) -> Response {
127        String::from(self).into_response()
128    }
129}
130
131impl IntoResponse for Cow<'static, str> {
132    fn into_response(self) -> Response {
133        let mut res = Body::from(self).into_response();
134        res.headers_mut().insert(
135            header::CONTENT_TYPE,
136            HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
137        );
138        res
139    }
140}
141
142impl IntoResponse for Bytes {
143    fn into_response(self) -> Response {
144        let mut res = Body::from(self).into_response();
145        res.headers_mut().insert(
146            header::CONTENT_TYPE,
147            HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()),
148        );
149        res
150    }
151}
152
153impl IntoResponse for BytesMut {
154    fn into_response(self) -> Response {
155        self.freeze().into_response()
156    }
157}
158
159impl<T, U> IntoResponse for Chain<T, U>
160where
161    T: Buf + Unpin + Send + Sync + 'static,
162    U: Buf + Unpin + Send + Sync + 'static,
163{
164    fn into_response(self) -> Response {
165        let (first, second) = self.into_inner();
166        let mut res = Response::new(Body::new(BytesChainBody {
167            first: Some(first),
168            second: Some(second),
169        }));
170        res.headers_mut().insert(
171            header::CONTENT_TYPE,
172            HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()),
173        );
174        res
175    }
176}
177
178struct BytesChainBody<T, U> {
179    first: Option<T>,
180    second: Option<U>,
181}
182
183impl<T, U> http_body::Body for BytesChainBody<T, U>
184where
185    T: Buf + Unpin,
186    U: Buf + Unpin,
187{
188    type Data = Bytes;
189    type Error = Infallible;
190
191    fn poll_frame(
192        mut self: Pin<&mut Self>,
193        _cx: &mut Context<'_>,
194    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
195        if let Some(mut buf) = self.first.take() {
196            let bytes = buf.copy_to_bytes(buf.remaining());
197            return Poll::Ready(Some(Ok(Frame::data(bytes))));
198        }
199
200        if let Some(mut buf) = self.second.take() {
201            let bytes = buf.copy_to_bytes(buf.remaining());
202            return Poll::Ready(Some(Ok(Frame::data(bytes))));
203        }
204
205        Poll::Ready(None)
206    }
207
208    fn is_end_stream(&self) -> bool {
209        self.first.is_none() && self.second.is_none()
210    }
211
212    fn size_hint(&self) -> SizeHint {
213        match (self.first.as_ref(), self.second.as_ref()) {
214            (Some(first), Some(second)) => {
215                let total_size = first.remaining() + second.remaining();
216                SizeHint::with_exact(total_size as u64)
217            }
218            (Some(buf), None) => SizeHint::with_exact(buf.remaining() as u64),
219            (None, Some(buf)) => SizeHint::with_exact(buf.remaining() as u64),
220            (None, None) => SizeHint::with_exact(0),
221        }
222    }
223}
224
225impl IntoResponse for &'static [u8] {
226    fn into_response(self) -> Response {
227        Cow::Borrowed(self).into_response()
228    }
229}
230
231impl<const N: usize> IntoResponse for &'static [u8; N] {
232    fn into_response(self) -> Response {
233        self.as_slice().into_response()
234    }
235}
236
237impl<const N: usize> IntoResponse for [u8; N] {
238    fn into_response(self) -> Response {
239        self.to_vec().into_response()
240    }
241}
242
243impl IntoResponse for Vec<u8> {
244    fn into_response(self) -> Response {
245        Cow::<'static, [u8]>::Owned(self).into_response()
246    }
247}
248
249impl IntoResponse for Box<[u8]> {
250    fn into_response(self) -> Response {
251        Vec::from(self).into_response()
252    }
253}
254
255impl IntoResponse for Cow<'static, [u8]> {
256    fn into_response(self) -> Response {
257        let mut res = Body::from(self).into_response();
258        res.headers_mut().insert(
259            header::CONTENT_TYPE,
260            HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()),
261        );
262        res
263    }
264}
265
266impl<R> IntoResponse for (StatusCode, R)
267where
268    R: IntoResponse,
269{
270    fn into_response(self) -> Response {
271        let mut res = self.1.into_response();
272        *res.status_mut() = self.0;
273        res
274    }
275}
276
277impl IntoResponse for HeaderMap {
278    fn into_response(self) -> Response {
279        let mut res = ().into_response();
280        *res.headers_mut() = self;
281        res
282    }
283}
284
285impl IntoResponse for Extensions {
286    fn into_response(self) -> Response {
287        let mut res = ().into_response();
288        *res.extensions_mut() = self;
289        res
290    }
291}
292
293impl<K, V, const N: usize> IntoResponse for [(K, V); N]
294where
295    K: TryInto<HeaderName, Error: fmt::Display>,
296    V: TryInto<HeaderValue, Error: fmt::Display>,
297{
298    fn into_response(self) -> Response {
299        (self, ()).into_response()
300    }
301}
302
303impl<R> IntoResponse for (http::response::Parts, R)
304where
305    R: IntoResponse,
306{
307    fn into_response(self) -> Response {
308        let (parts, res) = self;
309        (parts.status, parts.headers, parts.extensions, res).into_response()
310    }
311}
312
313impl<R> IntoResponse for (http::response::Response<()>, R)
314where
315    R: IntoResponse,
316{
317    fn into_response(self) -> Response {
318        let (template, res) = self;
319        let (parts, ()) = template.into_parts();
320        (parts, res).into_response()
321    }
322}
323
324impl<R> IntoResponse for (R,)
325where
326    R: IntoResponse,
327{
328    fn into_response(self) -> Response {
329        let (res,) = self;
330        res.into_response()
331    }
332}
333
// Generates `IntoResponse` impls for the tuple shapes
//   `(T1, .., Tn, R)`,
//   `(StatusCode, T1, .., Tn, R)`,
//   `(http::response::Parts, T1, .., Tn, R)` and
//   `(http::response::Response<()>, T1, .., Tn, R)`,
// where every `Ti: IntoResponseParts` and `R: IntoResponse`.
//
// In every shape, `R` is converted into the base response first and the `Ti`
// parts are then applied to it left to right. If any part fails, its error is
// turned into the response and returned immediately. The remaining parts, and
// any leading status/parts/template, are then NOT applied.
macro_rules! impl_into_response {
    ( $($ty:ident),* $(,)? ) => {
        // The type-parameter idents double as the local binding names in the
        // destructuring `let`, hence the uppercase locals being allowed.
        #[allow(non_snake_case)]
        impl<R, $($ty,)*> IntoResponse for ($($ty),*, R)
        where
            $( $ty: IntoResponseParts, )*
            R: IntoResponse,
        {
            fn into_response(self) -> Response {
                let ($($ty),*, res) = self;

                let res = res.into_response();
                let parts = ResponseParts { res };

                // Thread the response through each part in order; bail out on
                // the first error.
                $(
                    let parts = match $ty.into_response_parts(parts) {
                        Ok(parts) => parts,
                        Err(err) => {
                            return err.into_response();
                        }
                    };
                )*

                parts.res
            }
        }

        // Same as above, but the status code is applied last (through the
        // `(StatusCode, R)` impl), so it overrides any status set by a part.
        #[allow(non_snake_case)]
        impl<R, $($ty,)*> IntoResponse for (StatusCode, $($ty),*, R)
        where
            $( $ty: IntoResponseParts, )*
            R: IntoResponse,
        {
            fn into_response(self) -> Response {
                let (status, $($ty),*, res) = self;

                let res = res.into_response();
                let parts = ResponseParts { res };

                $(
                    let parts = match $ty.into_response_parts(parts) {
                        Ok(parts) => parts,
                        Err(err) => {
                            return err.into_response();
                        }
                    };
                )*

                (status, parts.res).into_response()
            }
        }

        // Same as the first impl, but the outer `Parts` are applied after all
        // `Ti` parts via the `(Parts, R)` impl (status, headers, extensions).
        #[allow(non_snake_case)]
        impl<R, $($ty,)*> IntoResponse for (http::response::Parts, $($ty),*, R)
        where
            $( $ty: IntoResponseParts, )*
            R: IntoResponse,
        {
            fn into_response(self) -> Response {
                let (outer_parts, $($ty),*, res) = self;

                let res = res.into_response();
                let parts = ResponseParts { res };
                $(
                    let parts = match $ty.into_response_parts(parts) {
                        Ok(parts) => parts,
                        Err(err) => {
                            return err.into_response();
                        }
                    };
                )*

                (outer_parts, parts.res).into_response()
            }
        }

        // A body-less response acts as a template: split off its parts and
        // defer to the `(Parts, T1, .., Tn, R)` impl above.
        #[allow(non_snake_case)]
        impl<R, $($ty,)*> IntoResponse for (http::response::Response<()>, $($ty),*, R)
        where
            $( $ty: IntoResponseParts, )*
            R: IntoResponse,
        {
            fn into_response(self) -> Response {
                let (template, $($ty),*, res) = self;
                let (parts, ()) = template.into_parts();
                (parts, $($ty),*, res).into_response()
            }
        }
    }
}

// NOTE(review): `all_the_tuples_no_last_special_case!` comes from rama_utils;
// presumably it invokes `impl_into_response!` once per supported tuple arity.
all_the_tuples_no_last_special_case!(impl_into_response);
426
// Implements `IntoResponse` for the either-type `rama_core::combinators::$id`,
// whose variants are named after its type parameters (`$id::$param`).
// Whichever variant is present is converted with its own `IntoResponse` impl.
macro_rules! impl_into_response_either {
    ($id:ident, $($param:ident),+ $(,)?) => {
        impl<$($param),+> IntoResponse for rama_core::combinators::$id<$($param),+>
        where
            $($param: IntoResponse),+
        {
            fn into_response(self) -> Response {
                match self {
                    $(
                        rama_core::combinators::$id::$param(val) => val.into_response(),
                    )+
                }
            }
        }
    };
}

// NOTE(review): `impl_either!` is defined in rama_core; presumably it invokes the
// macro above once per either-type (`Either`, `Either3`, ...) it provides.
rama_core::combinators::impl_either!(impl_into_response_either);
445
#[cfg(test)]
mod tests {
    use super::*;
    use rama_core::combinators::{Either, Either3};

    /// Returns the `Content-Type` header of `res`, panicking when it is missing.
    fn content_type(res: &Response) -> &HeaderValue {
        res.headers()
            .get(header::CONTENT_TYPE)
            .expect("response has no Content-Type header")
    }

    #[test]
    fn test_either_into_response() {
        type Pair = Either<&'static str, Vec<u8>>;

        let text_res = Pair::A("hello").into_response();
        assert_eq!(content_type(&text_res), mime::TEXT_PLAIN_UTF_8.as_ref());

        let bytes_res = Pair::B(vec![1, 2, 3]).into_response();
        assert_eq!(
            content_type(&bytes_res),
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );
    }

    #[test]
    fn test_either3_into_response() {
        type Triple = Either3<&'static str, Vec<u8>, StatusCode>;

        let a_res = Triple::A("hello").into_response();
        assert_eq!(content_type(&a_res), mime::TEXT_PLAIN_UTF_8.as_ref());

        let b_res = Triple::B(vec![1, 2, 3]).into_response();
        assert_eq!(
            content_type(&b_res),
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );

        let c_res = Triple::C(StatusCode::NOT_FOUND).into_response();
        assert_eq!(c_res.status(), StatusCode::NOT_FOUND);
    }
}