use std::ops::{Deref, DerefMut};
use http::StatusCode;
use serde::{Serialize, de::DeserializeOwned};
use crate::{
FromRequest, IntoResponse, Request, Response, Result, error::ParseJsonError, http::header,
web::RequestBody,
};
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct Json<T>(pub T);
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<'a, T: DeserializeOwned> FromRequest<'a> for Json<T> {
async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
let content_type = req
.headers()
.get(header::CONTENT_TYPE)
.and_then(|content_type| content_type.to_str().ok())
.ok_or(ParseJsonError::ContentTypeRequired)?;
if !is_json_content_type(content_type) {
return Err(ParseJsonError::InvalidContentType(content_type.into()).into());
}
#[cfg(not(feature = "sonic-rs"))]
{
Ok(Self(
serde_json::from_slice(&body.take()?.into_bytes().await?)
.map_err(ParseJsonError::Parse)?,
))
}
#[cfg(feature = "sonic-rs")]
{
Ok(Self(
sonic_rs::from_slice(&body.take()?.into_bytes().await?)
.map_err(ParseJsonError::Parse)?,
))
}
}
}
fn is_json_content_type(content_type: &str) -> bool {
matches!(content_type.parse::<mime::Mime>(),
Ok(content_type) if content_type.type_() == "application"
&& (content_type.subtype() == "json"
|| content_type
.suffix()
.is_some_and(|v| v == "json")))
}
impl<T: Serialize + Send> IntoResponse for Json<T> {
fn into_response(self) -> Response {
#[cfg(not(feature = "sonic-rs"))]
let vec = serde_json::to_vec(&self.0);
#[cfg(feature = "sonic-rs")]
let vec = sonic_rs::to_vec(&self.0);
let data = match vec {
Ok(data) => data,
Err(err) => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(err.to_string());
}
};
Response::builder()
.header(header::CONTENT_TYPE, "application/json; charset=utf-8")
.body(data)
}
}
#[cfg(test)]
mod tests {
use serde::{Deserialize, Serialize};
#[cfg(not(feature = "sonic-rs"))]
use serde_json::{json, to_string};
#[cfg(feature = "sonic-rs")]
use sonic_rs::{json, to_string};
use super::*;
use crate::{handler, test::TestClient};
#[derive(Deserialize, Serialize, Debug, Eq, PartialEq)]
struct CreateResource {
name: String,
value: i32,
}
#[tokio::test]
async fn test_json_extractor() {
#[handler(internal)]
async fn index(query: Json<CreateResource>) {
assert_eq!(query.name, "abc");
assert_eq!(query.value, 100);
}
let cli = TestClient::new(index);
cli.post("/")
.body_json(&json!({"name": "abc", "value": 100})) .send()
.await
.assert_status_is_ok();
}
#[tokio::test]
async fn test_json_extractor_fail() {
#[handler(internal)]
async fn index(query: Json<CreateResource>) {
assert_eq!(query.name, "abc");
assert_eq!(query.value, 100);
}
let create_resource = CreateResource {
name: "abc".to_string(),
value: 100,
};
let cli = TestClient::new(index);
cli.post("/")
.body(to_string(&create_resource).expect("Invalid json"))
.send()
.await
.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
#[tokio::test]
async fn test_json_response() {
#[handler(internal)]
async fn index() -> Json<CreateResource> {
Json(CreateResource {
name: "abc".to_string(),
value: 100,
})
}
let cli = TestClient::new(index);
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_json(&CreateResource {
name: "abc".to_string(),
value: 100,
})
.await;
}
}