Skip to main content

reinhardt_di/params/
query.rs

1//! Query parameter extraction
2
3use async_trait::async_trait;
4use reinhardt_http::Request;
5use serde::de::DeserializeOwned;
6use std::fmt::{self, Debug};
7use std::ops::Deref;
8
9use super::{ParamContext, ParamError, ParamResult, ParamType, extract::FromRequest};
10
11#[cfg(feature = "multi-value-arrays")]
12use std::collections::HashMap;
13
/// Extract query parameters from the URL
///
/// The raw query string (everything after `?`) is deserialized into `T`. A request
/// without a query string is treated as if it had an empty one.
///
/// With the `multi-value-arrays` feature (enabled by default), repeated query
/// parameters are properly parsed into vectors. For example, `?q=5&q=6` will be
/// parsed as `vec![5, 6]` when the target field type is `Vec<T>`.
///
/// # Limitations of `multi-value-arrays`
///
/// Query values are converted to JSON *before* deserialization, without knowledge
/// of the target type `T`. As a consequence:
///
/// - A key that appears exactly once is passed as a single value, not as a
///   one-element array, so a `Vec<T>` field does not accept `?q=5` on its own.
/// - Values that look like integers, floats or `true`/`false` become JSON numbers
///   or booleans, so a `String` field cannot receive e.g. `?search=123`.
///
/// # Example
///
/// ```rust
/// use reinhardt_di::params::Query;
/// # use serde::Deserialize;
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: Option<i32>,
///     per_page: Option<i32>,
/// }
///
/// let pagination = Pagination { page: Some(2), per_page: Some(25) };
/// let query = Query(pagination);
/// let page = query.page.unwrap_or(1);
/// let per_page = query.per_page.unwrap_or(10);
/// assert_eq!(page, 2);
/// assert_eq!(per_page, 25);
/// ```
///
/// # Multi-value Parameters
///
/// ```rust
/// use reinhardt_di::params::Query;
/// # use serde::Deserialize;
/// #[derive(Deserialize)]
/// struct SearchQuery {
///     q: Vec<i64>,  // Supports repeated keys: ?q=5&q=6
/// }
///
/// let search = SearchQuery { q: vec![5, 6] };
/// let query = Query(search);
/// assert_eq!(query.q, vec![5, 6]);
/// ```
pub struct Query<T>(pub T);
54
55impl<T> Query<T> {
56	/// Unwrap the Query and return the inner value
57	///
58	/// # Examples
59	///
60	/// ```
61	/// use reinhardt_di::params::Query;
62	/// use serde::Deserialize;
63	///
64	/// #[derive(Deserialize, Debug, PartialEq)]
65	/// struct Pagination {
66	///     page: i32,
67	///     per_page: i32,
68	/// }
69	///
70	/// let query = Query(Pagination { page: 1, per_page: 10 });
71	/// let inner = query.into_inner();
72	/// assert_eq!(inner.page, 1);
73	/// assert_eq!(inner.per_page, 10);
74	/// ```
75	pub fn into_inner(self) -> T {
76		self.0
77	}
78}
79
80impl<T> Deref for Query<T> {
81	type Target = T;
82
83	fn deref(&self) -> &Self::Target {
84		&self.0
85	}
86}
87
88impl<T: Debug> Debug for Query<T> {
89	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90		self.0.fmt(f)
91	}
92}
93
94impl<T: Clone> Clone for Query<T> {
95	fn clone(&self) -> Self {
96		Query(self.0.clone())
97	}
98}
99
/// Group every `key=value` pair of a URL-encoded query string by key.
///
/// Each key keeps its values in the order they appeared, so `q=5&q=6` becomes
/// `{"q": ["5", "6"]}`. Decoding of names and values is done by
/// `form_urlencoded::parse`.
#[cfg(feature = "multi-value-arrays")]
fn parse_query_multi_value(query_string: &str) -> HashMap<String, Vec<String>> {
	form_urlencoded::parse(query_string.as_bytes()).fold(
		HashMap::new(),
		|mut grouped, (name, val)| {
			grouped
				.entry(name.into_owned())
				.or_insert_with(Vec::new)
				.push(val.into_owned());
			grouped
		},
	)
}
115
116#[cfg(feature = "multi-value-arrays")]
117/// Convert a string value to the most appropriate JSON value type
118/// Tries number types first, then falls back to string
119fn string_to_json_value(s: &str) -> serde_json::Value {
120	// Try parsing as integer
121	if let Ok(i) = s.parse::<i64>() {
122		return serde_json::Value::Number(i.into());
123	}
124	// Try parsing as float
125	if let Ok(f) = s.parse::<f64>()
126		&& let Some(num) = serde_json::Number::from_f64(f)
127	{
128		return serde_json::Value::Number(num);
129	}
130	// Try parsing as boolean
131	if let Ok(b) = s.parse::<bool>() {
132		return serde_json::Value::Bool(b);
133	}
134	// Fall back to string
135	serde_json::Value::String(s.to_string())
136}
137
/// Build a JSON object from grouped query values so serde can deserialize it.
///
/// A key with exactly one value maps to that value, coerced by
/// `string_to_json_value`. A key with any other number of values maps to a JSON
/// array of coerced values, which lets `Vec<T>` fields receive repeated keys.
#[cfg(feature = "multi-value-arrays")]
fn multi_value_to_json_value(multi_map: &HashMap<String, Vec<String>>) -> serde_json::Value {
	let object: serde_json::Map<String, serde_json::Value> = multi_map
		.iter()
		.map(|(name, raw_values)| {
			let json = match raw_values.as_slice() {
				[only] => string_to_json_value(only),
				several => serde_json::Value::Array(
					several.iter().map(|raw| string_to_json_value(raw)).collect(),
				),
			};
			(name.clone(), json)
		})
		.collect();

	serde_json::Value::Object(object)
}
157
158#[async_trait]
159impl<T> FromRequest for Query<T>
160where
161	T: DeserializeOwned + Send,
162{
163	async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
164		// Extract query string from request
165		let query_string = req.uri.query().unwrap_or("");
166
167		// Deserialize query string to T
168		// If multi-value-arrays feature is enabled, parse repeated parameters as arrays
169		// (e.g., q=5&q=6 -> vec![5, 6])
170		#[cfg(feature = "multi-value-arrays")]
171		let result = {
172			let multi_map = parse_query_multi_value(query_string);
173			let json_value = multi_value_to_json_value(&multi_map);
174
175			serde_json::from_value(json_value).map(Query).map_err(|e| {
176				let raw_value = if query_string.is_empty() {
177					None
178				} else {
179					Some(query_string.to_string())
180				};
181				let mut ctx = super::ParamErrorContext::new(ParamType::Query, e.to_string())
182					.with_expected_type::<T>()
183					.with_source(Box::new(e));
184				if let Some(raw) = raw_value {
185					ctx = ctx.with_raw_value(raw);
186				}
187				ParamError::InvalidParameter(Box::new(ctx))
188			})
189		};
190
191		#[cfg(not(feature = "multi-value-arrays"))]
192		let result = serde_urlencoded::from_str(query_string)
193			.map(Query)
194			.map_err(|e| {
195				let raw_value = if query_string.is_empty() {
196					None
197				} else {
198					Some(query_string.to_string())
199				};
200				ParamError::url_encoding::<T>(ParamType::Query, e, raw_value)
201			});
202
203		result
204	}
205}
206
// Opt `Query<Inner>` into `WithValidation`. The impl body is empty, so it relies
// entirely on the trait's provided items.
#[cfg(feature = "validation")]
impl<Inner> super::validation::WithValidation for Query<Inner> {}
210
#[cfg(test)]
mod tests {
	use super::*;
	use serde::Deserialize;

	// Allow dead_code: the fields are only read through the derived `Debug`/`PartialEq`
	// impls, which dead-code analysis ignores. Without the `multi-value-arrays`
	// feature, the struct is not used at all.
	#[allow(dead_code)]
	#[derive(Debug, Deserialize, PartialEq)]
	struct TestQuery {
		page: Option<i32>,
		limit: Option<i32>,
		search: Option<String>,
	}

	#[cfg(feature = "multi-value-arrays")]
	#[test]
	fn parse_query_multi_value_groups_repeated_keys() {
		let map = parse_query_multi_value("q=5&q=6&page=2");
		assert_eq!(map.get("q"), Some(&vec!["5".to_string(), "6".to_string()]));
		assert_eq!(map.get("page"), Some(&vec!["2".to_string()]));
		assert!(parse_query_multi_value("").is_empty());
	}

	#[cfg(feature = "multi-value-arrays")]
	#[test]
	fn parse_query_multi_value_decodes_values() {
		let map = parse_query_multi_value("search=hello%20world&tag=a+b");
		assert_eq!(map["search"], vec!["hello world".to_string()]);
		assert_eq!(map["tag"], vec!["a b".to_string()]);
	}

	#[cfg(feature = "multi-value-arrays")]
	#[test]
	fn string_to_json_value_coerces_scalars() {
		assert_eq!(string_to_json_value("42"), serde_json::json!(42));
		assert_eq!(string_to_json_value("-1.5"), serde_json::json!(-1.5));
		assert_eq!(string_to_json_value("true"), serde_json::json!(true));
		assert_eq!(string_to_json_value("rust"), serde_json::json!("rust"));
		// Non-finite floats cannot be JSON numbers and stay strings.
		assert_eq!(string_to_json_value("NaN"), serde_json::json!("NaN"));
	}

	#[cfg(feature = "multi-value-arrays")]
	#[test]
	fn multi_value_json_deserializes_struct() {
		let json =
			multi_value_to_json_value(&parse_query_multi_value("page=2&limit=10&search=rust"));
		let query: TestQuery = serde_json::from_value(json).unwrap();
		assert_eq!(
			query,
			TestQuery {
				page: Some(2),
				limit: Some(10),
				search: Some("rust".to_string()),
			}
		);
	}

	#[cfg(feature = "multi-value-arrays")]
	#[test]
	fn multi_value_json_missing_fields_are_none() {
		let json = multi_value_to_json_value(&parse_query_multi_value(""));
		let query: TestQuery = serde_json::from_value(json).unwrap();
		assert_eq!(
			query,
			TestQuery {
				page: None,
				limit: None,
				search: None,
			}
		);
	}

	#[cfg(feature = "multi-value-arrays")]
	#[test]
	fn multi_value_json_collects_repeated_keys_into_vec() {
		#[derive(Deserialize)]
		struct SearchQuery {
			q: Vec<i64>,
		}

		let json = multi_value_to_json_value(&parse_query_multi_value("q=5&q=6"));
		let search: SearchQuery = serde_json::from_value(json).unwrap();
		assert_eq!(search.q, vec![5, 6]);
	}
}