axum_extra/
typed_header.rs

1//! Extractor and response for typed headers.
2
3use axum_core::{
4    extract::{FromRequestParts, OptionalFromRequestParts},
5    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
6};
7use headers::{Header, HeaderMapExt};
8use http::{request::Parts, StatusCode};
9use std::convert::Infallible;
10
/// Extractor and response that works with typed header values from [`headers`].
///
/// # As extractor
///
/// In general, it's recommended to extract only the needed headers via `TypedHeader` rather than
/// extracting all headers with the `HeaderMap` extractor.
///
/// If the header does not occur in the request, extraction fails with a [`TypedHeaderRejection`]
/// whose reason is [`TypedHeaderRejectionReason::Missing`]. If it occurs but cannot be decoded
/// into `T`, the reason is [`TypedHeaderRejectionReason::Error`]. Either rejection responds with
/// `400 Bad Request`.
///
/// ```rust,no_run
/// use axum::{
///     routing::get,
///     Router,
/// };
/// use headers::UserAgent;
/// use axum_extra::TypedHeader;
///
/// async fn users_teams_show(
///     TypedHeader(user_agent): TypedHeader<UserAgent>,
/// ) {
///     // ...
/// }
///
/// let app = Router::new().route("/users/{user_id}/team/{team_id}", get(users_teams_show));
/// # let _: Router = app;
/// ```
///
/// ## Optional headers
///
/// Wrap the extractor in an [`Option`] to make the header optional. An absent header yields
/// `None`, while a header that is present but cannot be decoded still rejects the request.
///
/// ```rust,no_run
/// use axum::{routing::get, Router};
/// use headers::UserAgent;
/// use axum_extra::TypedHeader;
///
/// async fn handler(user_agent: Option<TypedHeader<UserAgent>>) -> String {
///     match user_agent {
///         Some(TypedHeader(user_agent)) => user_agent.as_str().to_owned(),
///         None => "unknown".to_owned(),
///     }
/// }
///
/// let app = Router::new().route("/", get(handler));
/// # let _: Router = app;
/// ```
///
/// # As response
///
/// ```rust
/// use axum::{
///     response::IntoResponse,
/// };
/// use headers::ContentType;
/// use axum_extra::TypedHeader;
///
/// async fn handler() -> (TypedHeader<ContentType>, &'static str) {
///     (
///         TypedHeader(ContentType::text_utf8()),
///         "Hello, World!",
///     )
/// }
/// ```
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
56
57impl<T, S> FromRequestParts<S> for TypedHeader<T>
58where
59    T: Header,
60    S: Send + Sync,
61{
62    type Rejection = TypedHeaderRejection;
63
64    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
65        let mut values = parts.headers.get_all(T::name()).iter();
66        let is_missing = values.size_hint() == (0, Some(0));
67        T::decode(&mut values)
68            .map(Self)
69            .map_err(|err| TypedHeaderRejection {
70                name: T::name(),
71                reason: if is_missing {
72                    // Report a more precise rejection for the missing header case.
73                    TypedHeaderRejectionReason::Missing
74                } else {
75                    TypedHeaderRejectionReason::Error(err)
76                },
77            })
78    }
79}
80
81impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
82where
83    T: Header,
84    S: Send + Sync,
85{
86    type Rejection = TypedHeaderRejection;
87
88    async fn from_request_parts(
89        parts: &mut Parts,
90        _state: &S,
91    ) -> Result<Option<Self>, Self::Rejection> {
92        let mut values = parts.headers.get_all(T::name()).iter();
93        let is_missing = values.size_hint() == (0, Some(0));
94        match T::decode(&mut values) {
95            Ok(res) => Ok(Some(Self(res))),
96            Err(_) if is_missing => Ok(None),
97            Err(err) => Err(TypedHeaderRejection {
98                name: T::name(),
99                reason: TypedHeaderRejectionReason::Error(err),
100            }),
101        }
102    }
103}
104
105axum_core::__impl_deref!(TypedHeader);
106
107impl<T> IntoResponseParts for TypedHeader<T>
108where
109    T: Header,
110{
111    type Error = Infallible;
112
113    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
114        res.headers_mut().typed_insert(self.0);
115        Ok(res)
116    }
117}
118
119impl<T> IntoResponse for TypedHeader<T>
120where
121    T: Header,
122{
123    fn into_response(self) -> Response {
124        let mut res = ().into_response();
125        res.headers_mut().typed_insert(self.0);
126        res
127    }
128}
129
/// Rejection returned when a [`TypedHeader`] cannot be extracted from a request.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header that was being extracted.
    name: &'static http::HeaderName,
    /// Why the extraction failed.
    reason: TypedHeaderRejectionReason,
}
137
138impl TypedHeaderRejection {
139    /// Name of the header that caused the rejection
140    #[must_use]
141    pub fn name(&self) -> &http::header::HeaderName {
142        self.name
143    }
144
145    /// Reason why the header extraction has failed
146    #[must_use]
147    pub fn reason(&self) -> &TypedHeaderRejectionReason {
148        &self.reason
149    }
150
151    /// Returns `true` if the typed header rejection reason is [`Missing`].
152    ///
153    /// [`Missing`]: TypedHeaderRejectionReason::Missing
154    #[must_use]
155    pub fn is_missing(&self) -> bool {
156        self.reason.is_missing()
157    }
158}
159
/// Details on why a [`TypedHeaderRejection`] was produced.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// The request did not contain the header at all.
    Missing,
    /// The header was present but could not be decoded into the requested type.
    Error(headers::Error),
}
170
171impl TypedHeaderRejectionReason {
172    /// Returns `true` if the typed header rejection reason is [`Missing`].
173    ///
174    /// [`Missing`]: TypedHeaderRejectionReason::Missing
175    #[must_use]
176    pub fn is_missing(&self) -> bool {
177        matches!(self, Self::Missing)
178    }
179}
180
181impl IntoResponse for TypedHeaderRejection {
182    fn into_response(self) -> Response {
183        let status = StatusCode::BAD_REQUEST;
184        let body = self.to_string();
185        axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
186        (status, body).into_response()
187    }
188}
189
190impl std::fmt::Display for TypedHeaderRejection {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match &self.reason {
193            TypedHeaderRejectionReason::Missing => {
194                write!(f, "Header of type `{}` was missing", self.name)
195            }
196            TypedHeaderRejectionReason::Error(err) => {
197                write!(f, "{err} ({})", self.name)
198            }
199        }
200    }
201}
202
203impl std::error::Error for TypedHeaderRejection {
204    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
205        match &self.reason {
206            TypedHeaderRejectionReason::Error(err) => Some(err),
207            TypedHeaderRejectionReason::Missing => None,
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    #[tokio::test]
    async fn typed_header() {
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies = cookies.iter().collect::<Vec<_>>();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        let res = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await;
        let body = res.text().await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        let res = client.get("/").header("user-agent", "foobar").await;
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        let res = client.get("/").header("cookie", "a=1").await;
        let body = res.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }

    // `Option<TypedHeader<T>>` goes through `OptionalFromRequestParts`: an absent header must
    // become `None` rather than a rejection.
    #[tokio::test]
    async fn optional_typed_header() {
        async fn handle(user_agent: Option<TypedHeader<headers::UserAgent>>) -> String {
            match user_agent {
                Some(TypedHeader(user_agent)) => {
                    let user_agent = user_agent.as_str();
                    format!("User-Agent={user_agent:?}")
                }
                None => "no User-Agent".to_owned(),
            }
        }

        let app = Router::new().route("/", get(handle));
        let client = TestClient::new(app);

        let res = client.get("/").header("user-agent", "foobar").await;
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar""#);

        let res = client.get("/").await;
        let body = res.text().await;
        assert_eq!(body, "no User-Agent");
    }

    // A header that is present but cannot be decoded must still reject the request, even
    // when the extractor is optional.
    #[tokio::test]
    async fn optional_typed_header_rejects_malformed_value() {
        async fn handle(_: Option<TypedHeader<headers::IfModifiedSince>>) {}

        let app = Router::new().route("/", get(handle));
        let client = TestClient::new(app);

        let res = client
            .get("/")
            .header("if-modified-since", "not a date")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }
}