1use scirs2_core::random::{Random, Rng, StdRng};
9#[cfg(test)]
10use std::collections::HashMap;
11use std::collections::{BinaryHeap, HashSet};
12
/// Euclidean (L2) distance between two vectors.
///
/// Components are paired with `zip`, so when the slices differ in length only
/// the first `min(a.len(), b.len())` components contribute; the tail of the
/// longer slice is ignored rather than reported as an error.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
25
/// Cosine similarity of `a` and `b`, clamped to the range `[0.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero norm. Because of the clamp,
/// vectors pointing in opposing directions report `0.0`, the same as
/// orthogonal ones. The dot product pairs components with `zip`, while each norm
/// uses its full slice, so mismatched lengths are not rejected.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
}
37
/// Maps a uniform sample `rng_val` to the top layer of a new node.
///
/// Computes `floor(-ln(rng_val) * m_l)`. Samples that are not strictly
/// positive, where the logarithm would be undefined or infinite, map to
/// layer 0. NaN also maps to layer 0. Samples at or above 1 give a
/// non-positive product, which the `as usize` cast saturates to 0.
pub fn random_level(m_l: f64, rng_val: f64) -> usize {
    match rng_val {
        sample if sample > 0.0 => (-sample.ln() * m_l).floor() as usize,
        _ => 0,
    }
}
48
/// Tuning parameters used when building an `HnswGraph`.
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Number of neighbours a newly inserted node links to on each layer it
    /// shares with the graph. These are the first `m` (closest) candidates of
    /// the layer search.
    pub m: usize,
    /// Upper bound on the connections a node keeps per layer. When a back-link
    /// pushes a node past this bound, its farthest connections are pruned.
    pub m_max: usize,
    /// Size of the dynamic candidate list (`ef`) used when searching for
    /// neighbours during insertion.
    pub ef_construction: usize,
    /// Level-generation factor passed to `random_level`. Larger values make
    /// higher layers more likely.
    pub m_l: f64,
}
65
66impl Default for HnswConfig {
67 fn default() -> Self {
68 let m = 16_usize;
69 HnswConfig {
70 m,
71 m_max: 32,
72 ef_construction: 200,
73 m_l: 1.0 / (m as f64).ln(),
74 }
75 }
76}
77
78impl HnswConfig {
79 pub fn new(m: usize, ef_construction: usize) -> Self {
81 HnswConfig {
82 m,
83 m_max: m * 2,
84 ef_construction,
85 m_l: 1.0 / (m.max(2) as f64).ln(),
86 }
87 }
88}
89
/// A single stored vector together with its per-layer adjacency lists.
#[derive(Debug, Clone)]
pub struct HnswNode {
    /// Caller-supplied identifier; this is what searches report back.
    pub id: usize,
    /// The stored vector.
    pub vector: Vec<f32>,
    /// One adjacency list per layer. `connections[l]` holds *positions in the
    /// owning graph's `nodes` vector* (not external ids) of this node's
    /// neighbours on layer `l`.
    pub connections: Vec<Vec<usize>>,
}
104
105impl HnswNode {
106 fn new(id: usize, vector: Vec<f32>, max_layer: usize) -> Self {
107 HnswNode {
108 id,
109 vector,
110 connections: vec![Vec::new(); max_layer + 1],
111 }
112 }
113
114 fn ensure_layers(&mut self, layers: usize) {
115 while self.connections.len() <= layers {
116 self.connections.push(Vec::new());
117 }
118 }
119}
120
/// Hierarchical Navigable Small World graph for approximate nearest-neighbour
/// search under Euclidean distance.
pub struct HnswGraph {
    /// All inserted nodes in insertion order. Adjacency lists refer to positions
    /// in this vector.
    pub nodes: Vec<HnswNode>,
    /// Position (in `nodes`) of the node every search starts from; `None`
    /// until the first insertion.
    pub entry_point: Option<usize>,
    /// Highest layer currently present in the graph.
    pub max_layer: usize,
    /// Construction parameters.
    config: HnswConfig,
    /// RNG used to draw each inserted node's top layer.
    rng: StdRng,
}
134
135impl HnswGraph {
136 pub fn new(config: HnswConfig) -> Self {
138 HnswGraph {
139 nodes: Vec::new(),
140 entry_point: None,
141 max_layer: 0,
142 config,
143 rng: Random::seed(42),
144 }
145 }
146
147 pub fn insert(&mut self, id: usize, vector: Vec<f32>) {
151 let rng_val: f64 = self.rng.random::<f64>();
152 let node_layer = random_level(self.config.m_l, rng_val);
153
154 let mut node = HnswNode::new(id, vector.clone(), node_layer);
155 node.ensure_layers(node_layer);
156
157 let node_idx = self.nodes.len();
158
159 match self.entry_point {
160 None => {
161 self.entry_point = Some(node_idx);
163 self.max_layer = node_layer;
164 self.nodes.push(node);
165 return;
166 }
167 Some(ep) => {
168 let mut ep_idx = ep;
170 let current_top = self.max_layer;
171
172 if current_top > node_layer {
173 for lc in (node_layer + 1..=current_top).rev() {
174 ep_idx = self.greedy_search_layer(ep_idx, &vector, lc);
175 }
176 }
177
178 for lc in (0..=node_layer.min(current_top)).rev() {
180 let candidates =
181 self.search_layer_ef(ep_idx, &vector, self.config.ef_construction, lc);
182 let neighbours = self.select_neighbours(&candidates, self.config.m);
183
184 for &nb_idx in &neighbours {
186 if nb_idx < self.nodes.len() {
187 self.nodes[nb_idx].ensure_layers(lc);
188 self.nodes[nb_idx].connections[lc].push(node_idx);
189 let nb_vec = self.nodes[nb_idx].vector.clone();
191 self.shrink_connections(nb_idx, lc, &nb_vec);
192 }
193 }
194 node.ensure_layers(lc);
195 node.connections[lc] = neighbours.clone();
196
197 if !candidates.is_empty() {
199 ep_idx = candidates[0].0;
200 }
201 }
202
203 if node_layer > current_top {
205 self.max_layer = node_layer;
206 self.entry_point = Some(node_idx);
207 }
208 }
209 }
210 self.nodes.push(node);
211 }
212
213 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
218 if self.nodes.is_empty() {
219 return Vec::new();
220 }
221 let ep = match self.entry_point {
222 Some(e) => e,
223 None => return Vec::new(),
224 };
225
226 let mut ep_idx = ep;
227 for lc in (1..=self.max_layer).rev() {
229 ep_idx = self.greedy_search_layer(ep_idx, query, lc);
230 }
231
232 let candidates = self.search_layer_ef(ep_idx, query, ef.max(k), 0);
234
235 let mut results: Vec<(usize, f32)> = candidates
237 .iter()
238 .take(k)
239 .map(|&(idx, dist)| (self.nodes[idx].id, dist))
240 .collect();
241 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
242 results.truncate(k);
243 results
244 }
245
246 pub fn node_count(&self) -> usize {
248 self.nodes.len()
249 }
250
251 pub fn layer_count(&self) -> usize {
253 if self.nodes.is_empty() {
254 0
255 } else {
256 self.max_layer + 1
257 }
258 }
259
260 pub fn connections_at(&self, id: usize, layer: usize) -> Option<&Vec<usize>> {
262 let node = self.nodes.iter().find(|n| n.id == id)?;
263 node.connections.get(layer)
264 }
265
266 fn greedy_search_layer(&self, mut ep_idx: usize, query: &[f32], layer: usize) -> usize {
271 let mut best_dist = euclidean_distance(&self.nodes[ep_idx].vector, query);
272 loop {
273 let mut improved = false;
274 let conns: Vec<usize> = if layer < self.nodes[ep_idx].connections.len() {
275 self.nodes[ep_idx].connections[layer].clone()
276 } else {
277 Vec::new()
278 };
279 for nb_idx in conns {
280 if nb_idx < self.nodes.len() {
281 let d = euclidean_distance(&self.nodes[nb_idx].vector, query);
282 if d < best_dist {
283 best_dist = d;
284 ep_idx = nb_idx;
285 improved = true;
286 }
287 }
288 }
289 if !improved {
290 break;
291 }
292 }
293 ep_idx
294 }
295
296 fn search_layer_ef(
299 &self,
300 ep_idx: usize,
301 query: &[f32],
302 ef: usize,
303 layer: usize,
304 ) -> Vec<(usize, f32)> {
305 let ep_dist = euclidean_distance(&self.nodes[ep_idx].vector, query);
310
311 let mut candidates: BinaryHeap<OrdPair> = BinaryHeap::new();
314 let mut result: BinaryHeap<OrdPair> = BinaryHeap::new(); let mut visited: HashSet<usize> = HashSet::new();
316
317 candidates.push(OrdPair(ep_dist, ep_idx));
318 result.push(OrdPair(ep_dist, ep_idx));
319 visited.insert(ep_idx);
320
321 while let Some(OrdPair(dist, idx)) = pop_min(&mut candidates) {
322 if let Some(OrdPair(worst_dist, _)) = result.peek() {
324 if dist > *worst_dist && result.len() >= ef {
325 break;
326 }
327 }
328 let conns: Vec<usize> = if layer < self.nodes[idx].connections.len() {
330 self.nodes[idx].connections[layer].clone()
331 } else {
332 Vec::new()
333 };
334 for nb_idx in conns {
335 if nb_idx >= self.nodes.len() || visited.contains(&nb_idx) {
336 continue;
337 }
338 visited.insert(nb_idx);
339 let d = euclidean_distance(&self.nodes[nb_idx].vector, query);
340 let add = result.len() < ef || d < result.peek().map_or(f32::MAX, |p| p.0);
342 if add {
343 candidates.push(OrdPair(d, nb_idx));
344 result.push(OrdPair(d, nb_idx));
345 while result.len() > ef {
347 result.pop();
348 }
349 }
350 }
351 }
352
353 let mut out: Vec<(usize, f32)> = result.into_iter().map(|p| (p.1, p.0)).collect();
355 out.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
356 out
357 }
358
359 fn select_neighbours(&self, candidates: &[(usize, f32)], m: usize) -> Vec<usize> {
361 candidates.iter().take(m).map(|&(idx, _)| idx).collect()
362 }
363
364 fn shrink_connections(&mut self, node_idx: usize, layer: usize, node_vec: &[f32]) {
366 if layer >= self.nodes[node_idx].connections.len() {
367 return;
368 }
369 let m_max = self.config.m_max;
370 if self.nodes[node_idx].connections[layer].len() <= m_max {
371 return;
372 }
373 let mut conn_dists: Vec<(usize, f32)> = self.nodes[node_idx].connections[layer]
375 .iter()
376 .filter_map(|&nb| {
377 if nb < self.nodes.len() {
378 let d = euclidean_distance(&self.nodes[nb].vector, node_vec);
379 Some((nb, d))
380 } else {
381 None
382 }
383 })
384 .collect();
385 conn_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
386 conn_dists.truncate(m_max);
387 self.nodes[node_idx].connections[layer] =
388 conn_dists.into_iter().map(|(nb, _)| nb).collect();
389 }
390}
391
/// `(distance, node position)` pair whose ordering is *inverted on distance*.
///
/// A smaller distance compares as greater, so `BinaryHeap<OrdPair>` (a max-heap)
/// pops the closest entry first. Equal distances, including incomparable NaN
/// pairs, fall back to the position, and the larger position compares greater.
#[derive(Debug, Clone, PartialEq)]
struct OrdPair(f32, usize);

impl Eq for OrdPair {}

impl Ord for OrdPair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let by_distance = match self.0.partial_cmp(&other.0) {
            Some(ord) => ord.reverse(),
            None => std::cmp::Ordering::Equal,
        };
        by_distance.then_with(|| self.1.cmp(&other.1))
    }
}

impl PartialOrd for OrdPair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
419
420fn pop_min(heap: &mut BinaryHeap<OrdPair>) -> Option<OrdPair> {
423 heap.pop()
424}
425
/// Unit tests for the distance helpers, level sampling, configuration, node
/// bookkeeping, and graph insert/search behaviour.
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for 2-D test vectors.
    fn vec2(x: f32, y: f32) -> Vec<f32> {
        vec![x, y]
    }

    // ---- euclidean_distance ----

    #[test]
    fn test_euclidean_distance_zero() {
        let a = vec![1.0_f32, 2.0, 3.0];
        assert_eq!(euclidean_distance(&a, &a), 0.0);
    }

    #[test]
    fn test_euclidean_distance_unit() {
        // Classic 3-4-5 right triangle.
        let a = vec![0.0_f32, 0.0];
        let b = vec![3.0_f32, 4.0];
        let d = euclidean_distance(&a, &b);
        assert!((d - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_euclidean_distance_symmetric() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b = vec![4.0_f32, 5.0, 6.0];
        assert!((euclidean_distance(&a, &b) - euclidean_distance(&b, &a)).abs() < 1e-6);
    }

    // ---- cosine_similarity ----

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0_f32, 0.0, 0.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_range() {
        // cosine_similarity clamps its result into [0, 1].
        let a = vec![0.6_f32, 0.8];
        let b = vec![0.8_f32, 0.6];
        let s = cosine_similarity(&a, &b);
        assert!((0.0..=1.0).contains(&s));
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero-norm input short-circuits to 0.0 instead of dividing by zero.
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // ---- random_level ----

    #[test]
    fn test_random_level_near_zero_returns_zero() {
        // Despite the name, the sample here is close to 1: -ln(0.999) is tiny,
        // so floor(-ln(u) * m_l) is 0.
        let level = random_level(1.0 / (16.0_f64).ln(), 0.999);
        assert_eq!(level, 0);
    }

    #[test]
    fn test_random_level_small_value_high_level() {
        // A very small sample makes -ln(u) large, pushing the level above 0.
        let level = random_level(1.0 / (16.0_f64).ln(), 1e-10);
        assert!(level > 0);
    }

    #[test]
    fn test_random_level_distribution() {
        // With m_l = 1/ln(16), the probability of level 0 is 1 - 1/16, so level 0
        // should account for well over half of 1000 draws.
        let m_l = 1.0 / (16.0_f64).ln();
        let mut rng = Random::seed(0);
        let mut counts: HashMap<usize, usize> = HashMap::new();
        for _ in 0..1000 {
            let v: f64 = rng.random::<f64>();
            let level = random_level(m_l, v);
            *counts.entry(level).or_insert(0) += 1;
        }
        let count_0 = counts.get(&0).copied().unwrap_or(0);
        assert!(count_0 > 500, "Level 0 should dominate; got {count_0}");
    }

    // ---- HnswConfig ----

    #[test]
    fn test_config_default_values() {
        let cfg = HnswConfig::default();
        assert_eq!(cfg.m, 16);
        assert_eq!(cfg.m_max, 32);
        assert_eq!(cfg.ef_construction, 200);
        assert!(cfg.m_l > 0.0);
    }

    #[test]
    fn test_config_new() {
        // m_max is derived as 2 * m.
        let cfg = HnswConfig::new(8, 100);
        assert_eq!(cfg.m, 8);
        assert_eq!(cfg.m_max, 16);
        assert_eq!(cfg.ef_construction, 100);
    }

    // ---- insertion ----

    #[test]
    fn test_insert_single_node_entry_point_set() {
        // The first node is stored at position 0 and becomes the entry point.
        let mut g = HnswGraph::new(HnswConfig::default());
        g.insert(0, vec2(1.0, 0.0));
        assert_eq!(g.entry_point, Some(0));
        assert_eq!(g.node_count(), 1);
    }

    #[test]
    fn test_insert_single_node_layer_count() {
        let mut g = HnswGraph::new(HnswConfig::default());
        g.insert(0, vec2(0.0, 0.0));
        assert!(g.layer_count() >= 1);
    }

    #[test]
    fn test_insert_multiple_increases_node_count() {
        let mut g = HnswGraph::new(HnswConfig::default());
        for i in 0..10_u32 {
            g.insert(i as usize, vec![i as f32, 0.0]);
        }
        assert_eq!(g.node_count(), 10);
    }

    #[test]
    fn test_entry_point_set_after_first_insert() {
        let mut g = HnswGraph::new(HnswConfig::default());
        g.insert(42, vec![1.0, 2.0]);
        assert!(g.entry_point.is_some());
    }

    // ---- search ----

    #[test]
    fn test_search_empty_graph_returns_empty() {
        let g = HnswGraph::new(HnswConfig::default());
        let results = g.search(&[0.0, 0.0], 3, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_single_node_returns_it() {
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        g.insert(0, vec2(1.0, 0.0));
        let results = g.search(&[1.0, 0.0], 1, 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_search_returns_at_most_k_results() {
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        for i in 0..20_u32 {
            g.insert(i as usize, vec![i as f32, 0.0]);
        }
        let results = g.search(&[5.0, 0.0], 5, 20);
        assert!(results.len() <= 5);
    }

    #[test]
    fn test_search_results_ordered_by_distance() {
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        for i in 0..10_u32 {
            g.insert(i as usize, vec![i as f32, 0.0]);
        }
        let query = vec![4.5, 0.0];
        let results = g.search(&query, 5, 20);
        // Each adjacent pair must be non-decreasing (small tolerance for float noise).
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-5, "Results not sorted: {:?}", results);
        }
    }

    #[test]
    fn test_search_nearest_is_closest() {
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        g.insert(0, vec2(0.0, 0.0));
        g.insert(1, vec2(100.0, 0.0));
        g.insert(2, vec2(0.0, 100.0));
        let results = g.search(&[1.0, 1.0], 1, 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
    }

    // ---- layer bookkeeping ----

    #[test]
    fn test_layer_count_non_zero_after_insert() {
        let mut g = HnswGraph::new(HnswConfig::default());
        g.insert(0, vec![1.0]);
        assert!(g.layer_count() >= 1);
    }

    #[test]
    fn test_layer_count_zero_when_empty() {
        let g = HnswGraph::new(HnswConfig::default());
        assert_eq!(g.layer_count(), 0);
    }

    // ---- connections_at ----

    #[test]
    fn test_connections_at_returns_none_for_unknown_id() {
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        g.insert(0, vec2(1.0, 0.0));
        assert!(g.connections_at(99, 0).is_none());
    }

    #[test]
    fn test_connections_at_returns_some_for_inserted_node() {
        // Every node has at least layer 0, even if its adjacency list is empty.
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        g.insert(0, vec2(0.0, 0.0));
        assert!(g.connections_at(0, 0).is_some());
    }

    // ---- small exact scenarios ----

    #[test]
    fn test_exact_search_3_nodes() {
        let mut g = HnswGraph::new(HnswConfig::new(2, 20));
        g.insert(0, vec2(0.0, 0.0));
        g.insert(1, vec2(1.0, 0.0));
        g.insert(2, vec2(10.0, 0.0));

        let results = g.search(&[0.1, 0.0], 3, 10);
        assert!(!results.is_empty());
        let nearest = results[0].0;
        // Accept either nearby point; the search is approximate.
        assert!(
            nearest == 0 || nearest == 1,
            "Expected 0 or 1, got {nearest}"
        );
    }

    // ---- HnswNode ----

    #[test]
    fn test_hnsw_node_new() {
        // max_layer = 2 allocates adjacency lists for layers 0, 1 and 2.
        let n = HnswNode::new(5, vec![1.0, 2.0], 2);
        assert_eq!(n.id, 5);
        assert_eq!(n.vector, vec![1.0, 2.0]);
        assert_eq!(n.connections.len(), 3);
    }

    #[test]
    fn test_hnsw_node_ensure_layers() {
        let mut n = HnswNode::new(0, vec![1.0], 0);
        n.ensure_layers(3);
        assert!(n.connections.len() >= 4);
    }

    // ---- determinism and output invariants ----

    #[test]
    fn test_search_reproducible() {
        // search takes &self, so repeated identical queries must agree.
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        for i in 0..15_u32 {
            g.insert(i as usize, vec![(i as f32) * 0.1, 0.0]);
        }
        let r1 = g.search(&[0.5, 0.0], 3, 10);
        let r2 = g.search(&[0.5, 0.0], 3, 10);
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0, b.0);
        }
    }

    #[test]
    fn test_search_returns_k_or_fewer() {
        // Asking for more results than stored nodes returns at most node_count.
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        for i in 0..5_u32 {
            g.insert(i as usize, vec![i as f32]);
        }
        let results = g.search(&[2.0], 10, 10);
        assert!(results.len() <= 5);
    }

    #[test]
    fn test_distances_non_negative() {
        let mut g = HnswGraph::new(HnswConfig::new(4, 50));
        for i in 0..8_u32 {
            g.insert(i as usize, vec![i as f32, (8 - i) as f32]);
        }
        let results = g.search(&[4.0, 4.0], 5, 20);
        for (_, dist) in &results {
            assert!(*dist >= 0.0);
        }
    }
}