1use std::cmp::Reverse;
22use std::collections::{BinaryHeap, HashMap, HashSet};
23use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
24
25use super::distance::{
26 cmp_f32, distance_simd, DistanceMetric, DistanceResult, ReverseDistanceResult,
27};
28
/// Identifier of a vector stored in an [`HnswIndex`].
pub type NodeId = u64;
31
/// Tuning parameters for an [`HnswIndex`].
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Maximum number of neighbours kept per node on every layer above 0.
    pub m: usize,
    /// Maximum number of neighbours kept per node on layer 0.
    pub m_max0: usize,
    /// Candidate-list size used while inserting a node.
    pub ef_construction: usize,
    /// Default candidate-list size used by `search` and `search_filtered`.
    pub ef_search: usize,
    /// Level multiplier: a new node's top layer is `floor(-ln(U) * ml)` for a uniform draw `U`.
    pub ml: f64,
    /// Distance function used for every comparison.
    pub metric: DistanceMetric,
}
48
49impl Default for HnswConfig {
50 fn default() -> Self {
51 let m = 16;
52 Self {
53 m,
54 m_max0: m * 2,
55 ef_construction: 100,
56 ef_search: 50,
57 ml: 1.0 / (m as f64).ln(),
58 metric: DistanceMetric::L2,
59 }
60 }
61}
62
63impl HnswConfig {
64 pub fn with_m(m: usize) -> Self {
66 Self {
67 m,
68 m_max0: m * 2,
69 ml: 1.0 / (m as f64).ln(),
70 ..Default::default()
71 }
72 }
73
74 pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
76 self.metric = metric;
77 self
78 }
79
80 pub fn with_ef_construction(mut self, ef: usize) -> Self {
82 self.ef_construction = ef;
83 self
84 }
85
86 pub fn with_ef_search(mut self, ef: usize) -> Self {
88 self.ef_search = ef;
89 self
90 }
91}
92
/// One stored vector together with its per-layer adjacency lists.
#[derive(Debug, Clone)]
struct HnswNode {
    /// Id under which this node is stored in `HnswIndex::nodes`.
    id: NodeId,
    /// The stored vector.
    vector: Vec<f32>,
    /// Highest layer this node was created on.
    max_layer: usize,
    /// `connections[layer]` holds the neighbour ids of this node on `layer`.
    connections: Vec<Vec<NodeId>>,
}
105
106impl HnswNode {
107 fn new(id: NodeId, vector: Vec<f32>, max_layer: usize) -> Self {
108 let mut connections = Vec::with_capacity(max_layer + 1);
109 for _ in 0..=max_layer {
110 connections.push(Vec::new());
111 }
112 Self {
113 id,
114 vector,
115 max_layer,
116 connections,
117 }
118 }
119}
120
/// Hierarchical Navigable Small World graph over fixed-dimension `f32` vectors.
pub struct HnswIndex {
    /// Graph-construction and search parameters.
    config: HnswConfig,
    /// Every stored node, keyed by its id.
    nodes: HashMap<NodeId, HnswNode>,
    /// Node where searches and insertions start; `None` while the index is empty.
    entry_point: Option<NodeId>,
    /// Highest layer searched from the entry point (normally the entry point's top layer).
    max_layer: usize,
    /// Required length of every inserted vector (asserted in `insert_with_id`).
    dimension: usize,
    /// Next id handed out by `insert`.
    next_id: AtomicU64,
    /// xorshift64 state consumed by `random_layer`.
    rng_state: u64,
}
138
139impl HnswIndex {
140 pub fn new(dimension: usize, config: HnswConfig) -> Self {
142 Self {
143 config,
144 nodes: HashMap::new(),
145 entry_point: None,
146 max_layer: 0,
147 dimension,
148 next_id: AtomicU64::new(0),
149 rng_state: 0x853c49e6748fea9b, }
151 }
152
153 pub fn with_dimension(dimension: usize) -> Self {
155 Self::new(dimension, HnswConfig::default())
156 }
157
158 pub fn len(&self) -> usize {
160 self.nodes.len()
161 }
162
163 pub fn is_empty(&self) -> bool {
165 self.nodes.is_empty()
166 }
167
168 pub fn get_vector(&self, id: NodeId) -> Option<&[f32]> {
170 self.nodes.get(&id).map(|n| n.vector.as_slice())
171 }
172
173 pub fn insert(&mut self, vector: Vec<f32>) -> NodeId {
175 let id = self.next_id.fetch_add(1, AtomicOrdering::SeqCst);
176 self.insert_with_id(id, vector);
177 id
178 }
179
180 pub fn insert_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
182 assert_eq!(
183 vector.len(),
184 self.dimension,
185 "Vector dimension mismatch: expected {}, got {}",
186 self.dimension,
187 vector.len()
188 );
189
190 let node_layer = self.random_layer();
192
193 let node = HnswNode::new(id, vector, node_layer);
195
196 if self.entry_point.is_none() {
197 self.nodes.insert(id, node);
199 self.entry_point = Some(id);
200 self.max_layer = node_layer;
201 return;
202 }
203
204 let entry_point = self.entry_point.unwrap();
205 let vector = self.nodes.get(&id).map(|n| n.vector.clone());
206
207 let vector = node.vector.clone();
210 self.nodes.insert(id, node);
211
212 let mut current = entry_point;
214
215 for layer in (node_layer + 1..=self.max_layer).rev() {
218 current = self.search_layer_single(&vector, current, layer);
219 }
220
221 for layer in (0..=node_layer.min(self.max_layer)).rev() {
223 let neighbors = self.search_layer(&vector, current, self.config.ef_construction, layer);
225
226 let m = if layer == 0 {
228 self.config.m_max0
229 } else {
230 self.config.m
231 };
232 let selected: Vec<NodeId> = neighbors.into_iter().take(m).map(|r| r.id).collect();
233
234 if let Some(node) = self.nodes.get_mut(&id) {
236 node.connections[layer] = selected.clone();
237 }
238
239 for &neighbor_id in &selected {
241 self.add_connection(neighbor_id, id, layer);
242 }
243
244 if let Some(&first) = selected.first() {
246 current = first;
247 }
248 }
249
250 if node_layer > self.max_layer {
252 self.entry_point = Some(id);
253 self.max_layer = node_layer;
254 }
255 }
256
257 pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
259 self.search_with_ef(query, k, self.config.ef_search)
260 }
261
262 pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<DistanceResult> {
264 if self.entry_point.is_none() {
265 return Vec::new();
266 }
267
268 let entry_point = self.entry_point.unwrap();
269 let mut current = entry_point;
270
271 for layer in (1..=self.max_layer).rev() {
273 current = self.search_layer_single(query, current, layer);
274 }
275
276 let candidates = self.search_layer(query, current, ef.max(k), 0);
278
279 candidates.into_iter().take(k).collect()
281 }
282
283 pub fn search_filtered(
285 &self,
286 query: &[f32],
287 k: usize,
288 filter: &HashSet<NodeId>,
289 ) -> Vec<DistanceResult> {
290 self.search_filtered_with_ef(query, k, filter, self.config.ef_search)
291 }
292
293 pub fn search_filtered_with_ef(
295 &self,
296 query: &[f32],
297 k: usize,
298 filter: &HashSet<NodeId>,
299 ef: usize,
300 ) -> Vec<DistanceResult> {
301 if self.entry_point.is_none() || filter.is_empty() {
302 return Vec::new();
303 }
304
305 let entry_point = self.entry_point.unwrap();
306 let mut current = entry_point;
307
308 for layer in (1..=self.max_layer).rev() {
310 current = self.search_layer_single(query, current, layer);
311 }
312
313 let expanded_ef = (ef * 2).max(k * 4);
316 let candidates = self.search_layer(query, current, expanded_ef, 0);
317
318 candidates
320 .into_iter()
321 .filter(|r| filter.contains(&r.id))
322 .take(k)
323 .collect()
324 }
325
326 fn random_layer(&mut self) -> usize {
332 self.rng_state ^= self.rng_state << 13;
334 self.rng_state ^= self.rng_state >> 7;
335 self.rng_state ^= self.rng_state << 17;
336
337 let uniform = (self.rng_state as f64) / (u64::MAX as f64);
339
340 (-uniform.ln() * self.config.ml).floor() as usize
343 }
344
345 fn search_layer_single(&self, query: &[f32], entry: NodeId, layer: usize) -> NodeId {
347 let mut current = entry;
348 let mut current_dist = self.compute_distance(query, current);
349
350 loop {
351 let mut changed = false;
352
353 if let Some(node) = self.nodes.get(¤t) {
354 if layer < node.connections.len() {
355 for &neighbor in &node.connections[layer] {
356 let dist = self.compute_distance(query, neighbor);
357 if dist < current_dist {
358 current_dist = dist;
359 current = neighbor;
360 changed = true;
361 }
362 }
363 }
364 }
365
366 if !changed {
367 break;
368 }
369 }
370
371 current
372 }
373
374 fn search_layer(
376 &self,
377 query: &[f32],
378 entry: NodeId,
379 ef: usize,
380 layer: usize,
381 ) -> Vec<DistanceResult> {
382 let entry_dist = self.compute_distance(query, entry);
383
384 let mut candidates: BinaryHeap<Reverse<DistanceResult>> = BinaryHeap::new();
386 candidates.push(Reverse(DistanceResult::new(entry, entry_dist)));
387
388 let mut results: BinaryHeap<ReverseDistanceResult> = BinaryHeap::new();
390 results.push(ReverseDistanceResult(DistanceResult::new(
391 entry, entry_dist,
392 )));
393
394 let mut visited: HashSet<NodeId> = HashSet::new();
396 visited.insert(entry);
397
398 while let Some(Reverse(current)) = candidates.pop() {
399 let furthest_dist = results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
401
402 if current.distance > furthest_dist {
404 break;
405 }
406
407 if let Some(node) = self.nodes.get(¤t.id) {
409 if layer < node.connections.len() {
410 for &neighbor_id in &node.connections[layer] {
411 if visited.contains(&neighbor_id) {
412 continue;
413 }
414 visited.insert(neighbor_id);
415
416 let dist = self.compute_distance(query, neighbor_id);
417 let furthest_dist =
418 results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
419
420 if dist < furthest_dist || results.len() < ef {
422 candidates.push(Reverse(DistanceResult::new(neighbor_id, dist)));
423 results.push(ReverseDistanceResult(DistanceResult::new(
424 neighbor_id,
425 dist,
426 )));
427
428 while results.len() > ef {
430 results.pop();
431 }
432 }
433 }
434 }
435 }
436 }
437
438 let mut result_vec: Vec<DistanceResult> = results.into_iter().map(|r| r.0).collect();
440 result_vec.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
441 result_vec
442 }
443
444 fn add_connection(&mut self, from: NodeId, to: NodeId, layer: usize) {
446 let max_connections = if layer == 0 {
447 self.config.m_max0
448 } else {
449 self.config.m
450 };
451
452 if let Some(node) = self.nodes.get_mut(&from) {
453 while node.connections.len() <= layer {
455 node.connections.push(Vec::new());
456 }
457
458 if !node.connections[layer].contains(&to) {
460 node.connections[layer].push(to);
461
462 if node.connections[layer].len() > max_connections {
464 self.prune_connections(from, layer, max_connections);
465 }
466 }
467 }
468 }
469
470 fn prune_connections(&mut self, node_id: NodeId, layer: usize, max_connections: usize) {
472 let node_vector = self.nodes.get(&node_id).map(|n| n.vector.clone());
473 let node_vector = match node_vector {
474 Some(v) => v,
475 None => return,
476 };
477
478 let neighbors: Vec<NodeId> = self
480 .nodes
481 .get(&node_id)
482 .map(|n| n.connections[layer].clone())
483 .unwrap_or_default();
484
485 let mut scored: Vec<DistanceResult> = neighbors
486 .iter()
487 .map(|&neighbor_id| {
488 let dist = self.compute_distance(&node_vector, neighbor_id);
489 DistanceResult::new(neighbor_id, dist)
490 })
491 .collect();
492
493 scored.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
495
496 let kept: Vec<NodeId> = scored
498 .into_iter()
499 .take(max_connections)
500 .map(|r| r.id)
501 .collect();
502
503 if let Some(node) = self.nodes.get_mut(&node_id) {
504 node.connections[layer] = kept;
505 }
506 }
507
508 fn compute_distance(&self, query: &[f32], node_id: NodeId) -> f32 {
510 match self.nodes.get(&node_id) {
511 Some(node) => distance_simd(query, &node.vector, self.config.metric),
512 None => f32::MAX,
513 }
514 }
515
516 pub fn insert_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
524 vectors.into_iter().map(|v| self.insert(v)).collect()
525 }
526
527 pub fn insert_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
529 for (id, vector) in items {
530 self.insert_with_id(id, vector);
531 }
532 }
533
534 pub fn delete(&mut self, id: NodeId) -> bool {
544 if self.nodes.remove(&id).is_none() {
545 return false;
546 }
547
548 if self.entry_point == Some(id) {
550 self.entry_point = self.nodes.keys().next().copied();
551
552 if let Some(ep) = self.entry_point {
554 self.max_layer = self.nodes.get(&ep).map(|n| n.max_layer).unwrap_or(0);
555 } else {
556 self.max_layer = 0;
557 }
558 }
559
560 for node in self.nodes.values_mut() {
562 for layer_connections in node.connections.iter_mut() {
563 layer_connections.retain(|&neighbor| neighbor != id);
564 }
565 }
566
567 true
568 }
569
570 pub fn contains(&self, id: NodeId) -> bool {
572 self.nodes.contains_key(&id)
573 }
574
575 pub fn search_adaptive(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
584 let n = self.nodes.len();
586 let adaptive_ef = if n < 100 {
587 k.max(10)
588 } else if n < 10000 {
589 k.max(50)
590 } else if n < 100000 {
591 k.max(100)
592 } else {
593 k.max(200)
594 };
595
596 self.search_with_ef(query, k, adaptive_ef)
597 }
598
599 pub fn stats(&self) -> HnswStats {
605 let mut layer_counts = vec![0usize; self.max_layer + 1];
606 let mut total_connections = 0usize;
607 let mut max_connections = 0usize;
608 let mut min_connections = usize::MAX;
609
610 for node in self.nodes.values() {
611 for (layer, layer_count) in layer_counts.iter_mut().enumerate().take(node.max_layer + 1)
612 {
613 *layer_count += 1;
614 let conns = node.connections.get(layer).map(|c| c.len()).unwrap_or(0);
615 total_connections += conns;
616 max_connections = max_connections.max(conns);
617 if conns > 0 {
618 min_connections = min_connections.min(conns);
619 }
620 }
621 }
622
623 if self.nodes.is_empty() {
624 min_connections = 0;
625 }
626
627 HnswStats {
628 node_count: self.nodes.len(),
629 dimension: self.dimension,
630 max_layer: self.max_layer,
631 layer_counts,
632 total_connections,
633 avg_connections: if self.nodes.is_empty() {
634 0.0
635 } else {
636 total_connections as f64 / self.nodes.len() as f64
637 },
638 max_connections,
639 min_connections,
640 entry_point: self.entry_point,
641 }
642 }
643
644 pub fn to_bytes(&self) -> Vec<u8> {
650 let mut bytes = Vec::new();
651
652 bytes.extend_from_slice(b"HNSW");
654 bytes.extend_from_slice(&1u32.to_le_bytes()); bytes.extend_from_slice(&(self.dimension as u32).to_le_bytes());
658 bytes.extend_from_slice(&(self.config.m as u32).to_le_bytes());
659 bytes.extend_from_slice(&(self.config.m_max0 as u32).to_le_bytes());
660 bytes.extend_from_slice(&(self.config.ef_construction as u32).to_le_bytes());
661 bytes.extend_from_slice(&(self.config.ef_search as u32).to_le_bytes());
662 bytes.extend_from_slice(&self.config.ml.to_le_bytes());
663 bytes.push(match self.config.metric {
664 DistanceMetric::L2 => 0,
665 DistanceMetric::Cosine => 1,
666 DistanceMetric::InnerProduct => 2,
667 });
668
669 bytes.extend_from_slice(&(self.max_layer as u32).to_le_bytes());
671 bytes.extend_from_slice(&self.entry_point.unwrap_or(u64::MAX).to_le_bytes());
672
673 bytes.extend_from_slice(&(self.nodes.len() as u64).to_le_bytes());
675
676 for (&id, node) in &self.nodes {
678 bytes.extend_from_slice(&id.to_le_bytes());
679 bytes.extend_from_slice(&(node.max_layer as u32).to_le_bytes());
680
681 for &val in &node.vector {
683 bytes.extend_from_slice(&val.to_le_bytes());
684 }
685
686 for layer in 0..=node.max_layer {
688 let conns = &node.connections[layer];
689 bytes.extend_from_slice(&(conns.len() as u32).to_le_bytes());
690 for &conn in conns {
691 bytes.extend_from_slice(&conn.to_le_bytes());
692 }
693 }
694 }
695
696 bytes
697 }
698
699 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
701 if bytes.len() < 8 {
702 return Err("Data too short".to_string());
703 }
704
705 if &bytes[0..4] != b"HNSW" {
707 return Err("Invalid magic number".to_string());
708 }
709
710 let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
711 if version != 1 {
712 return Err(format!("Unsupported version: {}", version));
713 }
714
715 let mut pos = 8;
716
717 let dimension = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
719 pos += 4;
720 let m = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
721 pos += 4;
722 let m_max0 = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
723 pos += 4;
724 let ef_construction = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
725 pos += 4;
726 let ef_search = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
727 pos += 4;
728 let ml = f64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
729 pos += 8;
730 let metric = match bytes[pos] {
731 0 => DistanceMetric::L2,
732 1 => DistanceMetric::Cosine,
733 2 => DistanceMetric::InnerProduct,
734 _ => return Err("Invalid distance metric".to_string()),
735 };
736 pos += 1;
737
738 let config = HnswConfig {
739 m,
740 m_max0,
741 ef_construction,
742 ef_search,
743 ml,
744 metric,
745 };
746
747 let max_layer = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
749 pos += 4;
750 let ep_value = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
751 pos += 8;
752 let entry_point = if ep_value == u64::MAX {
753 None
754 } else {
755 Some(ep_value)
756 };
757
758 let node_count = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
760 pos += 8;
761
762 let mut nodes = HashMap::new();
763 let mut max_id = 0u64;
764
765 for _ in 0..node_count {
766 let id = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
767 pos += 8;
768 max_id = max_id.max(id);
769
770 let level = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
771 pos += 4;
772
773 let mut vector = Vec::with_capacity(dimension);
774 for _ in 0..dimension {
775 vector.push(f32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()));
776 pos += 4;
777 }
778
779 let mut connections = vec![Vec::new(); level + 1];
780 for conn_list in connections.iter_mut().take(level + 1) {
781 let conn_count =
782 u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
783 pos += 4;
784
785 for _ in 0..conn_count {
786 let conn = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
787 pos += 8;
788 conn_list.push(conn);
789 }
790 }
791
792 nodes.insert(
793 id,
794 HnswNode {
795 id,
796 max_layer: level,
797 vector,
798 connections,
799 },
800 );
801 }
802
803 Ok(Self {
804 config,
805 nodes,
806 entry_point,
807 max_layer,
808 dimension,
809 next_id: AtomicU64::new(max_id + 1),
810 rng_state: 12345, })
812 }
813}
814
/// Snapshot of graph shape returned by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
    /// Number of stored vectors.
    pub node_count: usize,
    /// Vector dimension of the index.
    pub dimension: usize,
    /// Top layer searched by the index.
    pub max_layer: usize,
    /// `layer_counts[l]` is the number of nodes present on layer `l`.
    pub layer_counts: Vec<usize>,
    /// Sum of adjacency-list lengths over all nodes and all of their layers.
    pub total_connections: usize,
    /// `total_connections / node_count` (edges summed across layers); 0.0 when empty.
    pub avg_connections: f64,
    /// Length of the longest single adjacency list.
    pub max_connections: usize,
    /// Length of the shortest non-empty adjacency list; 0 for an empty index.
    pub min_connections: usize,
    /// Current search entry point, if any.
    pub entry_point: Option<NodeId>,
}
837
838#[derive(Debug, Clone)]
840pub struct Bitset {
841 bits: Vec<u64>,
842 len: usize,
843}
844
845impl Bitset {
846 pub fn with_capacity(n: usize) -> Self {
848 let num_words = n.div_ceil(64);
849 Self {
850 bits: vec![0; num_words],
851 len: n,
852 }
853 }
854
855 pub fn all(n: usize) -> Self {
857 let num_words = n.div_ceil(64);
858 let mut bits = vec![u64::MAX; num_words];
859
860 if !n.is_multiple_of(64) {
862 let last_idx = num_words - 1;
863 let valid_bits = n % 64;
864 bits[last_idx] = (1u64 << valid_bits) - 1;
865 }
866
867 Self { bits, len: n }
868 }
869
870 pub fn set(&mut self, idx: usize) {
872 if idx < self.len {
873 let word = idx / 64;
874 let bit = idx % 64;
875 self.bits[word] |= 1u64 << bit;
876 }
877 }
878
879 pub fn clear(&mut self, idx: usize) {
881 if idx < self.len {
882 let word = idx / 64;
883 let bit = idx % 64;
884 self.bits[word] &= !(1u64 << bit);
885 }
886 }
887
888 pub fn is_set(&self, idx: usize) -> bool {
890 if idx >= self.len {
891 return false;
892 }
893 let word = idx / 64;
894 let bit = idx % 64;
895 (self.bits[word] & (1u64 << bit)) != 0
896 }
897
898 pub fn to_hashset(&self) -> HashSet<NodeId> {
900 let mut set = HashSet::new();
901 for i in 0..self.len {
902 if self.is_set(i) {
903 set.insert(i as NodeId);
904 }
905 }
906 set
907 }
908}
909
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[0, 1]`.
    ///
    /// Uses xorshift64 seeded by `seed`; a zero seed yields the zero vector.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                (state as f32) / (u64::MAX as f32)
            })
            .collect()
    }

    #[test]
    fn test_empty_index() {
        let index = HnswIndex::with_dimension(128);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);

        let results = index.search(&[0.0; 128], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_single_insert() {
        let mut index = HnswIndex::with_dimension(3);
        let id = index.insert(vec![1.0, 2.0, 3.0]);

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.get_vector(id).is_some());
    }

    #[test]
    fn test_exact_match() {
        let mut index = HnswIndex::with_dimension(3);
        index.insert(vec![1.0, 0.0, 0.0]);
        index.insert(vec![0.0, 1.0, 0.0]);
        index.insert(vec![0.0, 0.0, 1.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].distance, 0.0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let mut index = HnswIndex::with_dimension(2);
        index.insert_with_id(0, vec![0.0, 0.0]);
        index.insert_with_id(1, vec![1.0, 0.0]);
        index.insert_with_id(2, vec![2.0, 0.0]);
        index.insert_with_id(3, vec![3.0, 0.0]);

        let results = index.search(&[0.9, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_k_nearest() {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..10 {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }

        let results = index.search(&[4.5, 0.0], 3);
        assert_eq!(results.len(), 3);

        let ids: HashSet<_> = results.iter().map(|r| r.id).collect();
        assert!(ids.contains(&4));
        assert!(ids.contains(&5));
    }

    #[test]
    fn test_filtered_search() {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..10 {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }

        let filter: HashSet<NodeId> = [0, 2, 4, 6, 8].iter().copied().collect();
        let results = index.search_filtered(&[5.0, 0.0], 2, &filter);

        assert_eq!(results.len(), 2);
        let ids: HashSet<_> = results.iter().map(|r| r.id).collect();
        assert!(ids.contains(&4));
        assert!(ids.contains(&6));
    }

    #[test]
    fn test_filtered_search_without_indexed_matches() {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..10 {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }

        let filter: HashSet<NodeId> = [100, 200].iter().copied().collect();
        assert!(index.search_filtered(&[5.0, 0.0], 2, &filter).is_empty());
    }

    #[test]
    fn test_cosine_distance() {
        let config = HnswConfig::default().with_metric(DistanceMetric::Cosine);
        let mut index = HnswIndex::new(3, config);

        index.insert_with_id(0, vec![1.0, 0.0, 0.0]);
        index.insert_with_id(1, vec![0.0, 1.0, 0.0]);
        index.insert_with_id(2, vec![0.707, 0.707, 0.0]);

        let results = index.search(&[0.707, 0.707, 0.0], 1);
        assert_eq!(results[0].id, 2);
    }

    #[test]
    fn test_many_vectors() {
        let dim = 64;
        let n: usize = 1000;

        let mut index = HnswIndex::with_dimension(dim);
        for i in 0..n {
            index.insert_with_id(i as u64, random_vector(dim, i as u64));
        }
        assert_eq!(index.len(), n);

        let query = random_vector(dim, 12345);
        let results = index.search(&query, 10);
        assert_eq!(results.len(), 10);

        for pair in results.windows(2) {
            assert!(pair[1].distance >= pair[0].distance);
        }
    }

    #[test]
    fn test_insert_batch_assigns_sequential_ids() {
        let mut index = HnswIndex::with_dimension(2);
        let ids = index.insert_batch(vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]]);
        assert_eq!(ids, vec![0, 1, 2]);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_delete() {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..20 {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }

        assert!(index.delete(7));
        assert!(!index.delete(7));
        assert!(!index.contains(7));
        assert_eq!(index.len(), 19);

        let results = index.search(&[7.0, 0.0], 3);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id != 7));
    }

    #[test]
    fn test_delete_everything_then_reinsert() {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..5 {
            index.insert_with_id(i, vec![i as f32, 1.0]);
        }
        for i in 0..5 {
            assert!(index.delete(i));
        }
        assert!(index.is_empty());
        assert!(index.search(&[0.0, 0.0], 3).is_empty());

        index.insert_with_id(10, vec![1.0, 1.0]);
        let results = index.search(&[0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 10);
    }

    #[test]
    fn test_stats() {
        let empty = HnswIndex::with_dimension(3).stats();
        assert_eq!(empty.node_count, 0);
        assert_eq!(empty.min_connections, 0);
        assert_eq!(empty.avg_connections, 0.0);
        assert_eq!(empty.entry_point, None);

        let mut index = HnswIndex::with_dimension(2);
        for i in 0..10 {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }
        let stats = index.stats();
        assert_eq!(stats.node_count, 10);
        assert_eq!(stats.dimension, 2);
        assert_eq!(stats.layer_counts[0], 10);
        assert!(stats.total_connections > 0);
        assert!(stats.entry_point.is_some());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let dim = 8;
        let mut index = HnswIndex::with_dimension(dim);
        for i in 0..50u64 {
            index.insert_with_id(i, random_vector(dim, i + 1));
        }

        let bytes = index.to_bytes();
        let restored = HnswIndex::from_bytes(&bytes).expect("roundtrip must succeed");

        assert_eq!(restored.len(), index.len());
        for i in 0..50u64 {
            assert_eq!(restored.get_vector(i), index.get_vector(i));
        }

        let query = random_vector(dim, 999);
        let before: Vec<NodeId> = index.search(&query, 5).iter().map(|r| r.id).collect();
        let after: Vec<NodeId> = restored.search(&query, 5).iter().map(|r| r.id).collect();
        assert_eq!(before, after);
    }

    #[test]
    fn test_serialization_empty_index() {
        let bytes = HnswIndex::with_dimension(4).to_bytes();
        let restored = HnswIndex::from_bytes(&bytes).expect("empty index must load");
        assert!(restored.is_empty());
    }

    #[test]
    fn test_from_bytes_rejects_bad_header() {
        assert!(HnswIndex::from_bytes(b"HNS").is_err());
        assert!(HnswIndex::from_bytes(b"XXXX\x01\x00\x00\x00").is_err());

        let mut bytes = HnswIndex::with_dimension(2).to_bytes();
        bytes[4] = 2; // version byte
        assert!(HnswIndex::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_bitset() {
        let mut bs = Bitset::with_capacity(100);

        bs.set(0);
        bs.set(50);
        bs.set(99);

        assert!(bs.is_set(0));
        assert!(bs.is_set(50));
        assert!(bs.is_set(99));
        assert!(!bs.is_set(1));
        assert!(!bs.is_set(64));

        bs.clear(50);
        assert!(!bs.is_set(50));
    }

    #[test]
    fn test_bitset_all() {
        let bs = Bitset::all(100);

        for i in 0..100 {
            assert!(bs.is_set(i));
        }
        assert!(!bs.is_set(100));
    }

    #[test]
    fn test_bitset_to_hashset() {
        let mut bs = Bitset::with_capacity(130);
        bs.set(0);
        bs.set(64);
        bs.set(129);
        bs.set(500); // out of range, ignored

        let expected: HashSet<NodeId> = [0u64, 64, 129].into_iter().collect();
        assert_eq!(bs.to_hashset(), expected);
        assert!(Bitset::all(0).to_hashset().is_empty());
    }
}