1use std::{
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9
10use actix_web::{FromRequest, HttpMessage, HttpRequest, ResponseError, dev, http::StatusCode, web};
11use derive_more::{Display, Error};
12use futures_core::Stream as _;
13use tracing::debug;
14
/// Default bytes payload size limit: 4 MiB (4,194,304 bytes).
///
/// Used as the default value of the `LIMIT` const parameter on [`Bytes`].
pub const DEFAULT_BYTES_LIMIT: usize = 4 * 1024 * 1024;
17
/// Bytes extractor with a const-generic payload size limit.
///
/// Wraps the request body as a [`web::Bytes`]. `LIMIT` is the maximum accepted payload size in
/// bytes and defaults to [`DEFAULT_BYTES_LIMIT`]. Payloads that exceed it are rejected with a
/// [`BytesPayloadError`] while the body is being extracted.
#[derive(Debug)]
pub struct Bytes<const LIMIT: usize = DEFAULT_BYTES_LIMIT>(pub web::Bytes);
53
54mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
55 use super::*;
56
57 impl<const LIMIT: usize> std::ops::Deref for Bytes<LIMIT> {
58 type Target = web::Bytes;
59
60 fn deref(&self) -> &Self::Target {
61 &self.0
62 }
63 }
64
65 impl<const LIMIT: usize> std::ops::DerefMut for Bytes<LIMIT> {
66 fn deref_mut(&mut self) -> &mut Self::Target {
67 &mut self.0
68 }
69 }
70
71 impl<const LIMIT: usize> AsRef<web::Bytes> for Bytes<LIMIT> {
72 fn as_ref(&self) -> &web::Bytes {
73 &self.0
74 }
75 }
76
77 impl<const LIMIT: usize> AsMut<web::Bytes> for Bytes<LIMIT> {
78 fn as_mut(&mut self) -> &mut web::Bytes {
79 &mut self.0
80 }
81 }
82}
83
84impl<const LIMIT: usize> Bytes<LIMIT> {
85 pub fn into_inner(self) -> web::Bytes {
87 self.0
88 }
89}
90
91impl<const LIMIT: usize> FromRequest for Bytes<LIMIT> {
93 type Error = actix_web::Error;
94 type Future = BytesExtractFut<LIMIT>;
95
96 #[inline]
97 fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
98 BytesExtractFut {
99 req: Some(req.clone()),
100 fut: BytesBody::new(req, payload),
101 }
102 }
103}
104
/// Future returned by the [`Bytes`] extractor's `FromRequest` implementation.
// No `Debug` impl is provided for this type, hence the lint allowance.
#[allow(missing_debug_implementations)]
pub struct BytesExtractFut<const LIMIT: usize> {
    /// Clone of the request; taken on failure to name the handler in the debug log.
    req: Option<HttpRequest>,
    /// Inner future that reads the body and enforces `LIMIT`.
    fut: BytesBody<LIMIT>,
}
110
111impl<const LIMIT: usize> Future for BytesExtractFut<LIMIT> {
112 type Output = actix_web::Result<Bytes<LIMIT>>;
113
114 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
115 let this = self.get_mut();
116
117 let res = ready!(Pin::new(&mut this.fut).poll(cx));
118
119 let res = match res {
120 Err(err) => {
121 let req = this.req.take().unwrap();
122
123 debug!(
124 "Failed to extract Bytes from payload in handler: {}",
125 req.match_name().unwrap_or_else(|| req.path())
126 );
127
128 Err(err.into())
129 }
130 Ok(data) => Ok(Bytes(data)),
131 };
132
133 Poll::Ready(res)
134 }
135}
136
/// Future that reads a request payload into a single `web::Bytes` buffer, enforcing `LIMIT`.
pub enum BytesBody<const LIMIT: usize> {
    /// An error detected at construction time; it is yielded when the future is polled.
    Error(Option<BytesPayloadError>),

    /// The payload is being read.
    Body {
        /// Value of the request's `Content-Length` header, if one was present.
        // Stored but not read by this module, hence the lint allowance.
        #[allow(dead_code)]
        length: Option<usize>,
        /// Payload stream taken from the request.
        payload: dev::Payload,
        /// Accumulates the chunks received so far.
        buf: web::BytesMut,
    },
}
151
// The `Future` impl for `BytesBody` calls `Pin::get_mut`, which requires `Self: Unpin`.
impl<const LIMIT: usize> Unpin for BytesBody<LIMIT> {}
153
154impl<const LIMIT: usize> BytesBody<LIMIT> {
155 pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
157 let payload = payload.take();
158
159 let length = req
160 .get_header::<crate::header::ContentLength>()
161 .map(|cl| cl.into_inner());
162
163 if let Some(len) = length
168 && len > LIMIT
169 {
170 return BytesBody::Error(Some(BytesPayloadError::OverflowKnownLength {
171 length: len,
172 limit: LIMIT,
173 }));
174 }
175
176 BytesBody::Body {
177 length,
178 payload,
179 buf: web::BytesMut::with_capacity(8192),
180 }
181 }
182}
183
184impl<const LIMIT: usize> Future for BytesBody<LIMIT> {
185 type Output = Result<web::Bytes, BytesPayloadError>;
186
187 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188 let this = self.get_mut();
189
190 match this {
191 BytesBody::Body { buf, payload, .. } => loop {
192 let res = ready!(Pin::new(&mut *payload).poll_next(cx));
193
194 match res {
195 Some(chunk) => {
196 let chunk = chunk?;
197 let buf_len = buf.len() + chunk.len();
198 if buf_len > LIMIT {
199 return Poll::Ready(Err(BytesPayloadError::Overflow { limit: LIMIT }));
200 } else {
201 buf.extend_from_slice(&chunk);
202 }
203 }
204
205 None => return Poll::Ready(Ok(buf.split().freeze())),
206 }
207 },
208
209 BytesBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
210 }
211 }
212}
213
/// Errors that can occur while extracting a bytes payload.
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum BytesPayloadError {
    /// The size declared by the `Content-Length` header is larger than the allowed limit.
    #[display("Payload ({length} bytes) is larger than allowed (limit: {limit} bytes).")]
    OverflowKnownLength { length: usize, limit: usize },

    /// The payload grew past the allowed limit while its chunks were being read.
    #[display("Payload has exceeded limit ({limit} bytes).")]
    Overflow { limit: usize },

    /// The underlying payload stream returned an error.
    #[display("Error that occur during reading payload: {_0}")]
    Payload(actix_web::error::PayloadError),
}
230
231impl From<actix_web::error::PayloadError> for BytesPayloadError {
232 fn from(err: actix_web::error::PayloadError) -> Self {
233 Self::Payload(err)
234 }
235}
236
237impl ResponseError for BytesPayloadError {
238 fn status_code(&self) -> StatusCode {
239 match self {
240 Self::OverflowKnownLength { .. } => StatusCode::PAYLOAD_TOO_LARGE,
241 Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
242 Self::Payload(err) => err.status_code(),
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web};

    use super::*;

    // Test-only equality so errors can be compared with `assert_eq!`.
    // Only the two overflow variants can compare equal; any pairing that involves `Payload`
    // falls through to the wildcard arm and is unequal.
    #[cfg(test)]
    impl PartialEq for BytesPayloadError {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (
                    Self::OverflowKnownLength {
                        length: l_length,
                        limit: l_limit,
                    },
                    Self::OverflowKnownLength {
                        length: r_length,
                        limit: r_limit,
                    },
                ) => l_length == r_length && l_limit == r_limit,

                (Self::Overflow { limit: l_limit }, Self::Overflow { limit: r_limit }) => {
                    l_limit == r_limit
                }

                _ => false,
            }
        }
    }

    /// Exercises the `Bytes` extractor through `FromRequest`.
    #[actix_web::test]
    async fn extract() {
        // Small payload under the default limit extracts successfully.
        let (req, mut pl) = TestRequest::default()
            .set_payload(web::Bytes::from_static(b"foo"))
            .to_http_parts();

        let s = Bytes::<DEFAULT_BYTES_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.as_ref(), "foo");

        // 15-byte payload against a 10-byte limit.
        // NOTE(review): the known-length message is expected here although no Content-Length
        // header is inserted explicitly; presumably `set_payload` sets it — confirm.
        let (req, mut pl) = TestRequest::default()
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (15 bytes) is larger than allowed (limit: 10 bytes).",
        );

        // Explicit Content-Length over the limit; only the message fragment is checked.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();
        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err = format!("{}", s.unwrap_err());
        assert!(
            err.contains("larger than allowed"),
            "unexpected error string: {err:?}",
        );
    }

    /// Exercises the `BytesBody` future directly.
    #[actix_web::test]
    async fn body() {
        // A request with no payload set resolves successfully.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let _bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        // The Content-Type header does not affect extraction.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType("application/text".parse().unwrap()))
            .to_http_parts();
        BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        // Declared Content-Length above the limit fails without any payload being set.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(10000))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;
        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        );

        // 1000-byte payload against a 100-byte limit.
        // NOTE(review): expects the known-length variant, so `set_payload` presumably sets
        // Content-Length — confirm.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(web::Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;

        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::OverflowKnownLength {
                length: 1000,
                limit: 100
            },
        );

        // Payload under the default limit is returned unchanged.
        let (req, mut pl) = TestRequest::default()
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(bytes.ok().unwrap(), "foo foo foo foo");
    }

    /// The reported limit is the const-generic one (10), not the `PayloadConfig` limit (8)
    /// registered as app data; `BytesBody::new` does not read app data.
    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        let (req, mut pl) = TestRequest::default()
            .app_data(web::Data::new(web::PayloadConfig::default().limit(8)))
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(web::Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        assert!(s.is_err());

        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );
    }
}