// File: rama_http/service/web/endpoint/response/into_response.rs
into_response.rs1use super::{IntoResponseParts, ResponseParts};
2use crate::dep::http_body::{Frame, SizeHint};
3use crate::dep::mime;
4use crate::{Body, Response};
5use crate::{
6 StatusCode,
7 dep::http::Extensions,
8 header::{self, HeaderMap, HeaderName, HeaderValue},
9};
10use bytes::{Buf, Bytes, BytesMut, buf::Chain};
11use rama_core::error::BoxError;
12use rama_http_types::dep::{http, http_body};
13use rama_utils::macros::all_the_tuples_no_last_special_case;
14use std::{
15 borrow::Cow,
16 convert::Infallible,
17 fmt,
18 pin::Pin,
19 task::{Context, Poll},
20};
21
22pub trait IntoResponse {
31 fn into_response(self) -> Response;
33}
34
35pub struct StaticResponseFactory<T>(pub T);
38
39impl<T: IntoResponse> From<StaticResponseFactory<T>> for Response {
40 fn from(value: StaticResponseFactory<T>) -> Self {
41 value.0.into_response()
42 }
43}
44
45impl<T: fmt::Debug> fmt::Debug for StaticResponseFactory<T> {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_tuple("StaticResponseFactory")
48 .field(&self.0)
49 .finish()
50 }
51}
52
53impl<T: Clone> Clone for StaticResponseFactory<T> {
54 fn clone(&self) -> Self {
55 Self(self.0.clone())
56 }
57}
58
59impl IntoResponse for StatusCode {
60 fn into_response(self) -> Response {
61 let mut res = ().into_response();
62 *res.status_mut() = self;
63 res
64 }
65}
66
67impl IntoResponse for () {
68 fn into_response(self) -> Response {
69 Body::empty().into_response()
70 }
71}
72
73impl IntoResponse for Infallible {
74 fn into_response(self) -> Response {
75 match self {}
76 }
77}
78
79impl<T, E> IntoResponse for Result<T, E>
80where
81 T: IntoResponse,
82 E: IntoResponse,
83{
84 fn into_response(self) -> Response {
85 match self {
86 Ok(value) => value.into_response(),
87 Err(err) => err.into_response(),
88 }
89 }
90}
91
92impl<B> IntoResponse for Response<B>
93where
94 B: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
95{
96 fn into_response(self) -> Response {
97 self.map(Body::new)
98 }
99}
100
101impl IntoResponse for http::response::Parts {
102 fn into_response(self) -> Response {
103 Response::from_parts(self, Body::empty())
104 }
105}
106
107impl IntoResponse for Body {
108 fn into_response(self) -> Response {
109 Response::new(self)
110 }
111}
112
113impl IntoResponse for &'static str {
114 fn into_response(self) -> Response {
115 Cow::Borrowed(self).into_response()
116 }
117}
118
119impl IntoResponse for String {
120 fn into_response(self) -> Response {
121 Cow::<'static, str>::Owned(self).into_response()
122 }
123}
124
125impl IntoResponse for Box<str> {
126 fn into_response(self) -> Response {
127 String::from(self).into_response()
128 }
129}
130
131impl IntoResponse for Cow<'static, str> {
132 fn into_response(self) -> Response {
133 let mut res = Body::from(self).into_response();
134 res.headers_mut().insert(
135 header::CONTENT_TYPE,
136 HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
137 );
138 res
139 }
140}
141
142impl IntoResponse for Bytes {
143 fn into_response(self) -> Response {
144 let mut res = Body::from(self).into_response();
145 res.headers_mut().insert(
146 header::CONTENT_TYPE,
147 HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()),
148 );
149 res
150 }
151}
152
153impl IntoResponse for BytesMut {
154 fn into_response(self) -> Response {
155 self.freeze().into_response()
156 }
157}
158
159impl<T, U> IntoResponse for Chain<T, U>
160where
161 T: Buf + Unpin + Send + Sync + 'static,
162 U: Buf + Unpin + Send + Sync + 'static,
163{
164 fn into_response(self) -> Response {
165 let (first, second) = self.into_inner();
166 let mut res = Response::new(Body::new(BytesChainBody {
167 first: Some(first),
168 second: Some(second),
169 }));
170 res.headers_mut().insert(
171 header::CONTENT_TYPE,
172 HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()),
173 );
174 res
175 }
176}
177
178struct BytesChainBody<T, U> {
179 first: Option<T>,
180 second: Option<U>,
181}
182
183impl<T, U> http_body::Body for BytesChainBody<T, U>
184where
185 T: Buf + Unpin,
186 U: Buf + Unpin,
187{
188 type Data = Bytes;
189 type Error = Infallible;
190
191 fn poll_frame(
192 mut self: Pin<&mut Self>,
193 _cx: &mut Context<'_>,
194 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
195 if let Some(mut buf) = self.first.take() {
196 let bytes = buf.copy_to_bytes(buf.remaining());
197 return Poll::Ready(Some(Ok(Frame::data(bytes))));
198 }
199
200 if let Some(mut buf) = self.second.take() {
201 let bytes = buf.copy_to_bytes(buf.remaining());
202 return Poll::Ready(Some(Ok(Frame::data(bytes))));
203 }
204
205 Poll::Ready(None)
206 }
207
208 fn is_end_stream(&self) -> bool {
209 self.first.is_none() && self.second.is_none()
210 }
211
212 fn size_hint(&self) -> SizeHint {
213 match (self.first.as_ref(), self.second.as_ref()) {
214 (Some(first), Some(second)) => {
215 let total_size = first.remaining() + second.remaining();
216 SizeHint::with_exact(total_size as u64)
217 }
218 (Some(buf), None) => SizeHint::with_exact(buf.remaining() as u64),
219 (None, Some(buf)) => SizeHint::with_exact(buf.remaining() as u64),
220 (None, None) => SizeHint::with_exact(0),
221 }
222 }
223}
224
225impl IntoResponse for &'static [u8] {
226 fn into_response(self) -> Response {
227 Cow::Borrowed(self).into_response()
228 }
229}
230
231impl<const N: usize> IntoResponse for &'static [u8; N] {
232 fn into_response(self) -> Response {
233 self.as_slice().into_response()
234 }
235}
236
237impl<const N: usize> IntoResponse for [u8; N] {
238 fn into_response(self) -> Response {
239 self.to_vec().into_response()
240 }
241}
242
243impl IntoResponse for Vec<u8> {
244 fn into_response(self) -> Response {
245 Cow::<'static, [u8]>::Owned(self).into_response()
246 }
247}
248
249impl IntoResponse for Box<[u8]> {
250 fn into_response(self) -> Response {
251 Vec::from(self).into_response()
252 }
253}
254
255impl IntoResponse for Cow<'static, [u8]> {
256 fn into_response(self) -> Response {
257 let mut res = Body::from(self).into_response();
258 res.headers_mut().insert(
259 header::CONTENT_TYPE,
260 HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()),
261 );
262 res
263 }
264}
265
266impl<R> IntoResponse for (StatusCode, R)
267where
268 R: IntoResponse,
269{
270 fn into_response(self) -> Response {
271 let mut res = self.1.into_response();
272 *res.status_mut() = self.0;
273 res
274 }
275}
276
277impl IntoResponse for HeaderMap {
278 fn into_response(self) -> Response {
279 let mut res = ().into_response();
280 *res.headers_mut() = self;
281 res
282 }
283}
284
285impl IntoResponse for Extensions {
286 fn into_response(self) -> Response {
287 let mut res = ().into_response();
288 *res.extensions_mut() = self;
289 res
290 }
291}
292
293impl<K, V, const N: usize> IntoResponse for [(K, V); N]
294where
295 K: TryInto<HeaderName, Error: fmt::Display>,
296 V: TryInto<HeaderValue, Error: fmt::Display>,
297{
298 fn into_response(self) -> Response {
299 (self, ()).into_response()
300 }
301}
302
303impl<R> IntoResponse for (http::response::Parts, R)
304where
305 R: IntoResponse,
306{
307 fn into_response(self) -> Response {
308 let (parts, res) = self;
309 (parts.status, parts.headers, parts.extensions, res).into_response()
310 }
311}
312
313impl<R> IntoResponse for (http::response::Response<()>, R)
314where
315 R: IntoResponse,
316{
317 fn into_response(self) -> Response {
318 let (template, res) = self;
319 let (parts, ()) = template.into_parts();
320 (parts, res).into_response()
321 }
322}
323
324impl<R> IntoResponse for (R,)
325where
326 R: IntoResponse,
327{
328 fn into_response(self) -> Response {
329 let (res,) = self;
330 res.into_response()
331 }
332}
333
334macro_rules! impl_into_response {
335 ( $($ty:ident),* $(,)? ) => {
336 #[allow(non_snake_case)]
337 impl<R, $($ty,)*> IntoResponse for ($($ty),*, R)
338 where
339 $( $ty: IntoResponseParts, )*
340 R: IntoResponse,
341 {
342 fn into_response(self) -> Response {
343 let ($($ty),*, res) = self;
344
345 let res = res.into_response();
346 let parts = ResponseParts { res };
347
348 $(
349 let parts = match $ty.into_response_parts(parts) {
350 Ok(parts) => parts,
351 Err(err) => {
352 return err.into_response();
353 }
354 };
355 )*
356
357 parts.res
358 }
359 }
360
361 #[allow(non_snake_case)]
362 impl<R, $($ty,)*> IntoResponse for (StatusCode, $($ty),*, R)
363 where
364 $( $ty: IntoResponseParts, )*
365 R: IntoResponse,
366 {
367 fn into_response(self) -> Response {
368 let (status, $($ty),*, res) = self;
369
370 let res = res.into_response();
371 let parts = ResponseParts { res };
372
373 $(
374 let parts = match $ty.into_response_parts(parts) {
375 Ok(parts) => parts,
376 Err(err) => {
377 return err.into_response();
378 }
379 };
380 )*
381
382 (status, parts.res).into_response()
383 }
384 }
385
386 #[allow(non_snake_case)]
387 impl<R, $($ty,)*> IntoResponse for (http::response::Parts, $($ty),*, R)
388 where
389 $( $ty: IntoResponseParts, )*
390 R: IntoResponse,
391 {
392 fn into_response(self) -> Response {
393 let (outer_parts, $($ty),*, res) = self;
394
395 let res = res.into_response();
396 let parts = ResponseParts { res };
397 $(
398 let parts = match $ty.into_response_parts(parts) {
399 Ok(parts) => parts,
400 Err(err) => {
401 return err.into_response();
402 }
403 };
404 )*
405
406 (outer_parts, parts.res).into_response()
407 }
408 }
409
410 #[allow(non_snake_case)]
411 impl<R, $($ty,)*> IntoResponse for (http::response::Response<()>, $($ty),*, R)
412 where
413 $( $ty: IntoResponseParts, )*
414 R: IntoResponse,
415 {
416 fn into_response(self) -> Response {
417 let (template, $($ty),*, res) = self;
418 let (parts, ()) = template.into_parts();
419 (parts, $($ty),*, res).into_response()
420 }
421 }
422 }
423}
424
425all_the_tuples_no_last_special_case!(impl_into_response);
426
427macro_rules! impl_into_response_either {
428 ($id:ident, $($param:ident),+ $(,)?) => {
429 impl<$($param),+> IntoResponse for rama_core::combinators::$id<$($param),+>
430 where
431 $($param: IntoResponse),+
432 {
433 fn into_response(self) -> Response {
434 match self {
435 $(
436 rama_core::combinators::$id::$param(val) => val.into_response(),
437 )+
438 }
439 }
440 }
441 };
442}
443
444rama_core::combinators::impl_either!(impl_into_response_either);
445
#[cfg(test)]
mod tests {
    use super::*;
    use rama_core::combinators::{Either, Either3};

    /// Returns the `Content-Type` header of `res`, panicking when it is absent.
    fn content_type(res: &Response) -> &HeaderValue {
        res.headers().get(header::CONTENT_TYPE).unwrap()
    }

    #[test]
    fn test_either_into_response() {
        type E = Either<&'static str, Vec<u8>>;

        let text = E::A("hello").into_response();
        assert_eq!(content_type(&text), mime::TEXT_PLAIN_UTF_8.as_ref());

        let bytes = E::B(vec![1, 2, 3]).into_response();
        assert_eq!(content_type(&bytes), mime::APPLICATION_OCTET_STREAM.as_ref());
    }

    #[test]
    fn test_either3_into_response() {
        type E3 = Either3<&'static str, Vec<u8>, StatusCode>;

        let text = E3::A("hello").into_response();
        assert_eq!(content_type(&text), mime::TEXT_PLAIN_UTF_8.as_ref());

        let bytes = E3::B(vec![1, 2, 3]).into_response();
        assert_eq!(content_type(&bytes), mime::APPLICATION_OCTET_STREAM.as_ref());

        let status = E3::C(StatusCode::NOT_FOUND).into_response();
        assert_eq!(status.status(), StatusCode::NOT_FOUND);
    }
}