// actix_msgpack/msgpack_message.rs

1use crate::{ContentTypeHandler, MsgPackError, DEFAULT_PAYLOAD_LIMIT};
2use actix_web::{
3	dev::Payload, error::PayloadError, http::header::CONTENT_LENGTH, web::BytesMut, HttpMessage,
4	HttpRequest,
5};
6use futures_util::{future::LocalBoxFuture, stream::StreamExt, FutureExt};
7use mime::APPLICATION_MSGPACK;
8use serde::de::DeserializeOwned;
9use std::{
10	future::Future,
11	io,
12	pin::Pin,
13	task::{self, Poll},
14};
15
16pub struct MsgPackMessage<T> {
17	limit: usize,
18	length: Option<usize>,
19	stream: Option<Payload>,
20	err: Option<MsgPackError>,
21	fut: Option<LocalBoxFuture<'static, Result<T, MsgPackError>>>,
22}
23
24impl<T> MsgPackMessage<T> {
25	pub fn new(
26		req: &HttpRequest,
27		payload: &mut Payload,
28		content_type_fn: Option<ContentTypeHandler>,
29	) -> Self {
30		// Check content-type header
31		let can_parse = if let Ok(Some(mime_type)) = req.mime_type() {
32			if let Some(predicate) = content_type_fn {
33				predicate(mime_type)
34			} else {
35				mime_type == APPLICATION_MSGPACK
36			}
37		} else {
38			false
39		};
40
41		if !can_parse {
42			return MsgPackMessage {
43				limit: DEFAULT_PAYLOAD_LIMIT,
44				length: None,
45				stream: None,
46				fut: None,
47				err: Some(MsgPackError::ContentType),
48			};
49		}
50
51		let mut length = None;
52
53		if let Some(content_length) = req.headers().get(CONTENT_LENGTH) {
54			if let Ok(string) = content_length.to_str() {
55				if let Ok(l) = string.parse::<usize>() {
56					length = Some(l)
57				}
58			}
59		}
60
61		MsgPackMessage {
62			limit: DEFAULT_PAYLOAD_LIMIT,
63			length,
64			stream: Some(payload.take()),
65			fut: None,
66			err: None,
67		}
68	}
69
70	/// Set maximum accepted payload size in bytes
71	pub fn limit(mut self, limit: usize) -> Self {
72		self.limit = limit;
73		self
74	}
75}
76
77impl<T: DeserializeOwned + 'static> Future for MsgPackMessage<T> {
78	type Output = Result<T, MsgPackError>;
79
80	fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<Self::Output> {
81		if let Some(ref mut fut) = self.fut {
82			return Pin::new(fut).poll(task);
83		}
84
85		if let Some(err) = self.err.take() {
86			return Poll::Ready(Err(err));
87		}
88
89		let limit = self.limit;
90
91		if let Some(len) = self.length.take() {
92			if len > limit {
93				return Poll::Ready(Err(MsgPackError::Overflow));
94			}
95		}
96
97		let mut stream = self.stream.take().expect("MsgPackMessage could not be used second time");
98
99		self.fut = Some(
100			async move {
101				let mut body = BytesMut::with_capacity(8192);
102
103				while let Some(item) = stream.next().await {
104					let chunk = item?;
105
106					if body.len() + chunk.len() > limit {
107						return Err(MsgPackError::Overflow);
108					} else {
109						body.extend_from_slice(&chunk);
110					}
111				}
112
113				if body.is_empty() {
114					return Err(MsgPackError::Payload(PayloadError::Incomplete(Some(
115						io::Error::new(io::ErrorKind::InvalidData, "payload is empty"),
116					))));
117				}
118
119				Ok(rmp_serde::from_slice::<T>(&body)?)
120			}
121			.boxed_local(),
122		);
123		self.poll(task)
124	}
125}