use axum_core::__composite_rejection as composite_rejection;
use axum_core::__define_rejection as define_rejection;
use axum_core::extract::FromRequestParts;
use http::{request::Parts, Uri};
use serde_core::de::DeserializeOwned;
/// Extractor that deserializes query strings into some type `T`.
///
/// `T` must implement [`DeserializeOwned`]. The query string is parsed with
/// [`serde_html_form`], so a key that appears several times (e.g. `?value=one&value=two`) can
/// be collected into a `Vec` or another sequence type.
///
/// A request whose URI has no query component is deserialized as if the query string were
/// empty. Use [`OptionalQuery`] to receive `None` in that case instead.
///
/// # Rejection
///
/// If the query string cannot be deserialized into `T`, the request is rejected with
/// [`QueryRejection`], which responds with `400 Bad Request`. The error message includes the
/// path of the field that failed to deserialize.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
impl<T, S> FromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Query(value))
}
}
impl<T> Query<T>
where
T: DeserializeOwned,
{
pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
let query = value.query().unwrap_or_default();
let params =
serde_html_form::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Self(params))
}
}
// Lets a `Query<T>` be used like the wrapped `T` (the tests below read `result.foo` straight
// through it). NOTE(review): presumably also generates `DerefMut` — confirm against axum-core's
// `__impl_deref!`.
axum_core::__impl_deref!(Query);
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to deserialize query string"]
    /// Rejection used when [`Query`], [`OptionalQuery`], or [`Query::try_from_uri`] cannot
    /// deserialize the query string into the target type.
    ///
    /// Responds with `400 Bad Request`; the body is the message above followed by the
    /// underlying deserialization error.
    pub struct FailedToDeserializeQueryString(Error);
}
composite_rejection! {
    /// Rejection used for [`Query`].
    ///
    /// Contains one variant for each way the [`Query`] extractor (and
    /// [`Query::try_from_uri`]) can fail.
    pub enum QueryRejection {
        FailedToDeserializeQueryString,
    }
}
/// Extractor that deserializes query strings into `Some(T)`, or yields `None` when the request
/// URI has no query component at all.
///
/// Apart from that, it parses exactly like [`Query`]: `T` must implement [`DeserializeOwned`],
/// and repeated keys can be collected into sequences.
///
/// # Rejection
///
/// If a query string is present but cannot be deserialized into `T`, the request is rejected
/// with [`OptionalQueryRejection`] (`400 Bad Request`). It does not silently fall back to
/// `None`.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalQuery<T>(pub Option<T>);
/// Deserializes the query string into `Some(T)` when the URI has one, and produces `None`
/// when the URI has no query component.
impl<T, S> FromRequestParts<S> for OptionalQuery<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = OptionalQueryRejection;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(raw_query) = parts.uri.query() else {
            return Ok(OptionalQuery(None));
        };
        let pairs = form_urlencoded::parse(raw_query.as_bytes());
        let deserializer = serde_html_form::Deserializer::new(pairs);
        // A query string that is present but malformed is an error, not a silent `None`.
        let parsed: T = serde_path_to_error::deserialize(deserializer)
            .map_err(FailedToDeserializeQueryString::from_err)?;
        Ok(OptionalQuery(Some(parsed)))
    }
}
impl<T> std::ops::Deref for OptionalQuery<T> {
type Target = Option<T>;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> std::ops::DerefMut for OptionalQuery<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
composite_rejection! {
    /// Rejection used for [`OptionalQuery`].
    ///
    /// Contains one variant for each way the [`OptionalQuery`] extractor can fail. A missing
    /// query string is not a failure (it yields `None`); only a present query string that cannot
    /// be deserialized is.
    pub enum OptionalQueryRejection {
        FailedToDeserializeQueryString,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::routing::{get, post};
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::StatusCode;
    use serde::Deserialize;

    #[tokio::test]
    async fn query_supports_multiple_values() {
        #[derive(Deserialize)]
        struct Data {
            #[serde(rename = "value")]
            values: Vec<String>,
        }
        let app = Router::new().route(
            "/",
            post(|Query(data): Query<Data>| async move { data.values.join(",") }),
        );
        let client = TestClient::new(app);
        let res = client
            .post("/?value=one&value=two")
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body("")
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "one,two");
    }

    #[tokio::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)] // field only exists to drive deserialization
        struct Params {
            n: i32,
        }
        async fn handler(_: Query<Params>) {}
        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);
        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize query string: n: invalid digit found in string"
        );
    }

    #[tokio::test]
    async fn optional_query_supports_multiple_values() {
        #[derive(Deserialize)]
        struct Data {
            #[serde(rename = "value")]
            values: Vec<String>,
        }
        let app = Router::new().route(
            "/",
            post(|OptionalQuery(data): OptionalQuery<Data>| async move {
                data.map(|Data { values }| values.join(","))
                    .unwrap_or("None".to_owned())
            }),
        );
        let client = TestClient::new(app);
        let res = client
            .post("/?value=one&value=two")
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body("")
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "one,two");
    }

    #[tokio::test]
    async fn optional_query_deserializes_no_parameters_into_none() {
        #[derive(Deserialize)]
        struct Data {
            value: String,
        }
        let app = Router::new().route(
            "/",
            post(|OptionalQuery(data): OptionalQuery<Data>| async move {
                match data {
                    None => "None".into(),
                    Some(data) => data.value,
                }
            }),
        );
        let client = TestClient::new(app);
        let res = client.post("/").body("").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "None");
    }

    #[tokio::test]
    async fn optional_query_preserves_parsing_errors() {
        #[derive(Deserialize)]
        struct Data {
            value: String,
        }
        let app = Router::new().route(
            "/",
            post(|OptionalQuery(data): OptionalQuery<Data>| async move {
                match data {
                    None => "None".into(),
                    Some(data) => data.value,
                }
            }),
        );
        let client = TestClient::new(app);
        let res = client
            .post("/?other=something")
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body("")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: Vec<String>,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42&foo=goodbye"
            .parse()
            .unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, [String::from("hello"), String::from("goodbye")]);
        assert_eq!(result.bar, 42);
    }

    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // The field names must match the query keys. With non-matching names (e.g. `_foo`),
        // deserialization fails on a *missing field* and the test would pass without ever
        // exercising the invalid `bar` value.
        #[derive(Deserialize)]
        #[allow(dead_code)] // fields only exist to drive deserialization
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);
        let err = match result {
            Ok(_) => panic!("`bar=invalid` must not deserialize into a `u32`"),
            Err(err) => err,
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.body_text().contains("invalid digit found in string"));
    }
}