Skip to main content

reinhardt_di/params/
header.rs

1//! Header parameter extraction
2
3use async_trait::async_trait;
4use reinhardt_http::Request;
5use serde::de::DeserializeOwned;
6use std::collections::HashMap;
7use std::fmt::{self, Debug};
8use std::ops::Deref;
9
10use super::{
11	ParamContext, ParamError, ParamErrorContext, ParamResult, ParamType, extract::FromRequest,
12};
13
/// Extractor for a single request header value.
///
/// When used as a request extractor, the header name is taken from the
/// `ParamContext`. `Header<String>` fails with a missing-parameter error when
/// the header cannot be read, while `Header<Option<String>>` yields `None`.
/// The wrapped value is public and is also reachable through
/// [`Header::into_inner`] or transparently via `Deref`.
pub struct Header<T>(pub T);

impl<T> Header<T> {
	/// Consume the wrapper and hand back the wrapped value.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_di::params::Header;
	///
	/// let header = Header(String::from("application/json"));
	/// let inner = header.into_inner();
	/// assert_eq!(inner, "application/json");
	/// ```
	pub fn into_inner(self) -> T {
		let Header(inner) = self;
		inner
	}
}

impl<T> Deref for Header<T> {
	type Target = T;

	/// Borrow the wrapped value.
	fn deref(&self) -> &T {
		let Header(inner) = self;
		inner
	}
}

impl<T: Debug> Debug for Header<T> {
	/// Formats exactly like the wrapped value; no `Header(..)` wrapper is printed.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		Debug::fmt(&self.0, f)
	}
}
46
/// Extracts several request headers at once by deserializing them into `T`.
///
/// All readable headers of the request are collected into a name → value map,
/// and that map is deserialized into `T`, so each field of `T` corresponds to
/// one header.
///
/// Matching rules worth knowing:
///
/// - Header names are lower-cased before matching. Any `#[serde(rename = "...")]`
///   on `T`'s fields must therefore use the lower-case spelling (`"x-api-key"`,
///   not `"X-Api-Key"`), because a mixed-case rename will never match.
/// - Header values for which `to_str()` fails are skipped, as if the header
///   were absent. Use an `Option<_>` field if that case must be tolerated.
/// - If a header name occurs more than once, only one of its values is kept.
/// - Values are deserialized through `serde_urlencoded`, so non-string fields
///   are parsed from their textual form (e.g. `"123"` into an `i64`).
///
/// # Example
///
/// ```rust,no_run
/// # use reinhardt_di::params::HeaderStruct;
/// # use serde::Deserialize;
/// #[derive(Deserialize)]
/// struct MyHeaders {
///     // Lower-case names: header names are lower-cased before matching.
///     #[serde(rename = "x-api-key")]
///     api_key: String,
///
///     #[serde(rename = "user-agent")]
///     user_agent: Option<String>,
/// }
///
/// async fn handler(headers: HeaderStruct<MyHeaders>) {
///     let _api_key = &headers.api_key;
/// }
/// ```
pub struct HeaderStruct<T>(pub T);

impl<T> HeaderStruct<T> {
	/// Unwrap the `HeaderStruct` and return the inner value.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_di::params::HeaderStruct;
	/// use serde::Deserialize;
	///
	/// #[derive(Deserialize, Debug, PartialEq)]
	/// struct MyHeaders {
	///     content_type: String,
	/// }
	///
	/// let headers = HeaderStruct(MyHeaders {
	///     content_type: "text/html".to_string(),
	/// });
	/// let inner = headers.into_inner();
	/// assert_eq!(inner.content_type, "text/html");
	/// ```
	pub fn into_inner(self) -> T {
		self.0
	}
}

impl<T> Deref for HeaderStruct<T> {
	type Target = T;

	/// Borrow the deserialized header struct, so its fields read as
	/// `headers.field`.
	fn deref(&self) -> &Self::Target {
		&self.0
	}
}

impl<T: Debug> Debug for HeaderStruct<T> {
	/// Formats exactly like the wrapped value; no `HeaderStruct(..)` wrapper
	/// is printed.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		self.0.fmt(f)
	}
}
106
107/// Convert headers to a map for deserialization
108/// Header names are converted to lowercase
109fn headers_to_map(req: &Request) -> HashMap<String, String> {
110	let mut result = HashMap::new();
111
112	for (name, value) in req.headers.iter() {
113		if let Ok(value_str) = value.to_str() {
114			result.insert(name.as_str().to_lowercase(), value_str.to_string());
115		}
116	}
117
118	result
119}
120
121#[async_trait]
122impl<T> FromRequest for HeaderStruct<T>
123where
124	T: DeserializeOwned + Send,
125{
126	async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
127		let headers_map = headers_to_map(req);
128
129		// Use serde_urlencoded for proper string-to-type deserialization
130		// This handles type coercion naturally (e.g., "123" -> i64)
131		let encoded = serde_urlencoded::to_string(&headers_map).map_err(|e| {
132			ParamError::ParseError(Box::new(
133				ParamErrorContext::new(
134					ParamType::Header,
135					format!("Failed to encode headers: {}", e),
136				)
137				.with_expected_type::<T>()
138				.with_source(Box::new(e)),
139			))
140		})?;
141
142		serde_urlencoded::from_str(&encoded)
143			.map(HeaderStruct)
144			.map_err(|e| ParamError::url_encoding::<T>(ParamType::Header, e, Some(encoded.clone())))
145	}
146}
147
148#[async_trait]
149impl FromRequest for Header<String> {
150	async fn from_request(req: &Request, ctx: &ParamContext) -> ParamResult<Self> {
151		let name = ctx.get_header_name::<String>().ok_or_else(|| {
152			ParamError::MissingParameter(
153				"Header name not specified in ParamContext for this type".to_string(),
154			)
155		})?;
156
157		let value = req
158			.headers
159			.get(name)
160			.and_then(|v| v.to_str().ok())
161			.ok_or_else(|| ParamError::MissingParameter(name.to_string()))?;
162
163		Ok(Header(value.to_string()))
164	}
165}
166
167#[async_trait]
168impl FromRequest for Header<Option<String>> {
169	async fn from_request(req: &Request, ctx: &ParamContext) -> ParamResult<Self> {
170		let name = match ctx.get_header_name::<String>() {
171			Some(n) => n,
172			None => return Ok(Header(None)),
173		};
174		let maybe = req.headers.get(name).and_then(|v| v.to_str().ok());
175		Ok(Header(maybe.map(|s| s.to_string())))
176	}
177}
178
// Opt `Header<T>` into the validation extension trait when the `validation`
// feature is enabled. The impl declares no items of its own, so it compiles
// only because every item of `WithValidation` has a default (or there are none).
#[cfg(feature = "validation")]
impl<T> super::validation::WithValidation for Header<T> {
	// Intentionally empty: all behavior comes from the trait's provided items.
}