axum_yaml/
yaml.rs

1use std::ops::{Deref, DerefMut};
2
3use axum_core::{
4    extract::{FromRequest, Request},
5    response::{IntoResponse, Response},
6};
7use bytes::{BufMut, Bytes, BytesMut};
8use http::{header, HeaderMap, HeaderValue, StatusCode};
9use serde::{de::DeserializeOwned, Serialize};
10
11use crate::rejection::*;
12
/// YAML Extractor / Response.
///
/// When used as an extractor, it can deserialize request bodies into some type that
/// implements [`serde::Deserialize`]. The request is rejected when:
///
/// - it does not carry a YAML `Content-Type` header (`application/yaml`, optionally with
///   parameters such as `charset`, or any `application/*+yaml` type), with a
///   `415 Unsupported Media Type` response;
/// - its body cannot be deserialized into the target type (invalid YAML syntax, or data
///   that does not match the shape of the type), with a `400 Bad Request` response.
///
/// Failures while reading the body itself are forwarded from the underlying `Bytes`
/// extractor.
///
/// # Extractor example
///
/// ```no_run
/// use axum::{
///     extract,
///     routing::post,
///     Router,
/// };
/// use axum_yaml::Yaml;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct CreateUser {
///     email: String,
///     password: String,
/// }
///
/// async fn create_user(Yaml(payload): Yaml<CreateUser>) {
///     // payload is a `CreateUser`
/// }
///
/// let app = Router::new().route("/users", post(create_user));
/// # async {
/// #   axum::serve(
/// #       tokio::net::TcpListener::bind("").await.unwrap(),
/// #       app.into_make_service(),
/// #   )
/// #   .await
/// #   .unwrap();
/// # };
/// ```
///
/// When used as a response, it can serialize any type that implements [`serde::Serialize`] to
/// `YAML`, and will automatically set `Content-Type: application/yaml` header.
///
/// # Response example
///
/// ```no_run
/// use axum::{
///     extract::Path,
///     routing::get,
///     Router,
/// };
/// use axum_yaml::Yaml;
/// use serde::Serialize;
/// use uuid::Uuid;
///
/// #[derive(Serialize)]
/// struct User {
///     id: Uuid,
///     username: String,
/// }
///
/// async fn get_user(Path(user_id): Path<Uuid>) -> Yaml<User> {
///     let user = find_user(user_id).await;
///     Yaml(user)
/// }
///
/// async fn find_user(user_id: Uuid) -> User {
///     // ...
///     # unimplemented!()
/// }
///
/// // axum 0.8 route captures use `{name}`; the old `:name` syntax panics at startup.
/// let app = Router::new().route("/users/{id}", get(get_user));
/// # async {
/// #   axum::serve(
/// #       tokio::net::TcpListener::bind("").await.unwrap(),
/// #       app.into_make_service(),
/// #   )
/// #   .await
/// #   .unwrap();
/// # };
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Yaml<T>(pub T);
95
96impl<T, S> FromRequest<S> for Yaml<T>
97where
98    T: DeserializeOwned,
99    S: Send + Sync,
100{
101    type Rejection = YamlRejection;
102
103    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
104        if yaml_content_type(req.headers()) {
105            let bytes = Bytes::from_request(req, state).await?;
106            Self::from_bytes(&bytes)
107        } else {
108            Err(MissingYamlContentType.into())
109        }
110    }
111}
112
113fn yaml_content_type(headers: &HeaderMap) -> bool {
114    let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
115        return false;
116    };
117
118    let Ok(content_type) = content_type.to_str() else {
119        return false;
120    };
121
122    let Ok(mime) = content_type.parse::<mime::Mime>() else {
123        return false;
124    };
125
126    let is_yaml_content_type = mime.type_() == "application"
127        && (mime.subtype() == "yaml" || mime.suffix().map_or(false, |name| name == "yaml"));
128
129    is_yaml_content_type
130}
131
132impl<T> Deref for Yaml<T> {
133    type Target = T;
134
135    #[inline]
136    fn deref(&self) -> &Self::Target {
137        &self.0
138    }
139}
140
141impl<T> DerefMut for Yaml<T> {
142    #[inline]
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        &mut self.0
145    }
146}
147
148impl<T> From<T> for Yaml<T> {
149    fn from(inner: T) -> Self {
150        Self(inner)
151    }
152}
153
154impl<T> Yaml<T>
155where
156    T: DeserializeOwned,
157{
158    /// Construct a `Yaml<T>` from a byte slice. Most users should prefer to use the `FromRequest` impl
159    /// but special cases may require first extracting a `Request` into `Bytes` then optionally
160    /// constructing a `Yaml<T>`.
161    pub fn from_bytes(bytes: &[u8]) -> Result<Self, YamlRejection> {
162        let deserializer = serde_yaml::Deserializer::from_slice(bytes);
163
164        match serde_path_to_error::deserialize(deserializer) {
165            Ok(value) => Ok(Yaml(value)),
166            Err(err) => Err(YamlError::from_err(err).into()),
167        }
168    }
169}
170
171impl<T> IntoResponse for Yaml<T>
172where
173    T: Serialize,
174{
175    fn into_response(self) -> Response {
176        // Use a small initial capacity of 128 bytes like serde_json::to_vec
177        // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
178        let mut buf = BytesMut::with_capacity(128).writer();
179        match serde_yaml::to_writer(&mut buf, &self.0) {
180            Ok(()) => (
181                [(
182                    header::CONTENT_TYPE,
183                    HeaderValue::from_static("application/yaml"),
184                )],
185                buf.into_inner().freeze(),
186            )
187                .into_response(),
188            Err(err) => (
189                StatusCode::INTERNAL_SERVER_ERROR,
190                [(
191                    header::CONTENT_TYPE,
192                    HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
193                )],
194                err.to_string(),
195            )
196                .into_response(),
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use axum::routing::post;
    use axum::Router;
    use http::StatusCode;
    use serde::{Deserialize, Serialize};
    use serde_yaml::Value;

    use crate::test_client::TestClient;

    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route("/", post(|input: Yaml<Input>| async { input.0.foo }));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("foo: bar")
            .header("content-type", "application/yaml")
            .await;

        let body = res.text().await;
        assert_eq!(body, "bar");
    }

    #[tokio::test]
    async fn consume_body_to_yaml_requires_yaml_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route("/", post(|input: Yaml<Input>| async { input.0.foo }));

        let client = TestClient::new(app);
        let res = client.post("/").body("foo: bar").await;

        let status = res.status();
        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn yaml_content_types() {
        async fn valid_yaml_content_type(content_type: &str) -> bool {
            println!("testing {:?}", content_type);

            let app = Router::new().route("/", post(|Yaml(_): Yaml<Value>| async {}));

            let res = TestClient::new(app)
                .post("/")
                .header("content-type", content_type)
                .body("foo: ")
                .await;

            res.status() == StatusCode::OK
        }

        assert!(valid_yaml_content_type("application/yaml").await);
        assert!(valid_yaml_content_type("application/yaml;charset=utf-8").await);
        assert!(valid_yaml_content_type("application/yaml; charset=utf-8").await);
        // Structured-syntax suffix: `application/*+yaml` is accepted as well.
        assert!(valid_yaml_content_type("application/vnd.api+yaml").await);
        assert!(!valid_yaml_content_type("text/yaml").await);
    }

    #[tokio::test]
    async fn invalid_yaml_syntax() {
        let app = Router::new().route("/", post(|_: Yaml<Value>| async {}));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("- a\nb:")
            .header("content-type", "application/yaml")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    #[derive(Deserialize)]
    struct Foo {
        #[allow(dead_code)]
        a: i32,
        #[allow(dead_code)]
        b: Vec<Bar>,
    }

    #[derive(Deserialize)]
    struct Bar {
        #[allow(dead_code)]
        x: i32,
        #[allow(dead_code)]
        y: i32,
    }

    #[tokio::test]
    async fn invalid_yaml_data() {
        let app = Router::new().route("/", post(|_: Yaml<Foo>| async {}));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("a: 1\nb:\n    - x: 2")
            .header("content-type", "application/yaml")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        let body_text = res.text().await;
        assert_eq!(
            body_text,
            "Failed to deserialize the YAML body into the target type: b[0]: b[0]: missing field `y` at line 3 column 7"
        );
    }

    #[tokio::test]
    async fn serialize_response() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Payload {
            foo: String,
        }

        let res = Yaml(Payload {
            foo: "bar".to_owned(),
        })
        .into_response();

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.headers()[header::CONTENT_TYPE], "application/yaml");

        // Round-trip the body rather than pinning the serializer's exact formatting.
        let body = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        let decoded = Yaml::<Payload>::from_bytes(&body)
            .ok()
            .map(|Yaml(payload)| payload);
        assert_eq!(
            decoded,
            Some(Payload {
                foo: "bar".to_owned()
            })
        );
    }
}