use crate::distance::DistanceMetric;
use crate::error::{Result, SynaError};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
/// Tuning parameters for an [`HnswIndex`].
///
/// Serialized with bincode by `HnswIndex::save`, so the order of the fields below
/// is part of the on-disk format and must not change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
/// Target neighbour count per node; `HnswConfig::with_m` derives `m_max` and `ml` from it.
pub m: usize,
/// Second, larger neighbour limit (`2 * m` in both constructors).
/// NOTE(review): presumably the layer-0 cap applied by the insertion code — confirm.
pub m_max: usize,
/// Beam width for graph construction. Not read in this file; presumably consumed
/// by the insertion code — confirm.
pub ef_construction: usize,
/// Default beam width used by `HnswIndex::search` when searching layer 0.
pub ef_search: usize,
/// Level-generation factor: `HnswIndex::random_level` returns `floor(-ln(U) * ml)`.
pub ml: f64,
}
impl Default for HnswConfig {
fn default() -> Self {
Self {
m: 16,
m_max: 32,
ef_construction: 200,
ef_search: 100,
ml: 1.0 / (16.0_f64).ln(),
}
}
}
impl HnswConfig {
    /// Builds a configuration with `m` neighbours per node.
    ///
    /// `m_max` is set to `2 * m` and the level factor `ml` to `1 / ln(m)`; every
    /// other field comes from [`HnswConfig::default`].
    ///
    /// For `m < 2` the logarithm is `0` or `-inf`, which would make `ml` infinite
    /// (every node drawn onto the top level) or `-0.0`. `ml` is therefore computed
    /// from `max(m, 2)` so level generation always stays finite and positive.
    pub fn with_m(m: usize) -> Self {
        // ln(1) == 0 would give ml == +inf; clamp the base to keep ml finite.
        let level_base = m.max(2) as f64;
        Self {
            m,
            m_max: 2 * m,
            ml: 1.0 / level_base.ln(),
            ..Default::default()
        }
    }

    /// Returns `self` with `ef_construction` replaced by `ef`.
    pub fn ef_construction(mut self, ef: usize) -> Self {
        self.ef_construction = ef;
        self
    }

    /// Returns `self` with `ef_search` replaced by `ef`.
    pub fn ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }
}
/// A graph node paired with its distance to a query.
///
/// Ordered by ascending `distance`, so a [`std::collections::BinaryHeap`] of
/// `Candidate`s pops the *farthest* candidate first. Equal distances are broken by
/// `node_id`, which keeps the ordering total and consistent with `PartialEq`.
/// Distances are compared with [`f32::total_cmp`]: a positive NaN sorts after
/// `+inf` and a negative NaN before `-inf`, instead of comparing "equal" to
/// everything.
#[derive(Debug, Clone)]
pub struct Candidate {
    pub node_id: usize,
    pub distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when
        // `a.cmp(&b) == Equal`, as the `Eq`/`Ord` contracts require.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.node_id.cmp(&other.node_id))
    }
}
/// A graph node paired with its distance to a query, ordered in *reverse*.
///
/// A [`std::collections::BinaryHeap`] of `MinCandidate`s therefore pops the
/// *nearest* candidate first. Among equal distances the smaller `node_id` pops
/// first, which keeps the ordering total and consistent with `PartialEq`.
/// Distances are compared with [`f32::total_cmp`], so NaN has a fixed position.
#[derive(Debug, Clone)]
pub struct MinCandidate {
    pub node_id: usize,
    pub distance: f32,
}

impl PartialEq for MinCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MinCandidate {}

impl PartialOrd for MinCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped relative to `Candidate` to turn std's max-heap into a min-heap.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| other.node_id.cmp(&self.node_id))
    }
}
/// A stored vector together with its per-layer adjacency lists.
///
/// Serialized with bincode by `HnswIndex::save`; field order is part of the file format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswNode {
/// User-facing key; `HnswIndex::key_to_id` maps it back to this node's id.
pub key: String,
/// The stored vector.
pub vector: Vec<f32>,
/// `neighbors[l]` holds this node's `(neighbor_id, f32)` links on layer `l`; the
/// node's top layer is `neighbors.len() - 1`. Search ignores the `f32` — presumably
/// the distance to that neighbour (the tests fill it that way); confirm.
pub neighbors: Vec<Vec<(usize, f32)>>,
}
impl HnswNode {
    /// Creates a node whose top layer is `level`, with an empty neighbour list on
    /// every layer `0..=level`.
    pub fn new(key: String, vector: Vec<f32>, level: usize) -> Self {
        let layer_count = level + 1;
        let neighbors = std::iter::repeat_with(Vec::new).take(layer_count).collect();
        Self {
            key,
            vector,
            neighbors,
        }
    }

    /// Index of the node's top layer; `0` when it has no layers at all.
    pub fn level(&self) -> usize {
        match self.neighbors.len() {
            0 => 0,
            layers => layers - 1,
        }
    }
}
/// Hierarchical Navigable Small World graph index over `f32` vectors.
pub struct HnswIndex {
/// Construction and search parameters.
config: HnswConfig,
/// Distance function used for every comparison.
metric: DistanceMetric,
/// Declared vector dimensionality (checked by `load_validated`).
dimensions: u16,
/// All nodes; a node's position in this vector is its node id.
pub nodes: Vec<HnswNode>,
/// Node id where searches start; `None` makes searches return no results.
pub entry_point: Option<usize>,
/// Highest layer that `search` descends from.
max_level: usize,
/// Maps each key to its node id (rebuilt from `nodes` on load, not persisted).
pub key_to_id: HashMap<String, usize>,
/// xorshift64 state for `random_level`; persisted by `save`.
rng_state: u64,
}
impl HnswIndex {
    /// Creates an empty index for vectors of `dimensions` components compared with `metric`.
    ///
    /// The level RNG is seeded from the wall clock (nanoseconds since the Unix epoch,
    /// truncated to 64 bits), or with `42` if the clock reads earlier than the epoch.
    pub fn new(dimensions: u16, metric: DistanceMetric, config: HnswConfig) -> Self {
        Self {
            config,
            metric,
            dimensions,
            nodes: Vec::new(),
            entry_point: None,
            max_level: 0,
            key_to_id: HashMap::new(),
            // Truncation is intentional: only a varying seed is needed.
            rng_state: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(42),
        }
    }

    /// Number of nodes stored in the index.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when the index holds no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Declared vector dimensionality.
    pub fn dimensions(&self) -> u16 {
        self.dimensions
    }

    /// Distance metric used by this index.
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Construction and search parameters.
    pub fn config(&self) -> &HnswConfig {
        &self.config
    }

    /// Node id where searches start, if any.
    pub fn entry_point(&self) -> Option<usize> {
        self.entry_point
    }

    /// Highest layer that searches descend from.
    pub fn max_level(&self) -> usize {
        self.max_level
    }

    /// Raises the recorded maximum layer to `level`; smaller values are ignored.
    pub fn set_max_level(&mut self, level: usize) {
        if level > self.max_level {
            self.max_level = level;
        }
    }

    /// `true` if a node with `key` is stored.
    pub fn contains_key(&self, key: &str) -> bool {
        self.key_to_id.contains_key(key)
    }

    /// Node id stored for `key`, if any.
    pub fn get_node_id(&self, key: &str) -> Option<usize> {
        self.key_to_id.get(key).copied()
    }

    /// Node with id `node_id`, or `None` if the id is out of range.
    pub fn get_node(&self, node_id: usize) -> Option<&HnswNode> {
        self.nodes.get(node_id)
    }

    /// Iterates over all stored keys in unspecified (`HashMap`) order.
    pub fn keys(&self) -> impl Iterator<Item = &str> {
        self.key_to_id.keys().map(|s| s.as_str())
    }

    /// Draws a random top layer for a new node.
    ///
    /// Advances the internal xorshift64 generator, maps its state to `U` in `(0, 1]`
    /// and returns `floor(-ln(U) * ml)`, capped at 32.
    pub fn random_level(&mut self) -> usize {
        // Zero is a fixed point of xorshift: the state would never change again and
        // `-ln(0) = +inf` would put every node on layer 32. This can happen with a
        // clock-derived seed of 0 or a zero state read back from disk, so reseed with
        // an arbitrary non-zero constant.
        if self.rng_state == 0 {
            self.rng_state = 0x9E37_79B9_7F4A_7C15;
        }
        self.rng_state ^= self.rng_state << 13;
        self.rng_state ^= self.rng_state >> 7;
        self.rng_state ^= self.rng_state << 17;
        let uniform = (self.rng_state as f64) / (u64::MAX as f64);
        let level = (-uniform.ln() * self.config.ml).floor() as usize;
        level.min(32)
    }

    /// Distance from `query` to the vector of node `node_id` under this index's metric.
    ///
    /// # Panics
    /// Panics if `node_id` is out of range.
    pub fn distance_to_node(&self, query: &[f32], node_id: usize) -> f32 {
        self.metric.distance(query, &self.nodes[node_id].vector)
    }

    /// Ids of all nodes whose top layer is at least `level` (linear scan).
    pub fn nodes_at_level(&self, level: usize) -> Vec<usize> {
        self.nodes
            .iter()
            .enumerate()
            .filter(|(_, node)| node.level() >= level)
            .map(|(id, _)| id)
            .collect()
    }

    /// Computes per-layer node counts and edge totals.
    ///
    /// Layers above `max_level` are not counted in `level_counts`, but their edges
    /// still contribute to `total_edges`.
    pub fn stats(&self) -> HnswStats {
        let mut level_counts = vec![0usize; self.max_level + 1];
        let mut total_edges = 0usize;
        for node in &self.nodes {
            for (level, neighbors) in node.neighbors.iter().enumerate() {
                if let Some(count) = level_counts.get_mut(level) {
                    *count += 1;
                }
                total_edges += neighbors.len();
            }
        }
        HnswStats {
            num_nodes: self.nodes.len(),
            max_level: self.max_level,
            level_counts,
            total_edges,
            avg_edges_per_node: if self.nodes.is_empty() {
                0.0
            } else {
                total_edges as f64 / self.nodes.len() as f64
            },
        }
    }
}
/// Summary statistics produced by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
/// Number of nodes in the index.
pub num_nodes: usize,
/// The index's recorded maximum layer.
pub max_level: usize,
/// `level_counts[l]` is the number of nodes with an adjacency list on layer `l`
/// (only layers `0..=max_level` are counted).
pub level_counts: Vec<usize>,
/// Sum of adjacency-list lengths over all nodes and layers (directed links).
pub total_edges: usize,
/// `total_edges / num_nodes`, or `0.0` for an empty index.
pub avg_edges_per_node: f64,
}
/// `(distance, node_id)` heap entry ordered so a `BinaryHeap` pops the *smallest*
/// distance first (a min-heap). Equal distances are broken by node id (smaller id
/// pops first), keeping `Ord` total and consistent with `PartialEq`.
struct MinHeapEntry(f32, usize);

impl PartialEq for MinHeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MinHeapEntry {}

impl PartialOrd for MinHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinHeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped to turn std's max-heap into a min-heap.
        other
            .0
            .total_cmp(&self.0)
            .then_with(|| other.1.cmp(&self.1))
    }
}
/// `(distance, node_id)` heap entry ordered so a `BinaryHeap` pops the *largest*
/// distance first (a max-heap). Equal distances are broken by node id (larger id
/// pops first), keeping `Ord` total and consistent with `PartialEq`.
struct MaxHeapEntry(f32, usize);

impl PartialEq for MaxHeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxHeapEntry {}

impl PartialOrd for MaxHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxHeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0
            .total_cmp(&other.0)
            .then_with(|| self.1.cmp(&other.1))
    }
}
impl HnswIndex {
    /// Returns up to `k` approximate nearest neighbours of `query`, nearest first,
    /// as `(key, distance)` pairs, using the configured `ef_search`.
    ///
    /// Equivalent to `self.search_with_ef(query, k, self.config().ef_search)`.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
        self.search_with_ef(query, k, self.config.ef_search)
    }

    /// Best-first search of a single layer, starting from `entry_point`.
    ///
    /// The search keeps a min-heap of candidates to expand and a max-heap of the best
    /// `ef` results seen so far. It stops once the closest unexpanded candidate is
    /// farther than the worst kept result. Nodes with no adjacency list on `level`
    /// are not expanded.
    ///
    /// Returns at most `max(ef, 1)` `(node_id, distance)` pairs sorted by ascending
    /// distance; the entry point is always considered, so the result is never empty.
    ///
    /// # Panics
    /// Panics if `entry_point`, or any neighbour id reached, is out of range.
    pub fn search_layer(
        &self,
        query: &[f32],
        entry_point: usize,
        ef: usize,
        level: usize,
    ) -> Vec<(usize, f32)> {
        use std::collections::{BinaryHeap, HashSet};
        let mut visited = HashSet::new();
        let mut candidates = BinaryHeap::new();
        let mut results = BinaryHeap::new();
        let ep_dist = self.metric.distance(query, &self.nodes[entry_point].vector);
        visited.insert(entry_point);
        candidates.push(MinHeapEntry(ep_dist, entry_point));
        results.push(MaxHeapEntry(ep_dist, entry_point));
        while let Some(MinHeapEntry(c_dist, c_id)) = candidates.pop() {
            let worst_dist = results.peek().map(|e| e.0).unwrap_or(f32::MAX);
            // Every remaining candidate is at least this far away: nothing can improve.
            if c_dist > worst_dist {
                break;
            }
            if level < self.nodes[c_id].neighbors.len() {
                for (neighbor_id, _) in &self.nodes[c_id].neighbors[level] {
                    if visited.insert(*neighbor_id) {
                        let dist = self
                            .metric
                            .distance(query, &self.nodes[*neighbor_id].vector);
                        let worst_dist = results.peek().map(|e| e.0).unwrap_or(f32::MAX);
                        if results.len() < ef || dist < worst_dist {
                            candidates.push(MinHeapEntry(dist, *neighbor_id));
                            results.push(MaxHeapEntry(dist, *neighbor_id));
                            if results.len() > ef {
                                // Drop the current worst to keep the beam at `ef`.
                                results.pop();
                            }
                        }
                    }
                }
            }
        }
        let mut result_vec: Vec<_> = results
            .into_iter()
            .map(|MaxHeapEntry(d, id)| (id, d))
            .collect();
        result_vec.sort_by(|a, b| a.1.total_cmp(&b.1));
        result_vec
    }

    /// Returns up to `k` approximate nearest neighbours of `query`, nearest first,
    /// as `(key, distance)` pairs, using a layer-0 beam of `ef_search`.
    ///
    /// The beam is widened to `k` when `ef_search < k`; otherwise at most
    /// `ef_search` results could ever be returned. Returns an empty vector when the
    /// index has no entry point or `k == 0`.
    pub fn search_with_ef(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<(String, f32)> {
        let mut ep = match self.entry_point {
            Some(ep) => ep,
            None => return Vec::new(),
        };
        if k == 0 {
            return Vec::new();
        }
        // Greedy descent: on each upper layer, move to the single closest node.
        for lc in (1..=self.max_level).rev() {
            if let Some(&(closest, _)) = self.search_layer(query, ep, 1, lc).first() {
                ep = closest;
            }
        }
        let ef = ef_search.max(k);
        self.search_layer(query, ep, ef, 0)
            .into_iter()
            .take(k)
            .map(|(id, dist)| (self.nodes[id].key.clone(), dist))
            .collect()
    }
}
/// Magic bytes written at the start of every index file by `HnswIndex::save`.
const HNSW_MAGIC: &[u8; 4] = b"HNSW";
/// On-disk format version; `HnswIndex::load` rejects files carrying any other value.
const HNSW_VERSION: u16 = 1;
impl HnswIndex {
    /// Writes the index to `path`, creating or truncating the file.
    ///
    /// The file contains, in order:
    /// - the 4-byte magic `HNSW`;
    /// - the format version (`u16`, little-endian);
    /// - `dimensions` (`u16`, little-endian);
    /// - one metric byte (0 = Cosine, 1 = Euclidean, 2 = DotProduct);
    /// - bincode encodings of the config, nodes, entry point, max level and RNG state.
    ///
    /// `key_to_id` is not written; [`HnswIndex::load`] rebuilds it from the nodes.
    ///
    /// # Errors
    /// Propagates I/O and bincode serialization failures.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);
        writer.write_all(HNSW_MAGIC)?;
        writer.write_all(&HNSW_VERSION.to_le_bytes())?;
        writer.write_all(&self.dimensions.to_le_bytes())?;
        let metric_byte: u8 = match self.metric {
            DistanceMetric::Cosine => 0,
            DistanceMetric::Euclidean => 1,
            DistanceMetric::DotProduct => 2,
        };
        writer.write_all(&[metric_byte])?;
        bincode::serialize_into(&mut writer, &self.config)?;
        bincode::serialize_into(&mut writer, &self.nodes)?;
        bincode::serialize_into(&mut writer, &self.entry_point)?;
        bincode::serialize_into(&mut writer, &self.max_level)?;
        bincode::serialize_into(&mut writer, &self.rng_state)?;
        // Flush explicitly: dropping a BufWriter would swallow a write error.
        writer.flush()?;
        Ok(())
    }

    /// Reads an index previously written by [`HnswIndex::save`].
    ///
    /// # Errors
    /// - `SynaError::CorruptedIndex` in any of these cases:
    ///   - the magic bytes, version or metric byte are not recognised;
    ///   - the entry point refers to a node that does not exist;
    ///   - any neighbour id refers to a node that does not exist.
    /// - I/O and bincode errors (e.g. a truncated file) are propagated.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);
        let mut magic = [0u8; 4];
        reader.read_exact(&mut magic)?;
        if &magic != HNSW_MAGIC {
            return Err(SynaError::CorruptedIndex(
                "Invalid magic bytes - not an HNSW index file".to_string(),
            ));
        }
        let mut version_bytes = [0u8; 2];
        reader.read_exact(&mut version_bytes)?;
        let version = u16::from_le_bytes(version_bytes);
        if version != HNSW_VERSION {
            return Err(SynaError::CorruptedIndex(format!(
                "Unsupported version: {} (expected {})",
                version, HNSW_VERSION
            )));
        }
        let mut dims_bytes = [0u8; 2];
        reader.read_exact(&mut dims_bytes)?;
        let dimensions = u16::from_le_bytes(dims_bytes);
        let mut metric_byte = [0u8; 1];
        reader.read_exact(&mut metric_byte)?;
        let metric = match metric_byte[0] {
            0 => DistanceMetric::Cosine,
            1 => DistanceMetric::Euclidean,
            2 => DistanceMetric::DotProduct,
            _ => {
                return Err(SynaError::CorruptedIndex(format!(
                    "Invalid metric byte: {}",
                    metric_byte[0]
                )))
            }
        };
        let config: HnswConfig = bincode::deserialize_from(&mut reader)?;
        let nodes: Vec<HnswNode> = bincode::deserialize_from(&mut reader)?;
        let entry_point: Option<usize> = bincode::deserialize_from(&mut reader)?;
        let max_level: usize = bincode::deserialize_from(&mut reader)?;
        let rng_state: u64 = bincode::deserialize_from(&mut reader)?;
        // Search indexes `nodes` directly with these ids, so a damaged file would
        // otherwise load "successfully" and panic on the first query.
        Self::validate_loaded_node_ids(&nodes, entry_point)?;
        // Later duplicates overwrite earlier ones, as with repeated `insert`.
        let key_to_id: HashMap<String, usize> = nodes
            .iter()
            .enumerate()
            .map(|(id, node)| (node.key.clone(), id))
            .collect();
        Ok(Self {
            config,
            metric,
            dimensions,
            nodes,
            entry_point,
            max_level,
            key_to_id,
            rng_state,
        })
    }

    /// Loads an index and checks that it matches the expected dimensionality and metric.
    ///
    /// # Errors
    /// Everything [`HnswIndex::load`] can return, plus:
    /// - `SynaError::DimensionMismatch` if the stored dimensionality differs from
    ///   `expected_dims`;
    /// - `SynaError::CorruptedIndex` if the stored metric differs from `expected_metric`.
    pub fn load_validated<P: AsRef<Path>>(
        path: P,
        expected_dims: u16,
        expected_metric: DistanceMetric,
    ) -> Result<Self> {
        let index = Self::load(path)?;
        if index.dimensions != expected_dims {
            return Err(SynaError::DimensionMismatch {
                expected: expected_dims,
                got: index.dimensions,
            });
        }
        if index.metric != expected_metric {
            return Err(SynaError::CorruptedIndex(format!(
                "Metric mismatch: expected {:?}, got {:?}",
                expected_metric, index.metric
            )));
        }
        Ok(index)
    }

    /// Rejects deserialized graphs whose entry point or neighbour ids fall outside `nodes`.
    fn validate_loaded_node_ids(nodes: &[HnswNode], entry_point: Option<usize>) -> Result<()> {
        let len = nodes.len();
        if let Some(ep) = entry_point {
            if ep >= len {
                return Err(SynaError::CorruptedIndex(format!(
                    "Entry point {} out of range ({} nodes)",
                    ep, len
                )));
            }
        }
        for (id, node) in nodes.iter().enumerate() {
            for (level, neighbors) in node.neighbors.iter().enumerate() {
                if let Some(&(bad, _)) = neighbors.iter().find(|(n, _)| *n >= len) {
                    return Err(SynaError::CorruptedIndex(format!(
                        "Node {} has out-of-range neighbor {} at level {} ({} nodes)",
                        id, bad, level, len
                    )));
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::BinaryHeap;
#[test]
fn test_hnsw_config_default() {
let config = HnswConfig::default();
assert_eq!(config.m, 16);
assert_eq!(config.m_max, 32);
assert_eq!(config.ef_construction, 200);
assert_eq!(config.ef_search, 100);
assert!((config.ml - 1.0 / 16.0_f64.ln()).abs() < 1e-10);
}
#[test]
fn test_hnsw_config_with_m() {
let config = HnswConfig::with_m(32);
assert_eq!(config.m, 32);
assert_eq!(config.m_max, 64);
assert!((config.ml - 1.0 / 32.0_f64.ln()).abs() < 1e-10);
}
#[test]
fn test_hnsw_config_builder() {
let config = HnswConfig::with_m(8).ef_construction(100).ef_search(50);
assert_eq!(config.m, 8);
assert_eq!(config.ef_construction, 100);
assert_eq!(config.ef_search, 50);
}
#[test]
fn test_candidate_ordering() {
let mut heap = BinaryHeap::new();
heap.push(Candidate {
node_id: 0,
distance: 1.0,
});
heap.push(Candidate {
node_id: 1,
distance: 3.0,
});
heap.push(Candidate {
node_id: 2,
distance: 2.0,
});
assert_eq!(heap.pop().unwrap().distance, 3.0);
assert_eq!(heap.pop().unwrap().distance, 2.0);
assert_eq!(heap.pop().unwrap().distance, 1.0);
}
#[test]
fn test_min_candidate_ordering() {
let mut heap = BinaryHeap::new();
heap.push(MinCandidate {
node_id: 0,
distance: 1.0,
});
heap.push(MinCandidate {
node_id: 1,
distance: 3.0,
});
heap.push(MinCandidate {
node_id: 2,
distance: 2.0,
});
assert_eq!(heap.pop().unwrap().distance, 1.0);
assert_eq!(heap.pop().unwrap().distance, 2.0);
assert_eq!(heap.pop().unwrap().distance, 3.0);
}
#[test]
fn test_hnsw_node_creation() {
let node = HnswNode::new("test_key".to_string(), vec![1.0, 2.0, 3.0], 2);
assert_eq!(node.key, "test_key");
assert_eq!(node.vector, vec![1.0, 2.0, 3.0]);
assert_eq!(node.level(), 2);
assert_eq!(node.neighbors.len(), 3); }
#[test]
fn test_hnsw_index_creation() {
let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
assert_eq!(index.dimensions(), 128);
assert_eq!(index.metric(), DistanceMetric::Cosine);
assert!(index.is_empty());
assert_eq!(index.len(), 0);
assert!(index.entry_point().is_none());
}
#[test]
fn test_random_level_distribution() {
let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
let mut level_counts = [0usize; 10];
for _ in 0..10000 {
let level = index.random_level();
if level < 10 {
level_counts[level] += 1;
}
}
assert!(level_counts[0] > level_counts[1]);
assert!(level_counts[1] > level_counts[2]);
assert!(level_counts[0] > 5000);
}
#[test]
fn test_hnsw_stats_empty() {
let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
let stats = index.stats();
assert_eq!(stats.num_nodes, 0);
assert_eq!(stats.total_edges, 0);
assert_eq!(stats.avg_edges_per_node, 0.0);
}
#[test]
fn test_search_empty_index() {
let index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
let query = vec![1.0, 2.0, 3.0];
let results = index.search(&query, 5);
assert!(results.is_empty());
}
#[test]
fn test_search_single_node() {
let mut index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
let node = HnswNode::new("node1".to_string(), vec![1.0, 0.0, 0.0], 0);
index.nodes.push(node);
index.key_to_id.insert("node1".to_string(), 0);
index.entry_point = Some(0);
let query = vec![1.0, 0.0, 0.0];
let results = index.search(&query, 5);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, "node1");
assert!(results[0].1 < 0.001); }
#[test]
fn test_search_multiple_nodes_sorted() {
let mut index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
let node1 = HnswNode::new("close".to_string(), vec![1.0, 0.0, 0.0], 0);
let node2 = HnswNode::new("medium".to_string(), vec![2.0, 0.0, 0.0], 0);
let node3 = HnswNode::new("far".to_string(), vec![3.0, 0.0, 0.0], 0);
index.nodes.push(node1);
index.nodes.push(node2);
index.nodes.push(node3);
index.key_to_id.insert("close".to_string(), 0);
index.key_to_id.insert("medium".to_string(), 1);
index.key_to_id.insert("far".to_string(), 2);
index.nodes[0].neighbors[0] = vec![(1, 1.0), (2, 2.0)];
index.nodes[1].neighbors[0] = vec![(0, 1.0), (2, 1.0)];
index.nodes[2].neighbors[0] = vec![(0, 2.0), (1, 1.0)];
index.entry_point = Some(0);
let query = vec![0.0, 0.0, 0.0];
let results = index.search(&query, 3);
assert_eq!(results.len(), 3);
assert_eq!(results[0].0, "close"); assert_eq!(results[1].0, "medium"); assert_eq!(results[2].0, "far");
assert!(results[0].1 < results[1].1);
assert!(results[1].1 < results[2].1);
}
#[test]
fn test_search_k_limit() {
let mut index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
for i in 0..5 {
let node = HnswNode::new(format!("node{}", i), vec![i as f32, 0.0, 0.0], 0);
index.nodes.push(node);
index.key_to_id.insert(format!("node{}", i), i);
}
for i in 0..5 {
let mut neighbors = Vec::new();
for j in 0..5 {
if i != j {
neighbors.push((j, (i as f32 - j as f32).abs()));
}
}
index.nodes[i].neighbors[0] = neighbors;
}
index.entry_point = Some(0);
let query = vec![0.0, 0.0, 0.0];
let results = index.search(&query, 2);
assert_eq!(results.len(), 2);
}
#[test]
fn test_search_with_ef() {
let mut index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
let node = HnswNode::new("node1".to_string(), vec![1.0, 0.0, 0.0], 0);
index.nodes.push(node);
index.key_to_id.insert("node1".to_string(), 0);
index.entry_point = Some(0);
let query = vec![1.0, 0.0, 0.0];
let results = index.search_with_ef(&query, 5, 50);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, "node1");
}
#[test]
fn test_min_heap_entry_ordering() {
let mut heap = BinaryHeap::new();
heap.push(MinHeapEntry(3.0, 0));
heap.push(MinHeapEntry(1.0, 1));
heap.push(MinHeapEntry(2.0, 2));
assert_eq!(heap.pop().unwrap().0, 1.0);
assert_eq!(heap.pop().unwrap().0, 2.0);
assert_eq!(heap.pop().unwrap().0, 3.0);
}
#[test]
fn test_max_heap_entry_ordering() {
let mut heap = BinaryHeap::new();
heap.push(MaxHeapEntry(1.0, 0));
heap.push(MaxHeapEntry(3.0, 1));
heap.push(MaxHeapEntry(2.0, 2));
assert_eq!(heap.pop().unwrap().0, 3.0);
assert_eq!(heap.pop().unwrap().0, 2.0);
assert_eq!(heap.pop().unwrap().0, 1.0);
}
#[test]
fn test_save_load_empty_index() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("empty.hnsw");
let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
index.save(&path).unwrap();
let loaded = HnswIndex::load(&path).unwrap();
assert_eq!(loaded.dimensions(), 128);
assert_eq!(loaded.metric(), DistanceMetric::Cosine);
assert!(loaded.is_empty());
assert!(loaded.entry_point().is_none());
}
#[test]
fn test_save_load_with_nodes() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("nodes.hnsw");
let mut index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
let node1 = HnswNode::new("node1".to_string(), vec![1.0, 0.0, 0.0], 0);
let node2 = HnswNode::new("node2".to_string(), vec![0.0, 1.0, 0.0], 0);
let node3 = HnswNode::new("node3".to_string(), vec![0.0, 0.0, 1.0], 0);
index.nodes.push(node1);
index.nodes.push(node2);
index.nodes.push(node3);
index.key_to_id.insert("node1".to_string(), 0);
index.key_to_id.insert("node2".to_string(), 1);
index.key_to_id.insert("node3".to_string(), 2);
index.nodes[0].neighbors[0] = vec![(1, 1.414), (2, 1.414)];
index.nodes[1].neighbors[0] = vec![(0, 1.414), (2, 1.414)];
index.nodes[2].neighbors[0] = vec![(0, 1.414), (1, 1.414)];
index.entry_point = Some(0);
index.save(&path).unwrap();
let loaded = HnswIndex::load(&path).unwrap();
assert_eq!(loaded.dimensions(), 3);
assert_eq!(loaded.metric(), DistanceMetric::Euclidean);
assert_eq!(loaded.len(), 3);
assert_eq!(loaded.entry_point(), Some(0));
assert!(loaded.contains_key("node1"));
assert!(loaded.contains_key("node2"));
assert!(loaded.contains_key("node3"));
let node1 = loaded.get_node(0).unwrap();
assert_eq!(node1.vector, vec![1.0, 0.0, 0.0]);
assert_eq!(node1.neighbors[0].len(), 2);
}
#[test]
fn test_save_load_search_consistency() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("search.hnsw");
let mut index = HnswIndex::new(3, DistanceMetric::Euclidean, HnswConfig::default());
let node1 = HnswNode::new("close".to_string(), vec![1.0, 0.0, 0.0], 0);
let node2 = HnswNode::new("medium".to_string(), vec![2.0, 0.0, 0.0], 0);
let node3 = HnswNode::new("far".to_string(), vec![3.0, 0.0, 0.0], 0);
index.nodes.push(node1);
index.nodes.push(node2);
index.nodes.push(node3);
index.key_to_id.insert("close".to_string(), 0);
index.key_to_id.insert("medium".to_string(), 1);
index.key_to_id.insert("far".to_string(), 2);
index.nodes[0].neighbors[0] = vec![(1, 1.0), (2, 2.0)];
index.nodes[1].neighbors[0] = vec![(0, 1.0), (2, 1.0)];
index.nodes[2].neighbors[0] = vec![(0, 2.0), (1, 1.0)];
index.entry_point = Some(0);
let query = vec![0.0, 0.0, 0.0];
let results_before = index.search(&query, 3);
index.save(&path).unwrap();
let loaded = HnswIndex::load(&path).unwrap();
let results_after = loaded.search(&query, 3);
assert_eq!(results_before.len(), results_after.len());
for (before, after) in results_before.iter().zip(results_after.iter()) {
assert_eq!(before.0, after.0); assert!((before.1 - after.1).abs() < 1e-6); }
}
#[test]
fn test_load_invalid_magic() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("invalid.hnsw");
std::fs::write(&path, b"XXXX").unwrap();
let result = HnswIndex::load(&path);
assert!(result.is_err());
match result {
Err(crate::error::SynaError::CorruptedIndex(msg)) => {
assert!(msg.contains("Invalid magic bytes"));
}
_ => panic!("Expected CorruptedIndex error"),
}
}
#[test]
fn test_load_invalid_version() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("invalid_version.hnsw");
let mut data = Vec::new();
data.extend_from_slice(b"HNSW");
data.extend_from_slice(&99u16.to_le_bytes());
std::fs::write(&path, &data).unwrap();
let result = HnswIndex::load(&path);
assert!(result.is_err());
match result {
Err(crate::error::SynaError::CorruptedIndex(msg)) => {
assert!(msg.contains("Unsupported version"));
}
_ => panic!("Expected CorruptedIndex error"),
}
}
#[test]
fn test_load_validated_success() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("validated.hnsw");
let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
index.save(&path).unwrap();
let loaded = HnswIndex::load_validated(&path, 128, DistanceMetric::Cosine).unwrap();
assert_eq!(loaded.dimensions(), 128);
assert_eq!(loaded.metric(), DistanceMetric::Cosine);
}
#[test]
fn test_load_validated_dimension_mismatch() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("dim_mismatch.hnsw");
let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
index.save(&path).unwrap();
let result = HnswIndex::load_validated(&path, 256, DistanceMetric::Cosine);
assert!(result.is_err());
match result {
Err(crate::error::SynaError::DimensionMismatch { expected, got }) => {
assert_eq!(expected, 256);
assert_eq!(got, 128);
}
_ => panic!("Expected DimensionMismatch error"),
}
}
#[test]
fn test_load_validated_metric_mismatch() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("metric_mismatch.hnsw");
let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default());
index.save(&path).unwrap();
let result = HnswIndex::load_validated(&path, 128, DistanceMetric::Euclidean);
assert!(result.is_err());
match result {
Err(crate::error::SynaError::CorruptedIndex(msg)) => {
assert!(msg.contains("Metric mismatch"));
}
_ => panic!("Expected CorruptedIndex error"),
}
}
#[test]
fn test_save_load_preserves_config() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.hnsw");
let config = HnswConfig::with_m(32).ef_construction(300).ef_search(150);
let index = HnswIndex::new(64, DistanceMetric::DotProduct, config);
index.save(&path).unwrap();
let loaded = HnswIndex::load(&path).unwrap();
assert_eq!(loaded.config().m, 32);
assert_eq!(loaded.config().m_max, 64);
assert_eq!(loaded.config().ef_construction, 300);
assert_eq!(loaded.config().ef_search, 150);
assert_eq!(loaded.metric(), DistanceMetric::DotProduct);
}
#[test]
fn test_save_load_all_metrics() {
let dir = tempfile::tempdir().unwrap();
for metric in [
DistanceMetric::Cosine,
DistanceMetric::Euclidean,
DistanceMetric::DotProduct,
] {
let path = dir.path().join(format!("{:?}.hnsw", metric));
let index = HnswIndex::new(64, metric, HnswConfig::default());
index.save(&path).unwrap();
let loaded = HnswIndex::load(&path).unwrap();
assert_eq!(loaded.metric(), metric);
}
}
}