#![allow(clippy::module_name_repetitions)]
use std::ops::{Deref, DerefMut};
use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use axum_core::response::{IntoResponse, Response};
use axum_core::BoxError;
use bytes::Bytes;
use http::{header, HeaderValue, StatusCode};
use http_body::Body as HttpBody;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::rejection::XmlRejection;
mod rejection;
#[cfg(test)]
mod tests;
/// XML extractor and response wrapper.
///
/// As an extractor, it checks that the request's `Content-Type` is an XML media
/// type and deserializes the body into `T`. As a response, it serializes `T` to
/// XML and sets `Content-Type: application/xml`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Xml<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Xml<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Rejection = XmlRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if xml_content_type(req) {
let bytes = Bytes::from_request(req).await?;
let value = quick_xml::de::from_reader(&*bytes)?;
Ok(Self(value))
} else {
Err(XmlRejection::MissingXMLContentType)
}
}
}
/// Reports whether the request declares an XML `Content-Type`.
///
/// The top-level type must be `application` or `text`. In addition, the
/// subtype must be `xml` or carry a `+xml` suffix (e.g. `application/atom+xml`).
/// It returns `false` in three cases: the header is absent, the header value
/// cannot be read as a `&str` (`to_str` fails), or the value does not parse as
/// a `mime::Mime`.
fn xml_content_type<B>(req: &RequestParts<B>) -> bool {
    req.headers()
        .get(header::CONTENT_TYPE)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<mime::Mime>().ok())
        .map_or(false, |parsed| {
            let top_level = parsed.type_();
            let app_or_text = top_level == "application" || top_level == "text";
            let xml_flavoured = parsed.subtype() == "xml"
                || parsed.suffix().map_or(false, |suffix| suffix == "xml");
            app_or_text && xml_flavoured
        })
}
impl<T> Deref for Xml<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Xml<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for Xml<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}
/// Serializes the wrapped value to XML and turns it into an HTTP response.
///
/// On success, the body is the serialized bytes with
/// `Content-Type: application/xml`. If `quick_xml` serialization fails, the
/// response is `500 Internal Server Error`. Its body is the error's `to_string()`
/// and its `Content-Type` is `mime::TEXT_PLAIN_UTF_8`.
impl<T> IntoResponse for Xml<T>
where
    T: Serialize,
{
    fn into_response(self) -> Response {
        let mut buf: Vec<u8> = Vec::new();

        if let Err(err) = quick_xml::se::to_writer(&mut buf, &self.0) {
            let plain_text = HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref());
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(header::CONTENT_TYPE, plain_text)],
                err.to_string(),
            )
                .into_response();
        }

        let xml = HeaderValue::from_static("application/xml");
        ([(header::CONTENT_TYPE, xml)], buf).into_response()
    }
}