1use std::cmp::Reverse;
22use std::collections::{BinaryHeap, HashMap, HashSet};
23use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
24
25use super::distance::{
26 cmp_f32, distance_simd, DistanceMetric, DistanceResult, ReverseDistanceResult,
27};
28
/// Identifier of a vector stored in an [`HnswIndex`].
pub type NodeId = u64;
31
/// Tuning parameters for an [`HnswIndex`].
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Neighbour budget on layers above 0: the number of neighbours selected for a new
    /// node on those layers and the maximum kept per node after pruning.
    pub m: usize,
    /// Neighbour budget on layer 0 (selection size and pruning limit).
    pub m_max0: usize,
    /// Candidate-list size used while inserting a node.
    pub ef_construction: usize,
    /// Default candidate-list size used by `search` / `search_filtered`.
    pub ef_search: usize,
    /// Level multiplier: a node's top layer is drawn as `floor(-ln(U) * ml)` for a
    /// pseudo-random `U` in `(0, 1]`.
    pub ml: f64,
    /// Distance function used to compare vectors.
    pub metric: DistanceMetric,
}
48
49impl Default for HnswConfig {
50 fn default() -> Self {
51 let m = 16;
52 Self {
53 m,
54 m_max0: m * 2,
55 ef_construction: 100,
56 ef_search: 50,
57 ml: 1.0 / (m as f64).ln(),
58 metric: DistanceMetric::L2,
59 }
60 }
61}
62
63impl HnswConfig {
64 pub fn with_m(m: usize) -> Self {
66 Self {
67 m,
68 m_max0: m * 2,
69 ml: 1.0 / (m as f64).ln(),
70 ..Default::default()
71 }
72 }
73
74 pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
76 self.metric = metric;
77 self
78 }
79
80 pub fn with_ef_construction(mut self, ef: usize) -> Self {
82 self.ef_construction = ef;
83 self
84 }
85
86 pub fn with_ef_search(mut self, ef: usize) -> Self {
88 self.ef_search = ef;
89 self
90 }
91}
92
93#[derive(Debug, Clone)]
95struct HnswNode {
96 id: NodeId,
98 vector: Vec<f32>,
100 max_layer: usize,
102 connections: Vec<Vec<NodeId>>,
104}
105
106impl HnswNode {
107 fn new(id: NodeId, vector: Vec<f32>, max_layer: usize) -> Self {
108 let mut connections = Vec::with_capacity(max_layer + 1);
109 for _ in 0..=max_layer {
110 connections.push(Vec::new());
111 }
112 Self {
113 id,
114 vector,
115 max_layer,
116 connections,
117 }
118 }
119}
120
/// Hierarchical Navigable Small World graph over fixed-dimension `f32` vectors.
pub struct HnswIndex {
    /// Tuning parameters and distance metric.
    config: HnswConfig,
    /// All stored nodes, keyed by id.
    nodes: HashMap<NodeId, HnswNode>,
    /// Node every search starts from; `None` for an empty index.
    entry_point: Option<NodeId>,
    /// Top layer of the graph; searches greedily descend from here to layer 0.
    max_layer: usize,
    /// Length every inserted vector must have.
    dimension: usize,
    /// Next id handed out by `insert`.
    next_id: AtomicU64,
    /// State of the xorshift generator used to draw node layers.
    rng_state: u64,
}
138
139impl HnswIndex {
140 pub fn new(dimension: usize, config: HnswConfig) -> Self {
142 Self {
143 config,
144 nodes: HashMap::new(),
145 entry_point: None,
146 max_layer: 0,
147 dimension,
148 next_id: AtomicU64::new(0),
149 rng_state: 0x853c49e6748fea9b, }
151 }
152
153 pub fn with_dimension(dimension: usize) -> Self {
155 Self::new(dimension, HnswConfig::default())
156 }
157
158 pub fn len(&self) -> usize {
160 self.nodes.len()
161 }
162
163 pub fn is_empty(&self) -> bool {
165 self.nodes.is_empty()
166 }
167
168 pub fn get_vector(&self, id: NodeId) -> Option<&[f32]> {
170 self.nodes.get(&id).map(|n| n.vector.as_slice())
171 }
172
173 pub fn insert(&mut self, vector: Vec<f32>) -> NodeId {
175 let id = self.next_id.fetch_add(1, AtomicOrdering::SeqCst);
176 self.insert_with_id(id, vector);
177 id
178 }
179
180 pub fn insert_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
182 assert_eq!(
183 vector.len(),
184 self.dimension,
185 "Vector dimension mismatch: expected {}, got {}",
186 self.dimension,
187 vector.len()
188 );
189
190 let node_layer = self.random_layer();
192
193 let node = HnswNode::new(id, vector, node_layer);
195
196 if self.entry_point.is_none() {
197 self.nodes.insert(id, node);
199 self.entry_point = Some(id);
200 self.max_layer = node_layer;
201 return;
202 }
203
204 let entry_point = self.entry_point.unwrap();
205 let vector = self.nodes.get(&id).map(|n| n.vector.clone());
206
207 let vector = node.vector.clone();
210 self.nodes.insert(id, node);
211
212 let mut current = entry_point;
214
215 for layer in (node_layer + 1..=self.max_layer).rev() {
218 current = self.search_layer_single(&vector, current, layer);
219 }
220
221 for layer in (0..=node_layer.min(self.max_layer)).rev() {
223 let neighbors = self.search_layer(&vector, current, self.config.ef_construction, layer);
225
226 let m = if layer == 0 {
228 self.config.m_max0
229 } else {
230 self.config.m
231 };
232 let selected: Vec<NodeId> = neighbors.into_iter().take(m).map(|r| r.id).collect();
233
234 if let Some(node) = self.nodes.get_mut(&id) {
236 node.connections[layer] = selected.clone();
237 }
238
239 for &neighbor_id in &selected {
241 self.add_connection(neighbor_id, id, layer);
242 }
243
244 if let Some(&first) = selected.first() {
246 current = first;
247 }
248 }
249
250 if node_layer > self.max_layer {
252 self.entry_point = Some(id);
253 self.max_layer = node_layer;
254 }
255 }
256
257 pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
259 self.search_with_ef(query, k, self.config.ef_search)
260 }
261
262 pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<DistanceResult> {
264 if self.entry_point.is_none() {
265 return Vec::new();
266 }
267
268 let entry_point = self.entry_point.unwrap();
269 let mut current = entry_point;
270
271 for layer in (1..=self.max_layer).rev() {
273 current = self.search_layer_single(query, current, layer);
274 }
275
276 let candidates = self.search_layer(query, current, ef.max(k), 0);
278
279 candidates.into_iter().take(k).collect()
281 }
282
283 pub fn search_filtered(
285 &self,
286 query: &[f32],
287 k: usize,
288 filter: &HashSet<NodeId>,
289 ) -> Vec<DistanceResult> {
290 self.search_filtered_with_ef(query, k, filter, self.config.ef_search)
291 }
292
293 pub fn search_filtered_with_ef(
295 &self,
296 query: &[f32],
297 k: usize,
298 filter: &HashSet<NodeId>,
299 ef: usize,
300 ) -> Vec<DistanceResult> {
301 if self.entry_point.is_none() || filter.is_empty() {
302 return Vec::new();
303 }
304
305 let entry_point = self.entry_point.unwrap();
306 let mut current = entry_point;
307
308 for layer in (1..=self.max_layer).rev() {
310 current = self.search_layer_single(query, current, layer);
311 }
312
313 let expanded_ef = (ef * 2).max(k * 4);
316 let candidates = self.search_layer(query, current, expanded_ef, 0);
317
318 candidates
320 .into_iter()
321 .filter(|r| filter.contains(&r.id))
322 .take(k)
323 .collect()
324 }
325
326 fn random_layer(&mut self) -> usize {
332 self.rng_state ^= self.rng_state << 13;
334 self.rng_state ^= self.rng_state >> 7;
335 self.rng_state ^= self.rng_state << 17;
336
337 let uniform = (self.rng_state as f64) / (u64::MAX as f64);
339
340 (-uniform.ln() * self.config.ml).floor() as usize
343 }
344
345 fn search_layer_single(&self, query: &[f32], entry: NodeId, layer: usize) -> NodeId {
347 let mut current = entry;
348 let mut current_dist = self.compute_distance(query, current);
349
350 loop {
351 let mut changed = false;
352
353 if let Some(node) = self.nodes.get(¤t) {
354 if layer < node.connections.len() {
355 for &neighbor in &node.connections[layer] {
356 let dist = self.compute_distance(query, neighbor);
357 if dist < current_dist {
358 current_dist = dist;
359 current = neighbor;
360 changed = true;
361 }
362 }
363 }
364 }
365
366 if !changed {
367 break;
368 }
369 }
370
371 current
372 }
373
374 fn search_layer(
376 &self,
377 query: &[f32],
378 entry: NodeId,
379 ef: usize,
380 layer: usize,
381 ) -> Vec<DistanceResult> {
382 let entry_dist = self.compute_distance(query, entry);
383
384 let mut candidates: BinaryHeap<Reverse<DistanceResult>> = BinaryHeap::new();
386 candidates.push(Reverse(DistanceResult::new(entry, entry_dist)));
387
388 let mut results: BinaryHeap<ReverseDistanceResult> = BinaryHeap::new();
390 results.push(ReverseDistanceResult(DistanceResult::new(
391 entry, entry_dist,
392 )));
393
394 let mut visited: HashSet<NodeId> = HashSet::new();
396 visited.insert(entry);
397
398 while let Some(Reverse(current)) = candidates.pop() {
399 let furthest_dist = results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
401
402 if current.distance > furthest_dist {
404 break;
405 }
406
407 if let Some(node) = self.nodes.get(¤t.id) {
409 if layer < node.connections.len() {
410 for &neighbor_id in &node.connections[layer] {
411 if visited.contains(&neighbor_id) {
412 continue;
413 }
414 visited.insert(neighbor_id);
415
416 let dist = self.compute_distance(query, neighbor_id);
417 let furthest_dist =
418 results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
419
420 if dist < furthest_dist || results.len() < ef {
422 candidates.push(Reverse(DistanceResult::new(neighbor_id, dist)));
423 results.push(ReverseDistanceResult(DistanceResult::new(
424 neighbor_id,
425 dist,
426 )));
427
428 while results.len() > ef {
430 results.pop();
431 }
432 }
433 }
434 }
435 }
436 }
437
438 let mut result_vec: Vec<DistanceResult> = results.into_iter().map(|r| r.0).collect();
440 result_vec.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
441 result_vec
442 }
443
444 fn add_connection(&mut self, from: NodeId, to: NodeId, layer: usize) {
446 let max_connections = if layer == 0 {
447 self.config.m_max0
448 } else {
449 self.config.m
450 };
451
452 if let Some(node) = self.nodes.get_mut(&from) {
453 while node.connections.len() <= layer {
455 node.connections.push(Vec::new());
456 }
457
458 if !node.connections[layer].contains(&to) {
460 node.connections[layer].push(to);
461
462 if node.connections[layer].len() > max_connections {
464 self.prune_connections(from, layer, max_connections);
465 }
466 }
467 }
468 }
469
470 fn prune_connections(&mut self, node_id: NodeId, layer: usize, max_connections: usize) {
472 let node_vector = self.nodes.get(&node_id).map(|n| n.vector.clone());
473 let node_vector = match node_vector {
474 Some(v) => v,
475 None => return,
476 };
477
478 let neighbors: Vec<NodeId> = self
480 .nodes
481 .get(&node_id)
482 .map(|n| n.connections[layer].clone())
483 .unwrap_or_default();
484
485 let mut scored: Vec<DistanceResult> = neighbors
486 .iter()
487 .map(|&neighbor_id| {
488 let dist = self.compute_distance(&node_vector, neighbor_id);
489 DistanceResult::new(neighbor_id, dist)
490 })
491 .collect();
492
493 scored.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
495
496 let kept: Vec<NodeId> = scored
498 .into_iter()
499 .take(max_connections)
500 .map(|r| r.id)
501 .collect();
502
503 if let Some(node) = self.nodes.get_mut(&node_id) {
504 node.connections[layer] = kept;
505 }
506 }
507
508 fn compute_distance(&self, query: &[f32], node_id: NodeId) -> f32 {
510 match self.nodes.get(&node_id) {
511 Some(node) => distance_simd(query, &node.vector, self.config.metric),
512 None => f32::MAX,
513 }
514 }
515
516 pub fn insert_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
524 vectors.into_iter().map(|v| self.insert(v)).collect()
525 }
526
527 pub fn insert_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
529 for (id, vector) in items {
530 self.insert_with_id(id, vector);
531 }
532 }
533
534 pub fn delete(&mut self, id: NodeId) -> bool {
544 if self.nodes.remove(&id).is_none() {
545 return false;
546 }
547
548 if self.entry_point == Some(id) {
550 self.entry_point = self.nodes.keys().next().copied();
551
552 if let Some(ep) = self.entry_point {
554 self.max_layer = self.nodes.get(&ep).map(|n| n.max_layer).unwrap_or(0);
555 } else {
556 self.max_layer = 0;
557 }
558 }
559
560 for node in self.nodes.values_mut() {
562 for layer_connections in node.connections.iter_mut() {
563 layer_connections.retain(|&neighbor| neighbor != id);
564 }
565 }
566
567 true
568 }
569
570 pub fn contains(&self, id: NodeId) -> bool {
572 self.nodes.contains_key(&id)
573 }
574
575 pub fn search_adaptive(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
584 let n = self.nodes.len();
586 let adaptive_ef = if n < 100 {
587 k.max(10)
588 } else if n < 10000 {
589 k.max(50)
590 } else if n < 100000 {
591 k.max(100)
592 } else {
593 k.max(200)
594 };
595
596 self.search_with_ef(query, k, adaptive_ef)
597 }
598
599 pub fn stats(&self) -> HnswStats {
605 let mut layer_counts = vec![0usize; self.max_layer + 1];
606 let mut total_connections = 0usize;
607 let mut max_connections = 0usize;
608 let mut min_connections = usize::MAX;
609
610 for node in self.nodes.values() {
611 for (layer, layer_count) in layer_counts.iter_mut().enumerate().take(node.max_layer + 1)
612 {
613 *layer_count += 1;
614 let conns = node.connections.get(layer).map(|c| c.len()).unwrap_or(0);
615 total_connections += conns;
616 max_connections = max_connections.max(conns);
617 if conns > 0 {
618 min_connections = min_connections.min(conns);
619 }
620 }
621 }
622
623 if self.nodes.is_empty() {
624 min_connections = 0;
625 }
626
627 HnswStats {
628 node_count: self.nodes.len(),
629 dimension: self.dimension,
630 max_layer: self.max_layer,
631 layer_counts,
632 total_connections,
633 avg_connections: if self.nodes.is_empty() {
634 0.0
635 } else {
636 total_connections as f64 / self.nodes.len() as f64
637 },
638 max_connections,
639 min_connections,
640 entry_point: self.entry_point,
641 }
642 }
643
644 pub fn to_bytes(&self) -> Vec<u8> {
654 let metric = match self.config.metric {
655 DistanceMetric::L2 => 0,
656 DistanceMetric::Cosine => 1,
657 DistanceMetric::InnerProduct => 2,
658 };
659 let nodes = self
660 .nodes
661 .values()
662 .map(|node| reddb_file::HnswNodeLayout {
663 id: node.id,
664 max_layer: node.max_layer,
665 vector: node.vector.clone(),
666 connections: node.connections.clone(),
667 })
668 .collect();
669 let layout = reddb_file::HnswIndexLayout {
670 dimension: self.dimension,
671 m: self.config.m,
672 m_max0: self.config.m_max0,
673 ef_construction: self.config.ef_construction,
674 ef_search: self.config.ef_search,
675 ml: self.config.ml,
676 metric,
677 max_layer: self.max_layer,
678 entry_point: self.entry_point,
679 nodes,
680 };
681 reddb_file::encode_hnsw_index(&layout)
682 }
683
684 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
686 let layout = reddb_file::decode_hnsw_index(bytes).map_err(|e| e.to_string())?;
687
688 let metric = match layout.metric {
689 0 => DistanceMetric::L2,
690 1 => DistanceMetric::Cosine,
691 2 => DistanceMetric::InnerProduct,
692 _ => return Err("Invalid distance metric".to_string()),
693 };
694
695 let config = HnswConfig {
696 m: layout.m,
697 m_max0: layout.m_max0,
698 ef_construction: layout.ef_construction,
699 ef_search: layout.ef_search,
700 ml: layout.ml,
701 metric,
702 };
703
704 let mut nodes = HashMap::new();
705 let mut max_id = 0u64;
706 for node in layout.nodes {
707 max_id = max_id.max(node.id);
708 nodes.insert(
709 node.id,
710 HnswNode {
711 id: node.id,
712 max_layer: node.max_layer,
713 vector: node.vector,
714 connections: node.connections,
715 },
716 );
717 }
718
719 Ok(Self {
720 config,
721 nodes,
722 entry_point: layout.entry_point,
723 max_layer: layout.max_layer,
724 dimension: layout.dimension,
725 next_id: AtomicU64::new(max_id + 1),
726 rng_state: 12345, })
728 }
729}
730
/// Summary statistics produced by [`HnswIndex::stats`].
#[derive(Debug, Clone)]
pub struct HnswStats {
    /// Number of stored vectors.
    pub node_count: usize,
    /// Vector dimension of the index.
    pub dimension: usize,
    /// Top layer of the graph.
    pub max_layer: usize,
    /// `layer_counts[l]` is the number of nodes present on layer `l`, with one entry
    /// per layer `0..=max_layer`.
    pub layer_counts: Vec<usize>,
    /// Sum of neighbour-list lengths over every node and every layer it is on.
    pub total_connections: usize,
    /// `total_connections / node_count` (0.0 for an empty index).
    ///
    /// This is summed over all of a node's layers, not averaged per layer.
    pub avg_connections: f64,
    /// Longest single neighbour list (one node on one layer).
    pub max_connections: usize,
    /// Shortest non-empty neighbour list; empty lists are not counted.
    pub min_connections: usize,
    /// Node searches start from, if any.
    pub entry_point: Option<NodeId>,
}
753
754#[derive(Debug, Clone)]
756pub struct Bitset {
757 bits: Vec<u64>,
758 len: usize,
759}
760
761impl Bitset {
762 pub fn with_capacity(n: usize) -> Self {
764 let num_words = n.div_ceil(64);
765 Self {
766 bits: vec![0; num_words],
767 len: n,
768 }
769 }
770
771 pub fn all(n: usize) -> Self {
773 let num_words = n.div_ceil(64);
774 let mut bits = vec![u64::MAX; num_words];
775
776 if !n.is_multiple_of(64) {
778 let last_idx = num_words - 1;
779 let valid_bits = n % 64;
780 bits[last_idx] = (1u64 << valid_bits) - 1;
781 }
782
783 Self { bits, len: n }
784 }
785
786 pub fn set(&mut self, idx: usize) {
788 if idx < self.len {
789 let word = idx / 64;
790 let bit = idx % 64;
791 self.bits[word] |= 1u64 << bit;
792 }
793 }
794
795 pub fn clear(&mut self, idx: usize) {
797 if idx < self.len {
798 let word = idx / 64;
799 let bit = idx % 64;
800 self.bits[word] &= !(1u64 << bit);
801 }
802 }
803
804 pub fn is_set(&self, idx: usize) -> bool {
806 if idx >= self.len {
807 return false;
808 }
809 let word = idx / 64;
810 let bit = idx % 64;
811 (self.bits[word] & (1u64 << bit)) != 0
812 }
813
814 pub fn to_hashset(&self) -> HashSet<NodeId> {
816 let mut set = HashSet::new();
817 for i in 0..self.len {
818 if self.is_set(i) {
819 set.insert(i as NodeId);
820 }
821 }
822 set
823 }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[0, 1]`, derived from
    /// `seed`.
    ///
    /// The seed is offset before use because 0 is a fixed point of xorshift: feeding
    /// `seed == 0` straight in would produce an all-zero vector.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                (state as f32) / (u64::MAX as f32)
            })
            .collect()
    }

    /// Builds a 2-D index holding `(i, 0)` under id `i` for `i in 0..n`.
    fn line_index(n: u64) -> HnswIndex {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..n {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }
        index
    }

    #[test]
    fn test_random_vector_seed_zero_is_not_degenerate() {
        assert!(random_vector(8, 0).iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_empty_index() {
        let index = HnswIndex::with_dimension(128);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);

        let results = index.search(&[0.0; 128], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_single_insert() {
        let mut index = HnswIndex::with_dimension(3);
        let id = index.insert(vec![1.0, 2.0, 3.0]);

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.get_vector(id).is_some());
    }

    #[test]
    fn test_exact_match() {
        let mut index = HnswIndex::with_dimension(3);
        index.insert(vec![1.0, 0.0, 0.0]);
        index.insert(vec![0.0, 1.0, 0.0]);
        index.insert(vec![0.0, 0.0, 1.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].distance, 0.0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let index = line_index(4);

        let results = index.search(&[0.9, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_k_nearest() {
        let index = line_index(10);

        let results = index.search(&[4.5, 0.0], 3);
        assert_eq!(results.len(), 3);

        let ids: HashSet<_> = results.iter().map(|r| r.id).collect();
        assert!(ids.contains(&4));
        assert!(ids.contains(&5));
    }

    #[test]
    fn test_filtered_search() {
        let index = line_index(10);
        let filter: HashSet<NodeId> = [0, 2, 4, 6, 8].iter().copied().collect();

        let results = index.search_filtered(&[5.0, 0.0], 2, &filter);
        assert_eq!(results.len(), 2);

        let ids: HashSet<_> = results.iter().map(|r| r.id).collect();
        assert!(ids.contains(&4));
        assert!(ids.contains(&6));
    }

    #[test]
    fn test_cosine_distance() {
        let config = HnswConfig::default().with_metric(DistanceMetric::Cosine);
        let mut index = HnswIndex::new(3, config);

        index.insert_with_id(0, vec![1.0, 0.0, 0.0]);
        index.insert_with_id(1, vec![0.0, 1.0, 0.0]);
        index.insert_with_id(2, vec![0.707, 0.707, 0.0]);

        let results = index.search(&[0.707, 0.707, 0.0], 1);
        assert_eq!(results[0].id, 2);
    }

    #[test]
    fn test_many_vectors() {
        let dim = 64;
        let n: usize = 1000;

        let mut index = HnswIndex::with_dimension(dim);
        for i in 0..n {
            index.insert_with_id(i as u64, random_vector(dim, i as u64));
        }
        assert_eq!(index.len(), n);

        let query = random_vector(dim, 12345);
        let results = index.search(&query, 10);
        assert_eq!(results.len(), 10);

        for pair in results.windows(2) {
            assert!(pair[1].distance >= pair[0].distance);
        }
    }

    #[test]
    fn test_delete() {
        let mut index = line_index(10);

        assert!(index.delete(3));
        assert!(!index.delete(3));
        assert!(!index.contains(3));
        assert_eq!(index.len(), 9);

        let results = index.search(&[3.0, 0.0], 9);
        assert!(results.iter().all(|r| r.id != 3));
    }

    #[test]
    fn test_insert_batch_assigns_sequential_ids() {
        let mut index = HnswIndex::with_dimension(2);
        let ids = index.insert_batch(vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]]);
        assert_eq!(ids, vec![0, 1, 2]);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_stats_counts() {
        let index = line_index(10);
        let stats = index.stats();

        assert_eq!(stats.node_count, 10);
        assert_eq!(stats.dimension, 2);
        assert_eq!(stats.layer_counts.len(), stats.max_layer + 1);
        assert_eq!(stats.layer_counts[0], 10);
        assert!(stats.entry_point.is_some());
    }

    #[test]
    fn test_bitset() {
        let mut bs = Bitset::with_capacity(100);

        bs.set(0);
        bs.set(50);
        bs.set(99);

        assert!(bs.is_set(0));
        assert!(bs.is_set(50));
        assert!(bs.is_set(99));
        assert!(!bs.is_set(1));
        assert!(!bs.is_set(64));

        bs.clear(50);
        assert!(!bs.is_set(50));
    }

    #[test]
    fn test_bitset_all() {
        let bs = Bitset::all(100);

        for i in 0..100 {
            assert!(bs.is_set(i));
        }
        assert!(!bs.is_set(100));
    }

    #[test]
    fn test_bitset_all_word_aligned() {
        let bs = Bitset::all(128);
        assert!((0..128).all(|i| bs.is_set(i)));
        assert!(!bs.is_set(128));
        assert!(Bitset::all(0).to_hashset().is_empty());
    }

    #[test]
    fn test_bitset_to_hashset() {
        let mut bs = Bitset::with_capacity(130);
        bs.set(3);
        bs.set(64);
        bs.set(129);
        // Out of range: ignored.
        bs.set(130);

        let expected: HashSet<NodeId> = [3, 64, 129].iter().copied().collect();
        assert_eq!(bs.to_hashset(), expected);
    }
}