use std::collections::HashSet;
use std::fs::File;
use std::io::{BufReader, Error, ErrorKind, Read, Result};
use std::path::Path;
/// Loads all vectors from an `.fvecs` file.
///
/// Each record is a little-endian `u32` dimension `d` followed by `d`
/// little-endian `f32` components. Every record must have the same
/// dimension, and that dimension must lie in `1..=10_000`.
///
/// An empty file yields an empty `Vec`.
///
/// # Errors
///
/// * Any I/O error from opening or reading `path`.
/// * `ErrorKind::UnexpectedEof` if the file ends part-way through a record
///   header or a record body.
/// * `ErrorKind::InvalidData` if a record's dimension differs from the first
///   record's, or is `0` or greater than `10_000`.
pub fn load_fvecs(path: &Path) -> Result<Vec<Vec<f32>>> {
    // Upper bound on a record's dimension; rejects corrupt headers before
    // they can trigger an enormous allocation.
    const MAX_DIM: usize = 10_000;

    // Reads the 4-byte dimension header of record `index`.
    //
    // Returns `Ok(None)` only at a clean end of file (no bytes left). A file
    // that ends inside a header is reported as `UnexpectedEof` instead of
    // being silently treated as the end of the data.
    fn read_header<R: Read>(reader: &mut R, index: usize) -> Result<Option<u32>> {
        let mut buf = [0u8; 4];
        let mut filled = 0;
        while filled < buf.len() {
            match reader.read(&mut buf[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                // Retry interrupted reads, as `read_exact` does.
                Err(e) if e.kind() == ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        match filled {
            0 => Ok(None),
            4 => Ok(Some(u32::from_le_bytes(buf))),
            partial => Err(Error::new(
                ErrorKind::UnexpectedEof,
                format!(
                    "Truncated header at vector {}: got {} of 4 bytes",
                    index, partial
                ),
            )),
        }
    }

    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    let mut vectors = Vec::new();
    let mut expected_dim: Option<usize> = None;
    // Raw record bytes, reused across records (all records share one dimension).
    let mut vec_buf: Vec<u8> = Vec::new();

    while let Some(raw_dim) = read_header(&mut reader, vectors.len())? {
        let index = vectors.len();
        let dim = raw_dim as usize;
        match expected_dim {
            None => expected_dim = Some(dim),
            Some(expected) if dim != expected => {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    format!(
                        "Dimension mismatch at vector {}: expected {}, got {}",
                        index, expected, dim
                    ),
                ));
            }
            Some(_) => {}
        }
        if dim == 0 || dim > MAX_DIM {
            return Err(Error::new(
                ErrorKind::InvalidData,
                format!("Invalid dimension {} at vector {}", dim, index),
            ));
        }
        // dim <= MAX_DIM, so `dim * 4` cannot overflow.
        vec_buf.resize(dim * 4, 0);
        reader.read_exact(&mut vec_buf).map_err(|e| {
            // Keep the error kind but say which record was cut short.
            if e.kind() == ErrorKind::UnexpectedEof {
                Error::new(
                    ErrorKind::UnexpectedEof,
                    format!(
                        "Truncated vector {}: expected {} bytes of data",
                        index,
                        dim * 4
                    ),
                )
            } else {
                e
            }
        })?;
        let vec: Vec<f32> = vec_buf
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        vectors.push(vec);
    }
    Ok(vectors)
}
/// Loads all rows from an `.ivecs` file.
///
/// Each record is a little-endian `u32` count `k` followed by `k`
/// little-endian `u32` values. `k` must lie in `1..=1000`. Records are not
/// required to share the same length.
///
/// An empty file yields an empty `Vec`.
///
/// # Errors
///
/// * Any I/O error from opening or reading `path`.
/// * `ErrorKind::UnexpectedEof` if the file ends part-way through a record
///   header or a record body.
/// * `ErrorKind::InvalidData` if a record's `k` is `0` or greater than `1000`.
pub fn load_ivecs(path: &Path) -> Result<Vec<Vec<u32>>> {
    // Upper bound on a record's length; rejects corrupt headers before they
    // can trigger an enormous allocation.
    const MAX_K: usize = 1000;

    // Reads the 4-byte length header of record `index`.
    //
    // Returns `Ok(None)` only at a clean end of file (no bytes left). A file
    // that ends inside a header is reported as `UnexpectedEof` instead of
    // being silently treated as the end of the data.
    fn read_header<R: Read>(reader: &mut R, index: usize) -> Result<Option<u32>> {
        let mut buf = [0u8; 4];
        let mut filled = 0;
        while filled < buf.len() {
            match reader.read(&mut buf[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                // Retry interrupted reads, as `read_exact` does.
                Err(e) if e.kind() == ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        match filled {
            0 => Ok(None),
            4 => Ok(Some(u32::from_le_bytes(buf))),
            partial => Err(Error::new(
                ErrorKind::UnexpectedEof,
                format!(
                    "Truncated header at result {}: got {} of 4 bytes",
                    index, partial
                ),
            )),
        }
    }

    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    let mut results = Vec::new();
    // Raw record bytes, reused across records and resized to each record's length.
    let mut ids_buf: Vec<u8> = Vec::new();

    while let Some(raw_k) = read_header(&mut reader, results.len())? {
        let index = results.len();
        let k = raw_k as usize;
        if k == 0 || k > MAX_K {
            return Err(Error::new(
                ErrorKind::InvalidData,
                format!("Invalid k {} at result {}", k, index),
            ));
        }
        // k <= MAX_K, so `k * 4` cannot overflow.
        ids_buf.resize(k * 4, 0);
        reader.read_exact(&mut ids_buf).map_err(|e| {
            // Keep the error kind but say which record was cut short.
            if e.kind() == ErrorKind::UnexpectedEof {
                Error::new(
                    ErrorKind::UnexpectedEof,
                    format!(
                        "Truncated result {}: expected {} bytes of data",
                        index,
                        k * 4
                    ),
                )
            } else {
                e
            }
        })?;
        let ids: Vec<u32> = ids_buf
            .chunks_exact(4)
            .map(|b| u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        results.push(ids);
    }
    Ok(results)
}
/// Fraction of the top-`k` ground-truth ids that appear among the top-`k` results.
///
/// The effective depth is `k` clamped to the lengths of both slices. The
/// overlap is divided by that clamped depth, so a result list shorter than
/// `k` is scored only against as many ground-truth ids as it has entries.
/// An id that is repeated within either slice contributes at most once to
/// the overlap.
///
/// Returns `0.0` when the effective depth is zero (empty input or `k == 0`).
#[must_use]
pub fn calculate_recall(results: &[u64], ground_truth: &[u32], k: usize) -> f64 {
    let depth = k.min(results.len()).min(ground_truth.len());
    if depth == 0 {
        return 0.0;
    }

    // Widen the u32 ground-truth ids so they compare against u64 results.
    let relevant: HashSet<u64> = ground_truth[..depth]
        .iter()
        .map(|&id| u64::from(id))
        .collect();

    // Distinct returned ids that are also relevant, i.e. the set intersection.
    let found: HashSet<u64> = results[..depth]
        .iter()
        .copied()
        .filter(|id| relevant.contains(id))
        .collect();

    found.len() as f64 / depth as f64
}
/// Results of a single recall benchmark configuration.
///
/// Rendered as one Markdown table row by [`RecallBenchResult::as_table_row`].
#[derive(Debug, Clone)]
pub struct RecallBenchResult {
    /// Dataset label, e.g. `"SIFT-1M"`.
    pub dataset: String,
    /// Free-form label for the configuration under test, e.g. `"float32"`.
    pub mode: String,
    /// Neighbour count the run was evaluated at (presumably the `k` passed
    /// to `calculate_recall` — confirm with the caller).
    pub k: usize,
    /// Search-time `ef` parameter (NOTE(review): assumed to be an HNSW-style
    /// candidate-list size — confirm with the caller).
    pub ef_search: usize,
    /// Recall score; presumably an average of `calculate_recall` values in
    /// `[0, 1]` — confirm with the caller.
    pub recall: f64,
    /// Measured query throughput.
    pub queries_per_second: f64,
    /// Median query latency; the `_us` suffix suggests microseconds.
    pub latency_p50_us: f64,
    /// 99th-percentile query latency; the `_us` suffix suggests microseconds.
    pub latency_p99_us: f64,
}
impl RecallBenchResult {
#[must_use]
pub fn as_table_row(&self) -> String {
format!(
"| {} | {} | {} | {} | {:.4} | {:.0} | {:.0} | {:.0} |",
self.dataset,
self.mode,
self.ef_search,
self.k,
self.recall,
self.queries_per_second,
self.latency_p50_us,
self.latency_p99_us
)
}
}
/// Returns the value at quantile `p` of an ascending-sorted slice.
///
/// The index is `len * p` truncated toward zero and clamped to the last
/// element; no interpolation is performed. Sortedness is assumed, not
/// checked. `p` is normally in `[0, 1]`: values above 1 select the maximum,
/// and negative or NaN `p` select the first element (Rust's float-to-int
/// `as` cast saturates, mapping NaN to 0).
///
/// # Panics
///
/// Panics if `sorted_values` is empty.
#[must_use]
pub fn percentile(sorted_values: &[f64], p: f64) -> f64 {
    let last_idx = match sorted_values.len().checked_sub(1) {
        Some(i) => i,
        None => panic!("Cannot calculate percentile of empty slice"),
    };
    let rank = (sorted_values.len() as f64 * p) as usize;
    sorted_values[rank.min(last_idx)]
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `vectors` to a temp file in `.fvecs` layout (u32 dim + f32 values).
    fn create_test_fvecs(vectors: &[Vec<f32>]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for vec in vectors {
            let dim = vec.len() as u32;
            file.write_all(&dim.to_le_bytes()).unwrap();
            for &val in vec {
                file.write_all(&val.to_le_bytes()).unwrap();
            }
        }
        file.flush().unwrap();
        file
    }

    /// Writes `results` to a temp file in `.ivecs` layout (u32 k + u32 ids).
    fn create_test_ivecs(results: &[Vec<u32>]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for ids in results {
            let k = ids.len() as u32;
            file.write_all(&k.to_le_bytes()).unwrap();
            for &id in ids {
                file.write_all(&id.to_le_bytes()).unwrap();
            }
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_load_fvecs_valid() {
        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let file = create_test_fvecs(&vectors);
        let loaded = load_fvecs(file.path()).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(loaded[1], vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_load_fvecs_empty() {
        let file = NamedTempFile::new().unwrap();
        let loaded = load_fvecs(file.path()).unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_fvecs_inconsistent_dimensions() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&3u32.to_le_bytes()).unwrap();
        file.write_all(&1.0f32.to_le_bytes()).unwrap();
        file.write_all(&2.0f32.to_le_bytes()).unwrap();
        file.write_all(&3.0f32.to_le_bytes()).unwrap();
        file.write_all(&4u32.to_le_bytes()).unwrap();
        file.write_all(&1.0f32.to_le_bytes()).unwrap();
        file.write_all(&2.0f32.to_le_bytes()).unwrap();
        file.write_all(&3.0f32.to_le_bytes()).unwrap();
        file.write_all(&4.0f32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let result = load_fvecs(file.path());
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Dimension mismatch"));
    }

    #[test]
    fn test_load_fvecs_invalid_dimension_zero() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&0u32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let result = load_fvecs(file.path());
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid dimension"));
    }

    #[test]
    fn test_load_fvecs_invalid_dimension_too_large() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&20000u32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let result = load_fvecs(file.path());
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid dimension"));
    }

    #[test]
    fn test_load_fvecs_truncated_vector() {
        // Header promises 3 components but only one follows.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&3u32.to_le_bytes()).unwrap();
        file.write_all(&1.0f32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let err = load_fvecs(file.path()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_load_ivecs_valid() {
        let results = vec![vec![0, 1, 2], vec![3, 4, 5]];
        let file = create_test_ivecs(&results);
        let loaded = load_ivecs(file.path()).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0], vec![0, 1, 2]);
        assert_eq!(loaded[1], vec![3, 4, 5]);
    }

    #[test]
    fn test_load_ivecs_ragged_rows() {
        // ivecs rows are not required to share a length.
        let results = vec![vec![7, 8, 9], vec![1]];
        let file = create_test_ivecs(&results);
        let loaded = load_ivecs(file.path()).unwrap();
        assert_eq!(loaded, results);
    }

    #[test]
    fn test_load_ivecs_empty() {
        let file = NamedTempFile::new().unwrap();
        let loaded = load_ivecs(file.path()).unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_ivecs_invalid_k_zero() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&0u32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let result = load_ivecs(file.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid k"));
    }

    #[test]
    fn test_load_ivecs_invalid_k_too_large() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&2000u32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let result = load_ivecs(file.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid k"));
    }

    #[test]
    fn test_load_ivecs_truncated_result() {
        // Header promises 3 ids but only one follows.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&3u32.to_le_bytes()).unwrap();
        file.write_all(&1u32.to_le_bytes()).unwrap();
        file.flush().unwrap();
        let err = load_ivecs(file.path()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_calculate_recall_perfect() {
        let results = vec![0u64, 1, 2, 3, 4];
        let ground_truth = vec![0u32, 1, 2, 3, 4];
        let recall = calculate_recall(&results, &ground_truth, 5);
        assert!((recall - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_calculate_recall_zero() {
        let results = vec![10u64, 11, 12, 13, 14];
        let ground_truth = vec![0u32, 1, 2, 3, 4];
        let recall = calculate_recall(&results, &ground_truth, 5);
        assert!((recall - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_calculate_recall_partial() {
        let results = vec![0u64, 1, 10, 11, 12];
        let ground_truth = vec![0u32, 1, 2, 3, 4];
        let recall = calculate_recall(&results, &ground_truth, 5);
        assert!((recall - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_calculate_recall_k_larger_than_results() {
        // The denominator is clamped to the number of returned results.
        let results = vec![0u64, 1];
        let ground_truth = vec![0u32, 1, 2, 3, 4];
        let recall = calculate_recall(&results, &ground_truth, 5);
        assert!((recall - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_calculate_recall_duplicate_results() {
        // A repeated id in `results` is only counted once.
        let results = vec![0u64, 0];
        let ground_truth = vec![0u32, 1];
        let recall = calculate_recall(&results, &ground_truth, 2);
        assert!((recall - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_calculate_recall_empty() {
        let results: Vec<u64> = vec![];
        let ground_truth = vec![0u32, 1, 2];
        let recall = calculate_recall(&results, &ground_truth, 5);
        assert!((recall - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_percentile_median() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let p50 = percentile(&values, 0.5);
        assert!((p50 - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_percentile_p99() {
        let values: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let p99 = percentile(&values, 0.99);
        assert!((p99 - 99.0).abs() < 0.001);
    }

    #[test]
    #[should_panic(expected = "Cannot calculate percentile of empty slice")]
    fn test_percentile_empty_panics() {
        let _ = percentile(&[], 0.5);
    }

    #[test]
    fn test_recall_result_table_row() {
        let result = RecallBenchResult {
            dataset: "SIFT-1M".to_string(),
            mode: "float32".to_string(),
            k: 10,
            ef_search: 50,
            recall: 0.9512,
            queries_per_second: 5000.0,
            latency_p50_us: 150.0,
            latency_p99_us: 500.0,
        };
        let row = result.as_table_row();
        assert!(row.contains("SIFT-1M"));
        assert!(row.contains("float32"));
        assert!(row.contains("0.9512"));
        // Pin the full column order (ef_search precedes k).
        assert_eq!(
            row,
            "| SIFT-1M | float32 | 50 | 10 | 0.9512 | 5000 | 150 | 500 |"
        );
    }
}