use rusqlite::OptionalExtension;
use std::collections::HashMap;
use std::sync::atomic::AtomicU64;
use crate::hnsw::{
config::HnswConfig,
distance_metric::DistanceMetric,
errors::HnswError,
layer::HnswLayer,
multilayer::{LevelDistributor, MultiLayerNodeManager},
neighborhood::NeighborhoodSearch,
storage::{VectorStorage, VectorStorageStats},
};
#[cfg(test)]
use crate::hnsw::{config::hnsw_config, errors::HnswIndexError};
/// State of a single named HNSW (Hierarchical Navigable Small World) vector index.
///
/// Only the data layout is declared here. The methods exercised by the tests
/// in this file (`new`, `insert_vector`, `search`, `statistics`, ...) are not
/// defined in this section; presumably they come from the files spliced in
/// with `include!` further down.
pub struct HnswIndex {
/// Name given to the index when it was created or loaded.
pub(crate) name: String,
/// Configuration (dimension, distance metric, ...) the index was built with.
pub(crate) config: HnswConfig,
/// Graph layers of the index.
/// NOTE(review): presumably index 0 is the base layer — confirm.
pub(crate) layers: Vec<HnswLayer>,
/// Backing store for vectors, accessed through the `VectorStorage` trait object.
pub(crate) storage: Box<dyn VectorStorage>,
/// NOTE(review): presumably the ids that searches start from — confirm against the search code.
pub(crate) entry_points: Vec<u64>,
/// Number of vectors the index accounts for. The loading test in this file
/// shows this can be non-zero while `storage` is still empty (metadata-only load).
pub(crate) vector_count: usize,
/// Neighborhood search helper used by the index.
pub(crate) search_engine: NeighborhoodSearch,
/// Level sampler for multi-layer mode. Tests in this file expect it to be
/// present when `enable_multilayer` is set and absent otherwise.
pub(crate) level_distributor: Option<LevelDistributor>,
/// NOTE(review): presumably populated together with `level_distributor` in multi-layer mode — confirm.
pub(crate) multi_layer_manager: Option<MultiLayerNodeManager>,
/// Cache of raw vectors.
/// NOTE(review): presumably keyed by vector id — confirm.
pub(crate) vector_cache: HashMap<u64, Vec<f32>>,
/// Atomic counter; presumably the source of `HnswIndexStats::insert_count`.
pub(crate) insert_count: AtomicU64,
/// Atomic counter; presumably the source of `HnswIndexStats::search_count`.
pub(crate) search_count: AtomicU64,
/// Atomic counter; presumably the source of `HnswIndexStats::vector_cache_hits`.
pub(crate) vector_cache_hits: AtomicU64,
/// Atomic counter; presumably the source of `HnswIndexStats::vector_cache_misses`.
pub(crate) vector_cache_misses: AtomicU64,
}
/// Snapshot of index statistics, as returned by `HnswIndex::statistics()` in
/// the tests of this file.
#[derive(Debug, Clone)]
pub struct HnswIndexStats {
/// Number of vectors in the index.
pub vector_count: usize,
/// Number of layers in the index.
pub layer_count: usize,
/// Number of entry points.
pub entry_point_count: usize,
/// Vector dimension the index was configured with.
pub dimension: usize,
/// Distance metric the index was configured with.
pub distance_metric: DistanceMetric,
/// Statistics reported by the vector storage backend.
pub storage_stats: VectorStorageStats,
/// One tuple per layer. The tests in this file read element `.0` as the
/// number of vectors in that layer.
/// NOTE(review): the meaning of the second and third elements is not visible here — confirm.
pub layer_stats: Vec<(usize, usize, f32)>,
/// Number of inserts counted so far.
pub insert_count: u64,
/// Number of searches counted so far.
pub search_count: u64,
/// Number of vector-cache hits counted so far.
pub vector_cache_hits: u64,
/// Number of vector-cache misses counted so far.
pub vector_cache_misses: u64,
}
// The rest of this module lives in sibling files that are spliced in
// textually, so their items share this module's scope and `use` declarations.
// NOTE(review): judging by the file names these hold the public API, internal
// helpers and persistence code respectively — confirm.
include!("index_api.rs");
include!("index_internal.rs");
include!("index_persist.rs");
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::SqliteGraph;
use crate::hnsw::{DistanceMetric, HnswConfigBuilder};
/// A freshly created index echoes its configuration and starts with every
/// counter at zero.
#[test]
fn test_hnsw_index_creation() {
    let index = HnswIndex::new(
        "test_index",
        HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap(),
    )
    .unwrap();

    let snapshot = index.statistics().unwrap();

    // Configuration is reported back unchanged.
    assert_eq!(snapshot.vector_count, 0);
    assert_eq!(snapshot.dimension, 3);
    assert_eq!(snapshot.distance_metric, DistanceMetric::Euclidean);

    // No operation has been performed yet, so all counters are zero.
    assert_eq!(snapshot.insert_count, 0);
    assert_eq!(snapshot.search_count, 0);
    assert_eq!(snapshot.vector_cache_hits, 0);
    assert_eq!(snapshot.vector_cache_misses, 0);
}
/// Inserting one vector yields a positive id and bumps the vector count to 1.
#[test]
fn test_vector_insertion() {
    let mut index = HnswIndex::new(
        "test_insert",
        hnsw_config().dimension(3).build().unwrap(),
    )
    .unwrap();

    let unit_x = vec![1.0, 0.0, 0.0];
    let label = serde_json::json!({"label": "test"});
    let outcome = index.insert_vector(&unit_x, Some(label));
    println!("Insert result: {:?}", outcome);

    let assigned_id = outcome.unwrap();
    assert!(assigned_id > 0);

    let snapshot = index.statistics().unwrap();
    assert_eq!(snapshot.vector_count, 1);
}
/// A vector whose length differs from the configured dimension is rejected
/// with `VectorDimensionMismatch`.
#[test]
fn test_dimension_mismatch_error() {
    let mut index = HnswIndex::new("test_dim_error", HnswConfig::default()).unwrap();

    // Two components only, which the default configuration must reject.
    let too_short = vec![1.0, 0.0];
    let outcome = index.insert_vector(&too_short, None);
    assert!(outcome.is_err());

    let failure = outcome.unwrap_err();
    assert!(matches!(
        failure,
        HnswError::Index(HnswIndexError::VectorDimensionMismatch { .. })
    ));
}
/// Searching an index that holds no vectors succeeds and returns nothing.
#[test]
fn test_empty_search() {
    let empty_index = HnswIndex::new("test_empty_search", HnswConfig::default()).unwrap();

    // 768 components, the length this test pairs with the default configuration.
    let probe = vec![1.0; 768];
    let hits = empty_index.search(&probe, 5).unwrap();

    assert!(hits.is_empty());
}
/// A stored vector and its metadata can be read back unchanged by id.
#[test]
fn test_vector_retrieval() {
    let mut index = HnswIndex::new(
        "test_retrieval",
        hnsw_config().dimension(3).build().unwrap(),
    )
    .unwrap();

    let stored_vector = vec![1.0, 0.0, 0.0];
    let stored_metadata = serde_json::json!({"label": "test"});
    let id = index
        .insert_vector(&stored_vector, Some(stored_metadata.clone()))
        .unwrap();

    let fetched = index.get_vector(id).unwrap();
    assert!(fetched.is_some());

    let (got_vector, got_metadata) = fetched.unwrap();
    assert_eq!(got_vector, stored_vector);
    assert_eq!(got_metadata, stored_metadata);
}
/// An index created through an in-memory `SqliteGraph` can be looked up by
/// name and reports the configuration it was created with.
#[test]
fn test_sqlite_graph_integration() {
    let cosine_config = HnswConfigBuilder::new()
        .dimension(4)
        .distance_metric(DistanceMetric::Cosine)
        .build()
        .unwrap();

    let graph = SqliteGraph::open_in_memory().unwrap();
    let registry = graph.hnsw_index("test_index", cosine_config).unwrap();
    let index = registry.get("test_index").unwrap();

    let snapshot = index.statistics().unwrap();
    assert_eq!(snapshot.vector_count, 0);
    assert_eq!(snapshot.dimension, 4);
    assert_eq!(snapshot.distance_metric, DistanceMetric::Cosine);
}
/// Inserts the four unit vectors on the 2-D axes and checks that a search
/// near `(1, 0)` returns between one and two results ordered by ascending
/// distance.
#[test]
fn test_basic_search_functionality() {
    let mut hnsw = HnswIndex::new(
        "test_search",
        HnswConfigBuilder::new()
            .dimension(2)
            .m_connections(4)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap(),
    )
    .unwrap();

    let vectors = vec![
        vec![1.0, 0.0],
        vec![0.0, 1.0],
        vec![-1.0, 0.0],
        vec![0.0, -1.0],
    ];
    // The returned ids are not needed by any assertion below; only the
    // success of each insert matters, so nothing is collected.
    for vector in vectors {
        hnsw.insert_vector(&vector, None).unwrap();
    }

    let query = vec![0.9, 0.1];
    let results = hnsw.search(&query, 2).unwrap();
    assert!(!results.is_empty());
    assert!(results.len() <= 2);

    // Element `.1` of each hit is the distance; hits must be sorted by it.
    for window in results.windows(2) {
        assert!(window[0].1 <= window[1].1);
    }
}
/// After five inserts the statistics snapshot reflects the vector count,
/// layer count, dimension and the operation counters.
#[test]
fn test_index_statistics() {
    let three_layer_config = HnswConfigBuilder::new()
        .dimension(3)
        .max_layers(3)
        .distance_metric(DistanceMetric::Euclidean)
        .build()
        .unwrap();
    let mut index = HnswIndex::new("test_stats", three_layer_config).unwrap();

    for step in 1..=5 {
        let point = vec![step as f32, (step * 2) as f32, (step * 3) as f32];
        index.insert_vector(&point, None).unwrap();
    }

    let snapshot = index.statistics().unwrap();

    // Shape of the index.
    assert_eq!(snapshot.vector_count, 5);
    assert_eq!(snapshot.layer_count, 3);
    assert_eq!(snapshot.dimension, 3);

    // Operation counters: five inserts, no searches.
    assert_eq!(snapshot.insert_count, 5);
    assert_eq!(snapshot.search_count, 0);

    // Cache counters as produced by the insert path for this sequence.
    assert_eq!(snapshot.vector_cache_hits, 4);
    assert_eq!(snapshot.vector_cache_misses, 0);

    assert!(!snapshot.layer_stats.is_empty());
}
/// Saves index metadata through one `SqliteGraph` handle and checks that a
/// second handle opened on the same database file lists the index and reports
/// the same name, dimension and distance metric.
#[test]
fn test_metadata_persistence() {
use std::fs;
// NOTE(review): fixed path under /tmp — not portable to non-Unix targets and
// shared by every concurrent run of this test.
let test_dir = "/tmp/test_hnsw_metadata_persistence";
let db_path = format!("{}/test.db", test_dir);
// Start from a clean directory; a missing directory is not an error here.
let _ = fs::remove_dir_all(test_dir);
fs::create_dir_all(test_dir).unwrap();
// Session 1: create the index and persist its metadata. The block scope
// drops the graph handle before the database is reopened.
{
let graph = SqliteGraph::open(&db_path).unwrap();
let config = HnswConfigBuilder::new()
.dimension(128)
.distance_metric(DistanceMetric::Euclidean)
.build()
.unwrap();
let hnsw_indexes = graph.hnsw_index("persist_test", config).unwrap();
let hnsw = hnsw_indexes.get("persist_test").unwrap();
assert_eq!(hnsw.name(), "persist_test");
assert_eq!(hnsw.config().dimension, 128);
assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);
let conn = graph.connection();
let conn_ref = conn.underlying();
hnsw.save_metadata(conn_ref).unwrap();
}
// Session 2: a fresh handle must see exactly the one persisted index.
{
let graph2 = SqliteGraph::open(&db_path).unwrap();
let index_names = graph2.list_hnsw_indexes().unwrap();
assert_eq!(index_names, vec!["persist_test".to_string()]);
// The closure's return value (the dimension) is handed back to the caller.
let loaded_hnsw = graph2
.get_hnsw_index_ref("persist_test", |hnsw| {
assert_eq!(hnsw.name(), "persist_test");
assert_eq!(hnsw.config().dimension, 128);
assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);
hnsw.config().dimension
})
.unwrap();
assert_eq!(loaded_hnsw, 128);
}
// Best-effort cleanup; skipped if an assertion above panicked.
let _ = fs::remove_dir_all(test_dir);
}
/// Seeds the `hnsw_indexes` / `hnsw_vectors` tables with raw SQL, then checks
/// that `load_metadata` restores only the metadata while `load_with_vectors`
/// also fills the storage and yields a searchable index.
#[test]
fn test_vector_loading_and_rebuild() {
use rusqlite::Connection;
use std::fs;
// NOTE(review): fixed path under /tmp — not portable to non-Unix targets and
// shared by every concurrent run of this test.
let test_dir = "/tmp/test_hnsw_vector_loading";
let db_path = format!("{}/test.db", test_dir);
let _ = fs::remove_dir_all(test_dir);
fs::create_dir_all(test_dir).unwrap();
// Phase 1: write one index row and five vector rows directly, bypassing HnswIndex.
{
let conn = Connection::open(&db_path).unwrap();
crate::schema::ensure_schema(&conn).unwrap();
conn.execute(
"INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
rusqlite::params!["load_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
).unwrap();
let index_id = conn.last_insert_rowid();
for i in 0..5 {
let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
// Vectors are stored as the raw bytes of the f32 slice.
let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();
conn.execute(
"INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5)",
rusqlite::params![index_id, vector_bytes, None::<String>, 1000, 1000],
).unwrap();
}
}
// Phase 2: reopen the database and load the index both ways.
{
let conn2 = Connection::open(&db_path).unwrap();
crate::schema::ensure_schema(&conn2).unwrap();
// Metadata-only load: the count comes from the index row, storage stays empty.
let hnsw_metadata = HnswIndex::load_metadata(&conn2, "load_test").unwrap();
assert_eq!(hnsw_metadata.vector_count, 5);
assert_eq!(hnsw_metadata.storage.vector_count().unwrap(), 0);
// Full load: storage now holds all five vectors.
let hnsw_loaded = HnswIndex::load_with_vectors(&conn2, "load_test").unwrap();
assert_eq!(hnsw_loaded.vector_count, 5);
assert_eq!(hnsw_loaded.storage.vector_count().unwrap(), 5);
// Id 1 is expected to be the first inserted row (i = 0, the zero vector).
let (vector, _) = hnsw_loaded.get_vector(1).unwrap().unwrap();
assert_eq!(vector, vec![0.0, 0.0, 0.0]);
let query = vec![2.0, 4.0, 6.0];
let results = hnsw_loaded.search(&query, 3).unwrap();
assert!(!results.is_empty());
}
// Best-effort cleanup; skipped if an assertion above panicked.
let _ = fs::remove_dir_all(test_dir);
}
/// End-to-end check: rows written with raw SQL are picked up when the
/// database is opened through `SqliteGraph`, including vector data, JSON
/// metadata and searchability.
#[test]
fn test_e2e_hnsw_persistence() {
use rusqlite::Connection;
use std::fs;
// NOTE(review): fixed path under /tmp — not portable to non-Unix targets and
// shared by every concurrent run of this test.
let test_dir = "/tmp/test_hnsw_e2e_persistence";
let db_path = format!("{}/test.db", test_dir);
let _ = fs::remove_dir_all(test_dir);
fs::create_dir_all(test_dir).unwrap();
// Phase 1: seed one index row and five vector rows (with metadata) directly.
{
let conn = Connection::open(&db_path).unwrap();
crate::schema::ensure_schema(&conn).unwrap();
conn.execute(
"INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
rusqlite::params!["e2e_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
).unwrap();
let index_id = conn.last_insert_rowid();
for i in 0..5 {
let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
// Vectors are stored as the raw bytes of the f32 slice.
let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();
// Metadata is stored as serialized JSON text.
let metadata = serde_json::json!({"label": format!("vector_{}", i)}).to_string();
conn.execute(
"INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5)",
rusqlite::params![index_id, vector_bytes, metadata, 1000, 1000],
).unwrap();
}
}
// Phase 2: open through SqliteGraph and verify the index is usable.
{
let graph = SqliteGraph::open(&db_path).unwrap();
let index_names = graph.list_hnsw_indexes().unwrap();
assert_eq!(index_names, vec!["e2e_test".to_string()]);
let loaded_count = graph
.get_hnsw_index_ref("e2e_test", |hnsw| {
assert_eq!(hnsw.vector_count(), 5);
// Id 1 is expected to be the first inserted row (i = 0).
let (vector, metadata) = hnsw.get_vector(1).unwrap().unwrap();
assert_eq!(vector, vec![0.0, 0.0, 0.0]);
assert_eq!(metadata, serde_json::json!({"label": "vector_0"}));
let query = vec![2.0, 4.0, 6.0];
let results = hnsw.search(&query, 3).unwrap();
assert!(!results.is_empty());
hnsw.vector_count()
})
.unwrap();
assert_eq!(loaded_count, 5);
}
// Best-effort cleanup; skipped if an assertion above panicked.
let _ = fs::remove_dir_all(test_dir);
}
/// Checks that multi-layer mode creates a level distributor, then samples a
/// separately constructed, seeded `LevelDistributor` 1000 times and verifies
/// the per-level counts fall inside the expected bands.
#[test]
fn test_multilayer_level_distribution() {
let config = HnswConfig {
dimension: 4,
m: 16,
ef_construction: 200,
ef_search: 50,
ml: 4,
distance_metric: DistanceMetric::Euclidean,
enable_multilayer: true,
multilayer_level_distribution_base: Some(16),
multilayer_deterministic_seed: Some(42),
};
let hnsw = HnswIndex::new("test_multilayer_dist", config).unwrap();
assert!(
hnsw.has_level_distributor(),
"LevelDistributor should be initialized in multi-layer mode"
);
use crate::hnsw::multilayer::LevelDistributor;
// The sampling below uses its own distributor (base 16.0, 4 levels, seed 42),
// not the one held by `hnsw`.
let mut distributor = LevelDistributor::new(16.0, 4).with_seed(42);
// One bucket per level; a sampled level >= 4 would panic on the index below.
let mut level_counts = [0; 4];
for _ in 0..1000 {
let level = distributor.sample_level_internal();
level_counts[level] += 1;
}
assert!(
level_counts[0] >= 900 && level_counts[0] <= 950,
"Level 0 should have ~938 samples, got {}",
level_counts[0]
);
assert!(
level_counts[1] >= 40 && level_counts[1] <= 80,
"Level 1 should have ~62 samples, got {}",
level_counts[1]
);
assert!(
level_counts[2] >= 1 && level_counts[2] <= 10,
"Level 2 should have ~4 samples, got {}",
level_counts[2]
);
// Level 3 is only printed, not asserted.
println!(
"Level distribution (direct sampling): L0={}, L1={}, L2={}, L3={}",
level_counts[0], level_counts[1], level_counts[2], level_counts[3]
);
}
/// With `enable_multilayer` switched off no level distributor is created and
/// all inserted vectors end up in layer 0 only.
#[test]
fn test_single_layer_mode() {
    let single_layer_config = HnswConfig {
        dimension: 4,
        m: 16,
        ef_construction: 200,
        ef_search: 50,
        ml: 4,
        distance_metric: DistanceMetric::Euclidean,
        enable_multilayer: false,
        multilayer_level_distribution_base: None,
        multilayer_deterministic_seed: None,
    };

    // First index: only used to probe for the distributor.
    let probe = HnswIndex::new("test_single_layer", single_layer_config.clone()).unwrap();
    assert!(
        !probe.has_level_distributor(),
        "LevelDistributor should not be initialized in single-layer mode"
    );

    // Second index: receives 100 copies of the same vector.
    let repeated = vec![1.0, 0.0, 0.0, 0.0];
    let mut populated = HnswIndex::new("test_single_layer_mut", single_layer_config).unwrap();
    for _round in 0..100 {
        populated.insert_vector(&repeated, None).unwrap();
    }

    let snapshot = populated.statistics().unwrap();
    let per_layer = &snapshot.layer_stats;
    assert_eq!(per_layer[0].0, 100, "Layer 0 should have 100 vectors");
    assert_eq!(
        per_layer[1].0, 0,
        "Layer 1 should be empty in single-layer mode"
    );
    assert_eq!(
        per_layer[2].0, 0,
        "Layer 2 should be empty in single-layer mode"
    );
    assert_eq!(
        per_layer[3].0, 0,
        "Layer 3 should be empty in single-layer mode"
    );
}
/// Measures recall@10 of the multi-layer index against an exact brute-force
/// ranking over 1000 deterministic 64-dimensional vectors and requires at
/// least 90%.
#[test]
fn test_multilayer_recall() {
    use std::collections::HashSet;

    /// Plain Euclidean (L2) distance used for the brute-force reference.
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    let config = HnswConfig {
        dimension: 64,
        m: 16,
        ef_construction: 200,
        ef_search: 200,
        ml: 16,
        distance_metric: DistanceMetric::Euclidean,
        enable_multilayer: true,
        multilayer_level_distribution_base: Some(16),
        multilayer_deterministic_seed: Some(42),
    };
    let mut hnsw = HnswIndex::new("recall_test_unique", config).unwrap();

    // Deterministic data set: component j of vector i is cos((64 * i + j) * 0.01).
    let mut vectors = Vec::new();
    for i in 0..1000 {
        let vector: Vec<f32> = (0..64)
            .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
            .collect();
        vectors.push(vector.clone());
        hnsw.insert_vector(&vector, None).unwrap();
    }

    let k = 10;
    let query = &vectors[0];
    let hnsw_results = hnsw.search(query, k).unwrap();
    let hnsw_ids: HashSet<_> = hnsw_results.iter().map(|(id, _)| *id).collect();

    // Brute-force reference. This assumes ids are handed out sequentially
    // starting at 1, so the vector at position i has id i + 1.
    let mut exact_results: Vec<_> = vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (i as u64 + 1, euclidean_distance(query, v)))
        .collect();
    // Ascending by distance. `total_cmp` gives a total order on f32, and the
    // stable sort keeps equal distances in id order.
    exact_results.sort_by(|a, b| a.1.total_cmp(&b.1));

    let exact_ids: HashSet<_> = exact_results.iter().take(k).map(|(id, _)| *id).collect();
    let overlap = hnsw_ids.intersection(&exact_ids).count();
    let recall = (overlap as f64 / k as f64) * 100.0;
    println!("HNSW results: {:?}", hnsw_results);
    println!("Exact top {}: {:?}", k, exact_ids);
    println!("Recall: {:.1}% ({}/{})", recall, overlap, k);
    assert!(
        recall >= 90.0,
        "Recall {:.1}% is below 90% threshold",
        recall
    );
}
/// Timing-based scaling check: builds multi-layer indexes of 100, 1000 and
/// 10000 vectors, averages the search time over ten queries each, and asserts
/// that the time grows clearly slower than linearly with index size.
/// Ignored by default (see the reason on the attribute below).
#[test]
#[ignore = "flaky: fails non-deterministically when run with all lib tests due to HNSW test pollution / NodeNotFound bug"]
fn test_multilayer_search_complexity_ologn() {
use std::time::Instant;
let sizes = vec![100, 1000, 10000];
// Collected as (index size, average search time in nanoseconds).
let mut search_times = Vec::new();
for size in sizes {
let config = HnswConfig {
dimension: 64,
m: 16,
ef_construction: 200,
ef_search: 50,
ml: 16,
distance_metric: DistanceMetric::Euclidean,
enable_multilayer: true,
multilayer_level_distribution_base: Some(16),
multilayer_deterministic_seed: Some(42),
};
let mut hnsw = HnswIndex::new(&format!("complexity_test_{}", size), config).unwrap();
// Deterministic data: component j of vector i is sin((64 * i + j) * 0.01).
for i in 0..size {
let vector: Vec<f32> = (0..64)
.map(|j| ((i * 64 + j) as f32 * 0.01).sin())
.collect();
hnsw.insert_vector(&vector, None).unwrap();
}
// The query equals the first inserted vector (i = 0).
let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.01).sin()).collect();
let iterations = 10;
let start = Instant::now();
for _ in 0..iterations {
let _ = hnsw.search(&query, 10).unwrap();
}
let elapsed = start.elapsed();
let avg_time_ns = elapsed.as_nanos() / iterations as u128;
search_times.push((size, avg_time_ns));
println!("Size {}: avg search time = {} ns", size, avg_time_ns);
}
// Each 10x growth in size may cost less than 10x in time.
// NOTE(review): a measured average of 0 ns would make these ratios infinite or NaN.
let ratio_100_to_1000 = search_times[1].1 as f64 / search_times[0].1 as f64;
println!("Time ratio (1000/100): {:.2}x", ratio_100_to_1000);
assert!(
ratio_100_to_1000 < 10.0,
"Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
ratio_100_to_1000
);
let ratio_1000_to_10000 = search_times[2].1 as f64 / search_times[1].1 as f64;
println!("Time ratio (10000/1000): {:.2}x", ratio_1000_to_10000);
assert!(
ratio_1000_to_10000 < 10.0,
"Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
ratio_1000_to_10000
);
// The full 100x growth in size may cost less than 50x in time.
let overall_ratio = search_times[2].1 as f64 / search_times[0].1 as f64;
println!("Overall time ratio (10000/100): {:.2}x", overall_ratio);
assert!(
overall_ratio < 50.0,
"Overall search time ratio {:.2}x suggests linear scaling; expected < 50x for O(log N) (linear would be 100x)",
overall_ratio
);
}
/// Inserts 100 vectors in multi-layer mode and checks the per-layer vector
/// counts: layer 0 holds all of them, layer 1 holds some but fewer than 20,
/// and counts do not increase from one layer to the next.
#[test]
fn test_multilayer_insert_layers_correct() {
let config = HnswConfig {
dimension: 64,
m: 16,
ef_construction: 200,
ef_search: 50,
ml: 16,
distance_metric: DistanceMetric::Euclidean,
enable_multilayer: true,
multilayer_level_distribution_base: Some(16),
multilayer_deterministic_seed: Some(42),
};
let mut hnsw = HnswIndex::new("test_layers", config).unwrap();
// Deterministic data: component j of vector i is cos((64 * i + j) * 0.01).
for i in 0..100 {
let vector: Vec<f32> = (0..64)
.map(|j| ((i * 64 + j) as f32 * 0.01).cos())
.collect();
hnsw.insert_vector(&vector, None).unwrap();
}
let stats = hnsw.statistics().unwrap();
println!("Layer stats: {:?}", stats.layer_stats);
// Element `.0` of each layer tuple is read as the layer's vector count.
assert_eq!(
stats.layer_stats[0].0, 100,
"Layer 0 should have all 100 vectors"
);
let layer1_count = stats.layer_stats[1].0;
assert!(
layer1_count > 0 && layer1_count < 20,
"Layer 1 should have some vectors (got {}), but not all",
layer1_count
);
assert!(
stats.layer_stats[0].0 >= stats.layer_stats[1].0,
"Layer 0 should have >= Layer 1"
);
assert!(
stats.layer_stats[1].0 >= stats.layer_stats[2].0,
"Layer 1 should have >= Layer 2"
);
assert!(
hnsw.has_level_distributor(),
"LevelDistributor should be initialized"
);
}
/// Builds a persistent index in one `SqliteGraph` session, reopens the
/// database in a second session, and checks that the same query returns the
/// same top result in both.
#[test]
fn test_topology_persistence_cross_session() {
use std::fs;
// NOTE(review): fixed path under /tmp — not portable to non-Unix targets and
// shared by every concurrent run of this test.
let test_dir = "/tmp/test_hnsw_topology_cross_session";
let db_path = format!("{}/test.db", test_dir);
let _ = fs::remove_dir_all(test_dir);
fs::create_dir_all(test_dir).unwrap();
// 20 deterministic 8-dimensional vectors: component j of vector i is sin((8 * i + j) * 0.1).
let vectors: Vec<Vec<f32>> = (0..20)
.map(|i| (0..8).map(|j| ((i * 8 + j) as f32 * 0.1).sin()).collect())
.collect();
// Session 1: create the index, insert everything, search. The graph handle
// is dropped at the end of this block, before the database is reopened.
let session1_results = {
let graph = SqliteGraph::open(&db_path).unwrap();
let config = HnswConfigBuilder::new()
.dimension(8)
.m_connections(4)
.distance_metric(DistanceMetric::Euclidean)
.build()
.unwrap();
// The returned value is dropped immediately; the index is reached by name afterwards.
{
let _indexes = graph.hnsw_index_persistent("topo_test", config).unwrap();
}
for vector in &vectors {
graph
.get_hnsw_index_mut("topo_test", |idx| idx.insert_vector(vector, None).unwrap())
.unwrap();
}
let query = vectors[0].clone();
// Distances are converted to their bit patterns via `to_bits`.
graph
.get_hnsw_index_ref("topo_test", |idx| {
idx.search(&query, 5)
.unwrap()
.into_iter()
.map(|(id, dist)| (id, dist.to_bits()))
.collect::<Vec<_>>()
})
.unwrap()
};
assert!(
!session1_results.is_empty(),
"session 1 should return results"
);
// Session 2: a fresh handle on the same file, no inserts, same query.
let session2_results = {
let graph = SqliteGraph::open(&db_path).unwrap();
let query = vectors[0].clone();
graph
.get_hnsw_index_ref("topo_test", |idx| {
let stats = idx.statistics().unwrap();
assert!(
stats.vector_count > 0,
"restored index should have vectors, got {}",
stats.vector_count
);
idx.search(&query, 5)
.unwrap()
.into_iter()
.map(|(id, dist)| (id, dist.to_bits()))
.collect::<Vec<_>>()
})
.unwrap()
};
assert!(
!session2_results.is_empty(),
"session 2 should return results after restore"
);
// Only the id of the best hit is compared across sessions.
assert_eq!(
session1_results[0].0, session2_results[0].0,
"cross-session top result must match: session1={:?} session2={:?}",
session1_results, session2_results
);
// Best-effort cleanup; skipped if an assertion above panicked.
let _ = fs::remove_dir_all(test_dir);
}
/// Benchmark-style test: pre-fills an in-memory index in batches of 256, then
/// times one further batch of 16 vectors and prints the throughput.
///
/// NOTE(review): there is no assertion on the timing, so the test can only
/// fail if an insert returns an error. It inserts tens of thousands of
/// 768-dimensional vectors on every run; consider `#[ignore]` or moving it to
/// a benchmark.
#[test]
fn test_batch_insert_speed_large_index() {
use crate::hnsw::{DistanceMetric, HnswConfigBuilder, HnswIndex};
let config = HnswConfigBuilder::new()
.dimension(768)
.distance_metric(DistanceMetric::Cosine)
.build()
.unwrap();
let mut index = HnswIndex::new("bench", config).unwrap();
// All vectors share this template; only component 0 differs.
let dummy = vec![0.5f32; 768];
// Integer division: 50_000 / 256 = 195 batches, i.e. 49_920 vectors rather than exactly 50K.
for batch_idx in 0..(50_000 / 256) {
let batch: Vec<_> = (0..256)
.map(|i| {
let mut v = dummy.clone();
v[0] = ((batch_idx * 256 + i) as f32) * 0.00001;
(v, None)
})
.collect();
index.batch_insert_vectors(&batch).unwrap();
}
// Timed section: building the 16-vector batch is included in the measurement.
let start = std::time::Instant::now();
let batch: Vec<_> = (0..16)
.map(|i| {
let mut v = dummy.clone();
v[0] = (50_000 + i) as f32 * 0.00001;
(v, None)
})
.collect();
index.batch_insert_vectors(&batch).unwrap();
let elapsed = start.elapsed();
println!(
"Batch insert into 50K in-memory index: {:?} for 16 vectors = {:.1} vectors/sec",
elapsed,
16.0 / elapsed.as_secs_f64()
);
}
/// Creates two persistent indexes in the same database, fills each with ten
/// vectors, and checks that both are searchable, that their search results
/// share no vector ids, and that `index_a` is restored with all ten vectors
/// after reopening the database.
#[test]
fn test_multi_index_coexistence() {
use std::fs;
// NOTE(review): fixed path under /tmp — not portable to non-Unix targets and
// shared by every concurrent run of this test.
let test_dir = "/tmp/test_hnsw_multi_index_coexist";
let db_path = format!("{}/test.db", test_dir);
let _ = fs::remove_dir_all(test_dir);
fs::create_dir_all(test_dir).unwrap();
// Session 1: build both indexes side by side.
{
let graph = SqliteGraph::open(&db_path).unwrap();
let config_a = HnswConfigBuilder::new()
.dimension(4)
.m_connections(4)
.distance_metric(DistanceMetric::Euclidean)
.build()
.unwrap();
// The returned value is dropped immediately; the index is reached by name afterwards.
{
let _indexes = graph.hnsw_index_persistent("index_a", config_a).unwrap();
}
// index_a: multiples of (1, 2, 3, 4).
for i in 0..10u32 {
let v = vec![i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32];
graph
.get_hnsw_index_mut("index_a", |idx| idx.insert_vector(&v, None).unwrap())
.unwrap();
}
let config_b = HnswConfigBuilder::new()
.dimension(4)
.m_connections(4)
.distance_metric(DistanceMetric::Euclidean)
.build()
.unwrap();
{
let _indexes_b = graph.hnsw_index_persistent("index_b", config_b).unwrap();
}
// index_b: trigonometric components, so its data differs from index_a's.
for i in 100..110u32 {
let v = vec![
(i as f32).sin(),
(i as f32).cos(),
(i as f32).tan(),
(i as f32 * 0.5),
];
graph
.get_hnsw_index_mut("index_b", |idx| idx.insert_vector(&v, None).unwrap())
.unwrap();
}
let results_a = graph
.get_hnsw_index_ref("index_a", |idx| {
idx.search(&[1.0, 2.0, 3.0, 4.0], 3).unwrap()
})
.unwrap();
assert!(!results_a.is_empty(), "index_a should have results");
let results_b = graph
.get_hnsw_index_ref("index_b", |idx| {
idx.search(&[0.0, 1.0, 0.0, 50.0], 3).unwrap()
})
.unwrap();
assert!(!results_b.is_empty(), "index_b should have results");
// The ids returned by the two searches must be disjoint.
let ids_a: std::collections::HashSet<u64> =
results_a.iter().map(|(id, _)| *id).collect();
let ids_b: std::collections::HashSet<u64> =
results_b.iter().map(|(id, _)| *id).collect();
assert!(
ids_a.intersection(&ids_b).count() == 0,
"indices must not share vector IDs: A={:?} B={:?}",
ids_a,
ids_b
);
}
// Session 2: only index_a is checked after reopening.
{
let graph2 = SqliteGraph::open(&db_path).unwrap();
let restored_a = graph2
.get_hnsw_index_ref("index_a", |idx| {
let stats = idx.statistics().unwrap();
assert!(
stats.vector_count == 10,
"restored index_a should have 10 vectors, got {}",
stats.vector_count
);
idx.search(&[1.0, 2.0, 3.0, 4.0], 3).unwrap()
})
.unwrap();
assert!(
!restored_a.is_empty(),
"restored index_a should have results"
);
}
// Best-effort cleanup; skipped if an assertion above panicked.
let _ = fs::remove_dir_all(test_dir);
}
/// Batch-inserting an empty slice succeeds and returns no ids.
#[test]
fn test_batch_insert_empty() {
    let mut index = HnswIndex::new(
        "batch_empty",
        HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap(),
    )
    .unwrap();

    let assigned = index.batch_insert_vectors(&[]).unwrap();

    assert!(assigned.is_empty());
}
/// Batch-inserting three axis vectors (one carrying metadata) returns three
/// ids, updates the vector count and leaves the index searchable.
#[test]
fn test_batch_insert_basic() {
    let mut index = HnswIndex::new(
        "batch_basic",
        HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap(),
    )
    .unwrap();

    let x_axis = (vec![1.0, 0.0, 0.0], None);
    let y_axis = (
        vec![0.0, 1.0, 0.0],
        Some(serde_json::json!({"label": "y-axis"})),
    );
    let z_axis = (vec![0.0, 0.0, 1.0], None);
    let axes: Vec<(Vec<f32>, Option<serde_json::Value>)> = vec![x_axis, y_axis, z_axis];

    let assigned = index.batch_insert_vectors(&axes).unwrap();
    assert_eq!(assigned.len(), 3);
    assert_eq!(index.statistics().unwrap().vector_count, 3);

    let hits = index.search(&[0.9, 0.1, 0.0], 2).unwrap();
    assert!(!hits.is_empty());
}
/// A batch containing one wrongly sized vector is rejected as a whole: the
/// call fails and the index stays empty.
#[test]
fn test_batch_insert_dimension_mismatch() {
    let mut index = HnswIndex::new(
        "batch_dim",
        HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap(),
    )
    .unwrap();

    let well_formed = (vec![1.0, 0.0, 0.0], None);
    // Only two components, one short of the configured dimension.
    let too_short = (vec![0.0, 1.0], None);
    let mixed: Vec<(Vec<f32>, Option<serde_json::Value>)> = vec![well_formed, too_short];

    let outcome = index.batch_insert_vectors(&mixed);
    assert!(outcome.is_err());

    // The valid first entry must not have been kept either.
    assert_eq!(index.statistics().unwrap().vector_count, 0);
}
/// Inserting the same 20 vectors one by one and as a single batch yields
/// indexes with equal vector counts, and both return results for the same
/// query.
#[test]
fn test_batch_insert_vs_individual_equivalence() {
    let shared_config = HnswConfigBuilder::new()
        .dimension(4)
        .distance_metric(DistanceMetric::Cosine)
        .m_connections(8)
        .build()
        .unwrap();

    // Deterministic data set with the position recorded as metadata.
    let dataset: Vec<(Vec<f32>, Option<serde_json::Value>)> = (0..20)
        .map(|i| {
            let here = i as f32;
            let next = (i + 1) as f32;
            let point = vec![here.cos(), here.sin(), next.cos(), next.sin()];
            (point, Some(serde_json::json!({"idx": i})))
        })
        .collect();

    // Path 1: one insert call per vector.
    let mut one_by_one = HnswIndex::new("batch_vs_ind", shared_config.clone()).unwrap();
    for (point, metadata) in &dataset {
        one_by_one.insert_vector(point, metadata.clone()).unwrap();
    }

    // Path 2: a single batch call.
    let mut batched = HnswIndex::new("batch_vs_batch", shared_config).unwrap();
    let batch_ids = batched.batch_insert_vectors(&dataset).unwrap();
    assert_eq!(batch_ids.len(), 20);

    assert_eq!(
        one_by_one.statistics().unwrap().vector_count,
        batched.statistics().unwrap().vector_count
    );

    let probe = vec![1.0, 0.0, 0.9, 0.1];
    let hits_one_by_one = one_by_one.search(&probe, 5).unwrap();
    let hits_batched = batched.search(&probe, 5).unwrap();
    assert!(!hits_one_by_one.is_empty());
    assert!(!hits_batched.is_empty());
}
/// Batch-inserts three vectors into an index reached through `SqliteGraph`
/// and checks the vector count and searchability through the same handle.
///
/// NOTE(review): despite the name, this uses an in-memory graph and
/// `hnsw_index` (not `hnsw_index_persistent`), and never reopens the
/// database — confirm that persistence is actually exercised here.
#[test]
fn test_batch_insert_with_persistence() {
let graph = SqliteGraph::open_in_memory().unwrap();
let config = HnswConfigBuilder::new()
.dimension(3)
.distance_metric(DistanceMetric::Cosine)
.build()
.unwrap();
// The returned value is dropped immediately; the index is reached by name afterwards.
{
let _guard = graph.hnsw_index("test_idx", config).unwrap();
}
let batch: Vec<(Vec<f32>, Option<serde_json::Value>)> = vec![
(vec![1.0, 0.0, 0.0], Some(serde_json::json!({"id": 1}))),
(vec![0.0, 1.0, 0.0], Some(serde_json::json!({"id": 2}))),
(vec![0.0, 0.0, 1.0], Some(serde_json::json!({"id": 3}))),
];
let ids = graph
.get_hnsw_index_mut("test_idx", |idx| idx.batch_insert_vectors(&batch).unwrap())
.unwrap();
assert_eq!(ids.len(), 3);
let results = graph
.get_hnsw_index_ref("test_idx", |idx| {
let stats = idx.statistics().unwrap();
assert_eq!(stats.vector_count, 3);
idx.search(&[1.0, 0.0, 0.0], 2).unwrap()
})
.unwrap();
assert!(!results.is_empty(), "search should find results");
}
/// Batch-inserts 500 vectors into a file-backed persistent index and uses the
/// elapsed wall-clock time as a proxy for "the batch ran inside a single
/// transaction".
///
/// NOTE(review): the 500 ms bound is a wall-clock assertion and can fail on a
/// slow or heavily loaded machine even when the implementation is correct.
#[test]
fn test_batch_insert_uses_transaction_for_sqlite() {
use std::fs;
// NOTE(review): fixed path under /tmp — not portable to non-Unix targets and
// shared by every concurrent run of this test.
let test_dir = "/tmp/test_hnsw_batch_tx";
let db_path = format!("{}/test.db", test_dir);
let _ = fs::remove_dir_all(test_dir);
fs::create_dir_all(test_dir).unwrap();
let n = 500usize;
// Deterministic 4-dimensional data with the sequence number as metadata.
let batch: Vec<(Vec<f32>, Option<serde_json::Value>)> = (0..n)
.map(|i| {
let f = i as f32;
(
vec![f.sin(), f.cos(), f * 0.01, 1.0 - f * 0.005],
Some(serde_json::json!({"seq": i})),
)
})
.collect();
let config = HnswConfigBuilder::new()
.dimension(4)
.distance_metric(DistanceMetric::Cosine)
.build()
.unwrap();
let graph = SqliteGraph::open(&db_path).unwrap();
// The returned value is dropped immediately; the index is reached by name afterwards.
{
let _guard = graph.hnsw_index_persistent("tx_test", config).unwrap();
}
// Only the batch insert itself is timed.
let start = std::time::Instant::now();
let ids = graph
.get_hnsw_index_mut("tx_test", |idx| idx.batch_insert_vectors(&batch).unwrap())
.unwrap();
let elapsed = start.elapsed();
assert_eq!(ids.len(), n, "all {n} vectors must be inserted");
graph
.get_hnsw_index_ref("tx_test", |idx| {
assert_eq!(
idx.statistics().unwrap().vector_count,
n,
"vector_count must equal batch size"
);
})
.unwrap();
assert!(
elapsed.as_millis() < 500,
"batch_insert_vectors took {}ms for {n} vectors — missing transaction wrapper",
elapsed.as_millis()
);
let results = graph
.get_hnsw_index_ref("tx_test", |idx| {
idx.search(&[0.0, 1.0, 0.0, 1.0], 5).unwrap()
})
.unwrap();
assert!(
!results.is_empty(),
"search must find results after batch insert"
);
// Best-effort cleanup; skipped if an assertion above panicked.
let _ = fs::remove_dir_all(test_dir);
}
}
/// Extracts the `MemAvailable:` entry from `/proc/meminfo`-formatted text and
/// returns it in bytes.
///
/// Returns `None` when the text has no `MemAvailable:` line. When the line is
/// present but its value cannot be parsed as an unsigned integer, the result
/// is `Some(0)`. The kB-to-bytes conversion saturates at `u64::MAX` instead
/// of overflowing.
// Only the Linux branch of `available_memory_bytes` calls this helper, so it
// would otherwise be reported as dead code on other targets.
#[cfg_attr(not(target_os = "linux"), allow(dead_code))]
fn parse_mem_available_bytes(meminfo: &str) -> Option<u64> {
    meminfo
        .lines()
        .find_map(|line| line.strip_prefix("MemAvailable:"))
        .map(|rest| {
            let kib: u64 = rest
                .split_whitespace()
                .next()
                .and_then(|value| value.parse().ok())
                .unwrap_or(0);
            // The value is reported in kB; saturate rather than overflow.
            kib.saturating_mul(1024)
        })
}

/// Best-effort estimate of the memory currently available, in bytes.
///
/// * Linux: the `MemAvailable` entry of `/proc/meminfo`.
/// * macOS: half of the value printed by `sysctl -n hw.memsize`.
/// * Anything else, or when the platform query fails: `0`.
fn available_memory_bytes() -> u64 {
    #[cfg(target_os = "linux")]
    {
        if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
            if let Some(bytes) = parse_mem_available_bytes(&meminfo) {
                return bytes;
            }
        }
    }
    #[cfg(target_os = "macos")]
    {
        let total = std::process::Command::new("sysctl")
            .args(["-n", "hw.memsize"])
            .output()
            .ok()
            .and_then(|output| String::from_utf8(output.stdout).ok())
            .and_then(|text| text.trim().parse::<u64>().ok());
        if let Some(total) = total {
            // Half of the reported size is used as the estimate.
            return total / 2;
        }
    }
    0
}