#[cfg(test)]
mod tests {
use std::mem::size_of;
use std::sync::Arc;
use std::thread;
use tempfile::TempDir;
use crate::embedder::Embedding;
use crate::hnsw::{HnswIndex, HnswInner, LoadedHnsw};
use crate::EMBEDDING_DIM;
/// Builds a deterministic, L2-normalized test embedding of length `EMBEDDING_DIM`.
///
/// Every component starts at `0.01`; the component at `seed % EMBEDDING_DIM`
/// is raised to `1.0` before the whole vector is divided by its Euclidean norm.
/// Seeds that are congruent modulo `EMBEDDING_DIM` therefore yield the same vector.
fn make_embedding(seed: u32) -> Embedding {
    let mut components = vec![0.01f32; EMBEDDING_DIM];
    let dominant = (seed as usize) % EMBEDDING_DIM;
    components[dominant] = 1.0;

    // Scale to unit length.
    let magnitude: f32 = components.iter().map(|c| c * c).sum::<f32>().sqrt();
    components.iter_mut().for_each(|c| *c /= magnitude);

    Embedding::new(components)
}
/// Saves a 20-entry index, loads it back, and queries it once per stored seed.
///
/// Each query must return at least one hit, and every returned id must still
/// carry the `chunk` prefix that was used when the entries were built.
#[test]
fn test_loaded_index_multiple_searches() {
    let dir = TempDir::new().unwrap();

    let mut entries = Vec::new();
    for seed in 1..=20 {
        entries.push((format!("chunk{}", seed), make_embedding(seed)));
    }

    // The built index stays alive for the whole test, alongside the loaded one.
    let built = HnswIndex::build_with_dim(entries, crate::EMBEDDING_DIM).unwrap();
    built.save(dir.path(), "safety_test").unwrap();

    let reloaded =
        HnswIndex::load_with_dim(dir.path(), "safety_test", crate::EMBEDDING_DIM).unwrap();
    assert_eq!(reloaded.len(), 20);

    for seed in 1..=20 {
        let probe = make_embedding(seed);
        let hits = reloaded.search(&probe, 5);
        assert!(
            !hits.is_empty(),
            "Search {} should return results (memory corruption check)",
            seed
        );
        for hit in &hits {
            assert!(
                hit.id.starts_with("chunk"),
                "Result ID '{}' looks corrupted",
                hit.id
            );
        }
    }
}
/// Loads the same saved index five times in a row, searching and dropping it
/// on every cycle.
///
/// The index holds three entries (`a`, `b`, `c`); each cycle queries with the
/// embedding used for `a` and expects `a` to be the top hit.
#[test]
fn test_loaded_index_lifecycle() {
    let tmp = TempDir::new().unwrap();
    let embeddings = vec![
        ("a".to_string(), make_embedding(100)),
        ("b".to_string(), make_embedding(200)),
        ("c".to_string(), make_embedding(300)),
    ];
    HnswIndex::build_with_dim(embeddings, crate::EMBEDDING_DIM)
        .unwrap()
        .save(tmp.path(), "lifecycle")
        .unwrap();

    for cycle in 0..5 {
        // `loaded` is dropped at the end of each iteration, so every cycle
        // exercises a fresh load followed by a drop.
        let loaded =
            HnswIndex::load_with_dim(tmp.path(), "lifecycle", crate::EMBEDDING_DIM).unwrap();
        let results = loaded.search(&make_embedding(100), 3);

        // Check emptiness explicitly so a failed search reports the cycle
        // number instead of an opaque index-out-of-bounds panic.
        assert!(
            !results.is_empty(),
            "Cycle {} returned no results",
            cycle
        );
        assert_eq!(results[0].id, "a", "Cycle {} failed", cycle);
    }
}
/// Shares one loaded index between four threads through an `Arc` and has each
/// thread run twenty searches against it.
///
/// Every search must return at least one hit, and no worker thread may panic.
#[test]
fn test_loaded_index_threaded_access() {
    let dir = TempDir::new().unwrap();

    let mut entries = Vec::new();
    for seed in 1..=20 {
        entries.push((format!("item{}", seed), make_embedding(seed)));
    }
    HnswIndex::build_with_dim(entries, crate::EMBEDDING_DIM)
        .unwrap()
        .save(dir.path(), "threaded")
        .unwrap();

    let shared = Arc::new(
        HnswIndex::load_with_dim(dir.path(), "threaded", crate::EMBEDDING_DIM).unwrap(),
    );

    // Spawn every worker before joining any of them so the searches can overlap.
    let mut workers = Vec::with_capacity(4);
    for worker_id in 0..4 {
        let local = Arc::clone(&shared);
        workers.push(thread::spawn(move || {
            for seed in 1..=20 {
                let probe = make_embedding(seed);
                let hits = local.search(&probe, 3);
                assert!(
                    !hits.is_empty(),
                    "Thread {} search {} failed",
                    worker_id,
                    seed
                );
            }
        }));
    }

    for worker in workers {
        worker.join().expect("Thread panicked");
    }
}
/// Asserts upper bounds on the in-memory size of the index wrapper types:
/// `LoadedHnsw` must be smaller than 1024 bytes and `HnswInner` smaller than
/// 2048 bytes, as reported by `size_of`.
#[test]
fn test_layout_invariants() {
    const LOADED_LIMIT: usize = 1024;
    const INNER_LIMIT: usize = 2048;

    let loaded_bytes = size_of::<LoadedHnsw>();
    assert!(
        LOADED_LIMIT > loaded_bytes,
        "LoadedHnsw unexpectedly large: {} bytes",
        loaded_bytes
    );

    let inner_bytes = size_of::<HnswInner>();
    assert!(
        INNER_LIMIT > inner_bytes,
        "HnswInner unexpectedly large: {} bytes",
        inner_bytes
    );
}
/// Saves and reloads an index that contains a single entry.
///
/// A search asking for five neighbours must return exactly that one entry.
#[test]
fn test_loaded_minimal_index() {
    let dir = TempDir::new().unwrap();

    let single_entry = vec![(String::from("only"), make_embedding(42))];
    // The built index stays alive until the end of the test.
    let built = HnswIndex::build_with_dim(single_entry, crate::EMBEDDING_DIM).unwrap();
    built.save(dir.path(), "minimal").unwrap();

    let reloaded =
        HnswIndex::load_with_dim(dir.path(), "minimal", crate::EMBEDDING_DIM).unwrap();
    assert_eq!(reloaded.len(), 1);

    let probe = make_embedding(42);
    let hits = reloaded.search(&probe, 5);
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].id, "only");
}
}