use std::ops::{Deref, DerefMut};
use http::StatusCode;
use serde::{Serialize, de::DeserializeOwned};
use crate::{
FromRequest, IntoResponse, Request, Response, Result, error::ParseYamlError, http::header,
web::RequestBody,
};
/// YAML extractor and response wrapper.
///
/// When used as an extractor, the request must carry a YAML `Content-Type`
/// and its body is deserialized into `T`. When returned from a handler, the
/// inner value is serialized to YAML and sent with
/// `Content-Type: application/yaml; charset=utf-8`.
///
/// The wrapper dereferences to the inner `T`, so its fields and methods can be
/// used directly.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Yaml<T>(pub T);

impl<T> Deref for Yaml<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &Self::Target {
        let Yaml(inner) = self;
        inner
    }
}

impl<T> DerefMut for Yaml<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Yaml(inner) = self;
        inner
    }
}
impl<'a, T: DeserializeOwned> FromRequest<'a> for Yaml<T> {
    /// Extracts `T` from a YAML request body.
    ///
    /// # Errors
    ///
    /// - [`ParseYamlError::ContentTypeRequired`] if the `Content-Type` header
    ///   is absent or cannot be read as a string.
    /// - [`ParseYamlError::InvalidContentType`] if the header is present but
    ///   is not a YAML media type.
    /// - Any error from taking or reading the request body.
    /// - [`ParseYamlError::Parse`] if the body is not valid YAML for `T`.
    async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let Some(content_type) = req
            .headers()
            .get(header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
        else {
            return Err(ParseYamlError::ContentTypeRequired.into());
        };

        if !is_yaml_content_type(content_type) {
            return Err(ParseYamlError::InvalidContentType(content_type.into()).into());
        }

        let raw = body.take()?.into_bytes().await?;
        let parsed = serde_yaml::from_slice(&raw).map_err(ParseYamlError::Parse)?;
        Ok(Self(parsed))
    }
}
/// Returns `true` when `content_type` parses as a MIME type whose top-level
/// type is `application` and whose subtype is either `yaml` or carries a
/// `+yaml` structured-syntax suffix (e.g. `application/vnd.example+yaml`).
/// Unparseable values are treated as non-YAML.
fn is_yaml_content_type(content_type: &str) -> bool {
    let Ok(mime) = content_type.parse::<mime::Mime>() else {
        return false;
    };
    if mime.type_() != "application" {
        return false;
    }
    // `subtype()` excludes any `+suffix`, so the suffix is checked separately.
    mime.subtype() == "yaml" || mime.suffix().is_some_and(|suffix| suffix == "yaml")
}
impl<T: Serialize + Send> IntoResponse for Yaml<T> {
    /// Serializes the wrapped value to YAML.
    ///
    /// On success the body is the YAML text, sent with
    /// `Content-Type: application/yaml; charset=utf-8`. If serialization
    /// fails, a `500 Internal Server Error` is returned whose body is the
    /// serializer's error message.
    fn into_response(self) -> Response {
        match serde_yaml::to_string(&self.0) {
            Ok(yaml) => Response::builder()
                .header(header::CONTENT_TYPE, "application/yaml; charset=utf-8")
                .body(yaml),
            Err(err) => Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(err.to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{handler, test::TestClient};

    /// Payload used by every test in this module.
    #[derive(Deserialize, Serialize, Debug, Eq, PartialEq)]
    struct CreateResource {
        name: String,
        value: i32,
    }

    /// The single fixture value sent to, or expected from, the handlers.
    fn sample_resource() -> CreateResource {
        CreateResource {
            name: "abc".to_string(),
            value: 100,
        }
    }

    /// A body sent via `body_yaml` is accepted and deserialized intact.
    #[tokio::test]
    async fn test_yaml_extractor() {
        #[handler(internal)]
        async fn index(query: Yaml<CreateResource>) {
            assert_eq!(query.0, sample_resource());
        }

        TestClient::new(index)
            .post("/")
            .body_yaml(&sample_resource())
            .send()
            .await
            .assert_status_is_ok();
    }

    /// A YAML body sent without a YAML content type is rejected with 415.
    #[tokio::test]
    async fn test_yaml_extractor_fail() {
        #[handler(internal)]
        async fn index(query: Yaml<CreateResource>) {
            assert_eq!(query.0, sample_resource());
        }

        let yaml_text = serde_yaml::to_string(&sample_resource()).expect("Invalid yaml");
        TestClient::new(index)
            .post("/")
            .body(yaml_text)
            .send()
            .await
            .assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// Returning `Yaml<T>` from a handler produces a YAML body equal to `T`.
    #[tokio::test]
    async fn test_yaml_response() {
        #[handler(internal)]
        async fn index() -> Yaml<CreateResource> {
            Yaml(sample_resource())
        }

        let client = TestClient::new(index);
        let response = client.get("/").send().await;
        response.assert_status_is_ok();
        response.assert_yaml(&sample_resource()).await;
    }
}