1use std::future::Future;
4use std::io::{self, Cursor, Seek, SeekFrom, Write};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7use tokio::fs::File;
8use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
9use tokio::task::JoinHandle;
10
11pub use tempfile;
12
13#[derive(Debug)]
14enum DataLocation {
15 InMemory(Option<Cursor<Vec<u8>>>),
16 WritingToDisk(JoinHandle<io::Result<File>>),
17 OnDisk(File),
18 Poisoned,
19}
20
21#[derive(Debug)]
22struct Inner {
23 data_location: DataLocation,
24 last_write_err: Option<io::Error>,
25}
26
27#[derive(Debug)]
29pub enum SpooledData {
30 InMemory(Cursor<Vec<u8>>),
31 OnDisk(File),
32}
33
34#[derive(Debug)]
36pub struct SpooledTempFile {
37 max_size: usize,
38 inner: Inner,
39}
40
41impl SpooledTempFile {
42 pub fn new(max_size: usize) -> Self {
45 Self {
46 max_size,
47 inner: Inner {
48 data_location: DataLocation::InMemory(Some(Cursor::new(Vec::new()))),
49 last_write_err: None,
50 },
51 }
52 }
53
54 pub fn with_max_size_and_capacity(max_size: usize, capacity: usize) -> Self {
57 Self {
58 max_size,
59 inner: Inner {
60 data_location: DataLocation::InMemory(Some(Cursor::new(Vec::with_capacity(
61 capacity,
62 )))),
63 last_write_err: None,
64 },
65 }
66 }
67
68 pub fn is_rolled(&self) -> bool {
70 std::matches!(self.inner.data_location, DataLocation::OnDisk(..))
71 }
72
73 pub fn is_poisoned(&self) -> bool {
79 std::matches!(self.inner.data_location, DataLocation::Poisoned)
80 }
81
82 pub async fn into_inner(self) -> Result<SpooledData, io::Error> {
84 match self.inner.data_location {
85 DataLocation::InMemory(opt_mem_buffer) => {
86 Ok(SpooledData::InMemory(opt_mem_buffer.unwrap()))
87 }
88 DataLocation::WritingToDisk(handle) => match handle.await {
89 Ok(Ok(file)) => Ok(SpooledData::OnDisk(file)),
90 Ok(Err(err)) => Err(err),
91 Err(_) => Err(io::Error::new(
92 io::ErrorKind::Other,
93 "background task failed",
94 )),
95 },
96 DataLocation::OnDisk(file) => Ok(SpooledData::OnDisk(file)),
97 DataLocation::Poisoned => Err(io::Error::new(
98 io::ErrorKind::Other,
99 "failed to move data from memory to disk",
100 )),
101 }
102 }
103
104 fn poll_roll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
105 loop {
106 match self.inner.data_location {
107 DataLocation::InMemory(ref mut opt_mem_buffer) => {
108 let mut mem_buffer = opt_mem_buffer.take().unwrap();
109
110 let handle = tokio::task::spawn_blocking(move || {
111 let mut file = tempfile::tempfile()?;
112
113 file.write_all(mem_buffer.get_mut())?;
114 file.seek(SeekFrom::Start(mem_buffer.position()))?;
115
116 Ok(File::from_std(file))
117 });
118
119 self.inner.data_location = DataLocation::WritingToDisk(handle);
120 }
121 DataLocation::WritingToDisk(ref mut handle) => {
122 let res = ready!(Pin::new(handle).poll(cx));
123
124 match res {
125 Ok(Ok(file)) => {
126 self.inner.data_location = DataLocation::OnDisk(file);
127 }
128 Ok(Err(err)) => {
129 self.inner.data_location = DataLocation::Poisoned;
130 return Poll::Ready(Err(err));
131 }
132 Err(_) => {
133 self.inner.data_location = DataLocation::Poisoned;
134 return Poll::Ready(Err(io::Error::new(
135 io::ErrorKind::Other,
136 "background task failed",
137 )));
138 }
139 }
140 }
141 DataLocation::OnDisk(_) => {
142 return Poll::Ready(Ok(()));
143 }
144 DataLocation::Poisoned => {
145 return Poll::Ready(Err(io::Error::new(
146 io::ErrorKind::Other,
147 "failed to move data from memory to disk",
148 )));
149 }
150 }
151 }
152 }
153
154 pub async fn roll(&mut self) -> io::Result<()> {
157 std::future::poll_fn(|cx| self.poll_roll(cx)).await
158 }
159
160 pub async fn set_len(&mut self, size: u64) -> Result<(), io::Error> {
164 if size > self.max_size as u64 {
165 self.roll().await?;
166 }
167
168 loop {
169 match self.inner.data_location {
170 DataLocation::InMemory(ref mut opt_mem_buffer) => {
171 opt_mem_buffer
172 .as_mut()
173 .unwrap()
174 .get_mut()
175 .resize(size as usize, 0);
176 return Ok(());
177 }
178 DataLocation::WritingToDisk(_) => {
179 self.roll().await?;
180 }
181 DataLocation::OnDisk(ref mut file) => {
182 return file.set_len(size).await;
183 }
184 DataLocation::Poisoned => {
185 return Err(io::Error::new(
186 io::ErrorKind::Other,
187 "failed to move data from memory to disk",
188 ));
189 }
190 }
191 }
192 }
193}
194
195impl AsyncWrite for SpooledTempFile {
196 fn poll_write(
197 self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &[u8],
200 ) -> Poll<Result<usize, io::Error>> {
201 let me = self.get_mut();
202
203 if let Some(err) = me.inner.last_write_err.take() {
204 return Poll::Ready(Err(err));
205 }
206
207 loop {
208 match me.inner.data_location {
209 DataLocation::InMemory(ref mut opt_mem_buffer) => {
210 let mut mem_buffer = opt_mem_buffer.take().unwrap();
211
212 if mem_buffer.position().saturating_add(buf.len() as u64) > me.max_size as u64 {
213 *opt_mem_buffer = Some(mem_buffer);
214
215 ready!(me.poll_roll(cx))?;
216
217 continue;
218 }
219
220 let res = Pin::new(&mut mem_buffer).poll_write(cx, buf);
221
222 *opt_mem_buffer = Some(mem_buffer);
223
224 return res;
225 }
226 DataLocation::WritingToDisk(_) => {
227 ready!(me.poll_roll(cx))?;
228 }
229 DataLocation::OnDisk(ref mut file) => {
230 return Pin::new(file).poll_write(cx, buf);
231 }
232 DataLocation::Poisoned => {
233 return Poll::Ready(Err(io::Error::new(
234 io::ErrorKind::Other,
235 "failed to move data from memory to disk",
236 )));
237 }
238 }
239 }
240 }
241
242 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
243 let me = self.get_mut();
244
245 match me.inner.data_location {
246 DataLocation::InMemory(ref mut opt_mem_buffer) => {
247 Pin::new(opt_mem_buffer.as_mut().unwrap()).poll_flush(cx)
248 }
249 DataLocation::WritingToDisk(_) => me.poll_roll(cx),
250 DataLocation::OnDisk(ref mut file) => Pin::new(file).poll_flush(cx),
251 DataLocation::Poisoned => Poll::Ready(Err(io::Error::new(
252 io::ErrorKind::Other,
253 "failed to move data from memory to disk",
254 ))),
255 }
256 }
257
258 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
259 self.poll_flush(cx)
260 }
261}
262
263impl AsyncRead for SpooledTempFile {
264 fn poll_read(
265 self: Pin<&mut Self>,
266 cx: &mut Context<'_>,
267 buf: &mut ReadBuf<'_>,
268 ) -> Poll<io::Result<()>> {
269 let me = self.get_mut();
270
271 loop {
272 match me.inner.data_location {
273 DataLocation::InMemory(ref mut opt_mem_buffer) => {
274 return Pin::new(opt_mem_buffer.as_mut().unwrap()).poll_read(cx, buf);
275 }
276 DataLocation::WritingToDisk(_) => {
277 if let Err(write_err) = ready!(me.poll_roll(cx)) {
278 me.inner.last_write_err = Some(write_err);
279 }
280 }
281 DataLocation::OnDisk(ref mut file) => {
282 return Pin::new(file).poll_read(cx, buf);
283 }
284 DataLocation::Poisoned => {
285 return Poll::Ready(Err(io::Error::new(
286 io::ErrorKind::Other,
287 "failed to move data from memory to disk",
288 )));
289 }
290 }
291 }
292 }
293}
294
295impl AsyncSeek for SpooledTempFile {
296 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
297 let me = self.get_mut();
298
299 match me.inner.data_location {
300 DataLocation::InMemory(ref mut opt_mem_buffer) => {
301 Pin::new(opt_mem_buffer.as_mut().unwrap()).start_seek(position)
302 }
303 DataLocation::WritingToDisk(_) => Err(io::Error::new(
304 io::ErrorKind::Other,
305 "other operation is pending, call poll_complete before start_seek",
306 )),
307 DataLocation::OnDisk(ref mut file) => Pin::new(file).start_seek(position),
308 DataLocation::Poisoned => Err(io::Error::new(
309 io::ErrorKind::Other,
310 "failed to move data from memory to disk",
311 )),
312 }
313 }
314
315 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
316 let me = self.get_mut();
317
318 loop {
319 match me.inner.data_location {
320 DataLocation::InMemory(ref mut opt_mem_buffer) => {
321 return Pin::new(opt_mem_buffer.as_mut().unwrap()).poll_complete(cx);
322 }
323 DataLocation::WritingToDisk(_) => {
324 if let Err(write_err) = ready!(me.poll_roll(cx)) {
325 me.inner.last_write_err = Some(write_err);
326 }
327 }
328 DataLocation::OnDisk(ref mut file) => {
329 return Pin::new(file).poll_complete(cx);
330 }
331 DataLocation::Poisoned => {
332 return Poll::Ready(Err(io::Error::new(
333 io::ErrorKind::Other,
334 "failed to move data from memory to disk",
335 )));
336 }
337 }
338 }
339 }
340}