use std::{
fs::{File, OpenOptions},
io::{BufWriter, Error, ErrorKind, Write},
path::{Path, PathBuf},
sync::Arc,
};
use bytemuck::{cast_slice, try_cast_slice};
use memmap2::Mmap;
use crate::superfile::BuildError;
/// Buffered, append-only writer for a spill file of raw bytes.
///
/// Data goes through a [`BufWriter`]; [`SpillWriter::finish`] flushes it,
/// syncs the file, and hands the path back to the caller.
pub(crate) struct SpillWriter {
    /// Location of the spill file; returned by `finish`.
    path: PathBuf,
    /// Buffered handle every write goes through.
    writer: BufWriter<File>,
    /// Running total of bytes accepted by `write_all`.
    bytes_written: u64,
}

impl SpillWriter {
    /// Capacity of the in-memory write buffer (1 MiB).
    const BUF_CAPACITY: usize = 1 << 20;

    /// Opens `path` for writing, creating it if missing and truncating any
    /// existing content, and wraps the handle in a buffered writer.
    ///
    /// # Errors
    /// Propagates the I/O error if the file cannot be opened.
    pub fn create(path: PathBuf) -> Result<Self, BuildError> {
        let mut options = OpenOptions::new();
        options.write(true).create(true).truncate(true);
        let handle = options.open(&path)?;
        Ok(Self {
            writer: BufWriter::with_capacity(Self::BUF_CAPACITY, handle),
            bytes_written: 0,
            path,
        })
    }

    /// Appends `bytes` to the spill file and adds their length to the byte counter.
    ///
    /// Debug builds assert that the length is a multiple of 4.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying buffered write.
    pub fn write_all(&mut self, bytes: &[u8]) -> Result<(), BuildError> {
        debug_assert!(
            bytes.len().is_multiple_of(4),
            "spill write_all: byte length {} not a multiple of 4",
            bytes.len()
        );
        let appended = bytes.len() as u64;
        self.writer.write_all(bytes)?;
        self.bytes_written += appended;
        Ok(())
    }

    /// Appends the `f32` values in `vec`, reinterpreted as raw bytes.
    ///
    /// # Errors
    /// Same as [`SpillWriter::write_all`].
    pub fn write_vec(&mut self, vec: &[f32]) -> Result<(), BuildError> {
        self.write_all(cast_slice::<f32, u8>(vec))
    }

    /// Total number of bytes written so far (only compiled for tests).
    #[cfg(test)]
    pub(crate) fn bytes_written(&self) -> u64 {
        self.bytes_written
    }

    /// Flushes buffered data, syncs the file, and returns the spill file's path.
    ///
    /// # Errors
    /// Returns the I/O error if the flush, the buffer hand-off, or the sync fails.
    pub fn finish(self) -> Result<PathBuf, BuildError> {
        let Self {
            path, mut writer, ..
        } = self;
        writer.flush()?;
        // `into_inner` flushes again; on failure only the underlying
        // `io::Error` is kept and the writer itself is discarded.
        let file = match writer.into_inner() {
            Ok(file) => file,
            Err(err) => return Err(BuildError::Io(err.into_error())),
        };
        file.sync_all()?;
        Ok(path)
    }
}
/// A source of `f32` vector data that is consumed sequentially, one chunk
/// at a time, through [`ChunkedVectorSource::next_chunk`].
pub trait ChunkedVectorSource {
    /// Total number of rows the source holds.
    fn n_rows(&self) -> usize;
    /// Number of `f32` values per row (only compiled for tests).
    #[cfg(test)]
    fn dim(&self) -> usize;
    /// Number of rows per chunk this source was configured with; the
    /// implementations in this file use it as an upper bound per chunk.
    fn chunk_rows(&self) -> usize;
    /// Returns the next chunk as a flat `f32` slice and advances the source.
    ///
    /// The implementations in this file return `None` once every row has
    /// been handed out, and keep returning `None` on further calls.
    fn next_chunk(&mut self) -> Option<&[f32]>;
    /// Rewinds the source so iteration starts over (only compiled for tests).
    #[cfg(test)]
    fn reset(&mut self);
}
/// Chunked vector source backed by a shared, flat in-memory buffer.
pub struct InMemoryVectorSource {
    /// Flat buffer holding `dim` consecutive values per row.
    buf: Arc<Vec<f32>>,
    /// Number of `f32` values per row.
    dim: usize,
    /// Maximum number of rows handed out per chunk.
    chunk_rows: usize,
    /// Index of the next row to hand out.
    cursor: usize,
}

impl InMemoryVectorSource {
    /// Wraps `buf` as a source of `dim`-wide rows read `chunk_rows` rows at a
    /// time, positioned at the first row.
    ///
    /// # Panics
    /// In debug builds only: if `dim` or `chunk_rows` is zero, or if
    /// `buf.len()` is not a multiple of `dim`. Release builds skip these checks.
    pub fn new(buf: Arc<Vec<f32>>, dim: usize, chunk_rows: usize) -> Self {
        let source = Self {
            cursor: 0,
            buf,
            dim,
            chunk_rows,
        };
        debug_assert!(source.dim > 0, "InMemoryVectorSource: dim must be > 0");
        debug_assert!(
            source.chunk_rows > 0,
            "InMemoryVectorSource: chunk_rows must be > 0"
        );
        debug_assert!(
            source.buf.len().is_multiple_of(source.dim),
            "InMemoryVectorSource: buf.len() {} not a multiple of dim {}",
            source.buf.len(),
            source.dim
        );
        source
    }
}
/// Chunked iteration over the in-memory buffer, `chunk_rows` rows at a time.
impl ChunkedVectorSource for InMemoryVectorSource {
    /// Number of whole rows in the buffer (`buf.len() / dim`).
    fn n_rows(&self) -> usize {
        self.buf.len() / self.dim
    }

    /// Number of `f32` values per row (only compiled for tests).
    #[cfg(test)]
    fn dim(&self) -> usize {
        self.dim
    }

    /// Maximum number of rows returned by one `next_chunk` call.
    fn chunk_rows(&self) -> usize {
        self.chunk_rows
    }

    /// Returns the next run of at most `chunk_rows` rows as a flat slice and
    /// advances the cursor past it; `None` once every row has been returned.
    fn next_chunk(&mut self) -> Option<&[f32]> {
        let remaining = self.n_rows().saturating_sub(self.cursor);
        if remaining == 0 {
            return None;
        }
        // The final chunk may be shorter than `chunk_rows`.
        let rows = remaining.min(self.chunk_rows);
        let first = self.cursor * self.dim;
        self.cursor += rows;
        let last = self.cursor * self.dim;
        Some(&self.buf[first..last])
    }

    /// Rewinds to the first row (only compiled for tests).
    #[cfg(test)]
    fn reset(&mut self) {
        self.cursor = 0;
    }
}
pub struct MmapVectorSource {
map: Mmap,
dim: usize,
chunk_rows: usize,
cursor: usize, }
impl MmapVectorSource {
pub fn open(path: &Path, dim: usize, chunk_rows: usize) -> Result<Self, BuildError> {
debug_assert!(dim > 0, "MmapVectorSource: dim must be > 0");
debug_assert!(chunk_rows > 0, "MmapVectorSource: chunk_rows must be > 0");
let file = File::open(path)?;
let file_len = file.metadata()?.len() as usize;
let row_bytes = dim
.checked_mul(4)
.expect("dim * 4 overflows usize — dim > 2^29 is nonsense");
if !file_len.is_multiple_of(row_bytes) {
return Err(BuildError::Io(Error::new(
ErrorKind::InvalidData,
format!(
"spill file length {file_len} is not a multiple of \
row size {row_bytes} (dim={dim})"
),
)));
}
let map = unsafe { Mmap::map(&file)? };
Ok(Self {
map,
dim,
chunk_rows,
cursor: 0,
})
}
}
/// Chunked iteration over the mapped file, `chunk_rows` rows at a time.
impl ChunkedVectorSource for MmapVectorSource {
    /// Number of whole rows in the mapping (`map.len() / (dim * 4)`).
    fn n_rows(&self) -> usize {
        let row_bytes = self.dim * 4;
        self.map.len() / row_bytes
    }

    /// Number of `f32` values per row (only compiled for tests).
    #[cfg(test)]
    fn dim(&self) -> usize {
        self.dim
    }

    /// Maximum number of rows returned by one `next_chunk` call.
    fn chunk_rows(&self) -> usize {
        self.chunk_rows
    }

    /// Returns the next run of at most `chunk_rows` rows, viewed as `f32`s,
    /// and advances the cursor past it; `None` once every row has been returned.
    ///
    /// # Panics
    /// If the selected byte range cannot be reinterpreted as `[f32]`.
    fn next_chunk(&mut self) -> Option<&[f32]> {
        let remaining = self.n_rows().saturating_sub(self.cursor);
        if remaining == 0 {
            return None;
        }
        // The final chunk may be shorter than `chunk_rows`.
        let rows = remaining.min(self.chunk_rows);
        let row_bytes = self.dim * 4;
        let byte_range = self.cursor * row_bytes..(self.cursor + rows) * row_bytes;
        self.cursor += rows;
        let floats: &[f32] = try_cast_slice(&self.map[byte_range])
            .expect("mmap slice is page-aligned and length is row-aligned");
        Some(floats)
    }

    /// Rewinds to the first row (only compiled for tests).
    #[cfg(test)]
    fn reset(&mut self) {
        self.cursor = 0;
    }
}
#[cfg(test)]
mod tests {
    use std::{
        fs::write,
        io::{ErrorKind, Read},
    };

    use tempfile::tempdir;

    use super::*;

    /// Builds a row-major corpus in which element `(r, c)` is `r * 1000 + c`.
    fn synth(n_rows: usize, dim: usize) -> Vec<f32> {
        (0..n_rows)
            .flat_map(|r| (0..dim).map(move |c| r as f32 * 1000.0 + c as f32))
            .collect()
    }

    /// Pulls chunks from `src` until it is exhausted and concatenates them.
    fn drain(src: &mut impl ChunkedVectorSource) -> Vec<f32> {
        let mut out = Vec::new();
        while let Some(chunk) = src.next_chunk() {
            out.extend_from_slice(chunk);
        }
        out
    }

    /// Bytes written through `write_all` land on disk unchanged and come back
    /// as the same floats through the mmap source.
    #[test]
    fn spill_write_then_mmap_read_round_trip() {
        let (n_rows, dim) = (17, 8);
        let corpus = synth(n_rows, dim);
        let raw: &[u8] = cast_slice(&corpus);

        let tmp = tempdir().expect("tempdir");
        let path = tmp.path().join("spill.bin");

        let mut writer = SpillWriter::create(path.clone()).expect("create");
        writer.write_all(raw).expect("write_all");
        assert_eq!(writer.bytes_written(), raw.len() as u64);
        assert_eq!(writer.finish().expect("finish"), path);

        let mut on_disk = Vec::new();
        File::open(&path)
            .expect("open spill")
            .read_to_end(&mut on_disk)
            .expect("read");
        assert_eq!(on_disk, raw, "raw byte round-trip mismatch");

        let mut src = MmapVectorSource::open(&path, dim, 5).expect("mmap open");
        assert_eq!(src.n_rows(), n_rows);
        assert_eq!(src.dim(), dim);
        assert_eq!(src.chunk_rows(), 5);
        assert_eq!(drain(&mut src), corpus, "f32 round-trip via mmap mismatch");
    }

    /// Writing row by row with `write_vec` produces the same file content as
    /// one bulk write.
    #[test]
    fn spill_write_vec_per_row_matches_write_all() {
        let (n_rows, dim) = (13, 4);
        let corpus = synth(n_rows, dim);

        let tmp = tempdir().expect("tempdir");
        let path = tmp.path().join("spill_per_row.bin");

        let mut writer = SpillWriter::create(path.clone()).expect("create");
        for row in corpus.chunks_exact(dim) {
            writer.write_vec(row).expect("write_vec");
        }
        writer.finish().expect("finish");

        let mut src = MmapVectorSource::open(&path, dim, dim).expect("mmap open");
        assert_eq!(
            drain(&mut src),
            corpus,
            "per-row write_vec round-trip mismatch"
        );
    }

    /// 25 rows in steps of 7 give three full chunks, one partial chunk, then `None`.
    #[test]
    fn in_memory_source_yields_full_corpus_in_chunk_size_steps() {
        let (n_rows, dim) = (25, 3);
        let corpus = synth(n_rows, dim);
        let mut src = InMemoryVectorSource::new(Arc::new(corpus.clone()), dim, 7);
        assert_eq!(src.n_rows(), n_rows);
        assert_eq!(src.dim(), dim);
        assert_eq!(src.chunk_rows(), 7);

        let expected_steps = [
            ("chunk 0", 0..7),
            ("chunk 1", 7..14),
            ("chunk 2", 14..21),
            ("chunk 3 (partial)", 21..25),
        ];
        for (label, rows) in expected_steps {
            let chunk = src.next_chunk().expect(label);
            assert_eq!(chunk.len(), rows.len() * dim);
            assert_eq!(chunk, &corpus[rows.start * dim..rows.end * dim]);
        }

        assert!(src.next_chunk().is_none(), "expected exhausted");
        assert!(src.next_chunk().is_none(), "still exhausted on re-poll");
    }

    /// After `reset`, a second full pass returns the same data as the first.
    #[test]
    fn in_memory_source_reset_replays_from_zero() {
        let (n_rows, dim) = (10, 4);
        let corpus = synth(n_rows, dim);
        let mut src = InMemoryVectorSource::new(Arc::new(corpus.clone()), dim, 3);

        assert_eq!(drain(&mut src), corpus);
        src.reset();
        assert_eq!(drain(&mut src), corpus, "reset didn't replay full corpus");
    }

    /// The mmap source and the in-memory source split the same corpus at the
    /// same chunk boundaries and run out at the same time.
    #[test]
    fn mmap_source_chunk_boundary_matches_in_memory() {
        let (n_rows, dim, chunk_rows) = (50, 5, 11);
        let corpus = synth(n_rows, dim);

        let tmp = tempdir().expect("tempdir");
        let path = tmp.path().join("xcheck.bin");
        let mut writer = SpillWriter::create(path.clone()).expect("create");
        writer.write_all(cast_slice(&corpus)).expect("write");
        writer.finish().expect("finish");

        let mut mem = InMemoryVectorSource::new(Arc::new(corpus.clone()), dim, chunk_rows);
        let mut mm = MmapVectorSource::open(&path, dim, chunk_rows).expect("mmap");
        while let Some(from_mem) = mem.next_chunk() {
            let from_map = mm.next_chunk().expect("source exhaustion disagreement");
            assert_eq!(from_mem, from_map, "chunk-boundary divergence");
        }
        assert!(mm.next_chunk().is_none(), "source exhaustion disagreement");
    }

    /// A 17-byte file cannot hold whole 16-byte rows, so `open` must fail
    /// with `InvalidData`.
    #[test]
    fn mmap_source_rejects_misaligned_file_length() {
        let tmp = tempdir().expect("tempdir");
        let path = tmp.path().join("bad.bin");
        write(&path, [0u8; 17]).expect("write 17 bytes");

        let outcome = MmapVectorSource::open(&path, 4, 1);
        match outcome.err() {
            None => panic!("expected length-mismatch error, got Ok"),
            Some(BuildError::Io(e)) => assert_eq!(e.kind(), ErrorKind::InvalidData),
            Some(other) => panic!("expected Io InvalidData, got {other:?}"),
        }
    }

    /// Both source kinds report zero rows and no chunks for an empty corpus.
    #[test]
    fn empty_corpus_yields_no_chunks() {
        let mut in_memory = InMemoryVectorSource::new(Arc::new(Vec::<f32>::new()), 4, 8);
        assert_eq!(in_memory.n_rows(), 0);
        assert!(in_memory.next_chunk().is_none());

        let tmp = tempdir().expect("tempdir");
        let path = tmp.path().join("empty.bin");
        SpillWriter::create(path.clone())
            .expect("create")
            .finish()
            .expect("finish empty");

        let mut mapped = MmapVectorSource::open(&path, 4, 8).expect("open empty");
        assert_eq!(mapped.n_rows(), 0);
        assert!(mapped.next_chunk().is_none());
    }
}