axum_msgpack/
lib.rs

1#![forbid(unsafe_code)]
2
3use crate::rejection::{InvalidMsgPackBody, MissingMsgPackContentType};
4use axum::{
5    body::{Bytes, Body},
6    extract::{FromRequest, Request},
7    response::{IntoResponse, Response},
8    http::{header::HeaderValue, StatusCode},
9};
10use hyper::header;
11use rejection::MsgPackRejection;
12use serde::{de::DeserializeOwned, Serialize};
13use std::ops::{Deref, DerefMut};
14
15mod error;
16mod rejection;
17
/// MessagePack extractor and response.
///
/// # As an extractor
///
/// Deserializes the request body into any type that implements [`serde::Deserialize`].
/// The request is rejected with a `MsgPackRejection` when any of the following holds:
///
/// - The `Content-Type` header is missing, or malformed, or is not one of
///   `application/msgpack`, `application/x-msgpack` or `application/*+msgpack`.
/// - The request body cannot be read.
/// - The body is not valid MessagePack for the target type.
///
/// The HTTP status code of each rejection is defined in the crate's `rejection` module.
///
/// # Extractor example
///
/// ```no_run
/// use axum::{
///     routing::post,
///     Router,
/// };
/// use axum_msgpack::MsgPack;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct CreateUser {
///     email: String,
///     password: String,
/// }
///
/// async fn create_user(MsgPack(payload): MsgPack<CreateUser>) {
///     // payload is a `CreateUser`
/// }
///
/// let app = Router::new().route("/users", post(create_user));
/// # async {
/// #   axum::serve(tokio::net::TcpListener::bind(&"").await.unwrap(), app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # As a response
///
/// Serializes any type that implements [`serde::Serialize`] to MessagePack and sets the
/// `Content-Type: application/msgpack` header. Structs are encoded as maps keyed by field
/// name (`rmp_serde::encode::to_vec_named`). See [`MsgPackRaw`] for the more compact
/// array encoding.
///
/// If serialization fails, a `500 Internal Server Error` response with a `text/plain` error
/// message is returned instead.
///
/// # Response example
///
/// ```no_run
/// use axum::{
///     extract::Path,
///     routing::get,
///     Router,
/// };
/// use axum_msgpack::MsgPack;
/// use serde::Serialize;
/// use uuid::Uuid;
///
/// #[derive(Serialize)]
/// struct User {
///     id: Uuid,
///     username: String,
/// }
///
/// async fn get_user(Path(user_id): Path<Uuid>) -> MsgPack<User> {
///     let user = find_user(user_id).await;
///     MsgPack(user)
/// }
///
/// async fn find_user(user_id: Uuid) -> User {
///     // ...
///     # unimplemented!()
/// }
///
/// // Path parameters use the `{name}` capture syntax.
/// let app = Router::new().route("/users/{id}", get(get_user));
/// # async {
/// #   axum::serve(tokio::net::TcpListener::bind(&"").await.unwrap(), app.into_make_service()).await.unwrap();
/// # };
/// # mod uuid {
/// #   use serde::{Serialize, Deserialize};
/// #   #[derive(Serialize, Deserialize)]
/// #   pub struct Uuid;
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct MsgPack<T>(pub T);
94
95impl<T, S> FromRequest<S> for MsgPack<T>
96where
97    T: DeserializeOwned,
98    S: Send + Sync,
99{
100    type Rejection = MsgPackRejection;
101
102    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
103        if !message_pack_content_type(&req) {
104            return Err(MissingMsgPackContentType.into())
105        }
106        let bytes = Bytes::from_request(req, state).await?;
107        let value = rmp_serde::from_slice(&bytes).map_err(InvalidMsgPackBody::from_err)?;
108        Ok(MsgPack(value))
109    }
110}
111
112impl<T> Deref for MsgPack<T> {
113    type Target = T;
114
115    fn deref(&self) -> &Self::Target {
116        &self.0
117    }
118}
119
120impl<T> DerefMut for MsgPack<T> {
121    fn deref_mut(&mut self) -> &mut Self::Target {
122        &mut self.0
123    }
124}
125
126impl<T> From<T> for MsgPack<T> {
127    fn from(inner: T) -> Self {
128        Self(inner)
129    }
130}
131
132impl<T> IntoResponse for MsgPack<T>
133where
134    T: Serialize,
135{
136    fn into_response(self) -> Response {
137        let bytes = match rmp_serde::encode::to_vec_named(&self.0) {
138            Ok(res) => res,
139            Err(err) => {
140                return Response::builder()
141                    .status(StatusCode::INTERNAL_SERVER_ERROR)
142                    .header(header::CONTENT_TYPE, "text/plain")
143                    .body(Body::new(err.to_string()))
144                    .unwrap();
145            }
146        };
147
148        let mut res = bytes.into_response();
149
150        res.headers_mut().insert(
151            header::CONTENT_TYPE,
152            HeaderValue::from_static("application/msgpack"),
153        );
154        res
155    }
156}
157
/// MessagePack extractor and response using the compact (array) encoding.
///
/// This behaves like [`MsgPack`], except that responses are serialized with
/// `rmp_serde::encode::to_vec`. That function encodes structs as arrays of field values
/// without field names. The payload is smaller, but the receiver must know the field order.
///
/// # As an extractor
///
/// Request handling is the same as for [`MsgPack`]. The body is deserialized into any type
/// that implements [`serde::Deserialize`]. The request is rejected with a `MsgPackRejection`
/// when any of the following holds:
///
/// - The `Content-Type` header is missing, or malformed, or is not one of
///   `application/msgpack`, `application/x-msgpack` or `application/*+msgpack`.
/// - The request body cannot be read.
/// - The body is not valid MessagePack for the target type.
///
/// The HTTP status code of each rejection is defined in the crate's `rejection` module.
///
/// # Extractor example
///
/// ```no_run
/// use axum::{
///     routing::post,
///     Router,
/// };
/// use axum_msgpack::MsgPackRaw;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct CreateUser {
///     email: String,
///     password: String,
/// }
///
/// async fn create_user(MsgPackRaw(payload): MsgPackRaw<CreateUser>) {
///     // payload is a `CreateUser`
/// }
///
/// let app = Router::new().route("/users", post(create_user));
/// # async {
/// #   axum::serve(tokio::net::TcpListener::bind(&"").await.unwrap(), app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # As a response
///
/// Serializes any type that implements [`serde::Serialize`] to compact MessagePack and sets
/// the `Content-Type: application/msgpack` header.
///
/// If serialization fails, a `500 Internal Server Error` response with a `text/plain` error
/// message is returned instead.
///
/// # Response example
///
/// ```no_run
/// use axum::{
///     extract::Path,
///     routing::get,
///     Router,
/// };
/// use axum_msgpack::MsgPackRaw;
/// use serde::Serialize;
/// use uuid::Uuid;
///
/// #[derive(Serialize)]
/// struct User {
///     id: Uuid,
///     username: String,
/// }
///
/// async fn get_user(Path(user_id): Path<Uuid>) -> MsgPackRaw<User> {
///     let user = find_user(user_id).await;
///     MsgPackRaw(user)
/// }
///
/// async fn find_user(user_id: Uuid) -> User {
///     // ...
///     # unimplemented!()
/// }
///
/// // Path parameters use the `{name}` capture syntax.
/// let app = Router::new().route("/users/{id}", get(get_user));
/// # async {
/// #   axum::serve(tokio::net::TcpListener::bind(&"").await.unwrap(), app.into_make_service()).await.unwrap();
/// # };
/// # mod uuid {
/// #   use serde::{Serialize, Deserialize};
/// #   #[derive(Serialize, Deserialize)]
/// #   pub struct Uuid;
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct MsgPackRaw<T>(pub T);
234
235impl<T, S> FromRequest<S> for MsgPackRaw<T>
236where
237    T: DeserializeOwned,
238    S: Send + Sync,
239{
240    type Rejection = MsgPackRejection;
241
242    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
243        if !message_pack_content_type(&req) {
244            return Err(MissingMsgPackContentType.into())
245        }
246        let bytes = Bytes::from_request(req, state).await?;
247        let value = rmp_serde::from_slice(&bytes).map_err(InvalidMsgPackBody::from_err)?;
248        Ok(MsgPackRaw(value))
249    }
250}
251
252impl<T> Deref for MsgPackRaw<T> {
253    type Target = T;
254
255    fn deref(&self) -> &Self::Target {
256        &self.0
257    }
258}
259
260impl<T> DerefMut for MsgPackRaw<T> {
261    fn deref_mut(&mut self) -> &mut Self::Target {
262        &mut self.0
263    }
264}
265
266impl<T> From<T> for MsgPackRaw<T> {
267    fn from(inner: T) -> Self {
268        Self(inner)
269    }
270}
271
272impl<T> IntoResponse for MsgPackRaw<T>
273where
274    T: Serialize,
275{
276    fn into_response(self) -> Response {
277        let bytes = match rmp_serde::encode::to_vec(&self.0) {
278            Ok(res) => res,
279            Err(err) => {
280                return Response::builder()
281                    .status(StatusCode::INTERNAL_SERVER_ERROR)
282                    .header(header::CONTENT_TYPE, "text/plain")
283                    .body(Body::new(err.to_string()))
284                    .unwrap();
285            }
286        };
287
288        let mut res = bytes.into_response();
289
290        res.headers_mut().insert(
291            header::CONTENT_TYPE,
292            HeaderValue::from_static("application/msgpack"),
293        );
294        res
295    }
296}
297
298fn message_pack_content_type<B>(req: &Request<B>) -> bool {
299    let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
300        return false;
301    };
302    let  Ok(content_type) = content_type.to_str() else {
303        return false;
304    };
305    let Ok(mime) = content_type.parse::<mime::Mime>() else {
306        return false;
307    };
308
309    let is_message_pack = mime.type_() == "application"
310        && (["msgpack", "x-msgpack"]
311            .iter()
312            .any(|subtype| *subtype == mime.subtype())
313            || mime.suffix().map_or(false, |suffix| suffix == "msgpack"));
314
315    is_message_pack
316}
317
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        extract::FromRequest,
        http::{HeaderValue, StatusCode},
        response::IntoResponse,
    };
    use futures_util::StreamExt;

    use crate::{MsgPack, MsgPackRaw, MsgPackRejection};
    use hyper::{header, Request};
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Input {
        foo: String,
    }

    /// A value whose `Serialize` impl always fails. Used to exercise the 500 response path.
    struct Unserializable;

    impl Serialize for Unserializable {
        fn serialize<S: serde::Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
            Err(<S::Error as serde::ser::Error>::custom(
                "intentional serialization failure",
            ))
        }
    }

    /// Builds a header-less request whose body is `value` encoded with named (map) fields.
    fn into_request<T: Serialize>(value: &T) -> Request<Body> {
        let serialized =
            rmp_serde::encode::to_vec_named(value).expect("Failed to serialize test struct");
        Request::new(Body::from(serialized))
    }

    /// Builds a header-less request whose body is `value` in the compact (array) encoding.
    fn into_request_raw<T: Serialize>(value: &T) -> Request<Body> {
        let serialized =
            rmp_serde::encode::to_vec(value).expect("Failed to serialize test struct");
        Request::new(Body::from(serialized))
    }

    /// Sets the request's `Content-Type` header to `content_type` and returns the request.
    fn with_content_type(mut request: Request<Body>, content_type: &'static str) -> Request<Body> {
        request
            .headers_mut()
            .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
        request
    }

    /// Collects every data frame of `body` into one byte vector.
    async fn to_bytes(body: Body) -> Vec<u8> {
        let mut buffer = Vec::new();
        let mut stream = body.into_data_stream();

        while let Some(chunk) = stream.next().await {
            buffer.extend_from_slice(&chunk.expect("failed to read body frame"));
        }

        buffer
    }

    #[tokio::test]
    async fn serializes_named() {
        let input = Input { foo: "bar".into() };
        let expected = rmp_serde::encode::to_vec_named(&input).expect("serialization failed");

        let response = MsgPack(input).into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()[header::CONTENT_TYPE], "application/msgpack");
        assert_eq!(to_bytes(response.into_body()).await, expected);
    }

    #[tokio::test]
    async fn deserializes_named() {
        let input = Input { foo: "bar".into() };
        let request = with_content_type(into_request(&input), "application/msgpack");

        let outcome = <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &()).await;

        assert_eq!(outcome.expect("extraction should succeed").0, input);
    }

    #[tokio::test]
    async fn serializes_raw() {
        let input = Input { foo: "bar".into() };
        let expected = rmp_serde::encode::to_vec(&input).expect("serialization failed");

        let response = MsgPackRaw(input).into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()[header::CONTENT_TYPE], "application/msgpack");
        assert_eq!(to_bytes(response.into_body()).await, expected);
    }

    #[tokio::test]
    async fn deserializes_raw() {
        let input = Input { foo: "bar".into() };
        let request = with_content_type(into_request_raw(&input), "application/msgpack");

        let outcome = <MsgPackRaw<Input> as FromRequest<_, _>>::from_request(request, &()).await;

        assert_eq!(outcome.expect("extraction should succeed").0, input);
    }

    #[tokio::test]
    async fn supported_content_type() {
        let input = Input { foo: "bar".into() };

        for content_type in [
            "application/msgpack",
            "application/x-msgpack",
            "application/cloudevents+msgpack",
        ] {
            let request = with_content_type(into_request(&input), content_type);
            let outcome =
                <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &()).await;
            assert!(
                outcome.is_ok(),
                "`{content_type}` should be accepted, got: {outcome:?}"
            );
        }
    }

    #[tokio::test]
    async fn rejects_missing_content_type() {
        let input = Input { foo: "bar".into() };

        let outcome =
            <MsgPack<Input> as FromRequest<_, _>>::from_request(into_request(&input), &())
                .await;
        assert!(
            matches!(outcome, Err(MsgPackRejection::MissingMsgPackContentType(_))),
            "Expected missing MsgPack content type rejection, got: {outcome:?}"
        );

        let outcome =
            <MsgPackRaw<Input> as FromRequest<_, _>>::from_request(into_request_raw(&input), &())
                .await;
        assert!(
            matches!(outcome, Err(MsgPackRejection::MissingMsgPackContentType(_))),
            "Expected missing MsgPack content type rejection, got: {outcome:?}"
        );
    }

    #[tokio::test]
    async fn rejects_unsupported_content_type() {
        let input = Input { foo: "bar".into() };

        for content_type in ["application/json", "text/msgpack", "not a mime type"] {
            let request = with_content_type(into_request(&input), content_type);
            let outcome =
                <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &()).await;
            assert!(
                matches!(outcome, Err(MsgPackRejection::MissingMsgPackContentType(_))),
                "`{content_type}` should be rejected, got: {outcome:?}"
            );
        }
    }

    #[tokio::test]
    async fn rejects_invalid_body() {
        // 0xc1 is the byte MessagePack reserves as "never used", so decoding must fail.
        let request = with_content_type(
            Request::new(Body::from(vec![0xc1_u8])),
            "application/msgpack",
        );

        let outcome = <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &()).await;

        match outcome {
            Err(MsgPackRejection::MissingMsgPackContentType(_)) => {
                panic!("content type was valid; expected a body rejection instead")
            }
            Err(_) => {}
            Ok(value) => panic!("expected an invalid body rejection, got: {value:?}"),
        }
    }

    #[test]
    fn serialization_failure_returns_500() {
        for response in [
            MsgPack(Unserializable).into_response(),
            MsgPackRaw(Unserializable).into_response(),
        ] {
            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
            assert_eq!(response.headers()[header::CONTENT_TYPE], "text/plain");
        }
    }
}