use std::ops::Deref;
use std::ops::DerefMut;
use super::*;
use axum_core::extract::FromRequest;
use axum_core::extract::Request;
use axum_core::response::{IntoResponse, Response};
use bytes::{BufMut, Bytes, BytesMut};
use bronzerde::EDeserialize;
use http::header::{self, HeaderMap, HeaderValue};
use krabby_details::INTERNAL_SERVER_ERROR;
use serde::{de::DeserializeOwned, Serialize};
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[must_use]
pub struct Json<T>(pub T);
#[cfg(all(not(feature = "validator"), not(feature = "serde_valid")))]
impl<T, S> FromRequest<S> for Json<T>
where
T: DeserializeOwned,
T: for<'de> EDeserialize<'de>,
S: Send + Sync,
{
type Rejection = JsonRejection;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
check_json_content_type(req.headers())?;
let bytes = Bytes::from_request(req, state).await?;
Self::from_bytes(&bytes)
}
}
/// Extracts a JSON request body and validates it with [`validator::Validate`].
///
/// The request must declare a JSON `Content-Type`. The body is collected,
/// deserialized via [`Json::from_bytes`] and then validated. A validation
/// failure is reported as [`JsonRejection::ValidationErrors`].
#[cfg(feature = "validator")]
impl<T, S> FromRequest<S> for Json<T>
where
    T: DeserializeOwned + validator::Validate + for<'de> EDeserialize<'de>,
    S: Send + Sync,
{
    type Rejection = JsonRejection;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        check_json_content_type(req.headers())?;
        let body = Bytes::from_request(req, state).await?;
        let extracted = Self::from_bytes(&body)?;
        // Validate the inner value itself (not the wrapper).
        if let Err(errors) = extracted.0.validate() {
            return Err(JsonRejection::ValidationErrors(errors));
        }
        Ok(extracted)
    }
}
/// Extracts a JSON request body and validates it with [`serde_valid::Validate`].
///
/// The request must declare a JSON `Content-Type`. The body is collected,
/// deserialized via [`Json::from_bytes`] and then validated. A validation
/// failure is reported as [`JsonRejection::SerdeValidRejection`].
///
/// This impl is compiled only when the `validator` feature is *off*. The
/// `validator`-based impl is gated on `feature = "validator"` alone. Enabling
/// both features, which Cargo feature unification can do across a dependency
/// graph, would otherwise produce two conflicting `FromRequest` impls for
/// `Json<T>` and fail to compile. With both features enabled, the
/// `validator`-based impl takes precedence.
#[cfg(all(feature = "serde_valid", not(feature = "validator")))]
impl<T, S> FromRequest<S> for Json<T>
where
    T: serde_valid::Validate + DeserializeOwned,
    T: for<'de> EDeserialize<'de>,
    S: Send + Sync,
{
    type Rejection = JsonRejection;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // Reject early on a wrong/missing content type, before reading the body.
        check_json_content_type(req.headers())?;
        let bytes = Bytes::from_request(req, state).await?;
        let json = Self::from_bytes(&bytes)?;
        // Validate the inner value itself (not the wrapper).
        json.0
            .validate()
            .map_err(JsonRejection::SerdeValidRejection)?;
        Ok(json)
    }
}
fn check_json_content_type(headers: &HeaderMap) -> Result<(), JsonRejection> {
let Some(content_type) = headers.get(http::header::CONTENT_TYPE) else {
return Err(MissingJsonContentType.into());
};
let Ok(content_type) = content_type.to_str() else {
return Err(MissingJsonContentType.into());
};
let Ok(mime) = content_type.parse::<mime::Mime>() else {
return Err(JsonContentTypeMismatch {
actual: content_type.to_string(),
}
.into());
};
let is_json_content_type = mime.type_() == "application"
&& (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"));
if !is_json_content_type {
return Err(JsonContentTypeMismatch {
actual: content_type.to_string(),
}
.into());
}
Ok(())
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for Json<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}
impl<T> Json<T>
where
T: DeserializeOwned,
T: for<'de> EDeserialize<'de>,
{
pub fn from_bytes(bytes: &[u8]) -> Result<Self, JsonRejection> {
match bronzerde::json::from_slice(bytes) {
Ok(value) => Ok(Json(value)),
Err(errors) => Err(JsonError::new(errors).into()),
}
}
}
impl<T> IntoResponse for Json<T>
where
    T: Serialize,
{
    /// Serializes the wrapped value with `serde_json` into a response with
    /// `Content-Type: application/json`.
    ///
    /// If serialization fails, the error is discarded and `INTERNAL_SERVER_ERROR`
    /// is rendered instead (no JSON content type is set by this method then).
    fn into_response(self) -> Response {
        // Start from a small 128-byte buffer; it grows as serde_json writes into it.
        let mut writer = BytesMut::with_capacity(128).writer();
        if serde_json::to_writer(&mut writer, &self.0).is_err() {
            return INTERNAL_SERVER_ERROR.into_response();
        }

        let body = writer.into_inner().freeze();
        let content_type = HeaderValue::from_static(mime::APPLICATION_JSON.as_ref());
        ([(header::CONTENT_TYPE, content_type)], body).into_response()
    }
}