rama_http/service/web/endpoint/response/
into_response_parts.rs

1use super::IntoResponse;
2use crate::{
3    Response, StatusCode,
4    dep::http::Extensions,
5    header::{HeaderMap, HeaderName, HeaderValue},
6};
7use rama_utils::macros::all_the_tuples_no_last_special_case;
8use std::{convert::Infallible, fmt};
9
10/// Trait for adding headers and extensions to a response.
11///
12/// # Example
13///
14/// ```rust
15/// use rama_http_types::{
16///     StatusCode, HeaderName, HeaderValue, Response,
17/// };
18/// use rama_http::service::web::response::{
19///     ResponseParts, IntoResponse, IntoResponseParts,
20/// };
21///
22/// // Hypothetical helper type for setting a single header
23/// struct SetHeader<'a>(&'a str, &'a str);
24///
25/// impl<'a> IntoResponseParts for SetHeader<'a> {
26///     type Error = (StatusCode, String);
27///
28///     fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
29///         match (self.0.parse::<HeaderName>(), self.1.parse::<HeaderValue>()) {
30///             (Ok(name), Ok(value)) => {
31///                 res.headers_mut().insert(name, value);
32///             },
33///             (Err(_), _) => {
34///                 return Err((
35///                     StatusCode::INTERNAL_SERVER_ERROR,
36///                     format!("Invalid header name {}", self.0),
37///                 ));
38///             },
39///             (_, Err(_)) => {
40///                 return Err((
41///                     StatusCode::INTERNAL_SERVER_ERROR,
42///                     format!("Invalid header value {}", self.1),
43///                 ));
44///             },
45///         }
46///
47///         Ok(res)
48///     }
49/// }
50///
51/// // Its also recommended to implement `IntoResponse` so `SetHeader` can be used on its own as
52/// // the response
53/// impl<'a> IntoResponse for SetHeader<'a> {
54///     fn into_response(self) -> Response {
55///         // This gives an empty response with the header
56///         (self, ()).into_response()
57///     }
58/// }
59///
60/// // We can now return `SetHeader` in responses
61/// //
62/// // Note that returning `impl IntoResponse` might be easier if the response has many parts to
63/// // it. The return type is written out here for clarity.
64/// async fn handler() -> (SetHeader<'static>, SetHeader<'static>, &'static str) {
65///     (
66///         SetHeader("server", "rama"),
67///         SetHeader("x-foo", "custom"),
68///         "body",
69///     )
70/// }
71///
72/// // Or on its own as the whole response
73/// async fn other_handler() -> SetHeader<'static> {
74///     SetHeader("x-foo", "custom")
75/// }
76/// ```
77pub trait IntoResponseParts {
78    /// The type returned in the event of an error.
79    ///
80    /// This can be used to fallibly convert types into headers or extensions.
81    type Error: IntoResponse;
82
83    /// Set parts of the response
84    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error>;
85}
86
87impl<T> IntoResponseParts for Option<T>
88where
89    T: IntoResponseParts,
90{
91    type Error = T::Error;
92
93    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
94        if let Some(inner) = self {
95            inner.into_response_parts(res)
96        } else {
97            Ok(res)
98        }
99    }
100}
101
102/// Parts of a response.
103///
104/// Used with [`IntoResponseParts`].
105#[derive(Debug)]
106pub struct ResponseParts {
107    pub(crate) res: Response,
108}
109
110impl ResponseParts {
111    /// Gets a reference to the response headers.
112    pub fn headers(&self) -> &HeaderMap {
113        self.res.headers()
114    }
115
116    /// Gets a mutable reference to the response headers.
117    pub fn headers_mut(&mut self) -> &mut HeaderMap {
118        self.res.headers_mut()
119    }
120
121    /// Gets a reference to the response extensions.
122    pub fn extensions(&self) -> &Extensions {
123        self.res.extensions()
124    }
125
126    /// Gets a mutable reference to the response extensions.
127    pub fn extensions_mut(&mut self) -> &mut Extensions {
128        self.res.extensions_mut()
129    }
130}
131
132impl IntoResponseParts for HeaderMap {
133    type Error = Infallible;
134
135    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
136        res.headers_mut().extend(self);
137        Ok(res)
138    }
139}
140
141impl<K, V, const N: usize> IntoResponseParts for [(K, V); N]
142where
143    K: TryInto<HeaderName, Error: fmt::Display>,
144    V: TryInto<HeaderValue, Error: fmt::Display>,
145{
146    type Error = TryIntoHeaderError<K::Error, V::Error>;
147
148    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
149        for (key, value) in self {
150            let key = key.try_into().map_err(TryIntoHeaderError::key)?;
151            let value = value.try_into().map_err(TryIntoHeaderError::value)?;
152            res.headers_mut().insert(key, value);
153        }
154
155        Ok(res)
156    }
157}
158
159/// Error returned if converting a value to a header fails.
160pub struct TryIntoHeaderError<K, V> {
161    kind: TryIntoHeaderErrorKind<K, V>,
162}
163
164impl<K: fmt::Debug, V: fmt::Debug> fmt::Debug for TryIntoHeaderError<K, V> {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        f.debug_struct("TryIntoHeaderError")
167            .field("kind", &self.kind)
168            .finish()
169    }
170}
171
172impl<K, V> TryIntoHeaderError<K, V> {
173    pub(super) fn key(err: K) -> Self {
174        Self {
175            kind: TryIntoHeaderErrorKind::Key(err),
176        }
177    }
178
179    pub(super) fn value(err: V) -> Self {
180        Self {
181            kind: TryIntoHeaderErrorKind::Value(err),
182        }
183    }
184}
185
/// Which side of a header pair failed to convert, carrying the conversion error.
enum TryIntoHeaderErrorKind<K, V> {
    /// The key could not be converted.
    Key(K),
    /// The value could not be converted.
    Value(V),
}

impl<K: fmt::Debug, V: fmt::Debug> fmt::Debug for TryIntoHeaderErrorKind<K, V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The payload is rendered with a plain `{:?}`, so formatter flags such as `#`
        // given by the caller are not forwarded to it.
        match self {
            TryIntoHeaderErrorKind::Key(key) => {
                f.write_fmt(format_args!("TryIntoHeaderErrorKind::Key({key:?})"))
            }
            TryIntoHeaderErrorKind::Value(value) => {
                f.write_fmt(format_args!("TryIntoHeaderErrorKind::Value({value:?})"))
            }
        }
    }
}
199
200impl<K, V> IntoResponse for TryIntoHeaderError<K, V>
201where
202    K: fmt::Display,
203    V: fmt::Display,
204{
205    fn into_response(self) -> Response {
206        match self.kind {
207            TryIntoHeaderErrorKind::Key(inner) => {
208                (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
209            }
210            TryIntoHeaderErrorKind::Value(inner) => {
211                (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
212            }
213        }
214    }
215}
216
217impl<K, V> fmt::Display for TryIntoHeaderError<K, V> {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        match self.kind {
220            TryIntoHeaderErrorKind::Key(_) => write!(f, "failed to convert key to a header name"),
221            TryIntoHeaderErrorKind::Value(_) => {
222                write!(f, "failed to convert value to a header value")
223            }
224        }
225    }
226}
227
228impl<K, V> std::error::Error for TryIntoHeaderError<K, V>
229where
230    K: std::error::Error + 'static,
231    V: std::error::Error + 'static,
232{
233    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
234        match &self.kind {
235            TryIntoHeaderErrorKind::Key(inner) => Some(inner),
236            TryIntoHeaderErrorKind::Value(inner) => Some(inner),
237        }
238    }
239}
240
241macro_rules! impl_into_response_parts {
242    ( $($ty:ident),* $(,)? ) => {
243        #[allow(non_snake_case)]
244        impl<$($ty,)*> IntoResponseParts for ($($ty,)*)
245        where
246            $( $ty: IntoResponseParts, )*
247        {
248            type Error = Response;
249
250            fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
251                let ($($ty,)*) = self;
252
253                $(
254                    let res = match $ty.into_response_parts(res) {
255                        Ok(res) => res,
256                        Err(err) => {
257                            return Err(err.into_response());
258                        }
259                    };
260                )*
261
262                Ok(res)
263            }
264        }
265    }
266}
267
268all_the_tuples_no_last_special_case!(impl_into_response_parts);
269
270impl IntoResponseParts for Extensions {
271    type Error = Infallible;
272
273    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
274        res.extensions_mut().extend(self);
275        Ok(res)
276    }
277}