async_spooled_tempfile/
lib.rs

1//! Crate providing an asynchronous version of the [`tempfile::SpooledTempFile`](https://docs.rs/tempfile/latest/tempfile/struct.SpooledTempFile.html)
2//! structure exposed by the [tempfile](https://docs.rs/tempfile/latest/tempfile/index.html) crate.
3use std::future::Future;
4use std::io::{self, Cursor, Seek, SeekFrom, Write};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7use tokio::fs::File;
8use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
9use tokio::task::JoinHandle;
10
11pub use tempfile;
12
13#[derive(Debug)]
14enum DataLocation {
15    InMemory(Option<Cursor<Vec<u8>>>),
16    WritingToDisk(JoinHandle<io::Result<File>>),
17    OnDisk(File),
18    Poisoned,
19}
20
21#[derive(Debug)]
22struct Inner {
23    data_location: DataLocation,
24    last_write_err: Option<io::Error>,
25}
26
27/// Data stored in a [`SpooledTempFile`] instance.
28#[derive(Debug)]
29pub enum SpooledData {
30    InMemory(Cursor<Vec<u8>>),
31    OnDisk(File),
32}
33
34/// Asynchronous version of [`tempfile::SpooledTempFile`](https://docs.rs/tempfile/latest/tempfile/struct.SpooledTempFile.html).
35#[derive(Debug)]
36pub struct SpooledTempFile {
37    max_size: usize,
38    inner: Inner,
39}
40
41impl SpooledTempFile {
42    /// Creates a new instance of [`SpooledTempFile`] that can hold up to `max_size` bytes in
43    /// memory.
44    pub fn new(max_size: usize) -> Self {
45        Self {
46            max_size,
47            inner: Inner {
48                data_location: DataLocation::InMemory(Some(Cursor::new(Vec::new()))),
49                last_write_err: None,
50            },
51        }
52    }
53
54    /// Creates a new instance of [`SpooledTempFile`] that can hold up to `max_size` bytes in
55    /// memory and allocates space for the in-memory buffer.
56    pub fn with_max_size_and_capacity(max_size: usize, capacity: usize) -> Self {
57        Self {
58            max_size,
59            inner: Inner {
60                data_location: DataLocation::InMemory(Some(Cursor::new(Vec::with_capacity(
61                    capacity,
62                )))),
63                last_write_err: None,
64            },
65        }
66    }
67
68    /// Returns `true` if the data have been written to a file.
69    pub fn is_rolled(&self) -> bool {
70        std::matches!(self.inner.data_location, DataLocation::OnDisk(..))
71    }
72
73    /// Determines whether the current instance is poisoned or not.
74    ///
75    /// An instance of [`SpooledTempFile`] is poisoned if it failed to move its data
76    /// from memory to disk.
77    ///
78    pub fn is_poisoned(&self) -> bool {
79        std::matches!(self.inner.data_location, DataLocation::Poisoned)
80    }
81
82    /// Consumes and returns the inner [`SpooledData`] type.
83    pub async fn into_inner(self) -> Result<SpooledData, io::Error> {
84        match self.inner.data_location {
85            DataLocation::InMemory(opt_mem_buffer) => {
86                Ok(SpooledData::InMemory(opt_mem_buffer.unwrap()))
87            }
88            DataLocation::WritingToDisk(handle) => match handle.await {
89                Ok(Ok(file)) => Ok(SpooledData::OnDisk(file)),
90                Ok(Err(err)) => Err(err),
91                Err(_) => Err(io::Error::new(
92                    io::ErrorKind::Other,
93                    "background task failed",
94                )),
95            },
96            DataLocation::OnDisk(file) => Ok(SpooledData::OnDisk(file)),
97            DataLocation::Poisoned => Err(io::Error::new(
98                io::ErrorKind::Other,
99                "failed to move data from memory to disk",
100            )),
101        }
102    }
103
104    fn poll_roll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
105        loop {
106            match self.inner.data_location {
107                DataLocation::InMemory(ref mut opt_mem_buffer) => {
108                    let mut mem_buffer = opt_mem_buffer.take().unwrap();
109
110                    let handle = tokio::task::spawn_blocking(move || {
111                        let mut file = tempfile::tempfile()?;
112
113                        file.write_all(mem_buffer.get_mut())?;
114                        file.seek(SeekFrom::Start(mem_buffer.position()))?;
115
116                        Ok(File::from_std(file))
117                    });
118
119                    self.inner.data_location = DataLocation::WritingToDisk(handle);
120                }
121                DataLocation::WritingToDisk(ref mut handle) => {
122                    let res = ready!(Pin::new(handle).poll(cx));
123
124                    match res {
125                        Ok(Ok(file)) => {
126                            self.inner.data_location = DataLocation::OnDisk(file);
127                        }
128                        Ok(Err(err)) => {
129                            self.inner.data_location = DataLocation::Poisoned;
130                            return Poll::Ready(Err(err));
131                        }
132                        Err(_) => {
133                            self.inner.data_location = DataLocation::Poisoned;
134                            return Poll::Ready(Err(io::Error::new(
135                                io::ErrorKind::Other,
136                                "background task failed",
137                            )));
138                        }
139                    }
140                }
141                DataLocation::OnDisk(_) => {
142                    return Poll::Ready(Ok(()));
143                }
144                DataLocation::Poisoned => {
145                    return Poll::Ready(Err(io::Error::new(
146                        io::ErrorKind::Other,
147                        "failed to move data from memory to disk",
148                    )));
149                }
150            }
151        }
152    }
153
154    /// Moves the data from memory to disk.
155    /// Does nothing if the transition has already been made.
156    pub async fn roll(&mut self) -> io::Result<()> {
157        std::future::poll_fn(|cx| self.poll_roll(cx)).await
158    }
159
160    /// Truncates or extends the underlying buffer / file.
161    /// If the provided size is greater than `max_size`, data will be moved from
162    /// memory to disk regardless of the size of the data hold by the current instance.
163    pub async fn set_len(&mut self, size: u64) -> Result<(), io::Error> {
164        if size > self.max_size as u64 {
165            self.roll().await?;
166        }
167
168        loop {
169            match self.inner.data_location {
170                DataLocation::InMemory(ref mut opt_mem_buffer) => {
171                    opt_mem_buffer
172                        .as_mut()
173                        .unwrap()
174                        .get_mut()
175                        .resize(size as usize, 0);
176                    return Ok(());
177                }
178                DataLocation::WritingToDisk(_) => {
179                    self.roll().await?;
180                }
181                DataLocation::OnDisk(ref mut file) => {
182                    return file.set_len(size).await;
183                }
184                DataLocation::Poisoned => {
185                    return Err(io::Error::new(
186                        io::ErrorKind::Other,
187                        "failed to move data from memory to disk",
188                    ));
189                }
190            }
191        }
192    }
193}
194
195impl AsyncWrite for SpooledTempFile {
196    fn poll_write(
197        self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        buf: &[u8],
200    ) -> Poll<Result<usize, io::Error>> {
201        let me = self.get_mut();
202
203        if let Some(err) = me.inner.last_write_err.take() {
204            return Poll::Ready(Err(err));
205        }
206
207        loop {
208            match me.inner.data_location {
209                DataLocation::InMemory(ref mut opt_mem_buffer) => {
210                    let mut mem_buffer = opt_mem_buffer.take().unwrap();
211
212                    if mem_buffer.position().saturating_add(buf.len() as u64) > me.max_size as u64 {
213                        *opt_mem_buffer = Some(mem_buffer);
214
215                        ready!(me.poll_roll(cx))?;
216
217                        continue;
218                    }
219
220                    let res = Pin::new(&mut mem_buffer).poll_write(cx, buf);
221
222                    *opt_mem_buffer = Some(mem_buffer);
223
224                    return res;
225                }
226                DataLocation::WritingToDisk(_) => {
227                    ready!(me.poll_roll(cx))?;
228                }
229                DataLocation::OnDisk(ref mut file) => {
230                    return Pin::new(file).poll_write(cx, buf);
231                }
232                DataLocation::Poisoned => {
233                    return Poll::Ready(Err(io::Error::new(
234                        io::ErrorKind::Other,
235                        "failed to move data from memory to disk",
236                    )));
237                }
238            }
239        }
240    }
241
242    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
243        let me = self.get_mut();
244
245        match me.inner.data_location {
246            DataLocation::InMemory(ref mut opt_mem_buffer) => {
247                Pin::new(opt_mem_buffer.as_mut().unwrap()).poll_flush(cx)
248            }
249            DataLocation::WritingToDisk(_) => me.poll_roll(cx),
250            DataLocation::OnDisk(ref mut file) => Pin::new(file).poll_flush(cx),
251            DataLocation::Poisoned => Poll::Ready(Err(io::Error::new(
252                io::ErrorKind::Other,
253                "failed to move data from memory to disk",
254            ))),
255        }
256    }
257
258    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
259        self.poll_flush(cx)
260    }
261}
262
263impl AsyncRead for SpooledTempFile {
264    fn poll_read(
265        self: Pin<&mut Self>,
266        cx: &mut Context<'_>,
267        buf: &mut ReadBuf<'_>,
268    ) -> Poll<io::Result<()>> {
269        let me = self.get_mut();
270
271        loop {
272            match me.inner.data_location {
273                DataLocation::InMemory(ref mut opt_mem_buffer) => {
274                    return Pin::new(opt_mem_buffer.as_mut().unwrap()).poll_read(cx, buf);
275                }
276                DataLocation::WritingToDisk(_) => {
277                    if let Err(write_err) = ready!(me.poll_roll(cx)) {
278                        me.inner.last_write_err = Some(write_err);
279                    }
280                }
281                DataLocation::OnDisk(ref mut file) => {
282                    return Pin::new(file).poll_read(cx, buf);
283                }
284                DataLocation::Poisoned => {
285                    return Poll::Ready(Err(io::Error::new(
286                        io::ErrorKind::Other,
287                        "failed to move data from memory to disk",
288                    )));
289                }
290            }
291        }
292    }
293}
294
295impl AsyncSeek for SpooledTempFile {
296    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
297        let me = self.get_mut();
298
299        match me.inner.data_location {
300            DataLocation::InMemory(ref mut opt_mem_buffer) => {
301                Pin::new(opt_mem_buffer.as_mut().unwrap()).start_seek(position)
302            }
303            DataLocation::WritingToDisk(_) => Err(io::Error::new(
304                io::ErrorKind::Other,
305                "other operation is pending, call poll_complete before start_seek",
306            )),
307            DataLocation::OnDisk(ref mut file) => Pin::new(file).start_seek(position),
308            DataLocation::Poisoned => Err(io::Error::new(
309                io::ErrorKind::Other,
310                "failed to move data from memory to disk",
311            )),
312        }
313    }
314
315    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
316        let me = self.get_mut();
317
318        loop {
319            match me.inner.data_location {
320                DataLocation::InMemory(ref mut opt_mem_buffer) => {
321                    return Pin::new(opt_mem_buffer.as_mut().unwrap()).poll_complete(cx);
322                }
323                DataLocation::WritingToDisk(_) => {
324                    if let Err(write_err) = ready!(me.poll_roll(cx)) {
325                        me.inner.last_write_err = Some(write_err);
326                    }
327                }
328                DataLocation::OnDisk(ref mut file) => {
329                    return Pin::new(file).poll_complete(cx);
330                }
331                DataLocation::Poisoned => {
332                    return Poll::Ready(Err(io::Error::new(
333                        io::ErrorKind::Other,
334                        "failed to move data from memory to disk",
335                    )));
336                }
337            }
338        }
339    }
340}