reinhardt_di/params/
header.rs1use async_trait::async_trait;
4use reinhardt_http::Request;
5use serde::de::DeserializeOwned;
6use std::collections::HashMap;
7use std::fmt::{self, Debug};
8use std::ops::Deref;
9
10use super::{
11 ParamContext, ParamError, ParamErrorContext, ParamResult, ParamType, extract::FromRequest,
12};
13
/// Extractor wrapper holding a single value taken from a request header.
///
/// The wrapper is transparent: it dereferences to the wrapped value and its
/// `Debug` output is exactly the wrapped value's `Debug` output.
pub struct Header<T>(pub T);

impl<T> Header<T> {
    /// Consumes the wrapper and returns the extracted value.
    pub fn into_inner(self) -> T {
        let Header(inner) = self;
        inner
    }
}

impl<T> Deref for Header<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let Header(inner) = self;
        inner
    }
}

impl<T: Debug> Debug for Header<T> {
    /// Delegates straight to the wrapped value, so `Header(x)` prints like `x`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.0, f)
    }
}
46
/// Extractor wrapper holding a value of type `T` deserialized from the
/// request's headers as a whole (see its `FromRequest` impl in this module).
///
/// Like `Header`, the wrapper is transparent: it dereferences to the wrapped
/// value and its `Debug` output is the wrapped value's `Debug` output.
pub struct HeaderStruct<T>(pub T);

impl<T> HeaderStruct<T> {
    /// Consumes the wrapper and returns the deserialized value.
    pub fn into_inner(self) -> T {
        let HeaderStruct(inner) = self;
        inner
    }
}

impl<T> Deref for HeaderStruct<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let HeaderStruct(inner) = self;
        inner
    }
}

impl<T: Debug> Debug for HeaderStruct<T> {
    /// Delegates straight to the wrapped value, so `HeaderStruct(x)` prints like `x`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.0, f)
    }
}
106
107fn headers_to_map(req: &Request) -> HashMap<String, String> {
110 let mut result = HashMap::new();
111
112 for (name, value) in req.headers.iter() {
113 if let Ok(value_str) = value.to_str() {
114 result.insert(name.as_str().to_lowercase(), value_str.to_string());
115 }
116 }
117
118 result
119}
120
121#[async_trait]
122impl<T> FromRequest for HeaderStruct<T>
123where
124 T: DeserializeOwned + Send,
125{
126 async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
127 let headers_map = headers_to_map(req);
128
129 let encoded = serde_urlencoded::to_string(&headers_map).map_err(|e| {
132 ParamError::ParseError(Box::new(
133 ParamErrorContext::new(
134 ParamType::Header,
135 format!("Failed to encode headers: {}", e),
136 )
137 .with_expected_type::<T>()
138 .with_source(Box::new(e)),
139 ))
140 })?;
141
142 serde_urlencoded::from_str(&encoded)
143 .map(HeaderStruct)
144 .map_err(|e| ParamError::url_encoding::<T>(ParamType::Header, e, Some(encoded.clone())))
145 }
146}
147
148#[async_trait]
149impl FromRequest for Header<String> {
150 async fn from_request(req: &Request, ctx: &ParamContext) -> ParamResult<Self> {
151 let name = ctx.get_header_name::<String>().ok_or_else(|| {
152 ParamError::MissingParameter(
153 "Header name not specified in ParamContext for this type".to_string(),
154 )
155 })?;
156
157 let value = req
158 .headers
159 .get(name)
160 .and_then(|v| v.to_str().ok())
161 .ok_or_else(|| ParamError::MissingParameter(name.to_string()))?;
162
163 Ok(Header(value.to_string()))
164 }
165}
166
167#[async_trait]
168impl FromRequest for Header<Option<String>> {
169 async fn from_request(req: &Request, ctx: &ParamContext) -> ParamResult<Self> {
170 let name = match ctx.get_header_name::<String>() {
171 Some(n) => n,
172 None => return Ok(Header(None)),
173 };
174 let maybe = req.headers.get(name).and_then(|v| v.to_str().ok());
175 Ok(Header(maybe.map(|s| s.to_string())))
176 }
177}
178
// Opts `Header<T>` into `super::validation::WithValidation`. This impl is
// compiled only when the `validation` feature is enabled. The impl body is
// empty — presumably the trait is a marker or all its items have defaults;
// TODO confirm against `validation.rs`.
// NOTE(review): no equivalent impl for `HeaderStruct<T>` appears in this
// source — confirm whether that omission is intentional.
#[cfg(feature = "validation")]
impl<T> super::validation::WithValidation for Header<T> {}