1#![forbid(unsafe_code)]
4#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
5#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
6#![cfg_attr(docsrs, feature(doc_auto_cfg))]
7
8use std::{
9 fmt,
10 future::Future,
11 ops::{Deref, DerefMut},
12 pin::Pin,
13 task::{self, Poll},
14};
15
16use actix_web::{
17 body::BoxBody,
18 dev::Payload,
19 error::PayloadError,
20 http::header::{CONTENT_LENGTH, CONTENT_TYPE},
21 web::BytesMut,
22 Error, FromRequest, HttpMessage, HttpRequest, HttpResponse, HttpResponseBuilder, Responder,
23 ResponseError,
24};
25use derive_more::Display;
26use futures_util::{
27 future::{FutureExt as _, LocalBoxFuture},
28 stream::StreamExt as _,
29};
30use prost::{DecodeError as ProtoBufDecodeError, EncodeError as ProtoBufEncodeError, Message};
31
32#[derive(Debug, Display)]
33pub enum ProtoBufPayloadError {
34 #[display(fmt = "Payload size is bigger than 256k")]
36 Overflow,
37
38 #[display(fmt = "Content type error")]
40 ContentType,
41
42 #[display(fmt = "ProtoBuf serialize error: {_0}")]
44 Serialize(ProtoBufEncodeError),
45
46 #[display(fmt = "ProtoBuf deserialize error: {_0}")]
48 Deserialize(ProtoBufDecodeError),
49
50 #[display(fmt = "Error that occur during reading payload: {_0}")]
52 Payload(PayloadError),
53}
54
55impl ResponseError for ProtoBufPayloadError {
56 fn error_response(&self) -> HttpResponse {
57 match *self {
58 ProtoBufPayloadError::Overflow => HttpResponse::PayloadTooLarge().into(),
59 _ => HttpResponse::BadRequest().into(),
60 }
61 }
62}
63
64impl From<PayloadError> for ProtoBufPayloadError {
65 fn from(err: PayloadError) -> ProtoBufPayloadError {
66 ProtoBufPayloadError::Payload(err)
67 }
68}
69
70impl From<ProtoBufDecodeError> for ProtoBufPayloadError {
71 fn from(err: ProtoBufDecodeError) -> ProtoBufPayloadError {
72 ProtoBufPayloadError::Deserialize(err)
73 }
74}
75
76pub struct ProtoBuf<T: Message>(pub T);
77
78impl<T: Message> Deref for ProtoBuf<T> {
79 type Target = T;
80
81 fn deref(&self) -> &T {
82 &self.0
83 }
84}
85
86impl<T: Message> DerefMut for ProtoBuf<T> {
87 fn deref_mut(&mut self) -> &mut T {
88 &mut self.0
89 }
90}
91
92impl<T: Message> fmt::Debug for ProtoBuf<T>
93where
94 T: fmt::Debug,
95{
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 write!(f, "ProtoBuf: {:?}", self.0)
98 }
99}
100
101impl<T: Message> fmt::Display for ProtoBuf<T>
102where
103 T: fmt::Display,
104{
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 fmt::Display::fmt(&self.0, f)
107 }
108}
109
/// Configuration for the [`ProtoBuf`] extractor, looked up from app data.
#[derive(Debug, Clone)]
pub struct ProtoBufConfig {
    /// Maximum accepted payload size, in bytes.
    limit: usize,
}
113
114impl ProtoBufConfig {
115 pub fn limit(&mut self, limit: usize) -> &mut Self {
117 self.limit = limit;
118 self
119 }
120}
121
122impl Default for ProtoBufConfig {
123 fn default() -> Self {
124 ProtoBufConfig { limit: 262_144 }
125 }
126}
127
128impl<T> FromRequest for ProtoBuf<T>
129where
130 T: Message + Default + 'static,
131{
132 type Error = Error;
133 type Future = LocalBoxFuture<'static, Result<Self, Error>>;
134
135 #[inline]
136 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
137 let limit = req
138 .app_data::<ProtoBufConfig>()
139 .map(|c| c.limit)
140 .unwrap_or(262_144);
141 ProtoBufMessage::new(req, payload)
142 .limit(limit)
143 .map(move |res| match res {
144 Err(e) => Err(e.into()),
145 Ok(item) => Ok(ProtoBuf(item)),
146 })
147 .boxed_local()
148 }
149}
150
151impl<T: Message + Default> Responder for ProtoBuf<T> {
152 type Body = BoxBody;
153
154 fn respond_to(self, _: &HttpRequest) -> HttpResponse {
155 let mut buf = Vec::new();
156 match self.0.encode(&mut buf) {
157 Ok(()) => HttpResponse::Ok()
158 .content_type("application/protobuf")
159 .body(buf),
160 Err(err) => HttpResponse::from_error(Error::from(ProtoBufPayloadError::Serialize(err))),
161 }
162 }
163}
164
165pub struct ProtoBufMessage<T: Message + Default> {
166 limit: usize,
167 length: Option<usize>,
168 stream: Option<Payload>,
169 err: Option<ProtoBufPayloadError>,
170 fut: Option<LocalBoxFuture<'static, Result<T, ProtoBufPayloadError>>>,
171}
172
173impl<T: Message + Default> ProtoBufMessage<T> {
174 pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
176 if req.content_type() != "application/protobuf"
177 && req.content_type() != "application/x-protobuf"
178 {
179 return ProtoBufMessage {
180 limit: 262_144,
181 length: None,
182 stream: None,
183 fut: None,
184 err: Some(ProtoBufPayloadError::ContentType),
185 };
186 }
187
188 let mut len = None;
189 if let Some(l) = req.headers().get(CONTENT_LENGTH) {
190 if let Ok(s) = l.to_str() {
191 if let Ok(l) = s.parse::<usize>() {
192 len = Some(l)
193 }
194 }
195 }
196
197 ProtoBufMessage {
198 limit: 262_144,
199 length: len,
200 stream: Some(payload.take()),
201 fut: None,
202 err: None,
203 }
204 }
205
206 pub fn limit(mut self, limit: usize) -> Self {
208 self.limit = limit;
209 self
210 }
211}
212
213impl<T: Message + Default + 'static> Future for ProtoBufMessage<T> {
214 type Output = Result<T, ProtoBufPayloadError>;
215
216 fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<Self::Output> {
217 if let Some(ref mut fut) = self.fut {
218 return Pin::new(fut).poll(task);
219 }
220
221 if let Some(err) = self.err.take() {
222 return Poll::Ready(Err(err));
223 }
224
225 let limit = self.limit;
226 if let Some(len) = self.length.take() {
227 if len > limit {
228 return Poll::Ready(Err(ProtoBufPayloadError::Overflow));
229 }
230 }
231
232 let mut stream = self
233 .stream
234 .take()
235 .expect("ProtoBufMessage could not be used second time");
236
237 self.fut = Some(
238 async move {
239 let mut body = BytesMut::with_capacity(8192);
240
241 while let Some(item) = stream.next().await {
242 let chunk = item?;
243 if (body.len() + chunk.len()) > limit {
244 return Err(ProtoBufPayloadError::Overflow);
245 } else {
246 body.extend_from_slice(&chunk);
247 }
248 }
249
250 Ok(<T>::decode(&mut body)?)
251 }
252 .boxed_local(),
253 );
254 self.poll(task)
255 }
256}
257
258pub trait ProtoBufResponseBuilder {
259 fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error>;
260}
261
262impl ProtoBufResponseBuilder for HttpResponseBuilder {
263 fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error> {
264 self.insert_header((CONTENT_TYPE, "application/protobuf"));
265
266 let mut body = Vec::new();
267 value
268 .encode(&mut body)
269 .map_err(ProtoBufPayloadError::Serialize)?;
270
271 Ok(self.body(body))
272 }
273}
274
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest};

    use super::*;

    // Test-only equality: only the data-less variants compare equal; variants that
    // carry an inner error never do.
    impl PartialEq for ProtoBufPayloadError {
        fn eq(&self, other: &ProtoBufPayloadError) -> bool {
            matches!(
                (self, other),
                (ProtoBufPayloadError::Overflow, ProtoBufPayloadError::Overflow)
                    | (ProtoBufPayloadError::ContentType, ProtoBufPayloadError::ContentType)
            )
        }
    }

    #[derive(Clone, PartialEq, Eq, Message)]
    pub struct MyObject {
        #[prost(int32, tag = "1")]
        pub number: i32,
        #[prost(string, tag = "2")]
        pub name: String,
    }

    #[actix_web::test]
    async fn test_protobuf() {
        let protobuf = ProtoBuf(MyObject {
            number: 9,
            name: "test".to_owned(),
        });
        let req = TestRequest::default().to_http_request();
        let resp = protobuf.respond_to(&req);
        let ct = resp.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(ct, "application/protobuf");
    }

    #[actix_web::test]
    async fn test_protobuf_message() {
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType);

        let (req, mut pl) = TestRequest::get()
            .insert_header((header::CONTENT_TYPE, "application/text"))
            .to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType);

        let (req, mut pl) = TestRequest::get()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .insert_header((header::CONTENT_LENGTH, "10000"))
            .to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
            .limit(100)
            .await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::Overflow);
    }

    // Happy path: a well-formed protobuf body decodes back into the original message.
    #[actix_web::test]
    async fn test_protobuf_message_decodes_body() {
        let expected = MyObject {
            number: 9,
            name: "test".to_owned(),
        };
        let mut encoded = Vec::new();
        expected.encode(&mut encoded).unwrap();

        let (req, mut pl) = TestRequest::post()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .set_payload(encoded)
            .to_http_parts();
        let decoded = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(decoded, expected);
    }

    // A real body larger than the limit is rejected, whether via the declared length
    // or while streaming.
    #[actix_web::test]
    async fn test_protobuf_message_body_over_limit() {
        let message = MyObject {
            number: 9,
            name: "test".to_owned(),
        };
        let mut encoded = Vec::new();
        message.encode(&mut encoded).unwrap();
        let too_small = encoded.len() - 1;

        let (req, mut pl) = TestRequest::post()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .set_payload(encoded)
            .to_http_parts();
        let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
            .limit(too_small)
            .await;
        assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::Overflow);
    }
}