axum_postcard/
lib.rs

1// postcard deps
2use postcard::{from_bytes, to_allocvec};
3use serde::{de::DeserializeOwned, Serialize};
4// axum deps
5use async_trait::async_trait;
6use axum::{
7    body::{Body, Bytes},
8    extract::{rejection::BytesRejection, FromRequest},
9    http::{header, HeaderMap, Request, StatusCode},
10    response::{IntoResponse, Response},
11};
12
/// Postcard extractor / response.
///
/// When used as an extractor, it can deserialize request bodies into some type that
/// implements [`serde::Deserialize`]. The request will be rejected (and a [`PostcardRejection`]
/// will be returned) if:
///
/// - The request's `Content-Type` is not `application/postcard` (parameters such as `charset`
///   are ignored) or an `application` type with a `+postcard` suffix, such as
///   `application/cloudevents+postcard`.
/// - The body doesn't contain syntactically valid Postcard.
/// - The body contains syntactically valid Postcard but it couldn't be deserialized into the
///   target type.
/// - Buffering the request body fails.
///
/// ⚠️ Since parsing Postcard requires consuming the request body, the `Postcard` extractor must
/// be *last* if there are multiple extractors in a handler.
/// See ["the order of extractors"][order-of-extractors].
///
/// [order-of-extractors]: axum::extract#the-order-of-extractors
///
/// See [`PostcardRejection`] for more details.
///
/// # Extractor example
///
/// ```rust,no_run
/// use axum::{routing::post, Router};
/// use serde::Deserialize;
/// use axum_postcard::Postcard;
///
/// #[derive(Deserialize)]
/// struct CreateUser {
///     email: String,
///     password: String,
/// }
///
/// async fn create_user(Postcard(payload): Postcard<CreateUser>) {
///     // payload is a `CreateUser`
///     todo!()
/// }
///
/// let app = Router::new().route("/users", post(create_user));
/// # async {
/// # let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// # axum::serve(listener, app).await.unwrap();
/// # };
/// ```
///
/// When used as a response, it can serialize any type that implements [`serde::Serialize`] to
/// Postcard, and sets the `Content-Type: application/postcard` header. If serialization fails,
/// a `500 Internal Server Error` carrying the error message is returned instead.
///
/// # Response example
///
/// ```
/// use axum::{
///     extract::Path,
///     routing::get,
///     Router,
/// };
/// use serde::Serialize;
/// use axum_postcard::Postcard;
///
/// #[derive(Serialize)]
/// struct User {
///     id: u32,
///     username: String,
/// }
///
/// async fn get_user(Path(user_id) : Path<u32>) -> Postcard<User> {
///     let user = find_user(user_id).await;
///     Postcard(user)
/// }
///
/// async fn find_user(user_id: u32) -> User {
///     todo!()
/// }
///
/// let app = Router::new().route("/users/:id", get(get_user));
/// # async {
/// # let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// # axum::serve(listener, app).await.unwrap();
/// # };
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Postcard<T>(pub T);
98
99#[derive(thiserror::Error, Debug)]
100pub enum PostcardRejection {
101    #[error("Expected request with `Content-Type: application/postcard`")]
102    MissingPostcardContentType,
103    #[error(transparent)]
104    PostcardError(#[from] postcard::Error),
105    #[error(transparent)]
106    Bytes(#[from] BytesRejection),
107}
108
109impl IntoResponse for PostcardRejection {
110    fn into_response(self) -> Response {
111        use PostcardRejection::*;
112        // its often easiest to implement `IntoResponse` by calling other implementations
113        match self {
114            MissingPostcardContentType => {
115                (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()).into_response()
116            }
117            PostcardError(err) => (StatusCode::BAD_REQUEST, err.to_string()).into_response(),
118            _ => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response(),
119        }
120    }
121}
122
123#[async_trait]
124impl<T, S> FromRequest<S> for Postcard<T>
125where
126    T: DeserializeOwned,
127    S: Send + Sync,
128{
129    type Rejection = PostcardRejection;
130
131    async fn from_request(req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
132        if postcard_content_type(req.headers()) {
133            let bytes = Bytes::from_request(req, state).await?;
134
135            let value = match from_bytes(&*bytes) {
136                Ok(value) => value,
137                Err(err) => return Err(PostcardRejection::PostcardError(err)),
138            };
139            Ok(Postcard(value))
140        } else {
141            Err(PostcardRejection::MissingPostcardContentType)
142        }
143    }
144}
145
146fn postcard_content_type(headers: &HeaderMap) -> bool {
147    let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
148        content_type
149    } else {
150        return false;
151    };
152
153    let content_type = if let Ok(content_type) = content_type.to_str() {
154        content_type
155    } else {
156        return false;
157    };
158
159    let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
160        mime
161    } else {
162        return false;
163    };
164
165    let is_postcard_content_type = mime.type_() == "application"
166        && (mime.subtype() == "postcard" || mime.suffix().map_or(false, |name| name == "postcard"));
167
168    is_postcard_content_type
169}
170
171impl<T> IntoResponse for Postcard<T>
172where
173    T: Serialize,
174{
175    fn into_response(self) -> Response {
176        // TODO: maybe use 128 bytes cause serde is doing something like that
177        match to_allocvec(&self.0) {
178            Ok(value) => ([(header::CONTENT_TYPE, "application/postcard")], value).into_response(),
179            Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::to_bytes, routing::post, Router};
    use axum_test_helpers::*;
    use serde::Deserialize;

    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize, Serialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route(
            "/",
            post(|Postcard(input): Postcard<Input>| async move { input.foo }),
        );

        let client = TestClient::new(app);
        let response = client
            .post("/")
            .header("content-type", "application/postcard")
            // Length prefix 3, then the UTF-8 bytes of "bar".
            .body("\x03bar")
            .await;

        assert_eq!(response.text().await, "bar");
    }

    #[tokio::test]
    async fn consume_body_to_postcard_requires_postcard_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let app = Router::new().route(
            "/",
            post(|Postcard(input): Postcard<Input>| async move { input.foo }),
        );

        let client = TestClient::new(app);
        // No `content-type` header is sent at all.
        let response = client.post("/").body("\x03bar").await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn postcard_content_types() {
        /// Posts a valid Postcard `String` body with `content_type` and reports whether the
        /// handler accepted it.
        async fn accepts(content_type: &str) -> bool {
            let app = Router::new().route("/", post(|Postcard(_): Postcard<String>| async {}));

            let response = TestClient::new(app)
                .post("/")
                .header("content-type", content_type)
                .body("\x02hi")
                .await;

            response.status() == StatusCode::OK
        }

        let cases = [
            ("application/postcard", true),
            ("application/postcard; charset=utf-8", true),
            ("application/postcard;charset=utf-8", true),
            ("application/cloudevents+postcard", true),
            ("text/postcard", false),
        ];

        for (content_type, expected) in cases {
            assert_eq!(accepts(content_type).await, expected, "content type {content_type:?}");
        }
    }

    #[tokio::test]
    async fn invalid_postcard_syntax() {
        let app = Router::new().route("/", post(|_: Postcard<String>| async {}));

        let client = TestClient::new(app);
        // Declares a 3-byte string but supplies none of its bytes.
        let response = client
            .post("/")
            .header("content-type", "application/postcard")
            .body("\x03")
            .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // The fields are never read; they only give the decoder a shape to fill in.
    #[allow(dead_code)]
    #[derive(Deserialize)]
    struct Foo {
        a: i32,
        b: Vec<Bar>,
    }

    // The fields are never read; they only give the decoder a shape to fill in.
    #[allow(dead_code)]
    #[derive(Deserialize)]
    struct Bar {
        x: i32,
        y: i32,
    }

    #[tokio::test]
    async fn invalid_postcard_data() {
        let app = Router::new().route("/", post(|_: Postcard<Foo>| async {}));

        let client = TestClient::new(app);
        // In zig-zag varint encoding, these bytes are: a = 1, b has one element, and that
        // element's x = 2. Its `y` is missing.
        let response = client
            .post("/")
            .header("content-type", "application/postcard")
            .body("\x02\x01\x04")
            .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(response.text().await, "Hit the end of buffer, expected more data");
    }

    #[tokio::test]
    async fn serialize_response() {
        let response = Postcard("bar").into_response();

        assert!(postcard_content_type(response.headers()));

        let body = to_bytes(response.into_body(), 4).await.unwrap();
        assert_eq!(&body[..], b"\x03bar");
    }
}