1use std::ops::{Deref, DerefMut};
2
3use headers::{Header, HeaderMapExt};
4
5use crate::{FromRequest, Request, RequestBody, Result, error::ParseTypedHeaderError};
6
/// A thin wrapper around a strongly typed header value `T`.
///
/// When `T` implements `Header`, this type can be extracted from a request:
/// the header is looked up in the request's header map and decoded into `T`.
/// The wrapped value is public and is also reachable through `Deref` /
/// `DerefMut`.
#[derive(Debug)]
pub struct TypedHeader<T>(pub T);
43
44impl<T> Deref for TypedHeader<T> {
45 type Target = T;
46
47 fn deref(&self) -> &Self::Target {
48 &self.0
49 }
50}
51
52impl<T> DerefMut for TypedHeader<T> {
53 fn deref_mut(&mut self) -> &mut Self::Target {
54 &mut self.0
55 }
56}
57
58impl<T: Header> TypedHeader<T> {
59 async fn internal_from_request(req: &Request) -> Result<Self, ParseTypedHeaderError> {
60 let value = req.headers().typed_try_get::<T>()?;
61 Ok(Self(value.ok_or_else(|| {
62 ParseTypedHeaderError::HeaderRequired(T::name().to_string())
63 })?))
64 }
65}
66
67impl<'a, T: Header> FromRequest<'a> for TypedHeader<T> {
68 async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
69 Self::internal_from_request(req).await.map_err(Into::into)
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler,
        test::TestClient,
        web::headers::{ContentLength, Host},
    };

    /// A request carrying `content-length: 3` is decoded into
    /// `ContentLength(3)` and the handler responds successfully.
    #[tokio::test]
    async fn test_typed_header_extractor() {
        #[handler(internal)]
        async fn index(content_length: TypedHeader<ContentLength>) {
            let ContentLength(length) = content_length.0;
            assert_eq!(length, 3);
        }

        let client = TestClient::new(index);
        let request = client.get("/").header("content-length", 3).body("abc");
        let resp = request.send().await;
        resp.assert_status_is_ok();
    }

    /// Extracting `Host` from a request built without that header must fail
    /// with `HeaderRequired("host")`.
    #[tokio::test]
    async fn test_typed_header_extractor_error() {
        let (req, mut body) = Request::builder().body("abc").split();
        let res = TypedHeader::<Host>::from_request(&req, &mut body).await;

        let err = res.unwrap_err();
        let Some(ParseTypedHeaderError::HeaderRequired(name)) =
            err.downcast_ref::<ParseTypedHeaderError>()
        else {
            panic!();
        };
        if name != "host" {
            panic!();
        }
    }
}