1use postcard::{from_bytes, to_allocvec};
3use serde::{de::DeserializeOwned, Serialize};
4use async_trait::async_trait;
6use axum::{
7 body::{Body, Bytes},
8 extract::{rejection::BytesRejection, FromRequest},
9 http::{header, HeaderMap, Request, StatusCode},
10 response::{IntoResponse, Response},
11};
12
/// Wrapper that decodes a request body from, or encodes a response body into, the
/// postcard wire format.
///
/// As an extractor it requires a postcard `Content-Type` header; as a response it sets
/// `Content-Type: application/postcard`.
pub struct Postcard<T>(
    /// The value decoded from the request, or the value to encode into the response.
    pub T,
);
98
99#[derive(thiserror::Error, Debug)]
100pub enum PostcardRejection {
101 #[error("Expected request with `Content-Type: application/postcard`")]
102 MissingPostcardContentType,
103 #[error(transparent)]
104 PostcardError(#[from] postcard::Error),
105 #[error(transparent)]
106 Bytes(#[from] BytesRejection),
107}
108
109impl IntoResponse for PostcardRejection {
110 fn into_response(self) -> Response {
111 use PostcardRejection::*;
112 match self {
114 MissingPostcardContentType => {
115 (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()).into_response()
116 }
117 PostcardError(err) => (StatusCode::BAD_REQUEST, err.to_string()).into_response(),
118 _ => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response(),
119 }
120 }
121}
122
123#[async_trait]
124impl<T, S> FromRequest<S> for Postcard<T>
125where
126 T: DeserializeOwned,
127 S: Send + Sync,
128{
129 type Rejection = PostcardRejection;
130
131 async fn from_request(req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
132 if postcard_content_type(req.headers()) {
133 let bytes = Bytes::from_request(req, state).await?;
134
135 let value = match from_bytes(&*bytes) {
136 Ok(value) => value,
137 Err(err) => return Err(PostcardRejection::PostcardError(err)),
138 };
139 Ok(Postcard(value))
140 } else {
141 Err(PostcardRejection::MissingPostcardContentType)
142 }
143 }
144}
145
146fn postcard_content_type(headers: &HeaderMap) -> bool {
147 let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
148 content_type
149 } else {
150 return false;
151 };
152
153 let content_type = if let Ok(content_type) = content_type.to_str() {
154 content_type
155 } else {
156 return false;
157 };
158
159 let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
160 mime
161 } else {
162 return false;
163 };
164
165 let is_postcard_content_type = mime.type_() == "application"
166 && (mime.subtype() == "postcard" || mime.suffix().map_or(false, |name| name == "postcard"));
167
168 is_postcard_content_type
169}
170
171impl<T> IntoResponse for Postcard<T>
172where
173 T: Serialize,
174{
175 fn into_response(self) -> Response {
176 match to_allocvec(&self.0) {
178 Ok(value) => ([(header::CONTENT_TYPE, "application/postcard")], value).into_response(),
179 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::to_bytes, routing::post, Router};
    use axum_test_helpers::*;
    use serde::Deserialize;

    // Request bodies below are hand-written postcard encodings. A `String` is a length
    // prefix followed by its UTF-8 bytes, so `"\x03bar"` decodes to "bar"
    // (`serialize_response` pins exactly this encoding).

    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize, Serialize)]
        struct Input {
            foo: String,
        }

        let router =
            Router::new().route("/", post(|payload: Postcard<Input>| async { payload.0.foo }));
        let client = TestClient::new(router);

        let response = client
            .post("/")
            .header("content-type", "application/postcard")
            .body("\x03bar")
            .await;

        assert_eq!(response.text().await, "bar");
    }

    #[tokio::test]
    async fn consume_body_to_postcard_requires_postcard_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input {
            foo: String,
        }

        let router =
            Router::new().route("/", post(|payload: Postcard<Input>| async { payload.0.foo }));
        let client = TestClient::new(router);

        // No `content-type` header at all: the extractor must refuse the request.
        let response = client.post("/").body("\x03bar").await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn postcard_content_types() {
        // Sends a valid postcard `String` body ("hi") under `content_type` and reports
        // whether the extractor accepted it.
        async fn valid_postcard_content_type(content_type: &str) -> bool {
            println!("testing {content_type:?}");

            let router = Router::new().route("/", post(|Postcard(_): Postcard<String>| async {}));

            let response = TestClient::new(router)
                .post("/")
                .header("content-type", content_type)
                .body("\x02hi")
                .await;

            response.status() == StatusCode::OK
        }

        let cases = [
            ("application/postcard", true),
            ("application/postcard; charset=utf-8", true),
            ("application/postcard;charset=utf-8", true),
            ("application/cloudevents+postcard", true),
            ("text/postcard", false),
        ];
        for (content_type, expected) in cases {
            assert_eq!(
                valid_postcard_content_type(content_type).await,
                expected,
                "content type {content_type:?}"
            );
        }
    }

    #[tokio::test]
    async fn invalid_postcard_syntax() {
        let router = Router::new().route("/", post(|_: Postcard<String>| async {}));
        let client = TestClient::new(router);

        // The length prefix promises three bytes of string data, but none follow.
        let response = client
            .post("/")
            .body("\x03")
            .header("content-type", "application/postcard")
            .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // Decoding fixtures for `invalid_postcard_data`. Their fields are never read; they
    // only give the decoder a concrete shape to fill, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[derive(Deserialize)]
    struct Foo {
        a: i32,
        b: Vec<Bar>,
    }

    #[allow(dead_code)]
    #[derive(Deserialize)]
    struct Bar {
        x: i32,
        y: i32,
    }

    #[tokio::test]
    async fn invalid_postcard_data() {
        let router = Router::new().route("/", post(|_: Postcard<Foo>| async {}));
        let client = TestClient::new(router);

        // Three bytes cover `a`, the length of `b` (one element) and `b[0].x`;
        // `b[0].y` is missing, so decoding runs off the end of the buffer.
        let response = client
            .post("/")
            .header("content-type", "application/postcard")
            .body("\x02\x01\x04")
            .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(response.text().await, "Hit the end of buffer, expected more data");
    }

    #[tokio::test]
    async fn serialize_response() {
        let response = Postcard("bar").into_response();

        assert!(postcard_content_type(response.headers()));

        // The limit of 4 is exactly the length of the expected encoding.
        let body = to_bytes(response.into_body(), 4).await.unwrap();

        assert_eq!(&body[..], b"\x03bar");
    }
}