axum_extra/extract/
query.rs1use axum_core::__composite_rejection as composite_rejection;
2use axum_core::__define_rejection as define_rejection;
3use axum_core::extract::FromRequestParts;
4use http::{request::Parts, Uri};
5use serde_core::de::DeserializeOwned;
6
/// Extractor that deserializes the query string of a request URI into `T`.
///
/// `T` is expected to implement [`serde::Deserialize`]. The query string is
/// parsed as `application/x-www-form-urlencoded` pairs and handed to
/// `serde_html_form`, so a key that occurs several times (for example
/// `?value=one&value=two`) can be collected into a `Vec` field.
///
/// The wrapped value is public, so handlers can destructure the extractor
/// directly: `Query(params): Query<Params>`.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
80
81impl<T, S> FromRequestParts<S> for Query<T>
82where
83 T: DeserializeOwned,
84 S: Send + Sync,
85{
86 type Rejection = QueryRejection;
87
88 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
89 let query = parts.uri.query().unwrap_or_default();
90 let deserializer =
91 serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
92 let value = serde_path_to_error::deserialize(deserializer)
93 .map_err(FailedToDeserializeQueryString::from_err)?;
94 Ok(Query(value))
95 }
96}
97
98impl<T> Query<T>
99where
100 T: DeserializeOwned,
101{
102 pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
122 let query = value.query().unwrap_or_default();
123 let params =
124 serde_html_form::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
125 Ok(Self(params))
126 }
127}
128
129axum_core::__impl_deref!(Query);
130
131define_rejection! {
132 #[status = BAD_REQUEST]
133 #[body = "Failed to deserialize query string"]
134 pub struct FailedToDeserializeQueryString(Error);
137}
138
139composite_rejection! {
140 pub enum QueryRejection {
144 FailedToDeserializeQueryString,
145 }
146}
147
/// Extractor that deserializes the query string into `Some(T)`, or yields
/// `None` when the request URI has no query component at all.
///
/// A query string that is present but does not deserialize into `T` is still
/// rejected; it is not silently turned into `None`.
///
/// The wrapped `Option<T>` is public, so handlers can destructure the
/// extractor directly: `OptionalQuery(params): OptionalQuery<Params>`.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalQuery<T>(pub Option<T>);
190
191impl<T, S> FromRequestParts<S> for OptionalQuery<T>
192where
193 T: DeserializeOwned,
194 S: Send + Sync,
195{
196 type Rejection = OptionalQueryRejection;
197
198 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
199 if let Some(query) = parts.uri.query() {
200 let deserializer =
201 serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
202 let value = serde_path_to_error::deserialize(deserializer)
203 .map_err(FailedToDeserializeQueryString::from_err)?;
204 Ok(OptionalQuery(Some(value)))
205 } else {
206 Ok(OptionalQuery(None))
207 }
208 }
209}
210
211impl<T> std::ops::Deref for OptionalQuery<T> {
212 type Target = Option<T>;
213
214 #[inline]
215 fn deref(&self) -> &Self::Target {
216 &self.0
217 }
218}
219
220impl<T> std::ops::DerefMut for OptionalQuery<T> {
221 #[inline]
222 fn deref_mut(&mut self) -> &mut Self::Target {
223 &mut self.0
224 }
225}
226
227composite_rejection! {
228 pub enum OptionalQueryRejection {
232 FailedToDeserializeQueryString,
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::routing::{get, post};
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::StatusCode;
    use serde::Deserialize;

    /// A key repeated in the query string is collected into a `Vec` field.
    #[tokio::test]
    async fn query_supports_multiple_values() {
        #[derive(Deserialize)]
        struct Data {
            #[serde(rename = "value")]
            values: Vec<String>,
        }

        let app = Router::new().route(
            "/",
            post(|Query(data): Query<Data>| async move { data.values.join(",") }),
        );

        let client = TestClient::new(app);

        let res = client
            .post("/?value=one&value=two")
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body("")
            .await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "one,two");
    }

    /// A value of the wrong type yields `400 Bad Request` and a body that
    /// names the offending key.
    #[tokio::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize query string: n: invalid digit found in string"
        );
    }

    /// `OptionalQuery` handles repeated keys the same way `Query` does.
    #[tokio::test]
    async fn optional_query_supports_multiple_values() {
        #[derive(Deserialize)]
        struct Data {
            #[serde(rename = "value")]
            values: Vec<String>,
        }

        let app = Router::new().route(
            "/",
            post(|OptionalQuery(data): OptionalQuery<Data>| async move {
                data.map(|Data { values }| values.join(","))
                    // Closure form: only allocate the fallback when it is needed.
                    .unwrap_or_else(|| "None".to_owned())
            }),
        );

        let client = TestClient::new(app);

        let res = client
            .post("/?value=one&value=two")
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body("")
            .await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "one,two");
    }

    /// A request without any query string extracts as `None`.
    #[tokio::test]
    async fn optional_query_deserializes_no_parameters_into_none() {
        #[derive(Deserialize)]
        struct Data {
            value: String,
        }

        let app = Router::new().route(
            "/",
            post(|OptionalQuery(data): OptionalQuery<Data>| async move {
                match data {
                    None => "None".into(),
                    Some(data) => data.value,
                }
            }),
        );

        let client = TestClient::new(app);

        let res = client.post("/").body("").await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "None");
    }

    /// A query string that is present but does not match `Data` is rejected
    /// rather than turned into `None`.
    #[tokio::test]
    async fn optional_query_preserves_parsing_errors() {
        #[derive(Deserialize)]
        struct Data {
            value: String,
        }

        let app = Router::new().route(
            "/",
            post(|OptionalQuery(data): OptionalQuery<Data>| async move {
                match data {
                    None => "None".into(),
                    Some(data) => data.value,
                }
            }),
        );

        let client = TestClient::new(app);

        let res = client
            .post("/?other=something")
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body("")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    /// `try_from_uri` deserializes scalar and repeated keys from a full URI.
    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: Vec<String>,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42&foo=goodbye"
            .parse()
            .unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, [String::from("hello"), String::from("goodbye")]);
        assert_eq!(result.bar, 42);
    }

    /// `try_from_uri` fails when a value has the wrong type.
    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // The field names must match the query keys: with `_foo`/`_bar` the
        // parse would fail on a missing field no matter what `bar` contains,
        // and the invalid value would never be exercised.
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }

        // Control: the same shape with a numeric `bar` is accepted, so the
        // failure below can only come from the invalid value.
        let valid: Uri = "http://example.com/path?foo=hello&bar=42"
            .parse()
            .unwrap();
        assert!(Query::<TestQueryParams>::try_from_uri(&valid).is_ok());

        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        assert!(result.is_err());
    }
}