use axum_core::{
    RequestExt,
    body::Body,
    extract::{FromRequest, Request},
    response::{IntoResponse, Response},
};
use core::fmt;
use core::ops::{Deref, DerefMut};
use facet_core::Facet;
use http::{HeaderValue, StatusCode, header};
use http_body_util::BodyExt;

use crate::{DeserializeError, MsgPackSerializeError};
/// MessagePack wrapper usable both as an axum extractor and as a response type.
///
/// As an extractor it decodes the request body into `T`; as a response it encodes
/// `T` and tags the body with an `application/msgpack` content type.
#[derive(Clone, Copy, Debug, Default)]
pub struct MsgPack<T>(
    /// The decoded (or to-be-encoded) payload.
    pub T,
);
impl<T> MsgPack<T> {
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for MsgPack<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for MsgPack<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for MsgPack<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}
/// Error returned when the [`MsgPack`] extractor cannot produce a value.
///
/// The concrete failure is kept private. Callers inspect it through the
/// `status`, `is_body_error` and `is_deserialize_error` methods.
#[derive(Debug)]
pub struct MsgPackRejection {
    /// Which extraction stage failed, together with the underlying error.
    kind: MsgPackRejectionKind,
}
/// Internal classification of [`MsgPackRejection`] causes.
#[derive(Debug)]
enum MsgPackRejectionKind {
    /// Collecting the request body into memory failed.
    Body(axum_core::Error),
    /// The body bytes were read but could not be decoded into the target type.
    Deserialize(DeserializeError),
}
impl MsgPackRejection {
pub fn status(&self) -> StatusCode {
match &self.kind {
MsgPackRejectionKind::Body(_) => StatusCode::BAD_REQUEST,
MsgPackRejectionKind::Deserialize(_) => StatusCode::UNPROCESSABLE_ENTITY,
}
}
pub fn is_body_error(&self) -> bool {
matches!(&self.kind, MsgPackRejectionKind::Body(_))
}
pub fn is_deserialize_error(&self) -> bool {
matches!(&self.kind, MsgPackRejectionKind::Deserialize(_))
}
}
/// Human-readable message: a stage-specific prefix followed by the underlying error.
impl fmt::Display for MsgPackRejection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (context, cause) = match &self.kind {
            MsgPackRejectionKind::Body(err) => {
                ("Failed to read request body", err as &dyn fmt::Display)
            }
            MsgPackRejectionKind::Deserialize(err) => {
                ("Failed to deserialize MsgPack", err as &dyn fmt::Display)
            }
        };
        write!(f, "{context}: {cause}")
    }
}
impl std::error::Error for MsgPackRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.kind {
MsgPackRejectionKind::Body(err) => Some(err),
MsgPackRejectionKind::Deserialize(err) => Some(err),
}
}
}
/// Renders the rejection as a plain-text body carrying its `Display` message,
/// with the status chosen by `MsgPackRejection::status`.
impl IntoResponse for MsgPackRejection {
    fn into_response(self) -> Response {
        let status = self.status();
        let message = self.to_string();
        (status, message).into_response()
    }
}
/// Extracts a `MsgPack<T>` by buffering the request body and decoding it
/// with `crate::from_slice`.
///
/// NOTE: the `Content-Type` header is not inspected. Any body is decoded as
/// MessagePack.
impl<T, S> FromRequest<S> for MsgPack<T>
where
    T: Facet<'static>,
    S: Send + Sync,
{
    type Rejection = MsgPackRejection;

    /// # Errors
    ///
    /// - A `Body` rejection if the body cannot be fully read, including when it
    ///   exceeds the configured body-size limit.
    /// - A `Deserialize` rejection if the bytes are not a valid encoding of `T`.
    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
        // `into_limited_body` applies the limit configured through axum's
        // `DefaultBodyLimit` (falling back to axum's built-in default), matching
        // the `Json`/`Bytes` extractors. Plain `into_body()` would buffer an
        // arbitrarily large, client-controlled payload into memory.
        let bytes = req
            .into_limited_body()
            .collect()
            .await
            .map_err(|e| MsgPackRejection {
                kind: MsgPackRejectionKind::Body(axum_core::Error::new(e)),
            })?
            .to_bytes();

        let value: T = crate::from_slice(&bytes).map_err(|e| MsgPackRejection {
            kind: MsgPackRejectionKind::Deserialize(e),
        })?;

        Ok(MsgPack(value))
    }
}
/// Encodes the payload with `crate::to_vec` and sends it with
/// `Content-Type: application/msgpack`.
///
/// If encoding fails, a `500 Internal Server Error` with a plain-text
/// description is returned instead.
impl<T> IntoResponse for MsgPack<T>
where
    T: Facet<'static>,
{
    fn into_response(self) -> Response {
        let encoded = match crate::to_vec(&self.0) {
            Ok(encoded) => encoded,
            Err(err) => {
                let message = format!("Failed to serialize response: {err}");
                return (StatusCode::INTERNAL_SERVER_ERROR, message).into_response();
            }
        };

        let mut response = Response::new(Body::from(encoded));
        let content_type = HeaderValue::from_static("application/msgpack");
        response
            .headers_mut()
            .insert(header::CONTENT_TYPE, content_type);
        response
    }
}
/// Error wrapper for a failed MessagePack encoding of a response payload.
///
/// Converts into a `500 Internal Server Error` response.
#[derive(Debug)]
pub struct MsgPackSerializeRejection(
    /// The underlying serialization error.
    pub MsgPackSerializeError,
);
impl fmt::Display for MsgPackSerializeRejection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Failed to serialize MsgPack response: {}", self.0)
}
}
/// Reports the wrapped serialization error as the cause.
impl std::error::Error for MsgPackSerializeRejection {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.0;
        Some(cause)
    }
}
/// Renders as `500 Internal Server Error` with the `Display` message as a plain-text body.
impl IntoResponse for MsgPackSerializeRejection {
    fn into_response(self) -> Response {
        let message = self.to_string();
        (StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
    }
}