use http::StatusCode;
use http::header::HeaderValue;
use http_body_util::BodyExt;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::body::TakoBody;
use crate::extractors::FromRequest;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Chooses which JSON parser the `Json` extractor uses when the `simd` feature is enabled.
///
/// The extractor reads this value from the request's extensions and falls back to
/// [`SimdJsonMode::default`] when none is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdJsonMode {
    /// Always decode with the SIMD parser (`sonic_rs`).
    Always,
    /// Never use the SIMD parser; always decode with `serde_json`.
    Never,
    /// Use the SIMD parser only when the body length is at least this many bytes.
    Threshold(usize),
}

impl Default for SimdJsonMode {
    /// Returns `Threshold` with a 2 MiB cut-over.
    fn default() -> Self {
        const TWO_MIB: usize = 2 << 20;
        SimdJsonMode::Threshold(TWO_MIB)
    }
}
/// JSON wrapper, used both as a request extractor and as a response type.
///
/// As an extractor, it deserializes the request body into `T`. As a responder, it
/// serializes `T` and sets an `application/json` content type.
///
/// The standard derives are provided so an extracted value can be logged, cloned or
/// defaulted whenever `T` supports it. The derive bounds apply only to `T`, so they add no
/// requirements for other uses.
#[doc(alias = "json")]
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
/// Reasons the JSON extractor can reject a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JsonError {
    /// The request did not pass the JSON content-type check.
    InvalidContentType,
    /// The request carried no content-type header.
    MissingContentType,
    /// Reading the request body failed; holds the underlying error text.
    BodyReadError(String),
    /// The body could not be deserialized into the target type; holds the parser's message.
    DeserializationError(String),
}

impl std::fmt::Display for JsonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidContentType => {
                f.write_str("invalid content type; expected application/json")
            }
            Self::MissingContentType => f.write_str("missing content type header"),
            Self::BodyReadError(reason) => write!(f, "failed to read request body: {reason}"),
            Self::DeserializationError(reason) => {
                write!(f, "failed to deserialize JSON: {reason}")
            }
        }
    }
}

/// `JsonError` wraps no other error value, so it exposes no `source()`.
impl std::error::Error for JsonError {}
impl Responder for JsonError {
    /// Turns the extraction failure into a `400 Bad Request` response.
    ///
    /// The body is a human-readable description of what went wrong.
    fn into_response(self) -> crate::types::Response {
        let status = StatusCode::BAD_REQUEST;
        match self {
            Self::InvalidContentType => {
                (status, "Invalid content type; expected application/json").into_response()
            }
            Self::MissingContentType => (status, "Missing content type header").into_response(),
            Self::BodyReadError(err) => {
                let message = format!("Failed to read request body: {err}");
                (status, message).into_response()
            }
            Self::DeserializationError(err) => {
                let message = format!("Failed to deserialize JSON: {err}");
                (status, message).into_response()
            }
        }
    }
}
use crate::extractors::is_json_content_type;
impl<'a, T> FromRequest<'a> for Json<T>
where
T: DeserializeOwned + Send + 'static,
{
type Error = JsonError;
fn from_request(
req: &'a mut Request,
) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
async move {
if !is_json_content_type(req.headers()) {
return Err(JsonError::InvalidContentType);
}
let body_bytes = req
.body_mut()
.collect()
.await
.map_err(|e| JsonError::BodyReadError(e.to_string()))?
.to_bytes();
#[cfg(feature = "simd")]
let data = {
let mode = req
.extensions()
.get::<SimdJsonMode>()
.copied()
.unwrap_or_default();
let use_simd = match mode {
SimdJsonMode::Always => true,
SimdJsonMode::Never => false,
SimdJsonMode::Threshold(threshold) => body_bytes.len() >= threshold,
};
if use_simd {
let mut owned = body_bytes.to_vec();
sonic_rs::from_slice::<T>(&mut owned)
.map_err(|e| JsonError::DeserializationError(e.to_string()))?
} else {
serde_json::from_slice(&body_bytes)
.map_err(|e| JsonError::DeserializationError(e.to_string()))?
}
};
#[cfg(not(feature = "simd"))]
let data = serde_json::from_slice(&body_bytes)
.map_err(|e| JsonError::DeserializationError(e.to_string()))?;
Ok(Json(data))
}
}
}
impl<T> Responder for Json<T>
where
    T: Serialize,
{
    /// Serializes the wrapped value with `serde_json`.
    ///
    /// On success, the response is `200 OK` with an `application/json` body. If
    /// serialization fails, the response is `500 Internal Server Error` with the error's
    /// text as a `text/plain; charset=utf-8` body.
    fn into_response(self) -> Response {
        let (status, body, content_type) = match serde_json::to_vec(&self.0) {
            Ok(encoded) => (
                StatusCode::OK,
                TakoBody::from(encoded),
                HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
            ),
            Err(failure) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                TakoBody::from(failure.to_string()),
                HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
            ),
        };

        let mut response = Response::new(body);
        *response.status_mut() = status;
        response
            .headers_mut()
            .insert(http::header::CONTENT_TYPE, content_type);
        response
    }
}