reinhardt_di/params/
query.rs1use async_trait::async_trait;
4use reinhardt_http::Request;
5use serde::de::DeserializeOwned;
6use std::fmt::{self, Debug};
7use std::ops::Deref;
8
9use super::{ParamContext, ParamError, ParamResult, ParamType, extract::FromRequest};
10
11#[cfg(feature = "multi-value-arrays")]
12use std::collections::HashMap;
13
/// Extractor that deserializes the request URI's query string into `T`.
///
/// The wrapped value is public, so it can be destructured directly
/// (`Query(params)`), borrowed through [`Deref`], or moved out with
/// [`Query::into_inner`].
#[derive(Clone)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Consumes the extractor and hands back the deserialized value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}

impl<T> Deref for Query<T> {
    type Target = T;

    /// Borrows the deserialized value.
    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}

/// Formats exactly like the inner value — no `Query(...)` wrapper is printed.
impl<T: Debug> Debug for Query<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.0, f)
    }
}
99
/// Groups every `key=value` pair of a URL-encoded query string by key.
///
/// Keys and values are decoded by `form_urlencoded::parse`. When a key is
/// repeated, its values are kept in the order they appear in the query string.
#[cfg(feature = "multi-value-arrays")]
fn parse_query_multi_value(query_string: &str) -> HashMap<String, Vec<String>> {
    let mut grouped = HashMap::<String, Vec<String>>::new();

    form_urlencoded::parse(query_string.as_bytes())
        .map(|(name, val)| (name.into_owned(), val.into_owned()))
        .for_each(|(name, val)| grouped.entry(name).or_default().push(val));

    grouped
}
115
116#[cfg(feature = "multi-value-arrays")]
117fn string_to_json_value(s: &str) -> serde_json::Value {
120 if let Ok(i) = s.parse::<i64>() {
122 return serde_json::Value::Number(i.into());
123 }
124 if let Ok(f) = s.parse::<f64>()
126 && let Some(num) = serde_json::Number::from_f64(f)
127 {
128 return serde_json::Value::Number(num);
129 }
130 if let Ok(b) = s.parse::<bool>() {
132 return serde_json::Value::Bool(b);
133 }
134 serde_json::Value::String(s.to_string())
136}
137
/// Builds a JSON object from grouped query values.
///
/// A key with exactly one value maps to a scalar (see `string_to_json_value`).
/// Any other count maps to a JSON array of converted values. Consequently a
/// `Vec<_>` field that receives a single value sees a scalar, not an array,
/// and fails to deserialize.
#[cfg(feature = "multi-value-arrays")]
fn multi_value_to_json_value(multi_map: &HashMap<String, Vec<String>>) -> serde_json::Value {
    let object: serde_json::Map<String, serde_json::Value> = multi_map
        .iter()
        .map(|(key, values)| {
            let converted = match values.as_slice() {
                [single] => string_to_json_value(single),
                many => serde_json::Value::Array(
                    many.iter()
                        .map(String::as_str)
                        .map(string_to_json_value)
                        .collect(),
                ),
            };
            (key.clone(), converted)
        })
        .collect();

    serde_json::Value::Object(object)
}
157
158#[async_trait]
159impl<T> FromRequest for Query<T>
160where
161 T: DeserializeOwned + Send,
162{
163 async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
164 let query_string = req.uri.query().unwrap_or("");
166
167 #[cfg(feature = "multi-value-arrays")]
171 let result = {
172 let multi_map = parse_query_multi_value(query_string);
173 let json_value = multi_value_to_json_value(&multi_map);
174
175 serde_json::from_value(json_value).map(Query).map_err(|e| {
176 let raw_value = if query_string.is_empty() {
177 None
178 } else {
179 Some(query_string.to_string())
180 };
181 let mut ctx = super::ParamErrorContext::new(ParamType::Query, e.to_string())
182 .with_expected_type::<T>()
183 .with_source(Box::new(e));
184 if let Some(raw) = raw_value {
185 ctx = ctx.with_raw_value(raw);
186 }
187 ParamError::InvalidParameter(Box::new(ctx))
188 })
189 };
190
191 #[cfg(not(feature = "multi-value-arrays"))]
192 let result = serde_urlencoded::from_str(query_string)
193 .map(Query)
194 .map_err(|e| {
195 let raw_value = if query_string.is_empty() {
196 None
197 } else {
198 Some(query_string.to_string())
199 };
200 ParamError::url_encoding::<T>(ParamType::Query, e, raw_value)
201 });
202
203 result
204 }
205}
206
// Opts `Query<T>` into `super::validation::WithValidation` when the
// `validation` feature is enabled. The impl body is empty, so it relies
// entirely on the trait's default items (if it has any).
#[cfg(feature = "validation")]
impl<T> super::validation::WithValidation for Query<T> {}
210
// Unit tests for the `Query` wrapper and, when the `multi-value-arrays`
// feature is enabled, for the query-string -> JSON conversion helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct TestQuery {
        page: Option<i32>,
        limit: Option<i32>,
        search: Option<String>,
    }

    #[test]
    fn deref_and_into_inner_expose_the_wrapped_value() {
        let query = Query(TestQuery {
            page: Some(2),
            limit: None,
            search: Some("rust".to_string()),
        });

        assert_eq!(query.page, Some(2));
        assert_eq!(query.limit, None);
        assert_eq!(query.search.as_deref(), Some("rust"));
        assert_eq!(
            query.into_inner(),
            TestQuery {
                page: Some(2),
                limit: None,
                search: Some("rust".to_string()),
            }
        );
    }

    #[test]
    fn debug_is_transparent() {
        assert_eq!(format!("{:?}", Query(42)), "42");
        assert_eq!(format!("{:?}", Query("x")), "\"x\"");
    }

    #[test]
    fn clone_copies_the_inner_value() {
        let original = Query(vec!["a".to_string()]);
        let cloned = original.clone();
        assert_eq!(*cloned, *original);
    }

    #[cfg(feature = "multi-value-arrays")]
    #[test]
    fn repeated_keys_are_grouped_in_order() {
        let grouped = parse_query_multi_value("tag=a&tag=b&page=1");
        assert_eq!(grouped["tag"], vec!["a".to_string(), "b".to_string()]);
        assert_eq!(grouped["page"], vec!["1".to_string()]);
    }

    #[cfg(feature = "multi-value-arrays")]
    #[test]
    fn single_values_become_scalars_and_repeats_become_arrays() {
        let grouped = parse_query_multi_value("tag=a&tag=b&page=1");
        assert_eq!(
            multi_value_to_json_value(&grouped),
            serde_json::json!({ "tag": ["a", "b"], "page": 1 })
        );
    }

    #[cfg(feature = "multi-value-arrays")]
    #[test]
    fn multi_value_pipeline_deserializes_a_struct() {
        let grouped = parse_query_multi_value("page=2&limit=10&search=rust");
        let parsed: TestQuery = serde_json::from_value(multi_value_to_json_value(&grouped))
            .expect("all fields have compatible JSON types");
        assert_eq!(
            parsed,
            TestQuery {
                page: Some(2),
                limit: Some(10),
                search: Some("rust".to_string()),
            }
        );
    }
}