1use std::{
2 fs::{File, OpenOptions},
3 io::{Read, Seek, SeekFrom, Write},
4 path::{Path, PathBuf},
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use actix_web::{
10 dev::{Body, BodySize, MessageBody, Payload, ResponseBody, ServiceRequest, ServiceResponse},
11 web::{Bytes, BytesMut},
12 HttpMessage,
13};
14use futures::{ready, Stream, StreamExt};
15use uuid::Uuid;
16
/// Marker stored in a request's extensions once its payload has been wrapped
/// for buffering, so that `enable_request_buffering` does not wrap it twice.
#[derive(Debug, Clone, Copy)]
struct RequestBufferedMark;

/// Marker stored in a response's extensions once its body has been wrapped
/// for buffering, so that `enable_response_buffering` does not wrap it twice.
#[derive(Debug, Clone, Copy)]
struct ResponseBufferedMark;
19
20pub fn enable_request_buffering<T>(wrapper: T, req: &mut ServiceRequest)
22where
23 T: AsRef<FileBufferingStreamWrapper>,
24{
25 if !req.extensions().contains::<RequestBufferedMark>() {
26 let inner = req.take_payload();
27 req.set_payload(Payload::Stream(wrapper.as_ref().wrap(inner).boxed_local()));
28
29 req.extensions_mut().insert(RequestBufferedMark)
30 }
31}
32
33pub fn enable_response_buffering<T>(
35 wrapper: T,
36 mut svc_res: ServiceResponse<Body>,
37) -> ServiceResponse<Body>
38where
39 T: AsRef<FileBufferingStreamWrapper>,
40{
41 if !svc_res
42 .response()
43 .extensions()
44 .contains::<ResponseBufferedMark>()
45 {
46 svc_res
47 .response_mut()
48 .extensions_mut()
49 .insert(ResponseBufferedMark);
50
51 svc_res.map_body(|_, rb| {
52 let wrapped = wrapper.as_ref().wrap(rb);
53 ResponseBody::Body(Body::Message(Box::new(wrapped)))
54 })
55 } else {
56 svc_res
57 }
58}
59
/// Configuration for [`FileBufferingStream`]s; `wrap` applies it to a stream.
///
/// Recorded bytes are kept in memory as long as the total stays within
/// `threshold`; past that they spill to a uniquely named file in `tmp_dir`.
#[derive(Debug, Clone)]
pub struct FileBufferingStreamWrapper {
    /// Directory in which spill files are created.
    tmp_dir: PathBuf,
    /// Maximum number of bytes kept in memory before spilling to a file.
    threshold: usize,
    /// Size of the chunks produced when the recorded data is replayed.
    produce_chunk_size: usize,
    /// Optional cap on the total number of recorded bytes; `None` means unlimited.
    buffer_limit: Option<usize>,
}
67
68impl FileBufferingStreamWrapper {
69 pub fn new() -> Self {
70 Self {
71 tmp_dir: std::env::temp_dir(),
72 threshold: 1024 * 30,
73 produce_chunk_size: 1024 * 30,
74 buffer_limit: None,
75 }
76 }
77
78 pub fn tmp_dir(mut self, v: impl AsRef<Path>) -> Self {
80 self.tmp_dir = v.as_ref().to_path_buf();
81 self
82 }
83
84 pub fn threshold(mut self, v: usize) -> Self {
86 self.threshold = v;
87 self
88 }
89
90 pub fn produce_chunk_size(mut self, v: usize) -> Self {
92 self.produce_chunk_size = v;
93 self
94 }
95
96 pub fn buffer_limit(mut self, v: Option<usize>) -> Self {
98 self.buffer_limit = v;
99 self
100 }
101
102 pub fn wrap<S>(&self, inner: S) -> FileBufferingStream<S> {
103 FileBufferingStream::new(
104 inner,
105 self.tmp_dir.to_path_buf(),
106 self.threshold,
107 self.produce_chunk_size,
108 self.buffer_limit,
109 )
110 }
111}
112
113impl AsRef<FileBufferingStreamWrapper> for FileBufferingStreamWrapper {
114 fn as_ref(&self) -> &FileBufferingStreamWrapper {
115 self
116 }
117}
118
/// Storage backing a `FileBufferingStream`'s recorded bytes.
enum Buffer {
    /// Recorded bytes held in memory (used while the total stays within the threshold).
    Memory(BytesMut),
    /// Recorded bytes spilled to disk: the spill file's path (the file is
    /// removed when the stream is dropped) and its open read/write handle.
    File(PathBuf, File),
}
123
/// Stream adapter that forwards `inner`'s items while recording its
/// successful chunks, then replays the recorded bytes once `inner` is done.
///
/// The first pass yields `inner`'s items and ends with `None`. Polling again
/// afterwards yields the recorded bytes in chunks of `produce_chunk_size`,
/// again ending with `None`. After that the replay position is reset, so a
/// further pass starts from the beginning.
pub struct FileBufferingStream<S> {
    /// The wrapped stream.
    inner: S,
    /// Set once `inner` has returned `None`; switches polling to replay mode.
    inner_eof: bool,

    /// Directory in which the spill file is created.
    tmp_dir: PathBuf,
    /// Bytes kept in memory before spilling to a file.
    threshold: usize,
    /// Chunk size used when replaying.
    produce_chunk_size: usize,
    /// Optional cap on the total number of recorded bytes.
    buffer_limit: Option<usize>,

    /// In-memory or on-disk storage of the recorded bytes.
    buffer: Buffer,
    /// Total number of bytes recorded so far.
    buffer_size: usize,
    /// Offset of the next byte to replay.
    produce_index: usize,
}
137
138impl<S> Drop for FileBufferingStream<S> {
139 fn drop(&mut self) {
140 match self.buffer {
141 Buffer::Memory(_) => {}
142 Buffer::File(ref path, _) => match std::fs::remove_file(path) {
143 Ok(_) => {}
144 Err(e) => println!("error at remove buffering file {:?}. {}", path, e),
145 },
146 };
147 }
148}
149
150impl<S> FileBufferingStream<S> {
151 fn new(
152 inner: S,
153 tmp_dir: PathBuf,
154 threshold: usize,
155 produce_chunk_size: usize,
156 buffer_limit: Option<usize>,
157 ) -> Self {
158 Self {
159 inner: inner,
160 inner_eof: false,
161
162 tmp_dir,
163 threshold,
164 produce_chunk_size,
165 buffer_limit: buffer_limit,
166
167 buffer: Buffer::Memory(BytesMut::new()),
168 buffer_size: 0,
169 produce_index: 0,
170 }
171 }
172
173 fn write_to_buffer(&mut self, bytes: &Bytes) -> Result<(), BufferingError> {
174 if let Some(limit) = self.buffer_limit {
175 if self.buffer_size + bytes.len() > limit {
176 return Err(BufferingError::Overflow);
177 }
178 }
179
180 match self.buffer {
181 Buffer::Memory(ref mut memory) => {
182 if self.threshold < memory.len() + bytes.len() {
183 let mut path = self.tmp_dir.to_path_buf();
184 path.push(Uuid::new_v4().to_simple().to_string());
185
186 let mut file = OpenOptions::new()
187 .write(true)
188 .read(true)
189 .create_new(true)
190 .open(&path)?;
191
192 file.write_all(&memory[..])?;
193 file.write_all(bytes)?;
194
195 self.buffer = Buffer::File(path, file);
196 } else {
197 memory.extend_from_slice(bytes)
198 }
199 }
200 Buffer::File(_, ref mut file) => {
201 file.write_all(bytes)?;
202 }
203 }
204
205 self.buffer_size += bytes.len();
206
207 Ok(())
208 }
209
210 fn read_from_buffer(&mut self) -> Result<Bytes, BufferingError> {
211 let chunk_size = self.produce_chunk_size;
212 let buffer_size = self.buffer_size;
213 let current_index = self.produce_index;
214
215 if buffer_size <= current_index {
216 self.produce_index = 0;
217 return Ok(Bytes::new());
218 }
219
220 let bytes = match self.buffer {
221 Buffer::Memory(ref memory) => {
222 let bytes = {
223 if buffer_size <= current_index + chunk_size {
224 self.produce_index = buffer_size;
225 let start = current_index as usize;
226 Bytes::copy_from_slice(&memory[start..])
227 } else {
228 self.produce_index += chunk_size;
229 let start = current_index as usize;
230 let end = (current_index + chunk_size) as usize;
231 Bytes::copy_from_slice(&memory[start..end])
232 }
233 };
234
235 bytes
236 }
237 Buffer::File(_, ref mut file) => {
238 if current_index == 0 {
239 file.seek(SeekFrom::Start(0))?;
240 file.flush()?;
241 }
242
243 let mut bytes = {
244 if buffer_size <= current_index + chunk_size {
245 self.produce_index = buffer_size;
246 vec![0u8; buffer_size - current_index]
247 } else {
248 self.produce_index += chunk_size;
249 vec![0u8; chunk_size]
250 }
251 };
252
253 file.read_exact(bytes.as_mut_slice())?;
254
255 bytes.into()
256 }
257 };
258
259 Ok(bytes)
260 }
261}
262
263impl<S, E> FileBufferingStream<S>
264where
265 S: Stream<Item = Result<Bytes, E>> + Unpin,
266{
267 fn generic_poll_next<I>(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, I>>>
268 where
269 E: Into<I>,
270 I: From<BufferingError>,
271 {
272 let this = self.get_mut();
273
274 match this.inner_eof {
275 false => {
276 let op = ready!(this.inner.poll_next_unpin(cx));
277 match op {
278 Some(ref r) => {
279 if let Ok(ref o) = r {
280 this.write_to_buffer(o)?;
281 }
282 }
283 None => {
284 this.inner_eof = true;
285 }
286 };
287
288 Poll::Ready(op.map(|res| res.map_err(Into::into)))
289 }
290 true => {
291 let bytes = this.read_from_buffer()?;
292 if bytes.len() == 0 {
293 Poll::Ready(None)
294 } else {
295 Poll::Ready(Some(Ok(bytes)))
296 }
297 }
298 }
299 }
300}
301
/// Errors raised while recording or replaying buffered data.
#[derive(Debug)]
enum BufferingError {
    /// The total recorded size would exceed the configured limit.
    Overflow,
    /// Creating, writing, rewinding or reading the spill file failed.
    Io(std::io::Error),
}

impl std::fmt::Display for BufferingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Overflow => f.write_str("buffered data exceeds the configured buffer limit"),
            Self::Io(e) => write!(f, "i/o error while buffering data: {}", e),
        }
    }
}

impl std::error::Error for BufferingError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Overflow => None,
            Self::Io(e) => Some(e),
        }
    }
}

impl From<std::io::Error> for BufferingError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}
313
314impl<S, E> MessageBody for FileBufferingStream<S>
315where
316 S: Stream<Item = Result<Bytes, E>> + Unpin,
317 E: Into<actix_web::Error>,
318{
319 fn size(&self) -> BodySize {
320 match self.inner_eof {
321 false => BodySize::Stream,
322 true => BodySize::Sized(self.buffer_size as u64)
323 }
324 }
325
326 fn poll_next(
327 self: Pin<&mut Self>,
328 cx: &mut Context<'_>,
329 ) -> Poll<Option<Result<Bytes, actix_web::Error>>> {
330 self.generic_poll_next(cx)
331 }
332}
333
334impl<S> Stream for FileBufferingStream<S>
335where
336 S: Stream<Item = Result<Bytes, actix_web::error::PayloadError>> + Unpin,
337{
338 type Item = Result<Bytes, actix_web::error::PayloadError>;
339
340 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341 self.generic_poll_next(cx)
342 }
343
344 fn size_hint(&self) -> (usize, Option<usize>) {
345 match self.inner_eof {
346 false => self.inner.size_hint(),
347 true => (self.produce_index, Some(self.buffer_size))
348 }
349 }
350}
351
352impl From<BufferingError> for actix_web::error::PayloadError {
353 fn from(e: BufferingError) -> Self {
354 match e {
355 BufferingError::Overflow => actix_web::error::PayloadError::Overflow,
356 BufferingError::Io(io) => io.into(),
357 }
358 }
359}
360
361impl From<BufferingError> for actix_web::Error {
362 fn from(e: BufferingError) -> Self {
363 match e {
364 BufferingError::Overflow => actix_web::error::PayloadError::Overflow.into(),
365 BufferingError::Io(io) => io.into(),
366 }
367 }
368}