use axum_core::{
body::Body,
extract::{FromRequest, Request},
response::{IntoResponse, Response},
};
use core::fmt;
use core::ops::{Deref, DerefMut};
use facet_core::Facet;
use http::{HeaderValue, StatusCode, header};
use http_body_util::BodyExt;
use crate::DeserializeError;
/// JSON extractor and response wrapper backed by this crate's facet-based
/// (de)serializer.
///
/// As an extractor it requires a JSON `Content-Type` and deserializes the
/// request body into `T`; as a response it serializes `T` and sets
/// `Content-Type: application/json`. The wrapper dereferences to `T`, so the
/// inner value can be used directly.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}

impl<T> Deref for Json<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}

impl<T> DerefMut for Json<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Json(inner) = self;
        inner
    }
}

impl<T> From<T> for Json<T> {
    /// Wraps `value` as-is; no conversion or validation is performed.
    fn from(value: T) -> Self {
        Json(value)
    }
}
/// Rejection produced when the [`Json`] extractor cannot produce a value.
///
/// The failure reason is private; inspect it with [`JsonRejection::status`]
/// and the `is_*` predicates. Converting it into a response uses that status
/// with the rejection's `Display` text as the body.
#[derive(Debug)]
pub struct JsonRejection {
// Kept private; callers inspect it only through the methods on `JsonRejection`.
kind: JsonRejectionKind,
}
/// Internal classification of why JSON extraction failed.
#[derive(Debug)]
enum JsonRejectionKind {
/// Buffering the request body failed.
Body(axum_core::Error),
/// The body was read but could not be deserialized into the target type.
Deserialize(DeserializeError),
/// The request carried no `Content-Type` header at all.
MissingContentType,
/// A `Content-Type` header was present but did not name a JSON media type
/// (this includes headers that are not valid text or do not parse as a MIME type).
InvalidContentType,
}
impl JsonRejection {
    /// HTTP status code used when this rejection is turned into a response.
    ///
    /// - body could not be read: `400 Bad Request`
    /// - body could not be deserialized: `422 Unprocessable Entity`
    /// - `Content-Type` missing or not JSON: `415 Unsupported Media Type`
    pub const fn status(&self) -> StatusCode {
        match self.kind {
            JsonRejectionKind::Body(_) => StatusCode::BAD_REQUEST,
            JsonRejectionKind::Deserialize(_) => StatusCode::UNPROCESSABLE_ENTITY,
            JsonRejectionKind::MissingContentType | JsonRejectionKind::InvalidContentType => {
                StatusCode::UNSUPPORTED_MEDIA_TYPE
            }
        }
    }

    /// Returns `true` if reading the request body failed.
    pub const fn is_body_error(&self) -> bool {
        matches!(self.kind, JsonRejectionKind::Body(_))
    }

    /// Returns `true` if the body was read but could not be deserialized.
    pub const fn is_deserialize_error(&self) -> bool {
        matches!(self.kind, JsonRejectionKind::Deserialize(_))
    }

    /// Returns `true` if the request had no `Content-Type` header.
    pub const fn is_missing_content_type(&self) -> bool {
        matches!(self.kind, JsonRejectionKind::MissingContentType)
    }

    /// Returns `true` if a `Content-Type` header was present but not JSON.
    pub const fn is_invalid_content_type(&self) -> bool {
        matches!(self.kind, JsonRejectionKind::InvalidContentType)
    }
}
impl fmt::Display for JsonRejection {
    /// Writes a human-readable description of the failure; for body and
    /// deserialization failures the underlying error's message is appended.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.kind {
            JsonRejectionKind::Body(err) => write!(f, "Failed to read request body: {err}"),
            JsonRejectionKind::Deserialize(err) => write!(f, "Failed to deserialize JSON: {err}"),
            JsonRejectionKind::MissingContentType => {
                f.write_str("Missing `Content-Type: application/json` header")
            }
            JsonRejectionKind::InvalidContentType => {
                f.write_str("Invalid `Content-Type` header: expected `application/json`")
            }
        }
    }
}
impl std::error::Error for JsonRejection {
    /// Exposes the wrapped body-read or deserialization error; content-type
    /// rejections have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match &self.kind {
            JsonRejectionKind::Body(err) => err,
            JsonRejectionKind::Deserialize(err) => err,
            JsonRejectionKind::MissingContentType | JsonRejectionKind::InvalidContentType => {
                return None;
            }
        };
        Some(cause)
    }
}
impl IntoResponse for JsonRejection {
    /// Builds a response whose status is [`JsonRejection::status`] and whose
    /// body is the rejection's `Display` text.
    fn into_response(self) -> Response {
        let status = self.status();
        let message = self.to_string();
        (status, message).into_response()
    }
}
/// Returns `true` when the request's `Content-Type` header names a JSON media type.
///
/// Accepts `application/json` and any `application/*+json` type; only the
/// type, subtype and suffix are compared, so parameters such as `charset` do
/// not affect the result. A missing header, a value that `HeaderValue::to_str`
/// rejects, or a value that does not parse as a MIME type all yield `false`.
fn is_json_content_type(req: &Request) -> bool {
    req.headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| raw.parse::<mime::Mime>().ok())
        .is_some_and(|parsed| {
            parsed.type_() == mime::APPLICATION
                && (parsed.subtype() == mime::JSON || parsed.suffix() == Some(mime::JSON))
        })
}
impl<T, S> FromRequest<S> for Json<T>
where
    T: Facet<'static>,
    S: Send + Sync,
{
    type Rejection = JsonRejection;

    /// Extracts a `Json<T>` from a request with a JSON `Content-Type`.
    ///
    /// # Errors
    ///
    /// - [`JsonRejection::is_missing_content_type`] if there is no `Content-Type` header.
    /// - [`JsonRejection::is_invalid_content_type`] if the header is not a JSON media type.
    /// - [`JsonRejection::is_body_error`] if the body cannot be buffered, including when
    ///   it exceeds the configured body limit.
    /// - [`JsonRejection::is_deserialize_error`] if the bytes do not deserialize into `T`.
    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
        if !is_json_content_type(&req) {
            let kind = if req.headers().contains_key(header::CONTENT_TYPE) {
                JsonRejectionKind::InvalidContentType
            } else {
                JsonRejectionKind::MissingContentType
            };
            return Err(JsonRejection { kind });
        }

        // Buffer through `into_limited_body` so the router's `DefaultBodyLimit` (or axum's
        // built-in default when none is configured) is enforced, exactly as axum's own
        // `Bytes`/`Json` extractors do. A plain `into_body().collect()` would buffer an
        // arbitrarily large, client-controlled body into memory.
        let bytes = axum_core::RequestExt::into_limited_body(req)
            .collect()
            .await
            .map_err(|e| JsonRejection {
                kind: JsonRejectionKind::Body(axum_core::Error::new(e)),
            })?
            .to_bytes();

        let value: T = crate::from_slice(&bytes).map_err(|e| JsonRejection {
            kind: JsonRejectionKind::Deserialize(e),
        })?;
        Ok(Json(value))
    }
}
impl<T> IntoResponse for Json<T>
where
    T: Facet<'static>,
{
    /// Serializes the wrapped value as the response body and sets
    /// `Content-Type: application/json`.
    ///
    /// If serialization fails, a `500 Internal Server Error` response is
    /// returned instead, with a body describing the serialization error.
    fn into_response(self) -> Response {
        let payload = match crate::to_vec(&self.0) {
            Ok(payload) => payload,
            Err(err) => {
                let message = format!("Failed to serialize response: {err}");
                return (StatusCode::INTERNAL_SERVER_ERROR, message).into_response();
            }
        };

        let mut response = Response::new(Body::from(payload));
        response.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        response
    }
}