use std::fs::File;
use std::io::{BufReader, Read};
use std::path::{Path, PathBuf};
use crate::error::{EvalError, Result};
/// In-memory benchmark dataset made of base vectors, query vectors and
/// per-query ground-truth rows.
///
/// The invariants noted on the fields are the ones `load_sift_dataset` checks
/// before building a value; the fields are public, so values constructed by
/// other means are not guaranteed to satisfy them.
#[derive(Debug, Clone)]
pub struct SiftDataset {
/// Base vectors, one inner `Vec` per record. The loader rejects an empty set
/// and requires every row to have `dim` elements.
pub base: Vec<Vec<f32>>,
/// Query vectors, one inner `Vec` per record. The loader rejects an empty set
/// and requires every row to have `dim` elements.
pub queries: Vec<Vec<f32>>,
/// Ground-truth rows; the loader requires exactly one row per query.
/// NOTE(review): presumably each row lists indices into `base` (nearest
/// neighbours of the matching query) — confirm against the dataset spec.
pub ground_truth: Vec<Vec<u32>>,
/// Number of elements per vector; the loader sets this to the length of the
/// first base row.
pub dim: usize,
}
/// Largest per-record element count `read_vecs` accepts (2^20 elements).
///
/// A record header declaring more than this is reported as a parse error
/// before the payload buffer is sized, so a corrupt header cannot trigger a
/// huge allocation. At 4 bytes per element the payload buffer is at most 4 MiB.
const MAX_RECORD_DIM: usize = 1 << 20;
/// Reads every record of a `.fvecs`/`.ivecs`-style file into memory.
///
/// The layout parsed here is a flat sequence of records. Each record is a
/// 4-byte little-endian `u32` element count followed by that many 4-byte
/// elements. `decode` converts the raw bytes of one element into a `T`.
///
/// Reading stops normally when the input ends at a record boundary. It also
/// stops normally when the input ends part-way through a header: those
/// trailing bytes are ignored and the records read so far are returned.
///
/// # Errors
///
/// * `EvalError::Io` if the file cannot be opened, or a read fails for any
///   reason other than reaching the end of the input.
/// * `EvalError::Parse` if a header declares more than `MAX_RECORD_DIM`
///   elements.
/// * `EvalError::Parse` with `truncated_reason` as its reason if the input
///   ends inside a record payload.
fn read_vecs<T, F>(path: &Path, truncated_reason: &'static str, decode: F) -> Result<Vec<Vec<T>>>
where
    F: Fn([u8; 4]) -> T,
{
    let file = File::open(path).map_err(|source| EvalError::Io {
        path: path.to_path_buf(),
        source,
    })?;
    let mut reader = BufReader::new(file);
    let mut records: Vec<Vec<T>> = Vec::new();
    let mut header = [0u8; 4];
    // One scratch buffer is reused for every record instead of allocating a
    // fresh one per record; it is resized to the record's byte length below.
    let mut payload: Vec<u8> = Vec::new();

    loop {
        match reader.read_exact(&mut header) {
            Ok(()) => {}
            // End of input while reading a header (including a partial one)
            // terminates the record stream rather than being an error.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(source) => {
                return Err(EvalError::Io {
                    path: path.to_path_buf(),
                    source,
                });
            }
        }

        // Widening cast: the header is an unsigned 32-bit count.
        let dim = u32::from_le_bytes(header) as usize;
        if dim > MAX_RECORD_DIM {
            return Err(EvalError::Parse {
                path: path.to_path_buf(),
                reason: "record dimension exceeds the maximum supported (file likely corrupt)",
            });
        }

        // `dim` is capped above, so the byte length stays small. Stale bytes
        // left from the previous record are fully overwritten by `read_exact`
        // on success, and on failure the buffer is never decoded.
        payload.resize(dim * 4, 0);
        reader.read_exact(&mut payload).map_err(|source| {
            if source.kind() == std::io::ErrorKind::UnexpectedEof {
                EvalError::Parse {
                    path: path.to_path_buf(),
                    reason: truncated_reason,
                }
            } else {
                EvalError::Io {
                    path: path.to_path_buf(),
                    source,
                }
            }
        })?;

        let row: Vec<T> = payload
            .chunks_exact(4)
            .map(|c| decode([c[0], c[1], c[2], c[3]]))
            .collect();
        records.push(row);
    }

    Ok(records)
}
/// Reads the file at `path` via `read_vecs`, decoding every 4-byte element as
/// a little-endian `f32`.
///
/// # Errors
///
/// Propagates any error from `read_vecs`. A file that ends inside a record
/// payload is reported with the reason `"truncated fvecs record payload"`.
pub fn read_fvecs(path: impl AsRef<Path>) -> Result<Vec<Vec<f32>>> {
    let path = path.as_ref();
    read_vecs::<f32, _>(path, "truncated fvecs record payload", f32::from_le_bytes)
}
/// Reads the file at `path` via `read_vecs`, decoding every 4-byte element as
/// a little-endian `u32`.
///
/// # Errors
///
/// Propagates any error from `read_vecs`. A file that ends inside a record
/// payload is reported with the reason `"truncated ivecs record payload"`.
pub fn read_ivecs(path: impl AsRef<Path>) -> Result<Vec<Vec<u32>>> {
    let path = path.as_ref();
    read_vecs::<u32, _>(path, "truncated ivecs record payload", u32::from_le_bytes)
}
pub fn load_sift_dataset(root: impl AsRef<Path>, prefix: &str) -> Result<SiftDataset> {
let root = root.as_ref();
let base_path: PathBuf = root.join(format!("{prefix}_base.fvecs"));
let query_path: PathBuf = root.join(format!("{prefix}_query.fvecs"));
let gt_path: PathBuf = root.join(format!("{prefix}_groundtruth.ivecs"));
let base = read_fvecs(&base_path)?;
let queries = read_fvecs(&query_path)?;
let ground_truth = read_ivecs(>_path)?;
if base.is_empty() {
return Err(EvalError::EmptyInput { kind: "base" });
}
if queries.is_empty() {
return Err(EvalError::EmptyInput { kind: "queries" });
}
if ground_truth.is_empty() {
return Err(EvalError::EmptyInput {
kind: "ground_truth",
});
}
let dim = base[0].len();
if let Some(row) = base.iter().find(|r| r.len() != dim) {
return Err(EvalError::DimensionMismatch {
expected: dim,
found: row.len(),
});
}
if let Some(row) = queries.iter().find(|r| r.len() != dim) {
return Err(EvalError::DimensionMismatch {
expected: dim,
found: row.len(),
});
}
if queries.len() != ground_truth.len() {
return Err(EvalError::LengthMismatch {
kind: "queries vs ground_truth",
expected: queries.len(),
found: ground_truth.len(),
});
}
Ok(SiftDataset {
base,
queries,
ground_truth,
dim,
})
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]
    use super::*;
    use std::fs;

    /// Serialises `rows` as consecutive records: a little-endian `u32` length
    /// header followed by each `f32` in little-endian byte order.
    fn encode_fvecs(rows: &[&[f32]]) -> Vec<u8> {
        let mut encoded = Vec::new();
        for row in rows.iter() {
            let header = (row.len() as u32).to_le_bytes();
            encoded.extend_from_slice(&header);
            encoded.extend(row.iter().flat_map(|value| value.to_le_bytes()));
        }
        encoded
    }

    /// Serialises `rows` as consecutive records: a little-endian `u32` length
    /// header followed by each `u32` in little-endian byte order.
    fn encode_ivecs(rows: &[&[u32]]) -> Vec<u8> {
        let mut encoded = Vec::new();
        for row in rows.iter() {
            let header = (row.len() as u32).to_le_bytes();
            encoded.extend_from_slice(&header);
            encoded.extend(row.iter().flat_map(|value| value.to_le_bytes()));
        }
        encoded
    }

    /// A file in the system temp directory that is removed (best effort) when
    /// the guard is dropped.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Writes `bytes` to `<temp dir>/iqdb_eval_<name>` and returns the guard.
        fn new(name: &str, bytes: &[u8]) -> Self {
            let location = std::env::temp_dir().join(format!("iqdb_eval_{name}"));
            fs::write(&location, bytes).unwrap();
            TempFile(location)
        }

        /// Location of the backing file.
        fn path(&self) -> &Path {
            self.0.as_path()
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Removal failures are deliberately ignored.
            let _ = fs::remove_file(&self.0);
        }
    }

    #[test]
    fn fvecs_round_trips() {
        let encoded = encode_fvecs(&[&[1.0, 2.0, 3.0], &[-4.5, 0.0, 9.25]]);
        let file = TempFile::new("rt.fvecs", &encoded);
        let decoded = read_fvecs(file.path()).unwrap();
        let expected: Vec<Vec<f32>> = vec![vec![1.0, 2.0, 3.0], vec![-4.5, 0.0, 9.25]];
        assert_eq!(decoded, expected);
    }

    #[test]
    fn ivecs_round_trips() {
        let encoded = encode_ivecs(&[&[0, 1, 2], &[7, 8, 9]]);
        let file = TempFile::new("rt.ivecs", &encoded);
        let decoded = read_ivecs(file.path()).unwrap();
        let expected: Vec<Vec<u32>> = vec![vec![0, 1, 2], vec![7, 8, 9]];
        assert_eq!(decoded, expected);
    }

    #[test]
    fn empty_file_reads_empty() {
        let file = TempFile::new("empty.fvecs", &[]);
        let decoded = read_fvecs(file.path()).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn truncated_payload_is_parse_error() {
        // The header promises three elements but only two follow.
        let bytes: Vec<u8> = [
            3u32.to_le_bytes(),
            1.0f32.to_le_bytes(),
            2.0f32.to_le_bytes(),
        ]
        .concat();
        let file = TempFile::new("trunc.fvecs", &bytes);
        let err = read_fvecs(file.path()).unwrap_err();
        assert!(matches!(err, EvalError::Parse { .. }), "got {err:?}");
    }

    #[test]
    fn trailing_partial_header_stops_cleanly() {
        // One complete record followed by two stray bytes of a would-be header.
        let mut bytes = encode_fvecs(&[&[1.0, 2.0]]);
        bytes.push(0xAB);
        bytes.push(0xCD);
        let file = TempFile::new("partial.fvecs", &bytes);
        let decoded = read_fvecs(file.path()).unwrap();
        assert_eq!(decoded, vec![vec![1.0, 2.0]]);
    }

    #[test]
    fn oversized_dim_is_rejected_without_allocating() {
        let header_only = u32::MAX.to_le_bytes();
        let file = TempFile::new("huge.fvecs", &header_only);
        match read_fvecs(file.path()).unwrap_err() {
            EvalError::Parse { reason, .. } => {
                assert!(reason.contains("dimension"), "unexpected reason: {reason}");
            }
            other => panic!("expected Parse, got {other:?}"),
        }
    }

    #[test]
    fn dim_exactly_at_cap_is_accepted_in_header() {
        // A header at the cap passes the size check, so the failure must come
        // from the missing payload instead.
        let header_only = (MAX_RECORD_DIM as u32).to_le_bytes();
        let file = TempFile::new("atcap.fvecs", &header_only);
        let err = read_fvecs(file.path()).unwrap_err();
        let is_truncation =
            matches!(&err, EvalError::Parse { reason, .. } if reason.contains("truncated"));
        assert!(
            is_truncation,
            "expected truncated-payload parse error, got {err:?}",
        );
    }

    #[test]
    fn missing_file_is_io_error() {
        let missing = std::env::temp_dir().join("iqdb_eval_does_not_exist_xyz.fvecs");
        let err = read_fvecs(&missing).unwrap_err();
        assert!(matches!(err, EvalError::Io { .. }), "got {err:?}");
    }
}