use std::ops::{Deref, DerefMut};
use http::StatusCode;
use serde::{Serialize, de::DeserializeOwned};
use crate::{
FromRequest, IntoResponse, Request, Response, Result, error::ParseXmlError, http::header,
web::RequestBody,
};
#[cfg_attr(docsrs, doc(cfg(feature = "xml")))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct Xml<T>(pub T);
impl<T> Deref for Xml<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Xml<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<'a, T: DeserializeOwned> FromRequest<'a> for Xml<T> {
    /// Extracts an `Xml<T>` from the request body.
    ///
    /// # Errors
    ///
    /// - `ParseXmlError::ContentTypeRequired` when the `Content-Type` header
    ///   is absent or is not valid as a string.
    /// - `ParseXmlError::InvalidContentType` when the header is not an XML
    ///   media type according to `is_xml_content_type`.
    /// - `ParseXmlError::Parse` when `quick_xml` cannot deserialize the body
    ///   into `T`.
    /// - Any error produced while taking or reading the request body.
    async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        // The header check happens before the body is touched, so a request
        // with the wrong media type never has its body consumed.
        let Some(declared_type) = req
            .headers()
            .get(header::CONTENT_TYPE)
            .and_then(|raw| raw.to_str().ok())
        else {
            return Err(ParseXmlError::ContentTypeRequired.into());
        };

        if !is_xml_content_type(declared_type) {
            return Err(ParseXmlError::InvalidContentType(declared_type.into()).into());
        }

        let payload = body.take()?.into_bytes().await?;
        let value = quick_xml::de::from_reader(payload.as_ref()).map_err(ParseXmlError::Parse)?;
        Ok(Self(value))
    }
}
/// Returns `true` when `content_type` parses as a media type of the form
/// `application/xml` or `application/<anything>+xml`.
///
/// Strings that fail to parse as a MIME type yield `false`.
fn is_xml_content_type(content_type: &str) -> bool {
    let Ok(media_type) = content_type.parse::<mime::Mime>() else {
        return false;
    };

    if media_type.type_() != "application" {
        return false;
    }

    // Either the subtype itself is `xml`, or it carries a structured `+xml`
    // suffix.
    media_type.subtype() == "xml" || media_type.suffix().is_some_and(|suffix| suffix == "xml")
}
impl<T: Serialize + Send> IntoResponse for Xml<T> {
    /// Serializes the wrapped value to XML and builds the response.
    ///
    /// On success the body is the XML text and `Content-Type` is
    /// `application/xml; charset=utf-8`. On failure the response has status
    /// `500 Internal Server Error` and the serializer's error message as its
    /// body.
    fn into_response(self) -> Response {
        let Self(value) = self;

        // A first attempt that fails with `Unsupported` is retried with an
        // explicit root element named "root" (presumably the serializer could
        // not derive a root tag from the type itself). Any other outcome is
        // passed through untouched.
        let rendered = match quick_xml::se::to_string(&value) {
            Err(quick_xml::DeError::Unsupported(_)) => {
                quick_xml::se::to_string_with_root("root", &value)
            }
            other => other,
        };

        match rendered {
            Ok(xml) => Response::builder()
                .header(header::CONTENT_TYPE, "application/xml; charset=utf-8")
                .body(xml),
            Err(failure) => Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(failure.to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{handler, test::TestClient};

    /// Payload type sent through the XML extractor and returned by the XML
    /// responder in these tests.
    #[derive(Deserialize, Serialize, Debug, Eq, PartialEq)]
    struct CreateResource {
        name: String,
        value: i32,
    }

    /// Builds the fixture value that every test sends or expects.
    fn sample_resource() -> CreateResource {
        CreateResource {
            name: "abc".to_string(),
            value: 100,
        }
    }

    /// A body sent with `body_xml` is extracted into the handler argument.
    #[tokio::test]
    async fn test_xml_extractor() {
        #[handler(internal)]
        async fn index(payload: Xml<CreateResource>) {
            assert_eq!(payload.name, "abc");
            assert_eq!(payload.value, 100);
        }

        let client = TestClient::new(index);
        let response = client.post("/").body_xml(&sample_resource()).send().await;
        response.assert_status_is_ok();
    }

    /// The same XML sent through plain `body` (not `body_xml`) is rejected
    /// with `415 Unsupported Media Type`.
    #[tokio::test]
    async fn test_xml_extractor_fail() {
        #[handler(internal)]
        async fn index(payload: Xml<CreateResource>) {
            assert_eq!(payload.name, "abc");
            assert_eq!(payload.value, 100);
        }

        let raw_xml = quick_xml::se::to_string(&sample_resource()).expect("Invalid xml");

        let client = TestClient::new(index);
        let response = client.post("/").body(raw_xml).send().await;
        response.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// Returning `Xml<T>` from a handler produces an XML body equal to the
    /// wrapped value.
    #[tokio::test]
    async fn test_xml_response() {
        #[handler(internal)]
        async fn index() -> Xml<CreateResource> {
            Xml(sample_resource())
        }

        let client = TestClient::new(index);
        let response = client.get("/").send().await;
        response.assert_status_is_ok();
        response.assert_xml(&sample_resource()).await;
    }
}