Skip to main content

grafeo_core/execution/spill/
async_file.rs

//! Async spill file read/write abstraction using tokio.
2
3use std::path::{Path, PathBuf};
4use tokio::fs::File;
5use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader, BufWriter, SeekFrom};
6
/// Capacity, in bytes, of the buffered readers and writers used for spill files (64 KiB).
const BUFFER_SIZE: usize = 1024 * 64;
9
10/// Async handle for a single spill file.
11///
12/// AsyncSpillFile manages a temporary file used for spilling operator state to disk
13/// using tokio's async I/O primitives for non-blocking operations.
14pub struct AsyncSpillFile {
15    /// Path to the spill file.
16    path: PathBuf,
17    /// Buffered writer (Some during write phase, None after finish).
18    writer: Option<BufWriter<File>>,
19    /// Total bytes written to this file.
20    bytes_written: u64,
21}
22
23impl AsyncSpillFile {
24    /// Creates a new async spill file at the given path.
25    ///
26    /// # Errors
27    ///
28    /// Returns an error if the file cannot be created.
29    pub async fn new(path: PathBuf) -> std::io::Result<Self> {
30        let file = File::create(&path).await?;
31        let writer = BufWriter::with_capacity(BUFFER_SIZE, file);
32
33        Ok(Self {
34            path,
35            writer: Some(writer),
36            bytes_written: 0,
37        })
38    }
39
40    /// Returns the path to this spill file.
41    #[must_use]
42    pub fn path(&self) -> &Path {
43        &self.path
44    }
45
46    /// Returns the number of bytes written to this file.
47    #[must_use]
48    pub fn bytes_written(&self) -> u64 {
49        self.bytes_written
50    }
51
52    /// Writes raw bytes to the file asynchronously.
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if the write fails.
57    pub async fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
58        let writer = self
59            .writer
60            .as_mut()
61            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Write phase ended"))?;
62
63        writer.write_all(data).await?;
64        self.bytes_written += data.len() as u64;
65        Ok(())
66    }
67
68    /// Writes a u64 in little-endian format.
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if the write fails.
73    pub async fn write_u64_le(&mut self, value: u64) -> std::io::Result<()> {
74        self.write_all(&value.to_le_bytes()).await
75    }
76
77    /// Writes an i64 in little-endian format.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the write fails.
82    pub async fn write_i64_le(&mut self, value: i64) -> std::io::Result<()> {
83        self.write_all(&value.to_le_bytes()).await
84    }
85
86    /// Writes a length-prefixed byte slice.
87    ///
88    /// Format: [length: u64][data: bytes]
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if the write fails.
93    pub async fn write_bytes(&mut self, data: &[u8]) -> std::io::Result<()> {
94        self.write_u64_le(data.len() as u64).await?;
95        self.write_all(data).await
96    }
97
98    /// Finishes writing and flushes buffers.
99    ///
100    /// After this call, the file is ready for reading.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if the flush fails.
105    pub async fn finish_write(&mut self) -> std::io::Result<()> {
106        if let Some(mut writer) = self.writer.take() {
107            writer.flush().await?;
108        }
109        Ok(())
110    }
111
112    /// Returns whether this file is still in write mode.
113    #[must_use]
114    pub fn is_writable(&self) -> bool {
115        self.writer.is_some()
116    }
117
118    /// Creates an async reader for this file.
119    ///
120    /// Can be called multiple times to create multiple readers.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if the file cannot be opened for reading.
125    pub async fn reader(&self) -> std::io::Result<AsyncSpillFileReader> {
126        let file = File::open(&self.path).await?;
127        let reader = BufReader::with_capacity(BUFFER_SIZE, file);
128        Ok(AsyncSpillFileReader { reader })
129    }
130
131    /// Deletes this spill file.
132    ///
133    /// Consumes the AsyncSpillFile handle.
134    ///
135    /// # Errors
136    ///
137    /// Returns an error if the file cannot be deleted.
138    pub async fn delete(mut self) -> std::io::Result<()> {
139        // Close the writer first
140        self.writer = None;
141        tokio::fs::remove_file(&self.path).await
142    }
143}
144
145impl std::fmt::Debug for AsyncSpillFile {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        f.debug_struct("AsyncSpillFile")
148            .field("path", &self.path)
149            .field("bytes_written", &self.bytes_written)
150            .field("is_writable", &self.is_writable())
151            .finish()
152    }
153}
154
155/// Async reader for a spill file.
156///
157/// Provides buffered async reading of spill file contents.
158pub struct AsyncSpillFileReader {
159    /// Buffered reader.
160    reader: BufReader<File>,
161}
162
163impl AsyncSpillFileReader {
164    /// Reads exactly `buf.len()` bytes from the file.
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if not enough bytes are available.
169    pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
170        self.reader.read_exact(buf).await?;
171        Ok(())
172    }
173
174    /// Reads a u64 in little-endian format.
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if the read fails.
179    pub async fn read_u64_le(&mut self) -> std::io::Result<u64> {
180        let mut buf = [0u8; 8];
181        self.read_exact(&mut buf).await?;
182        Ok(u64::from_le_bytes(buf))
183    }
184
185    /// Reads an i64 in little-endian format.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if the read fails.
190    pub async fn read_i64_le(&mut self) -> std::io::Result<i64> {
191        let mut buf = [0u8; 8];
192        self.read_exact(&mut buf).await?;
193        Ok(i64::from_le_bytes(buf))
194    }
195
196    /// Reads a f64 in little-endian format.
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if the read fails.
201    pub async fn read_f64_le(&mut self) -> std::io::Result<f64> {
202        let mut buf = [0u8; 8];
203        self.read_exact(&mut buf).await?;
204        Ok(f64::from_le_bytes(buf))
205    }
206
207    /// Reads a u8 byte.
208    ///
209    /// # Errors
210    ///
211    /// Returns an error if the read fails.
212    pub async fn read_u8(&mut self) -> std::io::Result<u8> {
213        let mut buf = [0u8; 1];
214        self.read_exact(&mut buf).await?;
215        Ok(buf[0])
216    }
217
218    /// Reads a length-prefixed byte slice.
219    ///
220    /// Format: [length: u64][data: bytes]
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if the read fails.
225    pub async fn read_bytes(&mut self) -> std::io::Result<Vec<u8>> {
226        let len = self.read_u64_le().await? as usize;
227        let mut buf = vec![0u8; len];
228        self.read_exact(&mut buf).await?;
229        Ok(buf)
230    }
231
232    /// Seeks to a position in the file.
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if the seek fails.
237    pub async fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
238        self.reader.seek(pos).await
239    }
240
241    /// Seeks to the beginning of the file.
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if the seek fails.
246    pub async fn rewind(&mut self) -> std::io::Result<()> {
247        self.reader.seek(SeekFrom::Start(0)).await?;
248        Ok(())
249    }
250
251    /// Returns the current position in the file.
252    ///
253    /// # Errors
254    ///
255    /// Returns an error if the operation fails.
256    pub async fn position(&mut self) -> std::io::Result<u64> {
257        self.reader.stream_position().await
258    }
259}
260
261impl std::fmt::Debug for AsyncSpillFileReader {
262    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        f.debug_struct("AsyncSpillFileReader").finish()
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a scratch directory and a fresh spill file named `test.spill` in it.
    ///
    /// The directory guard is handed back so the caller keeps it alive for the
    /// duration of the test; the path is returned for tests that inspect the
    /// filesystem directly.
    async fn scratch_spill() -> (TempDir, PathBuf, AsyncSpillFile) {
        let dir = TempDir::new().unwrap();
        let spill_path = dir.path().join("test.spill");
        let spill = AsyncSpillFile::new(spill_path.clone()).await.unwrap();
        (dir, spill_path, spill)
    }

    /// Raw bytes written in two pieces come back as one contiguous run.
    #[tokio::test]
    async fn test_async_spill_file_write_read() {
        let (_dir, _, mut spill) = scratch_spill().await;

        // Write phase
        for piece in [&b"hello "[..], &b"world"[..]] {
            spill.write_all(piece).await.unwrap();
        }
        assert_eq!(spill.bytes_written(), 11);
        spill.finish_write().await.unwrap();

        // Read phase
        let mut contents = [0u8; 11];
        let mut rd = spill.reader().await.unwrap();
        rd.read_exact(&mut contents).await.unwrap();
        assert_eq!(&contents, b"hello world");
    }

    /// Extreme u64 / i64 values survive a write-then-read round trip.
    #[tokio::test]
    async fn test_async_spill_file_integers() {
        let (_dir, _, mut spill) = scratch_spill().await;

        spill.write_u64_le(u64::MAX).await.unwrap();
        spill.write_i64_le(i64::MIN).await.unwrap();
        spill.finish_write().await.unwrap();

        let mut rd = spill.reader().await.unwrap();
        let unsigned = rd.read_u64_le().await.unwrap();
        assert_eq!(unsigned, u64::MAX);
        let signed = rd.read_i64_le().await.unwrap();
        assert_eq!(signed, i64::MIN);
    }

    /// Length-prefixed slices are read back one record at a time.
    #[tokio::test]
    async fn test_async_spill_file_bytes_prefixed() {
        let (_dir, _, mut spill) = scratch_spill().await;

        spill.write_bytes(b"short").await.unwrap();
        spill.write_bytes(b"longer string here").await.unwrap();
        spill.finish_write().await.unwrap();

        let mut rd = spill.reader().await.unwrap();
        let first = rd.read_bytes().await.unwrap();
        assert_eq!(first, b"short");
        let second = rd.read_bytes().await.unwrap();
        assert_eq!(second, b"longer string here");
    }

    /// Two readers over the same file keep independent positions.
    #[tokio::test]
    async fn test_async_spill_file_multiple_readers() {
        let (_dir, _, mut spill) = scratch_spill().await;

        spill.write_u64_le(42).await.unwrap();
        spill.write_u64_le(100).await.unwrap();
        spill.finish_write().await.unwrap();

        // Create multiple readers
        let mut rd_a = spill.reader().await.unwrap();
        let mut rd_b = spill.reader().await.unwrap();

        // Advance the first reader by one value
        assert_eq!(rd_a.read_u64_le().await.unwrap(), 42);

        // The second reader is unaffected and starts from the beginning
        assert_eq!(rd_b.read_u64_le().await.unwrap(), 42);
        assert_eq!(rd_b.read_u64_le().await.unwrap(), 100);

        // The first reader resumes where it left off
        assert_eq!(rd_a.read_u64_le().await.unwrap(), 100);
    }

    /// `delete` removes the file from disk.
    #[tokio::test]
    async fn test_async_spill_file_delete() {
        let (_dir, on_disk, mut spill) = scratch_spill().await;

        spill.write_all(b"data").await.unwrap();
        spill.finish_write().await.unwrap();

        assert!(on_disk.exists());
        spill.delete().await.unwrap();
        assert!(!on_disk.exists());
    }

    /// Seeking to an offset and rewinding both land on the expected values.
    #[tokio::test]
    async fn test_async_reader_seek() {
        let (_dir, _, mut spill) = scratch_spill().await;

        for value in [1u64, 2, 3] {
            spill.write_u64_le(value).await.unwrap();
        }
        spill.finish_write().await.unwrap();

        let mut rd = spill.reader().await.unwrap();

        // Jump straight to the second value (byte offset 8)
        rd.seek(SeekFrom::Start(8)).await.unwrap();
        assert_eq!(rd.read_u64_le().await.unwrap(), 2);

        // Back to the start, then read the first value
        rd.rewind().await.unwrap();
        assert_eq!(rd.read_u64_le().await.unwrap(), 1);
    }
}