actix_web_buffering/
buffering.rs

1use std::{
2    fs::{File, OpenOptions},
3    io::{Read, Seek, SeekFrom, Write},
4    path::{Path, PathBuf},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use actix_web::{
10    dev::{Body, BodySize, MessageBody, Payload, ResponseBody, ServiceRequest, ServiceResponse},
11    web::{Bytes, BytesMut},
12    HttpMessage,
13};
14use futures::{ready, Stream, StreamExt};
15use uuid::Uuid;
16
/// Request-extension marker recording that the payload has already been wrapped.
struct RequestBufferedMark;
/// Response-extension marker recording that the body has already been wrapped.
struct ResponseBufferedMark;
19
20/// See crate example
21pub fn enable_request_buffering<T>(wrapper: T, req: &mut ServiceRequest)
22where
23    T: AsRef<FileBufferingStreamWrapper>,
24{
25    if !req.extensions().contains::<RequestBufferedMark>() {
26        let inner = req.take_payload();
27        req.set_payload(Payload::Stream(wrapper.as_ref().wrap(inner).boxed_local()));
28
29        req.extensions_mut().insert(RequestBufferedMark)
30    }
31}
32
33/// See crate example
34pub fn enable_response_buffering<T>(
35    wrapper: T,
36    mut svc_res: ServiceResponse<Body>,
37) -> ServiceResponse<Body>
38where
39    T: AsRef<FileBufferingStreamWrapper>,
40{
41    if !svc_res
42        .response()
43        .extensions()
44        .contains::<ResponseBufferedMark>()
45    {
46        svc_res
47            .response_mut()
48            .extensions_mut()
49            .insert(ResponseBufferedMark);
50
51        svc_res.map_body(|_, rb| {
52            let wrapped = wrapper.as_ref().wrap(rb);
53            ResponseBody::Body(Body::Message(Box::new(wrapped)))
54        })
55    } else {
56        svc_res
57    }
58}
59
/// Configuration for [`FileBufferingStream`]: wrapping a stream with it lets the stream be
/// read multiple times.
///
/// Small bodies are kept in memory; larger ones are moved to a file in a temporary
/// directory.
#[derive(Debug, Clone)]
pub struct FileBufferingStreamWrapper {
    /// Directory in which spill files are created.
    tmp_dir: PathBuf,
    /// Maximum number of bytes kept in memory before spilling to a file.
    threshold: usize,
    /// Size in bytes of the chunks produced when the buffered body is replayed.
    produce_chunk_size: usize,
    /// Optional cap on the total body size in bytes.
    buffer_limit: Option<usize>,
}
67
68impl FileBufferingStreamWrapper {
69    pub fn new() -> Self {
70        Self {
71            tmp_dir: std::env::temp_dir(),
72            threshold: 1024 * 30,
73            produce_chunk_size: 1024 * 30,
74            buffer_limit: None,
75        }
76    }
77
78    /// The temporary dir for larger bodies
79    pub fn tmp_dir(mut self, v: impl AsRef<Path>) -> Self {
80        self.tmp_dir = v.as_ref().to_path_buf();
81        self
82    }
83
84    /// The maximum size in bytes of the in-memory used to buffer the stream. Larger bodies are written to disk
85    pub fn threshold(mut self, v: usize) -> Self {
86        self.threshold = v;
87        self
88    }
89
90    /// The chunk size for read buffered bodies
91    pub fn produce_chunk_size(mut self, v: usize) -> Self {
92        self.produce_chunk_size = v;
93        self
94    }
95
96    /// The maximum size in bytes of the body. An attempt to read beyond this limit will cause an error
97    pub fn buffer_limit(mut self, v: Option<usize>) -> Self {
98        self.buffer_limit = v;
99        self
100    }
101
102    pub fn wrap<S>(&self, inner: S) -> FileBufferingStream<S> {
103        FileBufferingStream::new(
104            inner,
105            self.tmp_dir.to_path_buf(),
106            self.threshold,
107            self.produce_chunk_size,
108            self.buffer_limit,
109        )
110    }
111}
112
113impl AsRef<FileBufferingStreamWrapper> for FileBufferingStreamWrapper {
114    fn as_ref(&self) -> &FileBufferingStreamWrapper {
115        self
116    }
117}
118
119enum Buffer {
120    Memory(BytesMut),
121    File(PathBuf, File),
122}
123
124pub struct FileBufferingStream<S> {
125    inner: S,
126    inner_eof: bool,
127
128    tmp_dir: PathBuf,
129    threshold: usize,
130    produce_chunk_size: usize,
131    buffer_limit: Option<usize>,
132
133    buffer: Buffer,
134    buffer_size: usize,
135    produce_index: usize,
136}
137
138impl<S> Drop for FileBufferingStream<S> {
139    fn drop(&mut self) {
140        match self.buffer {
141            Buffer::Memory(_) => {}
142            Buffer::File(ref path, _) => match std::fs::remove_file(path) {
143                Ok(_) => {}
144                Err(e) => println!("error at remove buffering file {:?}. {}", path, e),
145            },
146        };
147    }
148}
149
150impl<S> FileBufferingStream<S> {
151    fn new(
152        inner: S,
153        tmp_dir: PathBuf,
154        threshold: usize,
155        produce_chunk_size: usize,
156        buffer_limit: Option<usize>,
157    ) -> Self {
158        Self {
159            inner: inner,
160            inner_eof: false,
161
162            tmp_dir,
163            threshold,
164            produce_chunk_size,
165            buffer_limit: buffer_limit,
166
167            buffer: Buffer::Memory(BytesMut::new()),
168            buffer_size: 0,
169            produce_index: 0,
170        }
171    }
172
173    fn write_to_buffer(&mut self, bytes: &Bytes) -> Result<(), BufferingError> {
174        if let Some(limit) = self.buffer_limit {
175            if self.buffer_size + bytes.len() > limit {
176                return Err(BufferingError::Overflow);
177            }
178        }
179
180        match self.buffer {
181            Buffer::Memory(ref mut memory) => {
182                if self.threshold < memory.len() + bytes.len() {
183                    let mut path = self.tmp_dir.to_path_buf();
184                    path.push(Uuid::new_v4().to_simple().to_string());
185
186                    let mut file = OpenOptions::new()
187                        .write(true)
188                        .read(true)
189                        .create_new(true)
190                        .open(&path)?;
191
192                    file.write_all(&memory[..])?;
193                    file.write_all(bytes)?;
194
195                    self.buffer = Buffer::File(path, file);
196                } else {
197                    memory.extend_from_slice(bytes)
198                }
199            }
200            Buffer::File(_, ref mut file) => {
201                file.write_all(bytes)?;
202            }
203        }
204
205        self.buffer_size += bytes.len();
206
207        Ok(())
208    }
209
210    fn read_from_buffer(&mut self) -> Result<Bytes, BufferingError> {
211        let chunk_size = self.produce_chunk_size;
212        let buffer_size = self.buffer_size;
213        let current_index = self.produce_index;
214
215        if buffer_size <= current_index {
216            self.produce_index = 0;
217            return Ok(Bytes::new());
218        }
219
220        let bytes = match self.buffer {
221            Buffer::Memory(ref memory) => {
222                let bytes = {
223                    if buffer_size <= current_index + chunk_size {
224                        self.produce_index = buffer_size;
225                        let start = current_index as usize;
226                        Bytes::copy_from_slice(&memory[start..])
227                    } else {
228                        self.produce_index += chunk_size;
229                        let start = current_index as usize;
230                        let end = (current_index + chunk_size) as usize;
231                        Bytes::copy_from_slice(&memory[start..end])
232                    }
233                };
234
235                bytes
236            }
237            Buffer::File(_, ref mut file) => {
238                if current_index == 0 {
239                    file.seek(SeekFrom::Start(0))?;
240                    file.flush()?;
241                }
242
243                let mut bytes = {
244                    if buffer_size <= current_index + chunk_size {
245                        self.produce_index = buffer_size;
246                        vec![0u8; buffer_size - current_index]
247                    } else {
248                        self.produce_index += chunk_size;
249                        vec![0u8; chunk_size]
250                    }
251                };
252
253                file.read_exact(bytes.as_mut_slice())?;
254
255                bytes.into()
256            }
257        };
258
259        Ok(bytes)
260    }
261}
262
263impl<S, E> FileBufferingStream<S>
264where
265    S: Stream<Item = Result<Bytes, E>> + Unpin,
266{
267    fn generic_poll_next<I>(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, I>>>
268    where
269        E: Into<I>,
270        I: From<BufferingError>,
271    {
272        let this = self.get_mut();
273
274        match this.inner_eof {
275            false => {
276                let op = ready!(this.inner.poll_next_unpin(cx));
277                match op {
278                    Some(ref r) => {
279                        if let Ok(ref o) = r {
280                            this.write_to_buffer(o)?;
281                        }
282                    }
283                    None => {
284                        this.inner_eof = true;
285                    }
286                };
287
288                Poll::Ready(op.map(|res| res.map_err(Into::into)))
289            }
290            true => {
291                let bytes = this.read_from_buffer()?;
292                if bytes.len() == 0 {
293                    Poll::Ready(None)
294                } else {
295                    Poll::Ready(Some(Ok(bytes)))
296                }
297            }
298        }
299    }
300}
301
/// Failure while recording or replaying a buffered body.
#[derive(Debug)]
enum BufferingError {
    /// Recording the next chunk would exceed the configured `buffer_limit`.
    Overflow,
    /// An operation on the temporary file failed: create, write, flush, seek or read.
    Io(std::io::Error),
}

impl From<std::io::Error> for BufferingError {
    /// Wraps an I/O failure, so `?` can be applied to file operations.
    fn from(err: std::io::Error) -> Self {
        BufferingError::Io(err)
    }
}
313
314impl<S, E> MessageBody for FileBufferingStream<S>
315where
316    S: Stream<Item = Result<Bytes, E>> + Unpin,
317    E: Into<actix_web::Error>,
318{    
319    fn size(&self) -> BodySize {
320        match self.inner_eof {
321            false => BodySize::Stream,
322            true =>  BodySize::Sized(self.buffer_size as u64)
323        }
324    }
325
326    fn poll_next(
327        self: Pin<&mut Self>,
328        cx: &mut Context<'_>,
329    ) -> Poll<Option<Result<Bytes, actix_web::Error>>> {
330        self.generic_poll_next(cx)
331    }
332}
333
334impl<S> Stream for FileBufferingStream<S>
335where
336    S: Stream<Item = Result<Bytes, actix_web::error::PayloadError>> + Unpin,
337{
338    type Item = Result<Bytes, actix_web::error::PayloadError>;
339
340    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341        self.generic_poll_next(cx)
342    }
343
344    fn size_hint(&self) -> (usize, Option<usize>) {
345        match self.inner_eof {
346            false => self.inner.size_hint(),
347            true => (self.produce_index, Some(self.buffer_size))
348        }
349    }
350}
351
352impl From<BufferingError> for actix_web::error::PayloadError {
353    fn from(e: BufferingError) -> Self {
354        match e {
355            BufferingError::Overflow => actix_web::error::PayloadError::Overflow,
356            BufferingError::Io(io) => io.into(),
357        }
358    }
359}
360
361impl From<BufferingError> for actix_web::Error {
362    fn from(e: BufferingError) -> Self {
363        match e {
364            BufferingError::Overflow => actix_web::error::PayloadError::Overflow.into(),
365            BufferingError::Io(io) => io.into(),
366        }
367    }
368}