#![allow(clippy::unwrap_used, clippy::expect_used)]
#![cfg(all(feature = "hnsw", feature = "serde"))]
use vicinity::hnsw::HNSWIndex;
/// Builds an HNSW index holding `n` deterministic pseudo-random vectors of length `dim`.
///
/// Components come from a fixed-seed 64-bit LCG and are centred on zero, then each
/// vector is scaled to unit L2 norm (a zero-norm vector is inserted unscaled).
/// Vector ids are `0..n`. The index is built before it is returned.
///
/// # Panics
///
/// Panics if index construction, insertion, or the final build reports an error.
fn build_deterministic_index(n: usize, dim: usize) -> HNSWIndex {
    // NOTE(review): 16 and 32 are presumably the graph degree and construction
    // beam width — confirm against `HNSWIndex::new`.
    let mut index = HNSWIndex::new(dim, 16, 32).expect("valid params");

    // Fixed seed so every run inserts exactly the same vectors.
    let mut state: u64 = 42;
    let mut next_component = || -> f32 {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        // Keep the top 24 bits: they are exactly representable in an f32, so
        // dividing by 2^24 yields a value in [0, 1) and the shift centres it on
        // [-0.5, 0.5). (Scaling a 31-bit value by u32::MAX only reached [0, 0.5),
        // which made every component negative.)
        ((state >> 40) as f32) / ((1u32 << 24) as f32) - 0.5
    };

    // Count ids directly in u32 instead of casting a usize loop index.
    for id in (0u32..).take(n) {
        let mut v: Vec<f32> = (0..dim).map(|_| next_component()).collect();
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        index.add(id, v).expect("add should succeed");
    }

    index.build().expect("build should succeed");
    index
}
/// Generates `count` deterministic pseudo-random query vectors of length `dim`.
///
/// Components come from a fixed-seed 64-bit LCG and are centred on zero, then each
/// vector is scaled to unit L2 norm (a zero-norm vector is returned unscaled).
/// Repeated calls with the same arguments return identical vectors.
fn deterministic_queries(count: usize, dim: usize) -> Vec<Vec<f32>> {
    // Fixed seed so every run produces exactly the same queries.
    let mut state: u64 = 12345;
    let mut next_component = || -> f32 {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        // Keep the top 24 bits: they are exactly representable in an f32, so
        // dividing by 2^24 yields a value in [0, 1) and the shift centres it on
        // [-0.5, 0.5). (Scaling a 31-bit value by u32::MAX only reached [0, 0.5),
        // which made every component negative.)
        ((state >> 40) as f32) / ((1u32 << 24) as f32) - 0.5
    };

    (0..count)
        .map(|_| {
            let mut q: Vec<f32> = (0..dim).map(|_| next_component()).collect();
            let norm = q.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut q {
                    *x /= norm;
                }
            }
            q
        })
        .collect()
}
/// Serializes `index` into a freshly allocated in-memory byte buffer.
///
/// # Panics
///
/// Panics if `save_to_writer` reports an error.
fn save_to_bytes(index: &HNSWIndex) -> Vec<u8> {
    let mut serialized: Vec<u8> = Vec::new();
    let outcome = index.save_to_writer(&mut serialized);
    outcome.expect("serialization should succeed");
    serialized
}
/// Saving an index to a file and loading it back must reproduce search results:
/// the same ids in the same order, with distances equal to within `f32::EPSILON`.
#[test]
fn hnsw_save_load_exact_roundtrip() {
    let dim = 16;
    let n = 50;
    let k = 5;
    let ef = 64;

    let index = build_deterministic_index(n, dim);
    let queries = deterministic_queries(5, dim);

    // Reference results taken from the in-memory index before any persistence.
    let original_results: Vec<_> = queries
        .iter()
        .map(|q| index.search(q, k, ef).expect("search should succeed"))
        .collect();

    let tmp = tempfile::NamedTempFile::new().expect("tempfile creation");
    {
        let mut writer = std::io::BufWriter::new(tmp.as_file());
        index
            .save_to_writer(&mut writer)
            .expect("save should succeed");
        // `BufWriter`'s `Drop` discards flush errors, so flush explicitly: a
        // failed write must surface here rather than as a confusing load error.
        std::io::Write::flush(&mut writer).expect("flush should succeed");
    }

    let loaded = HNSWIndex::load_from_reader(std::io::BufReader::new(
        std::fs::File::open(tmp.path()).expect("open temp file"),
    ))
    .expect("load should succeed");
    assert!(loaded.is_built());

    for (i, (q, expected)) in queries.iter().zip(&original_results).enumerate() {
        let loaded_results = loaded.search(q, k, ef).expect("search should succeed");
        assert_eq!(
            loaded_results.len(),
            expected.len(),
            "query {i}: result count mismatch"
        );
        for (j, (lr, or)) in loaded_results.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                lr.0, or.0,
                "query {i} result {j}: doc_id mismatch ({} vs {})",
                lr.0, or.0
            );
            assert!(
                (lr.1 - or.1).abs() < f32::EPSILON,
                "query {i} result {j}: distance mismatch ({} vs {})",
                lr.1,
                or.1
            );
        }
    }
}
/// Loading a serialized index cut short at 0, 1, half, or all-but-one bytes
/// must return `Err` rather than succeed.
#[test]
fn truncated_file_returns_err() {
    let full = save_to_bytes(&build_deterministic_index(50, 16));
    let total = full.len();
    assert!(
        total > 2,
        "sanity: serialized bytes should be non-trivial"
    );

    // Prefix lengths to try: empty, a single byte, half, and one byte short.
    for len in [0, 1, total / 2, total - 1] {
        let outcome = HNSWIndex::load_from_reader(&full[..len]);
        assert!(
            outcome.is_err(),
            "expected Err for truncated input ({len} of {} bytes), got Ok",
            total
        );
    }
}
/// Over ten trials, inverts one to three pseudo-randomly chosen bytes of a
/// serialized index and checks that `load_from_reader` never panics; returning
/// either `Ok` or `Err` is accepted.
#[test]
fn corrupted_bytes_do_not_panic() {
    let pristine = save_to_bytes(&build_deterministic_index(50, 16));

    // Fixed-seed LCG so the corrupted positions are identical on every run.
    let mut lcg_state: u64 = 99;

    for trial in 0..10 {
        let mut damaged = pristine.clone();

        // Trials cycle through 1, 2, 3 inverted bytes.
        for _ in 0..=(trial % 3) {
            lcg_state = lcg_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            let offset = (lcg_state >> 33) as usize % pristine.len();
            damaged[offset] ^= 0xFF;
        }

        let outcome =
            std::panic::catch_unwind(|| HNSWIndex::load_from_reader(damaged.as_slice()));

        // Only a caught panic is a failure; re-raise it with the trial number attached.
        if let Err(payload) = outcome {
            let msg: String = payload
                .downcast_ref::<&str>()
                .map(|s| (*s).to_owned())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "(non-string panic)".to_owned());
            panic!("trial {trial}: load_from_reader panicked: {msg}");
        }
    }
}