1use std::sync::RwLock;
7
8use manifoldb_core::EntityId;
9use manifoldb_storage::{StorageEngine, Transaction};
10
11use crate::distance::DistanceMetric;
12use crate::error::VectorError;
13use crate::types::Embedding;
14
15use super::config::HnswConfig;
16use super::graph::{
17 search_layer, search_layer_filtered, select_neighbors_heuristic, Candidate, HnswGraph, HnswNode,
18};
19use super::persistence::{
20 self, delete_node, load_graph, load_metadata, save_graph, save_metadata, save_node, table_name,
21 update_connections, IndexMetadata,
22};
23use super::traits::{FilteredSearchConfig, SearchResult, VectorIndex};
24
/// Pseudo-random layer picker for newly inserted HNSW nodes.
///
/// Levels are drawn as `floor(-ln(u) * ml)` where `u` is derived from a
/// xorshift generator, then capped at [`LevelGenerator::MAX_LEVEL`].
struct LevelGenerator {
    /// Multiplier applied to `-ln(u)` when computing a level.
    ml: f64,
    /// Current xorshift state. Zero is a fixed point of the shift/xor steps,
    /// so it is replaced by a fallback seed before use.
    rng_state: u64,
}

impl LevelGenerator {
    /// Highest level `generate_level` will ever return.
    const MAX_LEVEL: usize = 16;

    /// Seed used when the clock is unavailable or would yield a zero state.
    const FALLBACK_SEED: u64 = 12345;

    /// Creates a generator seeded from the current wall-clock time.
    ///
    /// Falls back to [`LevelGenerator::FALLBACK_SEED`] when the clock is before
    /// the Unix epoch or when the truncated nanosecond count is zero.
    // Truncating the u128 nanosecond count is intentional: only entropy matters.
    #[allow(clippy::cast_possible_truncation)]
    fn new(ml: f64) -> Self {
        let seed = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .ok()
            .filter(|&s| s != 0)
            .unwrap_or(Self::FALLBACK_SEED);
        Self { ml, rng_state: seed }
    }

    /// Advances the generator and returns a level in `0..=MAX_LEVEL`.
    // The casts are deliberate: u64 -> f64 only needs a rough uniform value and
    // f64 -> usize relies on Rust's saturating float-to-int conversion.
    #[allow(clippy::cast_precision_loss)]
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    fn generate_level(&mut self) -> usize {
        // A zero state would stay zero forever, making `uniform` 0.0 and
        // `ln(0)` pin every level at the cap; recover with the fallback seed.
        let mut x = if self.rng_state == 0 { Self::FALLBACK_SEED } else { self.rng_state };
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng_state = x;

        let uniform = (x as f64) / (u64::MAX as f64);

        let level = (-uniform.ln() * self.ml).floor() as usize;

        level.min(Self::MAX_LEVEL)
    }
}
68
/// HNSW vector index whose in-memory graph is backed by a storage engine `E`.
pub struct HnswIndex<E: StorageEngine> {
    /// Storage engine that node, connection and metadata writes go to.
    engine: E,
    /// Storage table name, derived from the index name via `table_name`.
    table: String,
    /// In-memory graph, behind a lock for shared reads and exclusive writes.
    graph: RwLock<HnswGraph>,
    /// Construction and search parameters in effect for this index.
    config: HnswConfig,
    /// Level generator; locked so `&self` methods can advance its state.
    level_gen: RwLock<LevelGenerator>,
}
85
86impl<E: StorageEngine> HnswIndex<E> {
87 pub fn new(
100 engine: E,
101 name: &str,
102 dimension: usize,
103 distance_metric: DistanceMetric,
104 config: HnswConfig,
105 ) -> Result<Self, VectorError> {
106 let table = table_name(name);
107
108 if let Some(metadata) = load_metadata(&engine, &table)? {
110 if metadata.dimension != dimension {
112 return Err(VectorError::DimensionMismatch {
113 expected: metadata.dimension,
114 actual: dimension,
115 });
116 }
117 if metadata.distance_metric != distance_metric {
118 return Err(VectorError::Encoding(format!(
119 "distance metric mismatch: stored {:?}, requested {:?}",
120 metadata.distance_metric, distance_metric
121 )));
122 }
123
124 let graph = load_graph(&engine, &table, &metadata)?;
126
127 let config = HnswConfig {
129 m: metadata.m,
130 m_max0: metadata.m_max0,
131 ef_construction: metadata.ef_construction,
132 ef_search: metadata.ef_search,
133 ml: f64::from_bits(metadata.ml_bits),
134 pq_segments: metadata.pq_segments,
135 pq_centroids: metadata.pq_centroids,
136 pq_training_samples: 1000, };
138
139 return Ok(Self {
140 engine,
141 table,
142 graph: RwLock::new(graph),
143 config: config.clone(),
144 level_gen: RwLock::new(LevelGenerator::new(config.ml)),
145 });
146 }
147
148 let graph = HnswGraph::new(dimension, distance_metric);
150
151 let metadata = IndexMetadata {
153 dimension,
154 distance_metric,
155 entry_point: None,
156 max_layer: 0,
157 m: config.m,
158 m_max0: config.m_max0,
159 ef_construction: config.ef_construction,
160 ef_search: config.ef_search,
161 ml_bits: config.ml.to_bits(),
162 pq_segments: config.pq_segments,
163 pq_centroids: config.pq_centroids,
164 };
165 save_metadata(&engine, &table, &metadata)?;
166
167 Ok(Self {
168 engine,
169 table,
170 graph: RwLock::new(graph),
171 config: config.clone(),
172 level_gen: RwLock::new(LevelGenerator::new(config.ml)),
173 })
174 }
175
176 pub fn open(engine: E, name: &str) -> Result<Self, VectorError> {
180 let table = table_name(name);
181
182 let metadata = load_metadata(&engine, &table)?.ok_or_else(|| {
183 VectorError::SpaceNotFound(format!("HNSW index '{}' not found", name))
184 })?;
185
186 let graph = load_graph(&engine, &table, &metadata)?;
187
188 let config = HnswConfig {
189 m: metadata.m,
190 m_max0: metadata.m_max0,
191 ef_construction: metadata.ef_construction,
192 ef_search: metadata.ef_search,
193 ml: f64::from_bits(metadata.ml_bits),
194 pq_segments: metadata.pq_segments,
195 pq_centroids: metadata.pq_centroids,
196 pq_training_samples: 1000, };
198
199 Ok(Self {
200 engine,
201 table,
202 graph: RwLock::new(graph),
203 config: config.clone(),
204 level_gen: RwLock::new(LevelGenerator::new(config.ml)),
205 })
206 }
207
    /// Returns the HNSW configuration this index is running with.
    #[must_use]
    pub fn config(&self) -> &HnswConfig {
        &self.config
    }
213
214 pub fn distance_metric(&self) -> Result<DistanceMetric, VectorError> {
221 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
222 Ok(graph.distance_metric)
223 }
224
225 pub fn flush(&self) -> Result<(), VectorError> {
234 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
235 save_graph(&self.engine, &self.table, &graph, &self.config)?;
236 Ok(())
237 }
238
239 fn insert_internal(
243 &self,
244 graph: &mut HnswGraph,
245 entity_id: EntityId,
246 embedding: &Embedding,
247 ) -> Result<(), VectorError> {
248 let node_level =
250 self.level_gen.write().map_err(|_| VectorError::LockPoisoned)?.generate_level();
251
252 let new_node = HnswNode::new(entity_id, embedding.clone(), node_level);
254
255 if graph.is_empty() {
257 graph.insert_node(new_node);
258 save_node(
259 &self.engine,
260 &self.table,
261 graph.get_node(entity_id).ok_or(VectorError::NodeNotFound(entity_id))?,
262 )?;
263 self.update_metadata(graph)?;
264 return Ok(());
265 }
266
267 let entry_point = graph
269 .entry_point
270 .ok_or(VectorError::InvalidGraphState("entry_point missing in non-empty graph"))?;
271 let current_max_layer = graph.max_layer;
272
273 let mut current_ep = vec![entry_point];
275
276 for layer in (node_level + 1..=current_max_layer).rev() {
277 let candidates = search_layer(graph, embedding, ¤t_ep, 1, layer);
278 current_ep = candidates.into_iter().map(|c| c.entity_id).collect();
279 if current_ep.is_empty() {
280 current_ep = vec![entry_point];
281 }
282 }
283
284 graph.insert_node(new_node);
286
287 let start_layer = node_level.min(current_max_layer);
289
290 for layer in (0..=start_layer).rev() {
291 let candidates =
293 search_layer(graph, embedding, ¤t_ep, self.config.ef_construction, layer);
294
295 let m = if layer == 0 { self.config.m_max0 } else { self.config.m };
297 let neighbors = select_neighbors_heuristic(graph, embedding, &candidates, m, false);
298
299 if let Some(node) = graph.get_node_mut(entity_id) {
301 node.set_connections(layer, neighbors.clone());
302 }
303
304 let mut neighbors_to_prune = Vec::new();
306 let max_conn = if layer == 0 { self.config.m_max0 } else { self.config.m };
307
308 for &neighbor_id in &neighbors {
309 if let Some(neighbor) = graph.get_node_mut(neighbor_id) {
310 neighbor.add_connection(layer, entity_id);
311
312 if neighbor.connections_at(layer).len() > max_conn {
314 let neighbor_conn_ids: Vec<EntityId> =
316 neighbor.connections_at(layer).to_vec();
317 let neighbor_embedding = neighbor.embedding.clone();
318 neighbors_to_prune.push((
319 neighbor_id,
320 neighbor_conn_ids,
321 neighbor_embedding,
322 ));
323 }
324 }
325 }
326
327 for (neighbor_id, neighbor_conn_ids, neighbor_embedding) in neighbors_to_prune {
329 let neighbor_connections: Vec<Candidate> = neighbor_conn_ids
331 .iter()
332 .filter_map(|&id| {
333 graph.get_node(id).map(|n| {
334 Candidate::new(id, graph.distance(&neighbor_embedding, &n.embedding))
335 })
336 })
337 .collect();
338
339 let pruned = select_neighbors_heuristic(
341 graph,
342 &neighbor_embedding,
343 &neighbor_connections,
344 max_conn,
345 false,
346 );
347
348 if let Some(neighbor) = graph.get_node_mut(neighbor_id) {
350 neighbor.set_connections(layer, pruned.clone());
351 }
352
353 update_connections(&self.engine, &self.table, neighbor_id, layer, &pruned)?;
355 }
356
357 current_ep = candidates.into_iter().map(|c| c.entity_id).collect();
359 if current_ep.is_empty() && !neighbors.is_empty() {
360 current_ep = neighbors;
361 }
362 }
363
364 save_node(
366 &self.engine,
367 &self.table,
368 graph.get_node(entity_id).ok_or(VectorError::NodeNotFound(entity_id))?,
369 )?;
370
371 for layer in 0..=start_layer {
373 if let Some(node) = graph.get_node(entity_id) {
374 for &neighbor_id in node.connections_at(layer) {
375 if let Some(neighbor) = graph.get_node(neighbor_id) {
376 update_connections(
377 &self.engine,
378 &self.table,
379 neighbor_id,
380 layer,
381 neighbor.connections_at(layer),
382 )?;
383 }
384 }
385 }
386 }
387
388 if node_level > current_max_layer {
390 graph.entry_point = Some(entity_id);
391 graph.max_layer = node_level;
392 self.update_metadata(graph)?;
393 }
394
395 Ok(())
396 }
397
398 fn update_metadata(&self, graph: &HnswGraph) -> Result<(), VectorError> {
400 let metadata = IndexMetadata {
401 dimension: graph.dimension,
402 distance_metric: graph.distance_metric,
403 entry_point: graph.entry_point,
404 max_layer: graph.max_layer,
405 m: self.config.m,
406 m_max0: self.config.m_max0,
407 ef_construction: self.config.ef_construction,
408 ef_search: self.config.ef_search,
409 ml_bits: self.config.ml.to_bits(),
410 pq_segments: self.config.pq_segments,
411 pq_centroids: self.config.pq_centroids,
412 };
413 persistence::save_metadata(&self.engine, &self.table, &metadata)?;
414 Ok(())
415 }
416
417 fn insert_batch_internal(
422 &self,
423 graph: &mut HnswGraph,
424 embeddings: &[(EntityId, &Embedding)],
425 ) -> Result<(), VectorError> {
426 let mut new_nodes: Vec<(EntityId, &Embedding, usize)> =
428 Vec::with_capacity(embeddings.len());
429
430 {
431 let mut level_gen = self.level_gen.write().map_err(|_| VectorError::LockPoisoned)?;
432 for (entity_id, embedding) in embeddings {
433 let node_level = level_gen.generate_level();
434 new_nodes.push((*entity_id, embedding, node_level));
435 }
436 }
437
438 let mut affected_neighbors: std::collections::HashSet<EntityId> =
441 std::collections::HashSet::new();
442
443 for (entity_id, embedding, node_level) in &new_nodes {
444 let new_node = HnswNode::new(*entity_id, (*embedding).clone(), *node_level);
445
446 if graph.is_empty() {
448 graph.insert_node(new_node);
449 continue;
450 }
451
452 let entry_point = graph
454 .entry_point
455 .ok_or(VectorError::InvalidGraphState("entry_point missing in non-empty graph"))?;
456 let current_max_layer = graph.max_layer;
457
458 let mut current_ep = vec![entry_point];
460
461 for layer in (node_level + 1..=current_max_layer).rev() {
462 let candidates = search_layer(graph, embedding, ¤t_ep, 1, layer);
463 current_ep = candidates.into_iter().map(|c| c.entity_id).collect();
464 if current_ep.is_empty() {
465 current_ep = vec![entry_point];
466 }
467 }
468
469 graph.insert_node(new_node);
471
472 let start_layer = (*node_level).min(current_max_layer);
474
475 for layer in (0..=start_layer).rev() {
476 let candidates =
477 search_layer(graph, embedding, ¤t_ep, self.config.ef_construction, layer);
478
479 let m = if layer == 0 { self.config.m_max0 } else { self.config.m };
480 let neighbors = select_neighbors_heuristic(graph, embedding, &candidates, m, false);
481
482 if let Some(node) = graph.get_node_mut(*entity_id) {
484 node.set_connections(layer, neighbors.clone());
485 }
486
487 let max_conn = if layer == 0 { self.config.m_max0 } else { self.config.m };
489 let mut neighbors_to_prune = Vec::new();
490
491 for &neighbor_id in &neighbors {
492 affected_neighbors.insert(neighbor_id);
493 if let Some(neighbor) = graph.get_node_mut(neighbor_id) {
494 neighbor.add_connection(layer, *entity_id);
495
496 if neighbor.connections_at(layer).len() > max_conn {
497 let neighbor_conn_ids: Vec<EntityId> =
498 neighbor.connections_at(layer).to_vec();
499 let neighbor_embedding = neighbor.embedding.clone();
500 neighbors_to_prune.push((
501 neighbor_id,
502 neighbor_conn_ids,
503 neighbor_embedding,
504 ));
505 }
506 }
507 }
508
509 for (neighbor_id, neighbor_conn_ids, neighbor_embedding) in neighbors_to_prune {
511 let neighbor_connections: Vec<Candidate> = neighbor_conn_ids
512 .iter()
513 .filter_map(|&id| {
514 graph.get_node(id).map(|n| {
515 Candidate::new(
516 id,
517 graph.distance(&neighbor_embedding, &n.embedding),
518 )
519 })
520 })
521 .collect();
522
523 let pruned = select_neighbors_heuristic(
524 graph,
525 &neighbor_embedding,
526 &neighbor_connections,
527 max_conn,
528 false,
529 );
530
531 if let Some(neighbor) = graph.get_node_mut(neighbor_id) {
532 neighbor.set_connections(layer, pruned);
533 }
534 }
535
536 current_ep = candidates.into_iter().map(|c| c.entity_id).collect();
537 if current_ep.is_empty() && !neighbors.is_empty() {
538 current_ep = neighbors;
539 }
540 }
541
542 if *node_level > current_max_layer {
544 graph.entry_point = Some(*entity_id);
545 graph.max_layer = *node_level;
546 }
547 }
548
549 let mut tx = self.engine.begin_write()?;
551
552 for (entity_id, _, _) in &new_nodes {
554 if let Some(node) = graph.get_node(*entity_id) {
555 persistence::save_node_tx(&mut tx, &self.table, node)?;
556 }
557 }
558
559 for neighbor_id in affected_neighbors {
561 if let Some(neighbor) = graph.get_node(neighbor_id) {
562 for (layer, connections) in neighbor.connections.iter().enumerate() {
564 persistence::update_connections_tx(
565 &mut tx,
566 &self.table,
567 neighbor_id,
568 layer,
569 connections,
570 )?;
571 }
572 }
573 }
574
575 let metadata = IndexMetadata {
577 dimension: graph.dimension,
578 distance_metric: graph.distance_metric,
579 entry_point: graph.entry_point,
580 max_layer: graph.max_layer,
581 m: self.config.m,
582 m_max0: self.config.m_max0,
583 ef_construction: self.config.ef_construction,
584 ef_search: self.config.ef_search,
585 ml_bits: self.config.ml.to_bits(),
586 pq_segments: self.config.pq_segments,
587 pq_centroids: self.config.pq_centroids,
588 };
589 persistence::save_metadata_tx(&mut tx, &self.table, &metadata)?;
590
591 tx.commit()?;
593
594 Ok(())
595 }
596}
597
598impl<E: StorageEngine> VectorIndex for HnswIndex<E> {
599 fn insert(&mut self, entity_id: EntityId, embedding: &Embedding) -> Result<(), VectorError> {
600 let mut graph = self.graph.write().map_err(|_| VectorError::LockPoisoned)?;
602 if embedding.dimension() != graph.dimension {
603 return Err(VectorError::DimensionMismatch {
604 expected: graph.dimension,
605 actual: embedding.dimension(),
606 });
607 }
608
609 if graph.contains(entity_id) {
611 self.delete_internal(&mut graph, entity_id)?;
612 }
613
614 self.insert_internal(&mut graph, entity_id, embedding)
615 }
616
617 fn insert_batch(&mut self, embeddings: &[(EntityId, &Embedding)]) -> Result<(), VectorError> {
618 if embeddings.is_empty() {
619 return Ok(());
620 }
621
622 let mut graph = self.graph.write().map_err(|_| VectorError::LockPoisoned)?;
623
624 for (entity_id, embedding) in embeddings {
626 if embedding.dimension() != graph.dimension {
627 return Err(VectorError::DimensionMismatch {
628 expected: graph.dimension,
629 actual: embedding.dimension(),
630 });
631 }
632
633 if graph.contains(*entity_id) {
635 self.delete_internal(&mut graph, *entity_id)?;
636 }
637 }
638
639 self.insert_batch_internal(&mut graph, embeddings)
641 }
642
643 fn delete(&mut self, entity_id: EntityId) -> Result<bool, VectorError> {
644 let mut graph = self.graph.write().map_err(|_| VectorError::LockPoisoned)?;
645 self.delete_internal(&mut graph, entity_id)
646 }
647
648 fn search(
649 &self,
650 query: &Embedding,
651 k: usize,
652 ef_search: Option<usize>,
653 ) -> Result<Vec<SearchResult>, VectorError> {
654 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
655
656 if query.dimension() != graph.dimension {
658 return Err(VectorError::DimensionMismatch {
659 expected: graph.dimension,
660 actual: query.dimension(),
661 });
662 }
663
664 if graph.is_empty() {
665 return Ok(Vec::new());
666 }
667
668 let ef = ef_search.unwrap_or(self.config.ef_search).max(k);
670 let entry_point = graph
671 .entry_point
672 .ok_or(VectorError::InvalidGraphState("entry_point missing in non-empty graph"))?;
673
674 let mut current_ep = vec![entry_point];
676
677 for layer in (1..=graph.max_layer).rev() {
678 let candidates = search_layer(&graph, query, ¤t_ep, 1, layer);
679 current_ep = candidates.into_iter().map(|c| c.entity_id).collect();
680 if current_ep.is_empty() {
681 current_ep = vec![entry_point];
682 }
683 }
684
685 let candidates = search_layer(&graph, query, ¤t_ep, ef, 0);
687
688 let results: Vec<SearchResult> = candidates
690 .into_iter()
691 .take(k)
692 .map(|c| SearchResult::new(c.entity_id, c.distance))
693 .collect();
694
695 Ok(results)
696 }
697
698 fn search_with_filter<F>(
699 &self,
700 query: &Embedding,
701 k: usize,
702 predicate: F,
703 ef_search: Option<usize>,
704 config: Option<FilteredSearchConfig>,
705 ) -> Result<Vec<SearchResult>, VectorError>
706 where
707 F: Fn(EntityId) -> bool,
708 {
709 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
710
711 if query.dimension() != graph.dimension {
713 return Err(VectorError::DimensionMismatch {
714 expected: graph.dimension,
715 actual: query.dimension(),
716 });
717 }
718
719 if graph.is_empty() {
720 return Ok(Vec::new());
721 }
722
723 let filter_config = config.unwrap_or_default();
725
726 let base_ef = ef_search.unwrap_or(self.config.ef_search).max(k);
729 let ef = filter_config.adjusted_ef(base_ef, None);
730
731 let entry_point = graph
732 .entry_point
733 .ok_or(VectorError::InvalidGraphState("entry_point missing in non-empty graph"))?;
734
735 let mut current_ep = vec![entry_point];
738
739 for layer in (1..=graph.max_layer).rev() {
740 let candidates = search_layer(&graph, query, ¤t_ep, 1, layer);
741 current_ep = candidates.into_iter().map(|c| c.entity_id).collect();
742 if current_ep.is_empty() {
743 current_ep = vec![entry_point];
744 }
745 }
746
747 let candidates = search_layer_filtered(&graph, query, ¤t_ep, ef, 0, &predicate);
749
750 let results: Vec<SearchResult> = candidates
752 .into_iter()
753 .take(k)
754 .map(|c| SearchResult::new(c.entity_id, c.distance))
755 .collect();
756
757 Ok(results)
758 }
759
760 fn contains(&self, entity_id: EntityId) -> Result<bool, VectorError> {
761 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
762 Ok(graph.contains(entity_id))
763 }
764
765 fn len(&self) -> Result<usize, VectorError> {
766 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
767 Ok(graph.len())
768 }
769
770 fn dimension(&self) -> Result<usize, VectorError> {
771 let graph = self.graph.read().map_err(|_| VectorError::LockPoisoned)?;
772 Ok(graph.dimension)
773 }
774}
775
776impl<E: StorageEngine> HnswIndex<E> {
777 fn delete_internal(
779 &self,
780 graph: &mut HnswGraph,
781 entity_id: EntityId,
782 ) -> Result<bool, VectorError> {
783 let node = match graph.remove_node(entity_id) {
784 Some(n) => n,
785 None => return Ok(false),
786 };
787
788 delete_node(&self.engine, &self.table, entity_id, node.max_layer)?;
790
791 self.update_metadata(graph)?;
793
794 for layer in 0..=node.max_layer {
797 for &neighbor_id in &node.connections[layer] {
798 if let Some(neighbor) = graph.get_node(neighbor_id) {
799 update_connections(
800 &self.engine,
801 &self.table,
802 neighbor_id,
803 layer,
804 neighbor.connections_at(layer),
805 )?;
806 }
807 }
808 }
809
810 Ok(true)
811 }
812}
813
814#[cfg(test)]
815mod tests {
816 use super::*;
817 use manifoldb_storage::backends::RedbEngine;
818
819 fn create_test_embedding(dim: usize, value: f32) -> Embedding {
820 Embedding::new(vec![value; dim]).unwrap()
821 }
822
823 #[test]
824 fn test_create_index() {
825 let engine = RedbEngine::in_memory().unwrap();
826 let config = HnswConfig::default();
827 let index = HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
828
829 assert_eq!(index.dimension().unwrap(), 4);
830 assert_eq!(index.len().unwrap(), 0);
831 assert!(index.is_empty().unwrap());
832 }
833
834 #[test]
835 fn test_insert_single() {
836 let engine = RedbEngine::in_memory().unwrap();
837 let config = HnswConfig::default();
838 let mut index =
839 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
840
841 let embedding = create_test_embedding(4, 1.0);
842 index.insert(EntityId::new(1), &embedding).unwrap();
843
844 assert_eq!(index.len().unwrap(), 1);
845 assert!(index.contains(EntityId::new(1)).unwrap());
846 }
847
848 #[test]
849 fn test_insert_multiple() {
850 let engine = RedbEngine::in_memory().unwrap();
851 let config = HnswConfig::new(4);
852 let mut index =
853 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
854
855 for i in 0..10 {
856 let embedding = create_test_embedding(4, i as f32);
857 index.insert(EntityId::new(i), &embedding).unwrap();
858 }
859
860 assert_eq!(index.len().unwrap(), 10);
861 }
862
863 #[test]
864 fn test_search_empty() {
865 let engine = RedbEngine::in_memory().unwrap();
866 let config = HnswConfig::default();
867 let index = HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
868
869 let query = create_test_embedding(4, 1.0);
870 let results = index.search(&query, 5, None).unwrap();
871
872 assert!(results.is_empty());
873 }
874
875 #[test]
876 fn test_search_single() {
877 let engine = RedbEngine::in_memory().unwrap();
878 let config = HnswConfig::default();
879 let mut index =
880 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
881
882 let embedding = create_test_embedding(4, 1.0);
883 index.insert(EntityId::new(1), &embedding).unwrap();
884
885 let query = create_test_embedding(4, 1.0);
886 let results = index.search(&query, 1, None).unwrap();
887
888 assert_eq!(results.len(), 1);
889 assert_eq!(results[0].entity_id, EntityId::new(1));
890 assert!(results[0].distance < 1e-6); }
892
893 #[test]
894 fn test_search_nearest() {
895 let engine = RedbEngine::in_memory().unwrap();
896 let config = HnswConfig::new(4);
897 let mut index =
898 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
899
900 for i in 1..=5 {
902 let embedding = create_test_embedding(4, i as f32);
903 index.insert(EntityId::new(i as u64), &embedding).unwrap();
904 }
905
906 let query = create_test_embedding(4, 1.5);
908 let results = index.search(&query, 3, None).unwrap();
909
910 assert_eq!(results.len(), 3);
911 assert!(
913 results[0].entity_id == EntityId::new(1) || results[0].entity_id == EntityId::new(2)
914 );
915 }
916
917 #[test]
918 fn test_delete() {
919 let engine = RedbEngine::in_memory().unwrap();
920 let config = HnswConfig::new(4);
921 let mut index =
922 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
923
924 let embedding = create_test_embedding(4, 1.0);
925 index.insert(EntityId::new(1), &embedding).unwrap();
926
927 assert!(index.delete(EntityId::new(1)).unwrap());
928 assert!(!index.contains(EntityId::new(1)).unwrap());
929 assert_eq!(index.len().unwrap(), 0);
930
931 assert!(!index.delete(EntityId::new(999)).unwrap());
933 }
934
935 #[test]
936 fn test_update_embedding() {
937 let engine = RedbEngine::in_memory().unwrap();
938 let config = HnswConfig::default();
939 let mut index =
940 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
941
942 let embedding1 = create_test_embedding(4, 1.0);
943 index.insert(EntityId::new(1), &embedding1).unwrap();
944
945 let embedding2 = create_test_embedding(4, 10.0);
947 index.insert(EntityId::new(1), &embedding2).unwrap();
948
949 assert_eq!(index.len().unwrap(), 1);
950
951 let query = create_test_embedding(4, 10.0);
953 let results = index.search(&query, 1, None).unwrap();
954 assert_eq!(results.len(), 1);
955 assert!(results[0].distance < 1e-6);
956 }
957
958 #[test]
959 fn test_dimension_mismatch_insert() {
960 let engine = RedbEngine::in_memory().unwrap();
961 let config = HnswConfig::default();
962 let mut index =
963 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
964
965 let embedding = create_test_embedding(8, 1.0); let result = index.insert(EntityId::new(1), &embedding);
967
968 assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
969 }
970
971 #[test]
972 fn test_dimension_mismatch_search() {
973 let engine = RedbEngine::in_memory().unwrap();
974 let config = HnswConfig::default();
975 let mut index =
976 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
977
978 let embedding = create_test_embedding(4, 1.0);
979 index.insert(EntityId::new(1), &embedding).unwrap();
980
981 let query = create_test_embedding(8, 1.0); let result = index.search(&query, 1, None);
983
984 assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
985 }
986
987 #[test]
988 fn test_persistence() {
989 let temp_dir = std::env::temp_dir();
991 let db_path = temp_dir.join(format!("hnsw_persist_test_{}.redb", std::process::id()));
992
993 let _ = std::fs::remove_file(&db_path);
995
996 {
998 let engine = RedbEngine::open(&db_path).unwrap();
999 let config = HnswConfig::new(4);
1000 let mut index =
1001 HnswIndex::new(engine, "persist_test", 4, DistanceMetric::Euclidean, config)
1002 .unwrap();
1003
1004 for i in 0..5 {
1005 let embedding = create_test_embedding(4, i as f32);
1006 index.insert(EntityId::new(i), &embedding).unwrap();
1007 }
1008
1009 index.flush().unwrap();
1010 }
1011
1012 {
1014 let engine = RedbEngine::open(&db_path).unwrap();
1015 let index: HnswIndex<RedbEngine> = HnswIndex::open(engine, "persist_test").unwrap();
1016
1017 assert_eq!(index.len().unwrap(), 5);
1018 assert_eq!(index.dimension().unwrap(), 4);
1019
1020 for i in 0..5 {
1021 assert!(index.contains(EntityId::new(i)).unwrap());
1022 }
1023 }
1024
1025 let _ = std::fs::remove_file(&db_path);
1027 }
1028
1029 #[test]
1030 fn test_cosine_distance() {
1031 let engine = RedbEngine::in_memory().unwrap();
1032 let config = HnswConfig::new(4);
1033 let mut index = HnswIndex::new(engine, "test", 4, DistanceMetric::Cosine, config).unwrap();
1034
1035 let e1 = Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap();
1037 let e2 = Embedding::new(vec![0.0, 1.0, 0.0, 0.0]).unwrap();
1038 let e3 = Embedding::new(vec![0.5, 0.5, 0.0, 0.0]).unwrap();
1039
1040 index.insert(EntityId::new(1), &e1).unwrap();
1041 index.insert(EntityId::new(2), &e2).unwrap();
1042 index.insert(EntityId::new(3), &e3).unwrap();
1043
1044 let query = Embedding::new(vec![2.0, 0.0, 0.0, 0.0]).unwrap();
1046 let results = index.search(&query, 3, None).unwrap();
1047
1048 assert_eq!(results[0].entity_id, EntityId::new(1));
1050 }
1051
1052 #[test]
1053 fn test_search_with_filter() {
1054 let engine = RedbEngine::in_memory().unwrap();
1055 let config = HnswConfig::new(4);
1056 let mut index =
1057 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1058
1059 for i in 1..=10 {
1061 let embedding = create_test_embedding(4, i as f32);
1062 index.insert(EntityId::new(i as u64), &embedding).unwrap();
1063 }
1064
1065 let query = create_test_embedding(4, 1.5);
1067
1068 let predicate = |id: EntityId| id.as_u64() % 2 == 0;
1070
1071 let results = index.search_with_filter(&query, 3, predicate, None, None).unwrap();
1072
1073 assert!(!results.is_empty());
1075 for result in &results {
1076 assert_eq!(result.entity_id.as_u64() % 2, 0);
1077 }
1078
1079 assert_eq!(results[0].entity_id, EntityId::new(2));
1081 }
1082
1083 #[test]
1084 fn test_search_with_filter_empty_match() {
1085 let engine = RedbEngine::in_memory().unwrap();
1086 let config = HnswConfig::new(4);
1087 let mut index =
1088 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1089
1090 for i in 1..=5 {
1092 let embedding = create_test_embedding(4, i as f32);
1093 index.insert(EntityId::new(i as u64), &embedding).unwrap();
1094 }
1095
1096 let query = create_test_embedding(4, 1.0);
1097
1098 let predicate = |_id: EntityId| false;
1100
1101 let results = index.search_with_filter(&query, 3, predicate, None, None).unwrap();
1102
1103 assert!(results.is_empty());
1105 }
1106
1107 #[test]
1108 fn test_search_with_filter_all_match() {
1109 let engine = RedbEngine::in_memory().unwrap();
1110 let config = HnswConfig::new(4);
1111 let mut index =
1112 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1113
1114 for i in 1..=5 {
1116 let embedding = create_test_embedding(4, i as f32);
1117 index.insert(EntityId::new(i as u64), &embedding).unwrap();
1118 }
1119
1120 let query = create_test_embedding(4, 1.0);
1121
1122 let predicate = |_id: EntityId| true;
1124
1125 let results = index.search_with_filter(&query, 3, predicate, None, None).unwrap();
1126
1127 assert_eq!(results.len(), 3);
1129 let regular_results = index.search(&query, 3, None).unwrap();
1131 assert_eq!(results[0].entity_id, regular_results[0].entity_id);
1132 }
1133
1134 #[test]
1135 fn test_filtered_search_config() {
1136 let config = FilteredSearchConfig::new()
1137 .with_min_ef_search(50)
1138 .with_max_ef_search(1000)
1139 .with_ef_multiplier(3.0);
1140
1141 assert_eq!(config.min_ef_search, 50);
1142 assert_eq!(config.max_ef_search, 1000);
1143 assert_eq!(config.ef_multiplier, 3.0);
1144
1145 let adjusted = config.adjusted_ef(100, None);
1147 assert_eq!(adjusted, 300); let adjusted_selective = config.adjusted_ef(100, Some(0.5));
1151 assert_eq!(adjusted_selective, 200);
1152
1153 let adjusted_very_selective = config.adjusted_ef(100, Some(0.1));
1155 assert_eq!(adjusted_very_selective, 1000); let adjusted_min = config.adjusted_ef(10, None);
1159 assert_eq!(adjusted_min, 50); }
1161
1162 #[test]
1167 fn test_insert_batch_empty() {
1168 let engine = RedbEngine::in_memory().unwrap();
1169 let config = HnswConfig::default();
1170 let mut index =
1171 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1172
1173 index.insert_batch(&[]).unwrap();
1175 assert_eq!(index.len().unwrap(), 0);
1176 }
1177
1178 #[test]
1179 fn test_insert_batch_single() {
1180 let engine = RedbEngine::in_memory().unwrap();
1181 let config = HnswConfig::default();
1182 let mut index =
1183 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1184
1185 let embedding = create_test_embedding(4, 1.0);
1186 index.insert_batch(&[(EntityId::new(1), &embedding)]).unwrap();
1187
1188 assert_eq!(index.len().unwrap(), 1);
1189 assert!(index.contains(EntityId::new(1)).unwrap());
1190 }
1191
1192 #[test]
1193 fn test_insert_batch_multiple() {
1194 let engine = RedbEngine::in_memory().unwrap();
1195 let config = HnswConfig::new(4);
1196 let mut index =
1197 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1198
1199 let embeddings: Vec<Embedding> =
1201 (0..10).map(|i| create_test_embedding(4, i as f32)).collect();
1202
1203 let batch: Vec<(EntityId, &Embedding)> =
1204 embeddings.iter().enumerate().map(|(i, e)| (EntityId::new(i as u64), e)).collect();
1205
1206 index.insert_batch(&batch).unwrap();
1207
1208 assert_eq!(index.len().unwrap(), 10);
1209 for i in 0..10 {
1210 assert!(index.contains(EntityId::new(i)).unwrap());
1211 }
1212 }
1213
1214 #[test]
1215 fn test_insert_batch_large() {
1216 let engine = RedbEngine::in_memory().unwrap();
1217 let config = HnswConfig::new(16);
1218 let mut index =
1219 HnswIndex::new(engine, "test", 128, DistanceMetric::Euclidean, config).unwrap();
1220
1221 let embeddings: Vec<Embedding> =
1223 (0..500).map(|i| Embedding::new(vec![i as f32 / 500.0; 128]).unwrap()).collect();
1224
1225 let batch: Vec<(EntityId, &Embedding)> =
1226 embeddings.iter().enumerate().map(|(i, e)| (EntityId::new(i as u64), e)).collect();
1227
1228 index.insert_batch(&batch).unwrap();
1229
1230 assert_eq!(index.len().unwrap(), 500);
1231
1232 let query = Embedding::new(vec![0.5; 128]).unwrap();
1234 let results = index.search(&query, 10, None).unwrap();
1235 assert_eq!(results.len(), 10);
1236 }
1237
1238 #[test]
1239 fn test_insert_batch_dimension_mismatch() {
1240 let engine = RedbEngine::in_memory().unwrap();
1241 let config = HnswConfig::default();
1242 let mut index =
1243 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1244
1245 let good_embedding = create_test_embedding(4, 1.0);
1246 let bad_embedding = create_test_embedding(8, 2.0); let batch: Vec<(EntityId, &Embedding)> =
1249 vec![(EntityId::new(1), &good_embedding), (EntityId::new(2), &bad_embedding)];
1250
1251 let result = index.insert_batch(&batch);
1252 assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
1253 }
1254
1255 #[test]
1256 fn test_insert_batch_updates_existing() {
1257 let engine = RedbEngine::in_memory().unwrap();
1258 let config = HnswConfig::new(4);
1259 let mut index =
1260 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1261
1262 let original = create_test_embedding(4, 1.0);
1264 index.insert(EntityId::new(1), &original).unwrap();
1265
1266 let updated = create_test_embedding(4, 10.0);
1268 index.insert_batch(&[(EntityId::new(1), &updated)]).unwrap();
1269
1270 assert_eq!(index.len().unwrap(), 1);
1272
1273 let query = create_test_embedding(4, 10.0);
1275 let results = index.search(&query, 1, None).unwrap();
1276 assert_eq!(results.len(), 1);
1277 assert!(results[0].distance < 1e-6);
1278 }
1279
1280 #[test]
1281 fn test_insert_batch_search_quality() {
1282 let engine = RedbEngine::in_memory().unwrap();
1283 let config = HnswConfig::new(16);
1284 let mut index =
1285 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1286
1287 let embeddings: Vec<Embedding> =
1289 (1..=20).map(|i| create_test_embedding(4, i as f32)).collect();
1290
1291 let batch: Vec<(EntityId, &Embedding)> = embeddings
1292 .iter()
1293 .enumerate()
1294 .map(|(i, e)| (EntityId::new((i + 1) as u64), e))
1295 .collect();
1296
1297 index.insert_batch(&batch).unwrap();
1298
1299 let query = create_test_embedding(4, 10.5);
1301 let results = index.search(&query, 5, None).unwrap();
1302
1303 let top_ids: Vec<u64> = results.iter().map(|r| r.entity_id.as_u64()).collect();
1305 assert!(top_ids.contains(&10) || top_ids.contains(&11));
1306 }
1307
1308 #[test]
1309 fn test_insert_batch_persistence() {
1310 let temp_dir = std::env::temp_dir();
1311 let db_path = temp_dir.join(format!("hnsw_batch_persist_test_{}.redb", std::process::id()));
1312 let _ = std::fs::remove_file(&db_path);
1313
1314 {
1316 let engine = RedbEngine::open(&db_path).unwrap();
1317 let config = HnswConfig::new(4);
1318 let mut index =
1319 HnswIndex::new(engine, "batch_test", 4, DistanceMetric::Euclidean, config).unwrap();
1320
1321 let embeddings: Vec<Embedding> =
1322 (0..50).map(|i| create_test_embedding(4, i as f32)).collect();
1323
1324 let batch: Vec<(EntityId, &Embedding)> =
1325 embeddings.iter().enumerate().map(|(i, e)| (EntityId::new(i as u64), e)).collect();
1326
1327 index.insert_batch(&batch).unwrap();
1328 index.flush().unwrap();
1329 }
1330
1331 {
1333 let engine = RedbEngine::open(&db_path).unwrap();
1334 let index: HnswIndex<RedbEngine> = HnswIndex::open(engine, "batch_test").unwrap();
1335
1336 assert_eq!(index.len().unwrap(), 50);
1337 for i in 0..50 {
1338 assert!(index.contains(EntityId::new(i)).unwrap());
1339 }
1340
1341 let query = create_test_embedding(4, 25.0);
1343 let results = index.search(&query, 5, None).unwrap();
1344 assert_eq!(results.len(), 5);
1345 }
1346
1347 let _ = std::fs::remove_file(&db_path);
1348 }
1349
1350 #[test]
1355 fn test_error_node_not_found_display() {
1356 let error = VectorError::NodeNotFound(EntityId::new(42));
1357 let msg = error.to_string();
1358 assert!(msg.contains("42"), "Error should contain entity ID");
1359 assert!(msg.contains("node not found"), "Error should describe issue");
1360 }
1361
1362 #[test]
1363 fn test_error_invalid_graph_state_display() {
1364 let error = VectorError::InvalidGraphState("entry_point missing in non-empty graph");
1365 let msg = error.to_string();
1366 assert!(msg.contains("entry_point"), "Error should contain context");
1367 assert!(msg.contains("invalid graph state"), "Error should describe issue");
1368 }
1369
1370 #[test]
1371 fn test_delete_nonexistent_returns_false() {
1372 let engine = RedbEngine::in_memory().unwrap();
1373 let config = HnswConfig::default();
1374 let mut index =
1375 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1376
1377 let result = index.delete(EntityId::new(999));
1379 assert!(result.is_ok());
1380 assert!(!result.unwrap());
1381 }
1382
1383 #[test]
1384 fn test_contains_nonexistent_returns_false() {
1385 let engine = RedbEngine::in_memory().unwrap();
1386 let config = HnswConfig::default();
1387 let index = HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1388
1389 assert!(!index.contains(EntityId::new(999)).unwrap());
1391 }
1392
1393 #[test]
1394 fn test_search_after_all_deleted() {
1395 let engine = RedbEngine::in_memory().unwrap();
1396 let config = HnswConfig::new(4);
1397 let mut index =
1398 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1399
1400 for i in 0..5 {
1402 let embedding = create_test_embedding(4, i as f32);
1403 index.insert(EntityId::new(i), &embedding).unwrap();
1404 }
1405
1406 for i in 0..5 {
1408 assert!(index.delete(EntityId::new(i)).unwrap());
1409 }
1410
1411 let query = create_test_embedding(4, 1.0);
1413 let results = index.search(&query, 5, None).unwrap();
1414 assert!(results.is_empty());
1415 }
1416
1417 #[test]
1418 fn test_filtered_search_on_empty_index() {
1419 let engine = RedbEngine::in_memory().unwrap();
1420 let config = HnswConfig::default();
1421 let index = HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1422
1423 let query = create_test_embedding(4, 1.0);
1424 let predicate = |_id: EntityId| true;
1425
1426 let results = index.search_with_filter(&query, 5, predicate, None, None).unwrap();
1427 assert!(results.is_empty());
1428 }
1429
1430 #[test]
1431 fn test_batch_insert_empty_vec() {
1432 let engine = RedbEngine::in_memory().unwrap();
1433 let config = HnswConfig::default();
1434 let mut index =
1435 HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, config).unwrap();
1436
1437 let empty: Vec<(EntityId, &Embedding)> = vec![];
1439 index.insert_batch(&empty).unwrap();
1440 assert_eq!(index.len().unwrap(), 0);
1441 }
1442}