use crate::BoxError;
use crate::{
extract::{rejection::*, take_body, FromRequest, RequestParts},
response::IntoResponse,
};
use async_trait::async_trait;
use bytes::Bytes;
use http::{
header::{self, HeaderValue},
StatusCode,
};
use http_body::Full;
use hyper::Response;
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
};
/// JSON extractor and response wrapper.
///
/// Used as an extractor, `Json<T>` deserializes the request body into `T`, rejecting
/// requests whose `Content-Type` is not a JSON media type. Used as a response, it
/// serializes `T` into a body sent with `Content-Type: application/json`.
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
pub struct Json<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Json<T>
where
    T: DeserializeOwned,
    B: http_body::Body + Send,
    B::Data: Send,
    B::Error: Into<BoxError>,
{
    type Rejection = JsonRejection;

    /// Extracts the request body and deserializes it as JSON into `T`.
    ///
    /// The `Content-Type` header is checked before the body is taken, so a request with a
    /// non-JSON content type is rejected with `MissingJsonContentType` and its body is left
    /// in place.
    ///
    /// # Errors
    ///
    /// - the header-extraction error from `json_content_type` if the headers were already taken;
    /// - `MissingJsonContentType` if the content type is not a JSON media type;
    /// - the error from `take_body` if the body was already taken;
    /// - `InvalidJsonBody` if reading the body fails or it does not deserialize into `T`.
    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
        if !json_content_type(req)? {
            return Err(MissingJsonContentType.into());
        }

        let body = take_body(req)?;

        // Collect the body into one contiguous buffer and parse it as a slice. serde_json's
        // slice parser is faster than `from_reader`, which goes through `io::Read`.
        let bytes = hyper::body::to_bytes(body)
            .await
            .map_err(InvalidJsonBody::from_err)?;
        let value = serde_json::from_slice(&bytes).map_err(InvalidJsonBody::from_err)?;

        Ok(Json(value))
    }
}
/// Reports whether the request's `Content-Type` header names a JSON media type.
///
/// A media type counts as JSON when its type is `application` and either its subtype is
/// `json` or it carries a `json` suffix (e.g. `application/cloudevents+json`). Each of the
/// following yields `Ok(false)`:
/// - a missing header;
/// - a header value that `HeaderValue::to_str` rejects;
/// - a value that does not parse as a `mime::Mime`.
///
/// # Errors
///
/// Returns `HeadersAlreadyExtracted` if the request's headers are no longer available.
fn json_content_type<B>(req: &RequestParts<B>) -> Result<bool, HeadersAlreadyExtracted> {
    let headers = req.headers().ok_or(HeadersAlreadyExtracted)?;

    let parsed = headers
        .get(header::CONTENT_TYPE)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<mime::Mime>().ok());

    let is_json = match parsed {
        Some(media) => {
            media.type_() == "application"
                && (media.subtype() == "json"
                    || media.suffix().map_or(false, |suffix| suffix == "json"))
        }
        None => false,
    };

    Ok(is_json)
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for Json<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}
impl<T> IntoResponse for Json<T>
where
    T: Serialize,
{
    type Body = Full<Bytes>;
    type BodyError = Infallible;

    /// Serializes the wrapped value into a JSON response.
    ///
    /// On success the response has status `200 OK` and `Content-Type: application/json`.
    /// If `serde_json::to_vec` fails, the response is `500 Internal Server Error`, and its
    /// plain-text body is the serializer's error message.
    fn into_response(self) -> Response<Self::Body> {
        let (status, content_type, body) = match serde_json::to_vec(&self.0) {
            Ok(bytes) => (StatusCode::OK, "application/json", Full::from(bytes)),
            // The error text is a Rust `String` (UTF-8). Declare the charset explicitly,
            // because a bare `text/plain` defaults to US-ASCII.
            Err(err) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "text/plain; charset=utf-8",
                Full::from(err.to_string()),
            ),
        };

        // Build the response directly instead of via `Response::builder().unwrap()`;
        // all header values here are static, so nothing can fail.
        let mut res = Response::new(body);
        *res.status_mut() = status;
        res.headers_mut()
            .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
        res
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::post, test_helpers::*, Router};
    use serde::Deserialize;
    use serde_json::{json, Value};

    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route("/", post(|input: Json<Input>| async { input.0.foo }));
        let client = TestClient::new(app);

        let res = client.post("/").json(&json!({ "foo": "bar" })).send().await;
        let body = res.text().await;

        assert_eq!(body, "bar");
    }

    #[tokio::test]
    async fn consume_body_to_json_requires_json_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route("/", post(|input: Json<Input>| async { input.0.foo }));
        let client = TestClient::new(app);

        let res = client.post("/").body(r#"{ "foo": "bar" }"#).send().await;
        let status = res.status();
        let body = res.text().await;

        // Report the body only when the assertion fails, instead of a leftover `dbg!`.
        assert_eq!(
            status,
            StatusCode::BAD_REQUEST,
            "unexpected response body: {}",
            body
        );
    }

    #[tokio::test]
    async fn json_content_types() {
        /// Sends an empty JSON object with the given `Content-Type` and reports whether
        /// the handler accepted it.
        async fn valid_json_content_type(content_type: &str) -> bool {
            let app = Router::new().route("/", post(|Json(_): Json<Value>| async {}));
            let res = TestClient::new(app)
                .post("/")
                .header("content-type", content_type)
                .body("{}")
                .send()
                .await;
            res.status() == StatusCode::OK
        }

        let accepted = [
            "application/json",
            "application/json; charset=utf-8",
            "application/json;charset=utf-8",
            "application/cloudevents+json",
        ];
        for content_type in accepted.iter() {
            assert!(
                valid_json_content_type(content_type).await,
                "expected {:?} to be accepted",
                content_type
            );
        }

        let rejected = ["text/json", "application/xml"];
        for content_type in rejected.iter() {
            assert!(
                !valid_json_content_type(content_type).await,
                "expected {:?} to be rejected",
                content_type
            );
        }
    }
}