use crate::{
body::{Bytes, HttpBody},
extract::{rejection::*, FromRequest},
BoxError,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use bytes::{BufMut, BytesMut};
use http::{
header::{self, HeaderMap, HeaderValue},
Request, StatusCode,
};
use serde::{de::DeserializeOwned, Serialize};
use std::ops::{Deref, DerefMut};
/// JSON wrapper usable both as a request extractor and as a response.
///
/// As an extractor (via the `FromRequest` impl in this file), `Json<T>`
/// requires the request to carry a JSON `Content-Type` and deserializes the
/// body into `T`, which must implement `serde::de::DeserializeOwned`.
/// Failures are reported through `JsonRejection`.
///
/// As a response (via the `IntoResponse` impl in this file), `Json<T>`
/// serializes `T`, which must implement `serde::Serialize`, and sets the
/// `Content-Type` header to `application/json`. If serialization fails, a
/// `500 Internal Server Error` response with a plain-text body containing
/// the error message is produced instead.
///
/// The inner value is public and can be reached with `.0`, by pattern
/// matching (`Json(value)`), or through `Deref`/`DerefMut`.
#[derive(Debug, Clone, Copy, Default)]
// Only affects rendered documentation: when built with `--cfg docsrs`, the
// item is annotated as being gated on the `json` feature.
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
pub struct Json<T>(pub T);
#[async_trait]
impl<T, S, B> FromRequest<S, B> for Json<T>
where
T: DeserializeOwned,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = JsonRejection;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if json_content_type(req.headers()) {
let bytes = Bytes::from_request(req, state).await?;
let deserializer = &mut serde_json::Deserializer::from_slice(&bytes);
let value = match serde_path_to_error::deserialize(deserializer) {
Ok(value) => value,
Err(err) => {
let rejection = match err.inner().classify() {
serde_json::error::Category::Data => JsonDataError::from_err(err).into(),
serde_json::error::Category::Syntax | serde_json::error::Category::Eof => {
JsonSyntaxError::from_err(err).into()
}
serde_json::error::Category::Io => {
if cfg!(debug_assertions) {
unreachable!()
} else {
JsonSyntaxError::from_err(err).into()
}
}
};
return Err(rejection);
}
};
Ok(Json(value))
} else {
Err(MissingJsonContentType.into())
}
}
}
/// Reports whether the `Content-Type` header in `headers` denotes JSON.
///
/// Returns `true` only when the header is present, is valid visible ASCII,
/// parses as a MIME type, has the top-level type `application`, and either
/// has the subtype `json` or carries a `+json` structured-syntax suffix
/// (for example `application/cloudevents+json`). Every other case,
/// including a missing or unparsable header, yields `false`.
fn json_content_type(headers: &HeaderMap) -> bool {
    headers
        .get(header::CONTENT_TYPE)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<mime::Mime>().ok())
        .map_or(false, |mime| {
            let is_application = mime.type_() == "application";
            let is_json = mime.subtype() == "json"
                || mime.suffix().map_or(false, |name| name == "json");
            is_application && is_json
        })
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for Json<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}
impl<T> IntoResponse for Json<T>
where
T: Serialize,
{
fn into_response(self) -> Response {
let mut buf = BytesMut::with_capacity(128).writer();
match serde_json::to_writer(&mut buf, &self.0) {
Ok(()) => (
[(
header::CONTENT_TYPE,
HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
)],
buf.into_inner().freeze(),
)
.into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
)],
err.to_string(),
)
.into_response(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::post, test_helpers::*, Router};
    use serde::Deserialize;
    use serde_json::{json, Value};

    // A JSON request body is deserialized into the handler's argument type.
    #[crate::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let router = Router::new().route("/", post(|input: Json<Input>| async { input.0.foo }));
        let client = TestClient::new(router);

        let response = client.post("/").json(&json!({ "foo": "bar" })).send().await;

        assert_eq!(response.text().await, "bar");
    }

    // Without a JSON content type the extractor rejects the request with 415.
    #[crate::test]
    async fn consume_body_to_json_requires_json_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let router = Router::new().route("/", post(|input: Json<Input>| async { input.0.foo }));
        let client = TestClient::new(router);

        let response = client.post("/").body(r#"{ "foo": "bar" }"#).send().await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    // Exercises which `Content-Type` values the extractor accepts.
    #[crate::test]
    async fn json_content_types() {
        // Sends `{}` with the given content type and reports whether the
        // handler was reached (i.e. the response status is 200).
        async fn valid_json_content_type(content_type: &str) -> bool {
            println!("testing {content_type:?}");

            let router = Router::new().route("/", post(|Json(_): Json<Value>| async {}));
            let response = TestClient::new(router)
                .post("/")
                .header("content-type", content_type)
                .body("{}")
                .send()
                .await;

            StatusCode::OK == response.status()
        }

        // Accepted: plain JSON, JSON with parameters, and `+json` suffixes.
        assert!(valid_json_content_type("application/json").await);
        assert!(valid_json_content_type("application/json; charset=utf-8").await);
        assert!(valid_json_content_type("application/json;charset=utf-8").await);
        assert!(valid_json_content_type("application/cloudevents+json").await);

        // Rejected: the top-level type must be `application`.
        assert!(!valid_json_content_type("text/json").await);
    }

    // Malformed JSON yields 400 Bad Request.
    #[crate::test]
    async fn invalid_json_syntax() {
        let router = Router::new().route("/", post(|_: Json<serde_json::Value>| async {}));
        let client = TestClient::new(router);

        let response = client
            .post("/")
            .body("{")
            .header("content-type", "application/json")
            .send()
            .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // Target types for `invalid_json_data`; the fields are only ever
    // populated by deserialization, hence the `dead_code` allowances.
    #[derive(Deserialize)]
    struct Foo {
        #[allow(dead_code)]
        a: i32,
        #[allow(dead_code)]
        b: Vec<Bar>,
    }

    #[derive(Deserialize)]
    struct Bar {
        #[allow(dead_code)]
        x: i32,
        #[allow(dead_code)]
        y: i32,
    }

    // Well-formed JSON that does not match the target type yields 422, and
    // the error text includes the path to the offending element.
    #[crate::test]
    async fn invalid_json_data() {
        let router = Router::new().route("/", post(|_: Json<Foo>| async {}));
        let client = TestClient::new(router);

        let response = client
            .post("/")
            .body("{\"a\": 1, \"b\": [{\"x\": 2}]}")
            .header("content-type", "application/json")
            .send()
            .await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);

        let message = response.text().await;
        assert_eq!(
            message,
            "Failed to deserialize the JSON body into the target type: b[0]: missing field `y` at line 1 column 23"
        );
    }
}