axum_extra/
typed_header.rs1use axum_core::{
4 extract::{FromRequestParts, OptionalFromRequestParts},
5 response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
6};
7use headers::{Header, HeaderMapExt};
8use http::{request::Parts, StatusCode};
9use std::convert::Infallible;
10
/// Extractor and response that works with typed header values from [`headers`].
///
/// # As extractor
///
/// Collects every value of the header named by `T::name()` from the request and
/// decodes them with `T::decode`. On failure the extractor rejects with a
/// [`TypedHeaderRejection`]:
///
/// - [`TypedHeaderRejectionReason::Missing`] when the request has no such header
///   and decoding the empty value set failed;
/// - [`TypedHeaderRejectionReason::Error`] when the header is present but
///   cannot be decoded.
///
/// Wrapped in `Option`, an absent header yields `None`, while a present but
/// malformed header is still rejected.
///
/// # As response
///
/// Implements `IntoResponseParts` and `IntoResponse`, writing the wrapped
/// header into the response's header map with `typed_insert`.
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
56
57impl<T, S> FromRequestParts<S> for TypedHeader<T>
58where
59 T: Header,
60 S: Send + Sync,
61{
62 type Rejection = TypedHeaderRejection;
63
64 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
65 let mut values = parts.headers.get_all(T::name()).iter();
66 let is_missing = values.size_hint() == (0, Some(0));
67 T::decode(&mut values)
68 .map(Self)
69 .map_err(|err| TypedHeaderRejection {
70 name: T::name(),
71 reason: if is_missing {
72 TypedHeaderRejectionReason::Missing
74 } else {
75 TypedHeaderRejectionReason::Error(err)
76 },
77 })
78 }
79}
80
81impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
82where
83 T: Header,
84 S: Send + Sync,
85{
86 type Rejection = TypedHeaderRejection;
87
88 async fn from_request_parts(
89 parts: &mut Parts,
90 _state: &S,
91 ) -> Result<Option<Self>, Self::Rejection> {
92 let mut values = parts.headers.get_all(T::name()).iter();
93 let is_missing = values.size_hint() == (0, Some(0));
94 match T::decode(&mut values) {
95 Ok(res) => Ok(Some(Self(res))),
96 Err(_) if is_missing => Ok(None),
97 Err(err) => Err(TypedHeaderRejection {
98 name: T::name(),
99 reason: TypedHeaderRejectionReason::Error(err),
100 }),
101 }
102 }
103}
104
// Internal axum_core helper macro; presumably generates `Deref`/`DerefMut`
// impls forwarding to the wrapped value (`self.0`) — confirm against the
// macro definition in axum_core.
axum_core::__impl_deref!(TypedHeader);
106
107impl<T> IntoResponseParts for TypedHeader<T>
108where
109 T: Header,
110{
111 type Error = Infallible;
112
113 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
114 res.headers_mut().typed_insert(self.0);
115 Ok(res)
116 }
117}
118
119impl<T> IntoResponse for TypedHeader<T>
120where
121 T: Header,
122{
123 fn into_response(self) -> Response {
124 let mut res = ().into_response();
125 res.headers_mut().typed_insert(self.0);
126 res
127 }
128}
129
/// Rejection produced when [`TypedHeader`] fails to extract its header.
///
/// Converts into a `400 Bad Request` response whose body is this value's
/// `Display` text.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header the extractor tried to decode.
    name: &'static http::header::HeaderName,
    /// Why extraction failed.
    reason: TypedHeaderRejectionReason,
}
137
138impl TypedHeaderRejection {
139 #[must_use]
141 pub fn name(&self) -> &http::header::HeaderName {
142 self.name
143 }
144
145 #[must_use]
147 pub fn reason(&self) -> &TypedHeaderRejectionReason {
148 &self.reason
149 }
150
151 #[must_use]
155 pub fn is_missing(&self) -> bool {
156 self.reason.is_missing()
157 }
158}
159
/// Why [`TypedHeader`] failed to extract its header.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// The request contained no header with the expected name, and the header
    /// type could not be decoded from an empty set of values.
    Missing,
    /// The header was present but decoding it failed.
    Error(headers::Error),
}
170
171impl TypedHeaderRejectionReason {
172 #[must_use]
176 pub fn is_missing(&self) -> bool {
177 matches!(self, Self::Missing)
178 }
179}
180
181impl IntoResponse for TypedHeaderRejection {
182 fn into_response(self) -> Response {
183 let status = StatusCode::BAD_REQUEST;
184 let body = self.to_string();
185 axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
186 (status, body).into_response()
187 }
188}
189
190impl std::fmt::Display for TypedHeaderRejection {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match &self.reason {
193 TypedHeaderRejectionReason::Missing => {
194 write!(f, "Header of type `{}` was missing", self.name)
195 }
196 TypedHeaderRejectionReason::Error(err) => {
197 write!(f, "{err} ({})", self.name)
198 }
199 }
200 }
201}
202
203impl std::error::Error for TypedHeaderRejection {
204 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
205 match &self.reason {
206 TypedHeaderRejectionReason::Error(err) => Some(err),
207 TypedHeaderRejectionReason::Missing => None,
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    #[tokio::test]
    async fn typed_header() {
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies = cookies.iter().collect::<Vec<_>>();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        let res = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await;
        let body = res.text().await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        // `Cookie` decodes from zero values, so its absence is not a rejection.
        let res = client.get("/").header("user-agent", "foobar").await;
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        let res = client.get("/").header("cookie", "a=1").await;
        let body = res.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }

    /// Covers the `OptionalFromRequestParts` impl: absent -> `None`,
    /// present -> `Some`, present but malformed -> rejection.
    #[tokio::test]
    async fn optional_typed_header() {
        async fn handle(user_agent: Option<TypedHeader<headers::UserAgent>>) -> String {
            match user_agent {
                Some(TypedHeader(user_agent)) => format!("User-Agent={:?}", user_agent.as_str()),
                None => "User-Agent missing".to_owned(),
            }
        }

        async fn handle_date(
            if_modified_since: Option<TypedHeader<headers::IfModifiedSince>>,
        ) -> &'static str {
            if if_modified_since.is_some() {
                "present"
            } else {
                "absent"
            }
        }

        let app = Router::new()
            .route("/", get(handle))
            .route("/date", get(handle_date));
        let client = TestClient::new(app);

        let res = client.get("/").header("user-agent", "foobar").await;
        assert_eq!(res.text().await, r#"User-Agent="foobar""#);

        let res = client.get("/").await;
        assert_eq!(res.text().await, "User-Agent missing");

        // A malformed header must still reject instead of silently becoming `None`.
        let res = client
            .get("/date")
            .header("if-modified-since", "not a date")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }
}
253}