1use std::{
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9
10use actix_web::{FromRequest, HttpMessage, HttpRequest, ResponseError, dev, http::StatusCode, web};
11use derive_more::{Display, Error};
12use futures_core::Stream as _;
13use tracing::debug;
14
/// Default upper bound, in bytes, on the payload size accepted by [`Bytes`] (4 MiB).
pub const DEFAULT_BYTES_LIMIT: usize = 4 * 1024 * 1024;
17
18#[derive(Debug)]
51pub struct Bytes<const LIMIT: usize = DEFAULT_BYTES_LIMIT>(pub web::Bytes);
53
54mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
55 use super::*;
56
57 impl<const LIMIT: usize> std::ops::Deref for Bytes<LIMIT> {
58 type Target = web::Bytes;
59
60 fn deref(&self) -> &Self::Target {
61 &self.0
62 }
63 }
64
65 impl<const LIMIT: usize> std::ops::DerefMut for Bytes<LIMIT> {
66 fn deref_mut(&mut self) -> &mut Self::Target {
67 &mut self.0
68 }
69 }
70
71 impl<const LIMIT: usize> AsRef<web::Bytes> for Bytes<LIMIT> {
72 fn as_ref(&self) -> &web::Bytes {
73 &self.0
74 }
75 }
76
77 impl<const LIMIT: usize> AsMut<web::Bytes> for Bytes<LIMIT> {
78 fn as_mut(&mut self) -> &mut web::Bytes {
79 &mut self.0
80 }
81 }
82}
83
84impl<const LIMIT: usize> Bytes<LIMIT> {
85 pub fn into_inner(self) -> web::Bytes {
87 self.0
88 }
89}
90
91impl<const LIMIT: usize> FromRequest for Bytes<LIMIT> {
93 type Error = actix_web::Error;
94 type Future = BytesExtractFut<LIMIT>;
95
96 #[inline]
97 fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
98 BytesExtractFut {
99 req: Some(req.clone()),
100 fut: BytesBody::new(req, payload),
101 }
102 }
103}
104
105#[allow(missing_debug_implementations)]
106pub struct BytesExtractFut<const LIMIT: usize> {
107 req: Option<HttpRequest>,
108 fut: BytesBody<LIMIT>,
109}
110
111impl<const LIMIT: usize> Future for BytesExtractFut<LIMIT> {
112 type Output = actix_web::Result<Bytes<LIMIT>>;
113
114 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
115 let this = self.get_mut();
116
117 let res = ready!(Pin::new(&mut this.fut).poll(cx));
118
119 let res = match res {
120 Err(err) => {
121 let req = this.req.take().unwrap();
122
123 debug!(
124 "Failed to extract Bytes from payload in handler: {}",
125 req.match_name().unwrap_or_else(|| req.path())
126 );
127
128 Err(err.into())
129 }
130 Ok(data) => Ok(Bytes(data)),
131 };
132
133 Poll::Ready(res)
134 }
135}
136
137pub enum BytesBody<const LIMIT: usize> {
142 Error(Option<BytesPayloadError>),
143 Body {
144 #[allow(dead_code)]
146 length: Option<usize>,
147 payload: dev::Payload,
148 buf: web::BytesMut,
149 },
150}
151
152impl<const LIMIT: usize> Unpin for BytesBody<LIMIT> {}
153
154impl<const LIMIT: usize> BytesBody<LIMIT> {
155 pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
157 let payload = payload.take();
158
159 let length = req
160 .get_header::<crate::header::ContentLength>()
161 .map(|cl| cl.into_inner());
162
163 if let Some(len) = length {
168 if len > LIMIT {
169 return BytesBody::Error(Some(BytesPayloadError::OverflowKnownLength {
170 length: len,
171 limit: LIMIT,
172 }));
173 }
174 }
175
176 BytesBody::Body {
177 length,
178 payload,
179 buf: web::BytesMut::with_capacity(8192),
180 }
181 }
182}
183
184impl<const LIMIT: usize> Future for BytesBody<LIMIT> {
185 type Output = Result<web::Bytes, BytesPayloadError>;
186
187 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188 let this = self.get_mut();
189
190 match this {
191 BytesBody::Body { buf, payload, .. } => loop {
192 let res = ready!(Pin::new(&mut *payload).poll_next(cx));
193
194 match res {
195 Some(chunk) => {
196 let chunk = chunk?;
197 let buf_len = buf.len() + chunk.len();
198 if buf_len > LIMIT {
199 return Poll::Ready(Err(BytesPayloadError::Overflow { limit: LIMIT }));
200 } else {
201 buf.extend_from_slice(&chunk);
202 }
203 }
204
205 None => return Poll::Ready(Ok(buf.split().freeze())),
206 }
207 },
208
209 BytesBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
210 }
211 }
212}
213
214#[derive(Debug, Display, Error)]
216#[non_exhaustive]
217pub enum BytesPayloadError {
218 #[display("Payload ({length} bytes) is larger than allowed (limit: {limit} bytes).")]
220 OverflowKnownLength { length: usize, limit: usize },
221
222 #[display("Payload has exceeded limit ({limit} bytes).")]
224 Overflow { limit: usize },
225
226 #[display("Error that occur during reading payload: {_0}")]
228 Payload(actix_web::error::PayloadError),
229}
230
231impl From<actix_web::error::PayloadError> for BytesPayloadError {
232 fn from(err: actix_web::error::PayloadError) -> Self {
233 Self::Payload(err)
234 }
235}
236
237impl ResponseError for BytesPayloadError {
238 fn status_code(&self) -> StatusCode {
239 match self {
240 Self::OverflowKnownLength { .. } => StatusCode::PAYLOAD_TOO_LARGE,
241 Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
242 Self::Payload(err) => err.status_code(),
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web};

    use super::*;

    // Test-only equality: the overflow variants compare field by field. Every other pairing,
    // including two `Payload` errors, is treated as unequal because the wrapped error is not
    // compared.
    impl PartialEq for BytesPayloadError {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (
                    Self::OverflowKnownLength {
                        length: l_length,
                        limit: l_limit,
                    },
                    Self::OverflowKnownLength {
                        length: r_length,
                        limit: r_limit,
                    },
                ) => l_length == r_length && l_limit == r_limit,

                (Self::Overflow { limit: l_limit }, Self::Overflow { limit: r_limit }) => {
                    l_limit == r_limit
                }

                _ => false,
            }
        }
    }

    #[actix_web::test]
    async fn extract() {
        // A body within the limit is extracted verbatim.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(3))
            .set_payload(web::Bytes::from_static(b"foo"))
            .to_http_parts();

        let s = Bytes::<DEFAULT_BYTES_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.as_ref(), "foo");

        // A declared length over the limit is rejected with the exact error message.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );

        // A body exactly at the limit is accepted: the limit is inclusive.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(3))
            .set_payload(web::Bytes::from_static(b"foo"))
            .to_http_parts();

        let s = Bytes::<3>::from_request(&req, &mut pl).await.unwrap();
        assert_eq!(s.as_ref(), "foo");
    }

    #[actix_web::test]
    async fn body() {
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let _bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType("application/text".parse().unwrap()))
            .to_http_parts();
        BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(10000))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;
        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        );

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(web::Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;

        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::Overflow { limit: 100 }
        );

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(bytes.ok().unwrap(), "foo foo foo foo");
    }

    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        let (req, mut pl) = TestRequest::default()
            .app_data(web::Data::new(web::PayloadConfig::default().limit(8)))
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(web::Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        assert!(s.is_err());

        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );
    }
}