use std::path::{Path, PathBuf};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader, BufWriter, SeekFrom};
const BUFFER_SIZE: usize = 64 * 1024;
pub struct AsyncSpillFile {
path: PathBuf,
writer: Option<BufWriter<File>>,
bytes_written: u64,
}
impl AsyncSpillFile {
pub async fn new(path: PathBuf) -> std::io::Result<Self> {
let file = File::create(&path).await?;
let writer = BufWriter::with_capacity(BUFFER_SIZE, file);
Ok(Self {
path,
writer: Some(writer),
bytes_written: 0,
})
}
#[must_use]
pub fn path(&self) -> &Path {
&self.path
}
#[must_use]
pub fn bytes_written(&self) -> u64 {
self.bytes_written
}
pub async fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
let writer = self
.writer
.as_mut()
.ok_or_else(|| std::io::Error::other("Write phase ended"))?;
writer.write_all(data).await?;
self.bytes_written += data.len() as u64;
Ok(())
}
pub async fn write_u64_le(&mut self, value: u64) -> std::io::Result<()> {
self.write_all(&value.to_le_bytes()).await
}
pub async fn write_i64_le(&mut self, value: i64) -> std::io::Result<()> {
self.write_all(&value.to_le_bytes()).await
}
pub async fn write_bytes(&mut self, data: &[u8]) -> std::io::Result<()> {
self.write_u64_le(data.len() as u64).await?;
self.write_all(data).await
}
pub async fn finish_write(&mut self) -> std::io::Result<()> {
if let Some(mut writer) = self.writer.take() {
writer.flush().await?;
}
Ok(())
}
#[must_use]
pub fn is_writable(&self) -> bool {
self.writer.is_some()
}
pub async fn reader(&self) -> std::io::Result<AsyncSpillFileReader> {
let file = File::open(&self.path).await?;
let reader = BufReader::with_capacity(BUFFER_SIZE, file);
Ok(AsyncSpillFileReader { reader })
}
pub async fn delete(mut self) -> std::io::Result<()> {
self.writer = None;
tokio::fs::remove_file(&self.path).await
}
}
impl std::fmt::Debug for AsyncSpillFile {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncSpillFile")
.field("path", &self.path)
.field("bytes_written", &self.bytes_written)
.field("is_writable", &self.is_writable())
.finish()
}
}
pub struct AsyncSpillFileReader {
reader: BufReader<File>,
}
impl AsyncSpillFileReader {
pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
self.reader.read_exact(buf).await?;
Ok(())
}
pub async fn read_u64_le(&mut self) -> std::io::Result<u64> {
let mut buf = [0u8; 8];
self.read_exact(&mut buf).await?;
Ok(u64::from_le_bytes(buf))
}
pub async fn read_i64_le(&mut self) -> std::io::Result<i64> {
let mut buf = [0u8; 8];
self.read_exact(&mut buf).await?;
Ok(i64::from_le_bytes(buf))
}
pub async fn read_f64_le(&mut self) -> std::io::Result<f64> {
let mut buf = [0u8; 8];
self.read_exact(&mut buf).await?;
Ok(f64::from_le_bytes(buf))
}
pub async fn read_u8(&mut self) -> std::io::Result<u8> {
let mut buf = [0u8; 1];
self.read_exact(&mut buf).await?;
Ok(buf[0])
}
pub async fn read_bytes(&mut self) -> std::io::Result<Vec<u8>> {
#[allow(clippy::cast_possible_truncation)]
let len = self.read_u64_le().await? as usize;
let mut buf = vec![0u8; len];
self.read_exact(&mut buf).await?;
Ok(buf)
}
pub async fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
self.reader.seek(pos).await
}
pub async fn rewind(&mut self) -> std::io::Result<()> {
self.reader.seek(SeekFrom::Start(0)).await?;
Ok(())
}
pub async fn position(&mut self) -> std::io::Result<u64> {
self.reader.stream_position().await
}
}
impl std::fmt::Debug for AsyncSpillFileReader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncSpillFileReader").finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_async_spill_file_write_read() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.spill");
let mut file = AsyncSpillFile::new(file_path).await.unwrap();
file.write_all(b"hello ").await.unwrap();
file.write_all(b"world").await.unwrap();
assert_eq!(file.bytes_written(), 11);
file.finish_write().await.unwrap();
let mut reader = file.reader().await.unwrap();
let mut buf = [0u8; 11];
reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello world");
}
#[tokio::test]
async fn test_async_spill_file_integers() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.spill");
let mut file = AsyncSpillFile::new(file_path).await.unwrap();
file.write_u64_le(u64::MAX).await.unwrap();
file.write_i64_le(i64::MIN).await.unwrap();
file.finish_write().await.unwrap();
let mut reader = file.reader().await.unwrap();
assert_eq!(reader.read_u64_le().await.unwrap(), u64::MAX);
assert_eq!(reader.read_i64_le().await.unwrap(), i64::MIN);
}
#[tokio::test]
async fn test_async_spill_file_bytes_prefixed() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.spill");
let mut file = AsyncSpillFile::new(file_path).await.unwrap();
file.write_bytes(b"short").await.unwrap();
file.write_bytes(b"longer string here").await.unwrap();
file.finish_write().await.unwrap();
let mut reader = file.reader().await.unwrap();
assert_eq!(reader.read_bytes().await.unwrap(), b"short");
assert_eq!(reader.read_bytes().await.unwrap(), b"longer string here");
}
#[tokio::test]
async fn test_async_spill_file_multiple_readers() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.spill");
let mut file = AsyncSpillFile::new(file_path).await.unwrap();
file.write_u64_le(42).await.unwrap();
file.write_u64_le(100).await.unwrap();
file.finish_write().await.unwrap();
let mut reader1 = file.reader().await.unwrap();
let mut reader2 = file.reader().await.unwrap();
assert_eq!(reader1.read_u64_le().await.unwrap(), 42);
assert_eq!(reader2.read_u64_le().await.unwrap(), 42);
assert_eq!(reader2.read_u64_le().await.unwrap(), 100);
assert_eq!(reader1.read_u64_le().await.unwrap(), 100);
}
#[tokio::test]
async fn test_async_spill_file_delete() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.spill");
let file_path_clone = file_path.clone();
let mut file = AsyncSpillFile::new(file_path).await.unwrap();
file.write_all(b"data").await.unwrap();
file.finish_write().await.unwrap();
assert!(file_path_clone.exists());
file.delete().await.unwrap();
assert!(!file_path_clone.exists());
}
#[tokio::test]
async fn test_async_reader_seek() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.spill");
let mut file = AsyncSpillFile::new(file_path).await.unwrap();
file.write_u64_le(1).await.unwrap();
file.write_u64_le(2).await.unwrap();
file.write_u64_le(3).await.unwrap();
file.finish_write().await.unwrap();
let mut reader = file.reader().await.unwrap();
reader.seek(SeekFrom::Start(8)).await.unwrap();
assert_eq!(reader.read_u64_le().await.unwrap(), 2);
reader.rewind().await.unwrap();
assert_eq!(reader.read_u64_le().await.unwrap(), 1);
}
}