rama_http_types/response/into_response_parts.rs

1use crate::{
2    StatusCode,
3    dep::http::Extensions,
4    header::{HeaderMap, HeaderName, HeaderValue},
5    response::{IntoResponse, Response},
6};
7use rama_utils::macros::all_the_tuples_no_last_special_case;
8use std::{convert::Infallible, fmt};
9
10/// Trait for adding headers and extensions to a response.
11///
12/// # Example
13///
14/// ```rust
15/// use rama_http_types::{
16///     response::{ResponseParts, IntoResponse, IntoResponseParts, Response},
17///     StatusCode, HeaderName, HeaderValue,
18/// };
19///
20/// // Hypothetical helper type for setting a single header
21/// struct SetHeader<'a>(&'a str, &'a str);
22///
23/// impl<'a> IntoResponseParts for SetHeader<'a> {
24///     type Error = (StatusCode, String);
25///
26///     fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
27///         match (self.0.parse::<HeaderName>(), self.1.parse::<HeaderValue>()) {
28///             (Ok(name), Ok(value)) => {
29///                 res.headers_mut().insert(name, value);
30///             },
31///             (Err(_), _) => {
32///                 return Err((
33///                     StatusCode::INTERNAL_SERVER_ERROR,
34///                     format!("Invalid header name {}", self.0),
35///                 ));
36///             },
37///             (_, Err(_)) => {
38///                 return Err((
39///                     StatusCode::INTERNAL_SERVER_ERROR,
40///                     format!("Invalid header value {}", self.1),
41///                 ));
42///             },
43///         }
44///
45///         Ok(res)
46///     }
47/// }
48///
49/// // Its also recommended to implement `IntoResponse` so `SetHeader` can be used on its own as
50/// // the response
51/// impl<'a> IntoResponse for SetHeader<'a> {
52///     fn into_response(self) -> Response {
53///         // This gives an empty response with the header
54///         (self, ()).into_response()
55///     }
56/// }
57///
58/// // We can now return `SetHeader` in responses
59/// //
60/// // Note that returning `impl IntoResponse` might be easier if the response has many parts to
61/// // it. The return type is written out here for clarity.
62/// async fn handler() -> (SetHeader<'static>, SetHeader<'static>, &'static str) {
63///     (
64///         SetHeader("server", "rama"),
65///         SetHeader("x-foo", "custom"),
66///         "body",
67///     )
68/// }
69///
70/// // Or on its own as the whole response
71/// async fn other_handler() -> SetHeader<'static> {
72///     SetHeader("x-foo", "custom")
73/// }
74/// ```
75pub trait IntoResponseParts {
76    /// The type returned in the event of an error.
77    ///
78    /// This can be used to fallibly convert types into headers or extensions.
79    type Error: IntoResponse;
80
81    /// Set parts of the response
82    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error>;
83}
84
85impl<T> IntoResponseParts for Option<T>
86where
87    T: IntoResponseParts,
88{
89    type Error = T::Error;
90
91    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
92        if let Some(inner) = self {
93            inner.into_response_parts(res)
94        } else {
95            Ok(res)
96        }
97    }
98}
99
100/// Parts of a response.
101///
102/// Used with [`IntoResponseParts`].
103#[derive(Debug)]
104pub struct ResponseParts {
105    pub(crate) res: Response,
106}
107
108impl ResponseParts {
109    /// Gets a reference to the response headers.
110    pub fn headers(&self) -> &HeaderMap {
111        self.res.headers()
112    }
113
114    /// Gets a mutable reference to the response headers.
115    pub fn headers_mut(&mut self) -> &mut HeaderMap {
116        self.res.headers_mut()
117    }
118
119    /// Gets a reference to the response extensions.
120    pub fn extensions(&self) -> &Extensions {
121        self.res.extensions()
122    }
123
124    /// Gets a mutable reference to the response extensions.
125    pub fn extensions_mut(&mut self) -> &mut Extensions {
126        self.res.extensions_mut()
127    }
128}
129
130impl IntoResponseParts for HeaderMap {
131    type Error = Infallible;
132
133    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
134        res.headers_mut().extend(self);
135        Ok(res)
136    }
137}
138
139impl<K, V, const N: usize> IntoResponseParts for [(K, V); N]
140where
141    K: TryInto<HeaderName, Error: fmt::Display>,
142    V: TryInto<HeaderValue, Error: fmt::Display>,
143{
144    type Error = TryIntoHeaderError<K::Error, V::Error>;
145
146    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
147        for (key, value) in self {
148            let key = key.try_into().map_err(TryIntoHeaderError::key)?;
149            let value = value.try_into().map_err(TryIntoHeaderError::value)?;
150            res.headers_mut().insert(key, value);
151        }
152
153        Ok(res)
154    }
155}
156
157/// Error returned if converting a value to a header fails.
158pub struct TryIntoHeaderError<K, V> {
159    kind: TryIntoHeaderErrorKind<K, V>,
160}
161
162impl<K: fmt::Debug, V: fmt::Debug> fmt::Debug for TryIntoHeaderError<K, V> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        f.debug_struct("TryIntoHeaderError")
165            .field("kind", &self.kind)
166            .finish()
167    }
168}
169
170impl<K, V> TryIntoHeaderError<K, V> {
171    pub(super) fn key(err: K) -> Self {
172        Self {
173            kind: TryIntoHeaderErrorKind::Key(err),
174        }
175    }
176
177    pub(super) fn value(err: V) -> Self {
178        Self {
179            kind: TryIntoHeaderErrorKind::Value(err),
180        }
181    }
182}
183
/// Which half of a header conversion failed, carrying the underlying error.
enum TryIntoHeaderErrorKind<K, V> {
    /// Converting the header name failed.
    Key(K),
    /// Converting the header value failed.
    Value(V),
}

impl<K: fmt::Debug, V: fmt::Debug> fmt::Debug for TryIntoHeaderErrorKind<K, V> {
    /// Prints `TryIntoHeaderErrorKind::<Variant>(<inner:?>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (variant, inner): (&str, &dyn fmt::Debug) = match self {
            Self::Key(key) => ("Key", key),
            Self::Value(value) => ("Value", value),
        };
        write!(f, "TryIntoHeaderErrorKind::{variant}({inner:?})")
    }
}
197
198impl<K, V> IntoResponse for TryIntoHeaderError<K, V>
199where
200    K: fmt::Display,
201    V: fmt::Display,
202{
203    fn into_response(self) -> Response {
204        match self.kind {
205            TryIntoHeaderErrorKind::Key(inner) => {
206                (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
207            }
208            TryIntoHeaderErrorKind::Value(inner) => {
209                (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
210            }
211        }
212    }
213}
214
215impl<K, V> fmt::Display for TryIntoHeaderError<K, V> {
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        match self.kind {
218            TryIntoHeaderErrorKind::Key(_) => write!(f, "failed to convert key to a header name"),
219            TryIntoHeaderErrorKind::Value(_) => {
220                write!(f, "failed to convert value to a header value")
221            }
222        }
223    }
224}
225
226impl<K, V> std::error::Error for TryIntoHeaderError<K, V>
227where
228    K: std::error::Error + 'static,
229    V: std::error::Error + 'static,
230{
231    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
232        match &self.kind {
233            TryIntoHeaderErrorKind::Key(inner) => Some(inner),
234            TryIntoHeaderErrorKind::Value(inner) => Some(inner),
235        }
236    }
237}
238
239macro_rules! impl_into_response_parts {
240    ( $($ty:ident),* $(,)? ) => {
241        #[allow(non_snake_case)]
242        impl<$($ty,)*> IntoResponseParts for ($($ty,)*)
243        where
244            $( $ty: IntoResponseParts, )*
245        {
246            type Error = Response;
247
248            fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
249                let ($($ty,)*) = self;
250
251                $(
252                    let res = match $ty.into_response_parts(res) {
253                        Ok(res) => res,
254                        Err(err) => {
255                            return Err(err.into_response());
256                        }
257                    };
258                )*
259
260                Ok(res)
261            }
262        }
263    }
264}
265
266all_the_tuples_no_last_special_case!(impl_into_response_parts);
267
268impl IntoResponseParts for Extensions {
269    type Error = Infallible;
270
271    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
272        res.extensions_mut().extend(self);
273        Ok(res)
274    }
275}