1#![forbid(unsafe_code)]
2
3use crate::rejection::{InvalidMsgPackBody, MissingMsgPackContentType};
4use axum::{
5 body::{Bytes, Body},
6 extract::{FromRequest, Request},
7 response::{IntoResponse, Response},
8 http::{header::HeaderValue, StatusCode},
9};
10use hyper::header;
11use rejection::MsgPackRejection;
12use serde::{de::DeserializeOwned, Serialize};
13use std::ops::{Deref, DerefMut};
14
15mod error;
16mod rejection;
17
/// MessagePack extractor and response wrapper around a value of type `T`.
///
/// * As an extractor (`FromRequest`), the request must carry a MessagePack
///   content type and its body is decoded with `rmp_serde::from_slice`.
/// * As a response (`IntoResponse`), the inner value is encoded with
///   `rmp_serde::encode::to_vec_named` and sent as `application/msgpack`.
///
/// The wrapped value is public and can be reached through `.0` or via
/// `Deref`/`DerefMut`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MsgPack<T>(pub T);
94
95impl<T, S> FromRequest<S> for MsgPack<T>
96where
97 T: DeserializeOwned,
98 S: Send + Sync,
99{
100 type Rejection = MsgPackRejection;
101
102 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
103 if !message_pack_content_type(&req) {
104 return Err(MissingMsgPackContentType.into())
105 }
106 let bytes = Bytes::from_request(req, state).await?;
107 let value = rmp_serde::from_slice(&bytes).map_err(InvalidMsgPackBody::from_err)?;
108 Ok(MsgPack(value))
109 }
110}
111
112impl<T> Deref for MsgPack<T> {
113 type Target = T;
114
115 fn deref(&self) -> &Self::Target {
116 &self.0
117 }
118}
119
120impl<T> DerefMut for MsgPack<T> {
121 fn deref_mut(&mut self) -> &mut Self::Target {
122 &mut self.0
123 }
124}
125
126impl<T> From<T> for MsgPack<T> {
127 fn from(inner: T) -> Self {
128 Self(inner)
129 }
130}
131
132impl<T> IntoResponse for MsgPack<T>
133where
134 T: Serialize,
135{
136 fn into_response(self) -> Response {
137 let bytes = match rmp_serde::encode::to_vec_named(&self.0) {
138 Ok(res) => res,
139 Err(err) => {
140 return Response::builder()
141 .status(StatusCode::INTERNAL_SERVER_ERROR)
142 .header(header::CONTENT_TYPE, "text/plain")
143 .body(Body::new(err.to_string()))
144 .unwrap();
145 }
146 };
147
148 let mut res = bytes.into_response();
149
150 res.headers_mut().insert(
151 header::CONTENT_TYPE,
152 HeaderValue::from_static("application/msgpack"),
153 );
154 res
155 }
156}
157
/// MessagePack extractor and response wrapper around a value of type `T`.
///
/// * As an extractor (`FromRequest`), the request must carry a MessagePack
///   content type and its body is decoded with `rmp_serde::from_slice`.
/// * As a response (`IntoResponse`), the inner value is encoded with
///   `rmp_serde::encode::to_vec` (rather than `to_vec_named`, which
///   [`MsgPack`] uses) and sent as `application/msgpack`.
///
/// The wrapped value is public and can be reached through `.0` or via
/// `Deref`/`DerefMut`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MsgPackRaw<T>(pub T);
234
235impl<T, S> FromRequest<S> for MsgPackRaw<T>
236where
237 T: DeserializeOwned,
238 S: Send + Sync,
239{
240 type Rejection = MsgPackRejection;
241
242 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
243 if !message_pack_content_type(&req) {
244 return Err(MissingMsgPackContentType.into())
245 }
246 let bytes = Bytes::from_request(req, state).await?;
247 let value = rmp_serde::from_slice(&bytes).map_err(InvalidMsgPackBody::from_err)?;
248 Ok(MsgPackRaw(value))
249 }
250}
251
252impl<T> Deref for MsgPackRaw<T> {
253 type Target = T;
254
255 fn deref(&self) -> &Self::Target {
256 &self.0
257 }
258}
259
260impl<T> DerefMut for MsgPackRaw<T> {
261 fn deref_mut(&mut self) -> &mut Self::Target {
262 &mut self.0
263 }
264}
265
266impl<T> From<T> for MsgPackRaw<T> {
267 fn from(inner: T) -> Self {
268 Self(inner)
269 }
270}
271
272impl<T> IntoResponse for MsgPackRaw<T>
273where
274 T: Serialize,
275{
276 fn into_response(self) -> Response {
277 let bytes = match rmp_serde::encode::to_vec(&self.0) {
278 Ok(res) => res,
279 Err(err) => {
280 return Response::builder()
281 .status(StatusCode::INTERNAL_SERVER_ERROR)
282 .header(header::CONTENT_TYPE, "text/plain")
283 .body(Body::new(err.to_string()))
284 .unwrap();
285 }
286 };
287
288 let mut res = bytes.into_response();
289
290 res.headers_mut().insert(
291 header::CONTENT_TYPE,
292 HeaderValue::from_static("application/msgpack"),
293 );
294 res
295 }
296}
297
298fn message_pack_content_type<B>(req: &Request<B>) -> bool {
299 let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
300 return false;
301 };
302 let Ok(content_type) = content_type.to_str() else {
303 return false;
304 };
305 let Ok(mime) = content_type.parse::<mime::Mime>() else {
306 return false;
307 };
308
309 let is_message_pack = mime.type_() == "application"
310 && (["msgpack", "x-msgpack"]
311 .iter()
312 .any(|subtype| *subtype == mime.subtype())
313 || mime.suffix().map_or(false, |suffix| suffix == "msgpack"));
314
315 is_message_pack
316}
317
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        extract::FromRequest,
        http::HeaderValue,
        response::IntoResponse,
    };
    use futures_util::StreamExt;

    use crate::{MsgPack, MsgPackRaw, MsgPackRejection};
    use hyper::{header, Request};
    use serde::{Deserialize, Serialize};

    /// Minimal payload type round-tripped by the tests below.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Input {
        foo: String,
    }

    /// Builds a header-less request whose body is `value` encoded with
    /// `to_vec_named`.
    fn into_request<T: Serialize>(value: &T) -> Request<Body> {
        let encoded =
            rmp_serde::encode::to_vec_named(value).expect("Failed to serialize test struct");
        Request::new(Body::from(encoded))
    }

    /// Builds a header-less request whose body is `value` encoded with
    /// `to_vec`.
    fn into_request_raw<T: Serialize>(value: &T) -> Request<Body> {
        let encoded = rmp_serde::encode::to_vec(value).expect("Failed to serialize test struct");
        Request::new(Body::from(encoded))
    }

    /// Returns `request` with its `Content-Type` header set to `content_type`.
    fn with_content_type(
        mut request: Request<Body>,
        content_type: &'static str,
    ) -> Request<Body> {
        request
            .headers_mut()
            .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
        request
    }

    /// Drains `body` and returns all of its data as one contiguous buffer.
    async fn to_bytes(body: Body) -> Vec<u8> {
        let mut chunks = body.into_data_stream();
        let mut collected = Vec::new();

        while let Some(chunk) = chunks.next().await {
            let chunk = chunk.unwrap();
            collected.extend_from_slice(&chunk);
        }

        collected
    }

    /// `MsgPack` responses must match `to_vec_named` output byte for byte.
    #[tokio::test]
    async fn serializes_named() {
        let input = Input { foo: "bar".into() };
        let expected = rmp_serde::encode::to_vec_named(&input);
        assert!(expected.is_ok());
        let expected = expected.unwrap();

        let actual = to_bytes(MsgPack(input).into_response().into_body()).await;

        assert_eq!(expected, actual);
    }

    /// A `to_vec_named` body with a msgpack content type decodes via `MsgPack`.
    #[tokio::test]
    async fn deserializes_named() {
        let input = Input { foo: "bar".into() };
        let request = with_content_type(into_request(&input), "application/msgpack");

        let outcome =
            <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &|| {}).await;

        let MsgPack(decoded) = outcome.unwrap();
        assert_eq!(input, decoded);
    }

    /// `MsgPackRaw` responses must match `to_vec` output byte for byte.
    #[tokio::test]
    async fn serializes_raw() {
        let input = Input { foo: "bar".into() };
        let expected = rmp_serde::encode::to_vec(&input);
        assert!(expected.is_ok());
        let expected = expected.unwrap();

        let actual = to_bytes(MsgPackRaw(input).into_response().into_body()).await;

        assert_eq!(expected, actual);
    }

    /// A `to_vec` body with a msgpack content type decodes via `MsgPackRaw`.
    #[tokio::test]
    async fn deserializes_raw() {
        let input = Input { foo: "bar".into() };
        let request = with_content_type(into_request_raw(&input), "application/msgpack");

        let outcome =
            <MsgPackRaw<Input> as FromRequest<_, _>>::from_request(request, &|| {}).await;

        let MsgPackRaw(decoded) = outcome.unwrap();
        assert_eq!(input, decoded);
    }

    /// Every accepted content type extracts successfully; a request without
    /// any content type is rejected with `MissingMsgPackContentType`.
    #[tokio::test]
    async fn supported_content_type() {
        let input = Input { foo: "bar".into() };

        for content_type in [
            "application/msgpack",
            "application/cloudevents+msgpack",
            "application/x-msgpack",
        ] {
            let request = with_content_type(into_request(&input), content_type);
            let outcome =
                <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &|| {}).await;
            assert!(outcome.is_ok());
        }

        let request = into_request(&input);
        let outcome =
            <MsgPack<Input> as FromRequest<_, _>>::from_request(request, &|| {}).await;

        match outcome {
            Err(MsgPackRejection::MissingMsgPackContentType(_)) => {}
            unexpected => unreachable!(
                "Expected missing MsgPack content type rejection, got: {:?}",
                unexpected
            ),
        }
    }
}