actix_protobuf/
lib.rs

1//! Protobuf payload extractor for Actix Web.
2
3#![forbid(unsafe_code)]
4#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
5#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
6#![cfg_attr(docsrs, feature(doc_auto_cfg))]
7
8use std::{
9    fmt,
10    future::Future,
11    ops::{Deref, DerefMut},
12    pin::Pin,
13    task::{self, Poll},
14};
15
16use actix_web::{
17    body::BoxBody,
18    dev::Payload,
19    error::PayloadError,
20    http::header::{CONTENT_LENGTH, CONTENT_TYPE},
21    web::BytesMut,
22    Error, FromRequest, HttpMessage, HttpRequest, HttpResponse, HttpResponseBuilder, Responder,
23    ResponseError,
24};
25use derive_more::Display;
26use futures_util::{
27    future::{FutureExt as _, LocalBoxFuture},
28    stream::StreamExt as _,
29};
30use prost::{DecodeError as ProtoBufDecodeError, EncodeError as ProtoBufEncodeError, Message};
31
32#[derive(Debug, Display)]
33pub enum ProtoBufPayloadError {
34    /// Payload size is bigger than 256k
35    #[display(fmt = "Payload size is bigger than 256k")]
36    Overflow,
37
38    /// Content type error
39    #[display(fmt = "Content type error")]
40    ContentType,
41
42    /// Serialize error
43    #[display(fmt = "ProtoBuf serialize error: {_0}")]
44    Serialize(ProtoBufEncodeError),
45
46    /// Deserialize error
47    #[display(fmt = "ProtoBuf deserialize error: {_0}")]
48    Deserialize(ProtoBufDecodeError),
49
50    /// Payload error
51    #[display(fmt = "Error that occur during reading payload: {_0}")]
52    Payload(PayloadError),
53}
54
55impl ResponseError for ProtoBufPayloadError {
56    fn error_response(&self) -> HttpResponse {
57        match *self {
58            ProtoBufPayloadError::Overflow => HttpResponse::PayloadTooLarge().into(),
59            _ => HttpResponse::BadRequest().into(),
60        }
61    }
62}
63
64impl From<PayloadError> for ProtoBufPayloadError {
65    fn from(err: PayloadError) -> ProtoBufPayloadError {
66        ProtoBufPayloadError::Payload(err)
67    }
68}
69
70impl From<ProtoBufDecodeError> for ProtoBufPayloadError {
71    fn from(err: ProtoBufDecodeError) -> ProtoBufPayloadError {
72        ProtoBufPayloadError::Deserialize(err)
73    }
74}
75
/// Protobuf extractor and responder.
///
/// As an extractor it reads the request body and decodes it into `T`. As a responder it
/// encodes `T` into a `200 OK` response with `Content-Type: application/protobuf`. The
/// wrapped value is public and also reachable through `Deref`/`DerefMut`.
///
/// No trait bound is placed on the struct itself; the `Message` requirement lives on the
/// impls that actually need it.
pub struct ProtoBuf<T>(pub T);
77
78impl<T: Message> Deref for ProtoBuf<T> {
79    type Target = T;
80
81    fn deref(&self) -> &T {
82        &self.0
83    }
84}
85
86impl<T: Message> DerefMut for ProtoBuf<T> {
87    fn deref_mut(&mut self) -> &mut T {
88        &mut self.0
89    }
90}
91
92impl<T: Message> fmt::Debug for ProtoBuf<T>
93where
94    T: fmt::Debug,
95{
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        write!(f, "ProtoBuf: {:?}", self.0)
98    }
99}
100
101impl<T: Message> fmt::Display for ProtoBuf<T>
102where
103    T: fmt::Display,
104{
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        fmt::Display::fmt(&self.0, f)
107    }
108}
109
/// Configuration for the [`ProtoBuf`] extractor.
///
/// Register it as app data to change the maximum accepted request body size.
#[derive(Debug, Clone)]
pub struct ProtoBufConfig {
    /// Maximum accepted payload size, in bytes.
    limit: usize,
}
113
114impl ProtoBufConfig {
115    /// Change max size of payload. By default max size is 256Kb
116    pub fn limit(&mut self, limit: usize) -> &mut Self {
117        self.limit = limit;
118        self
119    }
120}
121
122impl Default for ProtoBufConfig {
123    fn default() -> Self {
124        ProtoBufConfig { limit: 262_144 }
125    }
126}
127
128impl<T> FromRequest for ProtoBuf<T>
129where
130    T: Message + Default + 'static,
131{
132    type Error = Error;
133    type Future = LocalBoxFuture<'static, Result<Self, Error>>;
134
135    #[inline]
136    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
137        let limit = req
138            .app_data::<ProtoBufConfig>()
139            .map(|c| c.limit)
140            .unwrap_or(262_144);
141        ProtoBufMessage::new(req, payload)
142            .limit(limit)
143            .map(move |res| match res {
144                Err(e) => Err(e.into()),
145                Ok(item) => Ok(ProtoBuf(item)),
146            })
147            .boxed_local()
148    }
149}
150
151impl<T: Message + Default> Responder for ProtoBuf<T> {
152    type Body = BoxBody;
153
154    fn respond_to(self, _: &HttpRequest) -> HttpResponse {
155        let mut buf = Vec::new();
156        match self.0.encode(&mut buf) {
157            Ok(()) => HttpResponse::Ok()
158                .content_type("application/protobuf")
159                .body(buf),
160            Err(err) => HttpResponse::from_error(Error::from(ProtoBufPayloadError::Serialize(err))),
161        }
162    }
163}
164
165pub struct ProtoBufMessage<T: Message + Default> {
166    limit: usize,
167    length: Option<usize>,
168    stream: Option<Payload>,
169    err: Option<ProtoBufPayloadError>,
170    fut: Option<LocalBoxFuture<'static, Result<T, ProtoBufPayloadError>>>,
171}
172
173impl<T: Message + Default> ProtoBufMessage<T> {
174    /// Create `ProtoBufMessage` for request.
175    pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
176        if req.content_type() != "application/protobuf"
177            && req.content_type() != "application/x-protobuf"
178        {
179            return ProtoBufMessage {
180                limit: 262_144,
181                length: None,
182                stream: None,
183                fut: None,
184                err: Some(ProtoBufPayloadError::ContentType),
185            };
186        }
187
188        let mut len = None;
189        if let Some(l) = req.headers().get(CONTENT_LENGTH) {
190            if let Ok(s) = l.to_str() {
191                if let Ok(l) = s.parse::<usize>() {
192                    len = Some(l)
193                }
194            }
195        }
196
197        ProtoBufMessage {
198            limit: 262_144,
199            length: len,
200            stream: Some(payload.take()),
201            fut: None,
202            err: None,
203        }
204    }
205
206    /// Change max size of payload. By default max size is 256Kb
207    pub fn limit(mut self, limit: usize) -> Self {
208        self.limit = limit;
209        self
210    }
211}
212
213impl<T: Message + Default + 'static> Future for ProtoBufMessage<T> {
214    type Output = Result<T, ProtoBufPayloadError>;
215
216    fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<Self::Output> {
217        if let Some(ref mut fut) = self.fut {
218            return Pin::new(fut).poll(task);
219        }
220
221        if let Some(err) = self.err.take() {
222            return Poll::Ready(Err(err));
223        }
224
225        let limit = self.limit;
226        if let Some(len) = self.length.take() {
227            if len > limit {
228                return Poll::Ready(Err(ProtoBufPayloadError::Overflow));
229            }
230        }
231
232        let mut stream = self
233            .stream
234            .take()
235            .expect("ProtoBufMessage could not be used second time");
236
237        self.fut = Some(
238            async move {
239                let mut body = BytesMut::with_capacity(8192);
240
241                while let Some(item) = stream.next().await {
242                    let chunk = item?;
243                    if (body.len() + chunk.len()) > limit {
244                        return Err(ProtoBufPayloadError::Overflow);
245                    } else {
246                        body.extend_from_slice(&chunk);
247                    }
248                }
249
250                Ok(<T>::decode(&mut body)?)
251            }
252            .boxed_local(),
253        );
254        self.poll(task)
255    }
256}
257
258pub trait ProtoBufResponseBuilder {
259    fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error>;
260}
261
262impl ProtoBufResponseBuilder for HttpResponseBuilder {
263    fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error> {
264        self.insert_header((CONTENT_TYPE, "application/protobuf"));
265
266        let mut body = Vec::new();
267        value
268            .encode(&mut body)
269            .map_err(ProtoBufPayloadError::Serialize)?;
270
271        Ok(self.body(body))
272    }
273}
274
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest};

    use super::*;

    /// Test-only equality: only the payload-less variants (`Overflow`, `ContentType`)
    /// ever compare equal; variants wrapping another error always compare unequal.
    impl PartialEq for ProtoBufPayloadError {
        fn eq(&self, other: &ProtoBufPayloadError) -> bool {
            matches!(
                (self, other),
                (ProtoBufPayloadError::Overflow, ProtoBufPayloadError::Overflow)
                    | (ProtoBufPayloadError::ContentType, ProtoBufPayloadError::ContentType)
            )
        }
    }

    #[derive(Clone, PartialEq, Eq, Message)]
    pub struct MyObject {
        #[prost(int32, tag = "1")]
        pub number: i32,
        #[prost(string, tag = "2")]
        pub name: String,
    }

    /// Encodes `obj` with prost so it can be fed back in as a request body.
    fn encode_object(obj: &MyObject) -> Vec<u8> {
        let mut buf = Vec::new();
        obj.encode(&mut buf).unwrap();
        buf
    }

    /// A small message whose encoding is 13 bytes long.
    fn sample() -> MyObject {
        MyObject {
            number: 42,
            name: "roundtrip".to_owned(),
        }
    }

    #[actix_web::test]
    async fn test_protobuf() {
        let protobuf = ProtoBuf(MyObject {
            number: 9,
            name: "test".to_owned(),
        });
        let req = TestRequest::default().to_http_request();
        let resp = protobuf.respond_to(&req);
        let ct = resp.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(ct, "application/protobuf");
    }

    #[actix_web::test]
    async fn test_protobuf_message() {
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType);

        let (req, mut pl) = TestRequest::get()
            .insert_header((header::CONTENT_TYPE, "application/text"))
            .to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType);

        let (req, mut pl) = TestRequest::get()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .insert_header((header::CONTENT_LENGTH, "10000"))
            .to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
            .limit(100)
            .await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::Overflow);
    }

    #[actix_web::test]
    async fn test_protobuf_message_roundtrip() {
        let expected = sample();
        let body = encode_object(&expected);

        for content_type in ["application/protobuf", "application/x-protobuf"] {
            let (req, mut pl) = TestRequest::post()
                .insert_header((header::CONTENT_TYPE, content_type))
                .set_payload(body.clone())
                .to_http_parts();
            let decoded = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
                .await
                .unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[actix_web::test]
    async fn test_protobuf_message_body_overflow() {
        let body = encode_object(&sample());
        assert!(body.len() > 4);

        let (req, mut pl) = TestRequest::post()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .set_payload(body)
            .to_http_parts();
        let res = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
            .limit(4)
            .await;
        assert_eq!(res.err().unwrap(), ProtoBufPayloadError::Overflow);
    }

    #[actix_web::test]
    async fn test_extractor_respects_config_limit() {
        let body = encode_object(&sample());
        let mut cfg = ProtoBufConfig::default();
        cfg.limit(4);

        let (req, mut pl) = TestRequest::post()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .app_data(cfg)
            .set_payload(body)
            .to_http_parts();
        let err = ProtoBuf::<MyObject>::from_request(&req, &mut pl)
            .await
            .unwrap_err();
        assert_eq!(
            err.as_error::<ProtoBufPayloadError>(),
            Some(&ProtoBufPayloadError::Overflow)
        );
    }
}