use rusqlite::OptionalExtension;
use crate::hnsw::{
config::HnswConfig,
distance_metric::DistanceMetric,
errors::HnswError,
layer::HnswLayer,
multilayer::{LevelDistributor, MultiLayerNodeManager},
neighborhood::NeighborhoodSearch,
storage::{VectorStorage, VectorStorageStats},
};
#[cfg(test)]
use crate::hnsw::{config::hnsw_config, errors::HnswIndexError};
/// A Hierarchical Navigable Small World (HNSW) approximate-nearest-neighbor index.
///
/// Graph connectivity lives in `layers`; raw vectors and their JSON metadata
/// live in `storage`. The optional multi-layer machinery (`level_distributor`,
/// `multi_layer_manager`) is populated only when the config enables
/// multi-layer mode (see the tests exercising `has_level_distributor`).
pub struct HnswIndex {
// Index name; also used as the persistence key in the SQLite `hnsw_indexes` table.
pub(crate) name: String,
// Build/search parameters: dimension, m, ef_construction, ef_search, distance metric, etc.
pub(crate) config: HnswConfig,
// One graph layer per level. Layer 0 contains every inserted vector; higher
// layers hold progressively sparser subsets (see the layer-stats tests).
pub(crate) layers: Vec<HnswLayer>,
// Backing store for vector data + metadata; boxed trait object so storage backends can vary.
pub(crate) storage: Box<dyn VectorStorage>,
// Entry node ids used to start the search descent — presumably one per layer; TODO confirm.
pub(crate) entry_points: Vec<u64>,
// Cached count of vectors in the index (kept in sync with storage on insert/load).
pub(crate) vector_count: usize,
// Neighborhood-search implementation used by insert and query paths.
pub(crate) search_engine: NeighborhoodSearch,
// Some(..) only in multi-layer mode: samples the insertion level for new nodes.
pub(crate) level_distributor: Option<LevelDistributor>,
// Some(..) only in multi-layer mode: tracks per-level node membership.
pub(crate) multi_layer_manager: Option<MultiLayerNodeManager>,
}
/// Snapshot of index health/size returned by `HnswIndex::statistics()`.
#[derive(Debug, Clone)]
pub struct HnswIndexStats {
// Total vectors currently in the index.
pub vector_count: usize,
// Number of graph layers (fixed by config, not by occupancy — see the stats tests).
pub layer_count: usize,
// Number of search entry points currently registered.
pub entry_point_count: usize,
// Vector dimensionality enforced on insert.
pub dimension: usize,
// Distance metric the index was built with.
pub distance_metric: DistanceMetric,
// Statistics from the underlying vector storage backend.
pub storage_stats: VectorStorageStats,
// Per-layer tuples; tests read `.0` as the layer's node count. The meaning of
// the second (usize) and third (f32) components is not established here —
// NOTE(review): presumably edge count and average connectivity; confirm.
pub layer_stats: Vec<(usize, usize, f32)>,
}
include!("index_api.rs");
include!("index_internal.rs");
include!("index_persist.rs");
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::SqliteGraph;
    use crate::hnsw::{DistanceMetric, HnswConfigBuilder};

    /// A freshly built index reports zero vectors and echoes its configuration.
    #[test]
    fn test_hnsw_index_creation() {
        let config = HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap();
        let hnsw = HnswIndex::new("test_index", config).unwrap();
        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 0);
        assert_eq!(stats.dimension, 3);
        assert_eq!(stats.distance_metric, DistanceMetric::Euclidean);
    }

    /// Inserting one vector yields a positive id and bumps the vector count.
    #[test]
    fn test_vector_insertion() {
        let config = hnsw_config().dimension(3).build().unwrap();
        let mut hnsw = HnswIndex::new("test_insert", config).unwrap();
        let vector = vec![1.0, 0.0, 0.0];
        let metadata = serde_json::json!({"label": "test"});
        let result = hnsw.insert_vector(&vector, Some(metadata));
        println!("Insert result: {:?}", result);
        let vector_id = result.unwrap();
        assert!(vector_id > 0);
        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 1);
    }

    /// Inserting a vector of the wrong dimension surfaces a typed error.
    #[test]
    fn test_dimension_mismatch_error() {
        let mut hnsw = HnswIndex::new("test_dim_error", HnswConfig::default()).unwrap();
        // Default config dimension differs from 2, so this must be rejected.
        let wrong_vector = vec![1.0, 0.0];
        let result = hnsw.insert_vector(&wrong_vector, None);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(
            error,
            HnswError::Index(HnswIndexError::VectorDimensionMismatch { .. })
        ));
    }

    /// Searching an empty index succeeds and returns no results.
    #[test]
    fn test_empty_search() {
        let hnsw = HnswIndex::new("test_empty_search", HnswConfig::default()).unwrap();
        let query = vec![1.0; 768];
        let results = hnsw.search(&query, 5).unwrap();
        assert!(results.is_empty());
    }

    /// Round-trip: an inserted vector and its metadata come back unchanged.
    #[test]
    fn test_vector_retrieval() {
        let config = hnsw_config().dimension(3).build().unwrap();
        let mut hnsw = HnswIndex::new("test_retrieval", config).unwrap();
        let vector = vec![1.0, 0.0, 0.0];
        let metadata = serde_json::json!({"label": "test"});
        let vector_id = hnsw.insert_vector(&vector, Some(metadata.clone())).unwrap();
        let result = hnsw.get_vector(vector_id).unwrap();
        assert!(result.is_some());
        let (retrieved_vector, retrieved_metadata) = result.unwrap();
        assert_eq!(retrieved_vector, vector);
        assert_eq!(retrieved_metadata, metadata);
    }

    /// An index created through `SqliteGraph::hnsw_index` is retrievable and
    /// carries the requested configuration.
    #[test]
    fn test_sqlite_graph_integration() {
        let graph = SqliteGraph::open_in_memory().unwrap();
        let config = HnswConfigBuilder::new()
            .dimension(4)
            .distance_metric(DistanceMetric::Cosine)
            .build()
            .unwrap();
        let hnsw_indexes = graph.hnsw_index("test_index", config).unwrap();
        let hnsw = hnsw_indexes.get("test_index").unwrap();
        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 0);
        assert_eq!(stats.dimension, 4);
        assert_eq!(stats.distance_metric, DistanceMetric::Cosine);
    }

    /// Search over a few 2D unit vectors returns <= k results sorted by
    /// ascending distance.
    #[test]
    fn test_basic_search_functionality() {
        let mut hnsw = HnswIndex::new(
            "test_search",
            HnswConfigBuilder::new()
                .dimension(2)
                .m_connections(4)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap(),
        )
        .unwrap();
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];
        let mut vector_ids = Vec::new();
        for vector in vectors {
            let id = hnsw.insert_vector(&vector, None).unwrap();
            vector_ids.push(id);
        }
        // Query near (1, 0); closest should be the first inserted vector.
        let query = vec![0.9, 0.1];
        let results = hnsw.search(&query, 2).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        // Results must be ordered by non-decreasing distance.
        for window in results.windows(2) {
            assert!(window[0].1 <= window[1].1);
        }
    }

    /// `statistics()` reflects inserts and the configured layer count.
    #[test]
    fn test_index_statistics() {
        let mut hnsw = HnswIndex::new(
            "test_stats",
            HnswConfigBuilder::new()
                .dimension(3)
                .max_layers(3)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap(),
        )
        .unwrap();
        for i in 1..=5 {
            let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
            hnsw.insert_vector(&vector, None).unwrap();
        }
        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 5);
        assert_eq!(stats.layer_count, 3);
        assert_eq!(stats.dimension, 3);
        assert!(!stats.layer_stats.is_empty());
    }

    /// Index metadata saved in one process/connection is listable and loadable
    /// from a fresh connection to the same database file.
    ///
    /// NOTE(review): uses a fixed /tmp path — fine while the name is unique per
    /// test, but parallel runs of the same test binary twice could collide.
    #[test]
    fn test_metadata_persistence() {
        use std::fs;
        let test_dir = "/tmp/test_hnsw_metadata_persistence";
        let db_path = format!("{}/test.db", test_dir);
        let _ = fs::remove_dir_all(test_dir);
        fs::create_dir_all(test_dir).unwrap();
        {
            let graph = SqliteGraph::open(&db_path).unwrap();
            let config = HnswConfigBuilder::new()
                .dimension(128)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap();
            let hnsw_indexes = graph.hnsw_index("persist_test", config).unwrap();
            let hnsw = hnsw_indexes.get("persist_test").unwrap();
            assert_eq!(hnsw.name(), "persist_test");
            assert_eq!(hnsw.config().dimension, 128);
            assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);
            let conn = graph.connection();
            let conn_ref = conn.underlying();
            hnsw.save_metadata(conn_ref).unwrap();
        }
        {
            let graph2 = SqliteGraph::open(&db_path).unwrap();
            let index_names = graph2.list_hnsw_indexes().unwrap();
            assert_eq!(index_names, vec!["persist_test".to_string()]);
            let loaded_hnsw = graph2
                .get_hnsw_index_ref("persist_test", |hnsw| {
                    assert_eq!(hnsw.name(), "persist_test");
                    assert_eq!(hnsw.config().dimension, 128);
                    assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);
                    hnsw.config().dimension
                })
                .unwrap();
            assert_eq!(loaded_hnsw, 128);
        }
        let _ = fs::remove_dir_all(test_dir);
    }

    /// Rows written directly to the SQLite schema can be loaded back:
    /// `load_metadata` restores counts without vectors, `load_with_vectors`
    /// rebuilds a searchable index.
    #[test]
    fn test_vector_loading_and_rebuild() {
        use rusqlite::Connection;
        use std::fs;
        let test_dir = "/tmp/test_hnsw_vector_loading";
        let db_path = format!("{}/test.db", test_dir);
        let _ = fs::remove_dir_all(test_dir);
        fs::create_dir_all(test_dir).unwrap();
        {
            let conn = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn).unwrap();
            conn.execute(
                "INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
                rusqlite::params!["load_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
            ).unwrap();
            let index_id = conn.last_insert_rowid();
            for i in 0..5 {
                let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
                let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();
                conn.execute(
                    "INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, ?5)",
                    rusqlite::params![index_id, vector_bytes, None::<String>, 1000, 1000],
                ).unwrap();
            }
        }
        {
            let conn2 = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn2).unwrap();
            // Metadata-only load: count is restored, but storage stays empty.
            let hnsw_metadata = HnswIndex::load_metadata(&conn2, "load_test").unwrap();
            assert_eq!(hnsw_metadata.vector_count, 5);
            assert_eq!(hnsw_metadata.storage.vector_count().unwrap(), 0);
            // Full load: vectors are present and the graph is queryable.
            let hnsw_loaded = HnswIndex::load_with_vectors(&conn2, "load_test").unwrap();
            assert_eq!(hnsw_loaded.vector_count, 5);
            assert_eq!(hnsw_loaded.storage.vector_count().unwrap(), 5);
            let (vector, _) = hnsw_loaded.get_vector(1).unwrap().unwrap();
            assert_eq!(vector, vec![0.0, 0.0, 0.0]);
            let query = vec![2.0, 4.0, 6.0];
            let results = hnsw_loaded.search(&query, 3).unwrap();
            assert!(!results.is_empty());
        }
        let _ = fs::remove_dir_all(test_dir);
    }

    /// End-to-end: data seeded straight into SQLite is visible through the
    /// `SqliteGraph` API, including metadata and search.
    #[test]
    fn test_e2e_hnsw_persistence() {
        use rusqlite::Connection;
        use std::fs;
        let test_dir = "/tmp/test_hnsw_e2e_persistence";
        let db_path = format!("{}/test.db", test_dir);
        let _ = fs::remove_dir_all(test_dir);
        fs::create_dir_all(test_dir).unwrap();
        {
            let conn = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn).unwrap();
            conn.execute(
                "INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
                rusqlite::params!["e2e_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
            ).unwrap();
            let index_id = conn.last_insert_rowid();
            for i in 0..5 {
                let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
                let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();
                let metadata = serde_json::json!({"label": format!("vector_{}", i)}).to_string();
                conn.execute(
                    "INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, ?5)",
                    rusqlite::params![index_id, vector_bytes, metadata, 1000, 1000],
                ).unwrap();
            }
        }
        {
            let graph = SqliteGraph::open(&db_path).unwrap();
            let index_names = graph.list_hnsw_indexes().unwrap();
            assert_eq!(index_names, vec!["e2e_test".to_string()]);
            let loaded_count = graph
                .get_hnsw_index_ref("e2e_test", |hnsw| {
                    assert_eq!(hnsw.vector_count(), 5);
                    let (vector, metadata) = hnsw.get_vector(1).unwrap().unwrap();
                    assert_eq!(vector, vec![0.0, 0.0, 0.0]);
                    assert_eq!(metadata, serde_json::json!({"label": "vector_0"}));
                    let query = vec![2.0, 4.0, 6.0];
                    let results = hnsw.search(&query, 3).unwrap();
                    assert!(!results.is_empty());
                    hnsw.vector_count()
                })
                .unwrap();
            assert_eq!(loaded_count, 5);
        }
        let _ = fs::remove_dir_all(test_dir);
    }

    /// With multi-layer enabled, the level distributor exists and its seeded
    /// geometric level sampling matches the expected distribution
    /// (~1/16 of samples promoted per level for base 16).
    #[test]
    fn test_multilayer_level_distribution() {
        let config = HnswConfig {
            dimension: 4,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 4,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };
        let hnsw = HnswIndex::new("test_multilayer_dist", config).unwrap();
        assert!(
            hnsw.has_level_distributor(),
            "LevelDistributor should be initialized in multi-layer mode"
        );
        use crate::hnsw::multilayer::LevelDistributor;
        let mut distributor = LevelDistributor::new(16.0, 4).with_seed(42);
        let mut level_counts = vec![0; 4];
        for _ in 0..1000 {
            let level = distributor.sample_level_internal();
            level_counts[level] += 1;
        }
        // Deterministic seed 42 makes these bounds reproducible, not flaky.
        assert!(
            level_counts[0] >= 900 && level_counts[0] <= 950,
            "Level 0 should have ~938 samples, got {}",
            level_counts[0]
        );
        assert!(
            level_counts[1] >= 40 && level_counts[1] <= 80,
            "Level 1 should have ~62 samples, got {}",
            level_counts[1]
        );
        assert!(
            level_counts[2] >= 1 && level_counts[2] <= 10,
            "Level 2 should have ~4 samples, got {}",
            level_counts[2]
        );
        println!(
            "Level distribution (direct sampling): L0={}, L1={}, L2={}, L3={}",
            level_counts[0], level_counts[1], level_counts[2], level_counts[3]
        );
    }

    /// With multi-layer disabled, no distributor exists and every insert lands
    /// on layer 0 only.
    #[test]
    fn test_single_layer_mode() {
        let config = HnswConfig {
            dimension: 4,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 4,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: false,
            multilayer_level_distribution_base: None,
            multilayer_deterministic_seed: None,
        };
        let hnsw = HnswIndex::new("test_single_layer", config.clone()).unwrap();
        assert!(
            !hnsw.has_level_distributor(),
            "LevelDistributor should not be initialized in single-layer mode"
        );
        let test_vector = vec![1.0, 0.0, 0.0, 0.0];
        let mut hnsw_mut = HnswIndex::new("test_single_layer_mut", config).unwrap();
        for _ in 0..100 {
            hnsw_mut.insert_vector(&test_vector, None).unwrap();
        }
        let stats = hnsw_mut.statistics().unwrap();
        assert_eq!(
            stats.layer_stats[0].0, 100,
            "Layer 0 should have 100 vectors"
        );
        assert_eq!(
            stats.layer_stats[1].0, 0,
            "Layer 1 should be empty in single-layer mode"
        );
        assert_eq!(
            stats.layer_stats[2].0, 0,
            "Layer 2 should be empty in single-layer mode"
        );
        assert_eq!(
            stats.layer_stats[3].0, 0,
            "Layer 3 should be empty in single-layer mode"
        );
    }

    /// HNSW search on 1000 synthetic 64-d vectors achieves >= 90% recall@10
    /// against a brute-force exact scan.
    #[test]
    fn test_multilayer_recall() {
        use std::collections::HashSet;
        let config = HnswConfig {
            dimension: 64,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 16,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };
        let mut hnsw = HnswIndex::new("recall_test_unique", config).unwrap();
        let mut vectors = Vec::new();
        for i in 0..1000 {
            let vector: Vec<f32> = (0..64)
                .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
                .collect();
            vectors.push(vector.clone());
            hnsw.insert_vector(&vector, None).unwrap();
        }
        let k = 10;
        let query = &vectors[0];
        let hnsw_results = hnsw.search(query, k).unwrap();
        let hnsw_ids: HashSet<_> = hnsw_results.iter().map(|(id, _)| *id).collect();
        fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
            a.iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<f32>()
                .sqrt()
        }
        // Brute-force ground truth: ids are 1-based to match insert order.
        let mut exact_results: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i as u64 + 1, euclidean_distance(query, v)))
            .collect();
        // Sort ascending by distance. `total_cmp` provides a total order on
        // f32 (all distances here are finite); this replaces a hand-rolled
        // O(n^2) selection sort with a manual temp-variable swap.
        exact_results.sort_by(|a, b| a.1.total_cmp(&b.1));
        let exact_ids: HashSet<_> = exact_results.iter().take(k).map(|(id, _)| *id).collect();
        let overlap = hnsw_ids.intersection(&exact_ids).count();
        let recall = (overlap as f64 / k as f64) * 100.0;
        println!("HNSW results: {:?}", hnsw_results);
        println!("Exact top {}: {:?}", k, exact_ids);
        println!("Recall: {:.1}% ({}/{})", recall, overlap, k);
        assert!(
            recall >= 90.0,
            "Recall {:.1}% is below 90% threshold",
            recall
        );
    }

    /// Coarse timing check that search scales sub-linearly with index size.
    /// Wall-clock based, hence inherently noisy — kept ignored (see attribute).
    #[test]
    #[ignore = "flaky: fails non-deterministically when run with all lib tests due to HNSW test pollution / NodeNotFound bug"]
    fn test_multilayer_search_complexity_ologn() {
        use std::time::Instant;
        let sizes = vec![100, 1000, 10000];
        let mut search_times = Vec::new();
        for size in sizes {
            let config = HnswConfig {
                dimension: 64,
                m: 16,
                ef_construction: 200,
                ef_search: 50,
                ml: 16,
                distance_metric: DistanceMetric::Euclidean,
                enable_multilayer: true,
                multilayer_level_distribution_base: Some(16),
                multilayer_deterministic_seed: Some(42),
            };
            let mut hnsw = HnswIndex::new(&format!("complexity_test_{}", size), config).unwrap();
            for i in 0..size {
                let vector: Vec<f32> = (0..64)
                    .map(|j| ((i * 64 + j) as f32 * 0.01).sin())
                    .collect();
                hnsw.insert_vector(&vector, None).unwrap();
            }
            let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.01).sin()).collect();
            let iterations = 10;
            let start = Instant::now();
            for _ in 0..iterations {
                let _ = hnsw.search(&query, 10).unwrap();
            }
            let elapsed = start.elapsed();
            let avg_time_ns = elapsed.as_nanos() / iterations as u128;
            search_times.push((size, avg_time_ns));
            println!("Size {}: avg search time = {} ns", size, avg_time_ns);
        }
        let ratio_100_to_1000 = search_times[1].1 as f64 / search_times[0].1 as f64;
        println!("Time ratio (1000/100): {:.2}x", ratio_100_to_1000);
        assert!(
            ratio_100_to_1000 < 10.0,
            "Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
            ratio_100_to_1000
        );
        let ratio_1000_to_10000 = search_times[2].1 as f64 / search_times[1].1 as f64;
        println!("Time ratio (10000/1000): {:.2}x", ratio_1000_to_10000);
        assert!(
            ratio_1000_to_10000 < 10.0,
            "Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
            ratio_1000_to_10000
        );
        let overall_ratio = search_times[2].1 as f64 / search_times[0].1 as f64;
        println!("Overall time ratio (10000/100): {:.2}x", overall_ratio);
        assert!(
            overall_ratio < 50.0,
            "Overall search time ratio {:.2}x suggests linear scaling; expected < 50x for O(log N) (linear would be 100x)",
            overall_ratio
        );
    }

    /// After 100 multi-layer inserts, layer occupancy forms the expected
    /// pyramid: all nodes on layer 0, a small promoted subset on layer 1,
    /// and monotonically non-increasing counts up the stack.
    #[test]
    fn test_multilayer_insert_layers_correct() {
        let config = HnswConfig {
            dimension: 64,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 16,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };
        let mut hnsw = HnswIndex::new("test_layers", config).unwrap();
        for i in 0..100 {
            let vector: Vec<f32> = (0..64)
                .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
                .collect();
            hnsw.insert_vector(&vector, None).unwrap();
        }
        let stats = hnsw.statistics().unwrap();
        println!("Layer stats: {:?}", stats.layer_stats);
        assert_eq!(
            stats.layer_stats[0].0, 100,
            "Layer 0 should have all 100 vectors"
        );
        let layer1_count = stats.layer_stats[1].0;
        assert!(
            layer1_count > 0 && layer1_count < 20,
            "Layer 1 should have some vectors (got {}), but not all",
            layer1_count
        );
        assert!(
            stats.layer_stats[0].0 >= stats.layer_stats[1].0,
            "Layer 0 should have >= Layer 1"
        );
        assert!(
            stats.layer_stats[1].0 >= stats.layer_stats[2].0,
            "Layer 1 should have >= Layer 2"
        );
        assert!(
            hnsw.has_level_distributor(),
            "LevelDistributor should be initialized"
        );
    }
}