1use std::ops::{Deref, DerefMut};
2
3use axum_core::{
4 extract::{FromRequest, Request},
5 response::{IntoResponse, Response},
6};
7use bytes::{BufMut, Bytes, BytesMut};
8use http::{header, HeaderMap, HeaderValue, StatusCode};
9use serde::{de::DeserializeOwned, Serialize};
10
11use crate::rejection::*;
12
/// A transparent newtype around a value of type `T` that is (de)serialized as YAML.
///
/// When used as an extractor, the request body is deserialized from YAML into `T`.
/// When returned from a handler, `T` is serialized to YAML and sent back with an
/// `application/yaml` content type.
#[derive(Clone, Copy, Debug, Default)]
pub struct Yaml<T>(pub T);
95
96impl<T, S> FromRequest<S> for Yaml<T>
97where
98 T: DeserializeOwned,
99 S: Send + Sync,
100{
101 type Rejection = YamlRejection;
102
103 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
104 if yaml_content_type(req.headers()) {
105 let bytes = Bytes::from_request(req, state).await?;
106 Self::from_bytes(&bytes)
107 } else {
108 Err(MissingYamlContentType.into())
109 }
110 }
111}
112
113fn yaml_content_type(headers: &HeaderMap) -> bool {
114 let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
115 return false;
116 };
117
118 let Ok(content_type) = content_type.to_str() else {
119 return false;
120 };
121
122 let Ok(mime) = content_type.parse::<mime::Mime>() else {
123 return false;
124 };
125
126 let is_yaml_content_type = mime.type_() == "application"
127 && (mime.subtype() == "yaml" || mime.suffix().map_or(false, |name| name == "yaml"));
128
129 is_yaml_content_type
130}
131
132impl<T> Deref for Yaml<T> {
133 type Target = T;
134
135 #[inline]
136 fn deref(&self) -> &Self::Target {
137 &self.0
138 }
139}
140
141impl<T> DerefMut for Yaml<T> {
142 #[inline]
143 fn deref_mut(&mut self) -> &mut Self::Target {
144 &mut self.0
145 }
146}
147
148impl<T> From<T> for Yaml<T> {
149 fn from(inner: T) -> Self {
150 Self(inner)
151 }
152}
153
154impl<T> Yaml<T>
155where
156 T: DeserializeOwned,
157{
158 pub fn from_bytes(bytes: &[u8]) -> Result<Self, YamlRejection> {
162 let deserializer = serde_yaml::Deserializer::from_slice(bytes);
163
164 match serde_path_to_error::deserialize(deserializer) {
165 Ok(value) => Ok(Yaml(value)),
166 Err(err) => Err(YamlError::from_err(err).into()),
167 }
168 }
169}
170
171impl<T> IntoResponse for Yaml<T>
172where
173 T: Serialize,
174{
175 fn into_response(self) -> Response {
176 let mut buf = BytesMut::with_capacity(128).writer();
179 match serde_yaml::to_writer(&mut buf, &self.0) {
180 Ok(()) => (
181 [(
182 header::CONTENT_TYPE,
183 HeaderValue::from_static("application/yaml"),
184 )],
185 buf.into_inner().freeze(),
186 )
187 .into_response(),
188 Err(err) => (
189 StatusCode::INTERNAL_SERVER_ERROR,
190 [(
191 header::CONTENT_TYPE,
192 HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
193 )],
194 err.to_string(),
195 )
196 .into_response(),
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use axum::routing::post;
    use axum::Router;
    use http::StatusCode;
    use serde::Deserialize;
    use serde_yaml::Value;

    use crate::test_client::TestClient;

    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route("/", post(|input: Yaml<Input>| async { input.0.foo }));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("foo: bar")
            .header("content-type", "application/yaml")
            .await;

        let body = res.text().await;
        assert_eq!(body, "bar");
    }

    #[tokio::test]
    async fn consume_body_to_yaml_requires_yaml_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route("/", post(|input: Yaml<Input>| async { input.0.foo }));

        let client = TestClient::new(app);
        let res = client.post("/").body("foo: bar").await;

        let status = res.status();
        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn yaml_content_types() {
        /// Posts a small YAML body with the given content type and reports whether
        /// the extractor accepted it.
        async fn valid_yaml_content_type(content_type: &str) -> bool {
            println!("testing {content_type:?}");

            let app = Router::new().route("/", post(|Yaml(_): Yaml<Value>| async {}));

            let res = TestClient::new(app)
                .post("/")
                .header("content-type", content_type)
                .body("foo: ")
                .await;

            res.status() == StatusCode::OK
        }

        assert!(valid_yaml_content_type("application/yaml").await);
        assert!(valid_yaml_content_type("application/yaml;charset=utf-8").await);
        assert!(valid_yaml_content_type("application/yaml; charset=utf-8").await);
        // Structured-syntax suffix form (`application/<subtype>+yaml`).
        assert!(valid_yaml_content_type("application/vnd.api+yaml").await);
        assert!(!valid_yaml_content_type("text/yaml").await);
        assert!(!valid_yaml_content_type("application/json").await);
    }

    #[tokio::test]
    async fn invalid_yaml_syntax() {
        let app = Router::new().route("/", post(|_: Yaml<Value>| async {}));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("- a\nb:")
            .header("content-type", "application/yaml")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    // The fields are never read; they only give deserialization a shape to validate.
    #[allow(dead_code)]
    #[derive(Deserialize)]
    struct Foo {
        a: i32,
        b: Vec<Bar>,
    }

    // The fields are never read; they only give deserialization a shape to validate.
    #[allow(dead_code)]
    #[derive(Deserialize)]
    struct Bar {
        x: i32,
        y: i32,
    }

    #[tokio::test]
    async fn invalid_yaml_data() {
        let app = Router::new().route("/", post(|_: Yaml<Foo>| async {}));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("a: 1\nb:\n - x: 2")
            .header("content-type", "application/yaml")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        let body_text = res.text().await;
        assert_eq!(
            body_text,
            "Failed to deserialize the YAML body into the target type: b[0]: b[0]: missing field `y` at line 3 column 7"
        );
    }
}