poem/web/
typed_header.rs

1use std::ops::{Deref, DerefMut};
2
3use headers::{Header, HeaderMapExt};
4
5use crate::{FromRequest, Request, RequestBody, Result, error::ParseTypedHeaderError};
6
/// Extractor that decodes one request header into a strongly typed value.
///
/// The header is located by [`Header::name`] and decoded with
/// [`HeaderMapExt::typed_try_get`]. The decoded value is reachable through the
/// public tuple field, or transparently through `Deref`/`DerefMut`.
///
/// # Errors
///
/// - [`ParseTypedHeaderError`]: returned as
///   [`ParseTypedHeaderError::HeaderRequired`] (carrying the header name) when
///   the header is missing, or carrying the decode failure when the header is
///   present but cannot be parsed as `T`.
///
/// # Example
///
/// ```
/// use poem::{
///     Endpoint, Request, Route, get, handler,
///     http::{StatusCode, header},
///     test::TestClient,
///     web::{TypedHeader, headers::Host},
/// };
///
/// #[handler]
/// fn index(TypedHeader(host): TypedHeader<Host>) -> String {
///     host.hostname().to_string()
/// }
///
/// let app = Route::new().at("/", get(index));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = cli
///     .get("/")
///     .header(header::HOST, "example.com")
///     .send()
///     .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("example.com").await;
/// # });
/// ```
#[derive(Debug)]
pub struct TypedHeader<T>(
    /// The decoded header value.
    pub T,
);
43
44impl<T> Deref for TypedHeader<T> {
45    type Target = T;
46
47    fn deref(&self) -> &Self::Target {
48        &self.0
49    }
50}
51
52impl<T> DerefMut for TypedHeader<T> {
53    fn deref_mut(&mut self) -> &mut Self::Target {
54        &mut self.0
55    }
56}
57
58impl<T: Header> TypedHeader<T> {
59    async fn internal_from_request(req: &Request) -> Result<Self, ParseTypedHeaderError> {
60        let value = req.headers().typed_try_get::<T>()?;
61        Ok(Self(value.ok_or_else(|| {
62            ParseTypedHeaderError::HeaderRequired(T::name().to_string())
63        })?))
64    }
65}
66
67impl<'a, T: Header> FromRequest<'a> for TypedHeader<T> {
68    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
69        Self::internal_from_request(req).await.map_err(Into::into)
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler,
        test::TestClient,
        web::headers::{ContentLength, Host},
    };

    /// A present, well-formed header is decoded and handed to the handler.
    #[tokio::test]
    async fn test_typed_header_extractor() {
        #[handler(internal)]
        async fn index(content_length: TypedHeader<ContentLength>) {
            assert_eq!(content_length.0.0, 3);
        }

        let cli = TestClient::new(index);
        let resp = cli
            .get("/")
            .header("content-length", 3)
            .body("abc")
            .send()
            .await;
        resp.assert_status_is_ok();
    }

    /// A missing header yields `HeaderRequired` carrying the header name.
    #[tokio::test]
    async fn test_typed_header_extractor_error() {
        let (req, mut body) = Request::builder().body("abc").split();
        let err = TypedHeader::<Host>::from_request(&req, &mut body)
            .await
            .unwrap_err();

        let parsed = err.downcast_ref::<ParseTypedHeaderError>();
        assert!(
            matches!(parsed, Some(ParseTypedHeaderError::HeaderRequired(name)) if name == "host"),
            "expected HeaderRequired(\"host\"), got {parsed:?}"
        );
    }

    /// A header that is present but cannot be decoded is reported as a
    /// `ParseTypedHeaderError` other than `HeaderRequired`.
    #[tokio::test]
    async fn test_typed_header_extractor_malformed() {
        let (req, mut body) = Request::builder()
            .header("content-length", "not-a-number")
            .body("abc")
            .split();
        let err = TypedHeader::<ContentLength>::from_request(&req, &mut body)
            .await
            .unwrap_err();

        let parsed = err.downcast_ref::<ParseTypedHeaderError>();
        assert!(
            matches!(parsed, Some(e) if !matches!(e, ParseTypedHeaderError::HeaderRequired(_))),
            "expected a header decode error, got {parsed:?}"
        );
    }
}
109}