1use std::cmp::Ordering;
63use std::collections::{BinaryHeap, HashMap, HashSet};
64
65use turbovec::codebook::codebook;
66use turbovec::rotation::make_rotation_matrix;
67
68use crate::distance::Distance;
69use crate::index::{HnswParams, IndexError, NodeId, SearchResult};
70
71pub trait CodecDistance {
82 fn distance(&self, a: NodeId, b: NodeId) -> f32;
89}
90
91#[derive(Clone, Debug)]
93struct TurboHnswNode {
94 id: NodeId,
95 levels: Vec<Vec<usize>>,
98 deleted: bool,
102}
103
104impl TurboHnswNode {
105 fn level(&self) -> usize {
106 self.levels.len().saturating_sub(1)
107 }
108}
109
110pub struct TurboHnswIndex<const BITS: u8> {
116 distance: Distance,
118 dim: u16,
120 params: HnswParams,
123
124 rotation: Vec<f32>,
127 boundaries: Vec<f32>,
130 centroids: Vec<f32>,
133
134 packed: Vec<u8>,
139 scales: Vec<f32>,
142
143 nodes: Vec<TurboHnswNode>,
146 id_to_idx: HashMap<NodeId, usize>,
149 entry: Option<usize>,
152 rng_state: u64,
154 ml: f64,
157}
158
/// Search-frontier entry, ordered *in reverse* by `score` so that a
/// `BinaryHeap<Candidate>` pops the smallest distance first.
///
/// Ordering uses `f32::total_cmp`, so it is a true total order: NaN scores get a
/// consistent position instead of comparing `Equal` to everything, and `Eq` is
/// reflexive (required by `Ord`). Note `total_cmp` also distinguishes `-0.0` from
/// `0.0`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    /// Slot index of the node.
    idx: usize,
    /// Distance to the query (lower is closer).
    score: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so `Eq`/`Ord` can never disagree (e.g. on NaN).
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped: turns std's max-heap into a min-heap on score.
        other.score.total_cmp(&self.score)
    }
}
186
/// Result-set entry, ordered by `score` so that a `BinaryHeap<MaxCandidate>` keeps
/// the *largest* (worst) distance on top, ready to be evicted.
///
/// Ordering uses `f32::total_cmp`, so it is a true total order (NaN included) and
/// `Eq` is reflexive as `Ord` requires. `-0.0` and `0.0` are ordered distinctly.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    /// Slot index of the node.
    idx: usize,
    /// Distance to the query (lower is closer).
    score: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so `Eq`/`Ord` can never disagree (e.g. on NaN).
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}
213
214impl<const BITS: u8> TurboHnswIndex<BITS> {
215 pub fn new(distance: Distance, dim: u16, params: HnswParams) -> Result<Self, IndexError> {
227 if !(2..=4).contains(&BITS) {
228 return Err(IndexError::Empty);
229 }
230 if dim == 0 {
231 return Err(IndexError::Empty);
232 }
233 if !dim.is_multiple_of(8) {
234 return Err(IndexError::DimensionMismatch {
235 expected: ((dim / 8) + 1) * 8,
236 got: dim,
237 });
238 }
239 let dim_usize = usize::from(dim);
240 let bits_usize = usize::from(BITS);
241 let rotation = make_rotation_matrix(dim_usize);
242 let (boundaries, centroids) = codebook(bits_usize, dim_usize);
243 let ml = if params.m > 1 {
244 1.0 / f64::from(u32::try_from(params.m).unwrap_or(u32::MAX)).ln()
245 } else {
246 1.0
247 };
248 Ok(Self {
249 distance,
250 dim,
251 params,
252 rotation,
253 boundaries,
254 centroids,
255 packed: Vec::new(),
256 scales: Vec::new(),
257 nodes: Vec::new(),
258 id_to_idx: HashMap::new(),
259 entry: None,
260 rng_state: params.seed,
261 ml,
262 })
263 }
264
265 #[must_use]
267 pub fn len(&self) -> usize {
268 self.nodes.iter().filter(|n| !n.deleted).count()
269 }
270
271 #[must_use]
273 pub fn is_empty(&self) -> bool {
274 self.len() == 0
275 }
276
277 #[must_use]
279 pub fn dim(&self) -> u16 {
280 self.dim
281 }
282
283 #[must_use]
285 pub fn distance_metric(&self) -> Distance {
286 self.distance
287 }
288
289 #[must_use]
291 pub fn bits(&self) -> u8 {
292 BITS
293 }
294
295 #[must_use]
297 pub fn contains(&self, id: NodeId) -> bool {
298 self.id_to_idx
299 .get(&id)
300 .is_some_and(|&idx| !self.nodes[idx].deleted)
301 }
302
303 pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
312 if vector.is_empty() {
313 return Err(IndexError::Empty);
314 }
315 let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
316 if got != self.dim {
317 return Err(IndexError::DimensionMismatch {
318 expected: self.dim,
319 got,
320 });
321 }
322 if self.id_to_idx.contains_key(&id) {
323 return Err(IndexError::Duplicate(id));
324 }
325 let prepared = match self.distance {
330 Distance::Cosine | Distance::Euclidean => l2_normalise(&vector),
331 Distance::DotProduct => vector,
332 };
333 for v in &prepared {
337 if !v.is_finite() || v.abs() >= 1e16_f32 {
338 return Err(IndexError::Empty);
339 }
340 }
341 let dim_usize = usize::from(self.dim);
342 let bytes_per_vec = self.bytes_per_vec();
343 let (packed, scale) = self.encode_one(&prepared);
344 debug_assert_eq!(packed.len(), bytes_per_vec);
345 let _ = dim_usize;
346
347 self.packed.extend_from_slice(&packed);
351 self.scales.push(scale);
352
353 let level = self.random_level();
354 let mut levels: Vec<Vec<usize>> = Vec::with_capacity(level + 1);
355 for _ in 0..=level {
356 levels.push(Vec::new());
357 }
358
359 let new_idx = self.nodes.len();
360 self.nodes.push(TurboHnswNode {
361 id,
362 levels,
363 deleted: false,
364 });
365 self.id_to_idx.insert(id, new_idx);
366
367 let Some(entry) = self.entry else {
368 self.entry = Some(new_idx);
369 return Ok(());
370 };
371 let entry_level = self.nodes[entry].level();
372
373 let q_rot = self.rotate(&prepared);
376 let mut current = entry;
377 if entry_level > level {
378 for lc in (level + 1..=entry_level).rev() {
379 current = self.greedy_search_layer(&q_rot, current, lc, new_idx);
380 }
381 }
382
383 let start_layer = level.min(entry_level);
387 let mut entry_points = vec![current];
388 for lc in (0..=start_layer).rev() {
389 let neighbours = self.search_layer(
390 &q_rot,
391 &entry_points,
392 lc,
393 self.params.ef_construction,
394 Some(new_idx),
395 );
396 let m = if lc == 0 {
397 self.params.m0
398 } else {
399 self.params.m
400 };
401 let selected = Self::select_neighbours(&neighbours, m);
402 for &nb in &selected {
403 self.nodes[new_idx].levels[lc].push(nb);
404 self.nodes[nb].levels[lc].push(new_idx);
405 let cap = if lc == 0 {
406 self.params.m0
407 } else {
408 self.params.m
409 };
410 if self.nodes[nb].levels[lc].len() > cap {
411 self.shrink_connections(nb, lc, cap);
412 }
413 }
414 entry_points = selected;
415 if entry_points.is_empty() {
416 entry_points = vec![current];
417 }
418 }
419
420 if level > entry_level {
421 self.entry = Some(new_idx);
422 }
423 Ok(())
424 }
425
426 pub fn delete(&mut self, id: NodeId) -> bool {
429 let Some(&idx) = self.id_to_idx.get(&id) else {
430 return false;
431 };
432 if self.nodes[idx].deleted {
433 return false;
434 }
435 self.nodes[idx].deleted = true;
436 true
437 }
438
439 pub fn search(
448 &self,
449 query: &[f32],
450 k: usize,
451 ef: Option<usize>,
452 ) -> Result<Vec<SearchResult>, IndexError> {
453 if query.is_empty() {
454 return Ok(Vec::new());
455 }
456 if self.nodes.is_empty() {
457 return Ok(Vec::new());
458 }
459 let got = u16::try_from(query.len()).unwrap_or(u16::MAX);
460 if got != self.dim {
461 return Err(IndexError::DimensionMismatch {
462 expected: self.dim,
463 got,
464 });
465 }
466
467 let prepared = match self.distance {
468 Distance::Cosine | Distance::Euclidean => l2_normalise(query),
469 Distance::DotProduct => query.to_vec(),
470 };
471 let q_rot = self.rotate(&prepared);
472
473 let mut entry = self.entry.unwrap_or(0);
474 let entry_level = self.nodes[entry].level();
475 let ef = ef.unwrap_or(self.params.ef_search).max(k);
476
477 for lc in (1..=entry_level).rev() {
478 entry = self.greedy_search_layer(&q_rot, entry, lc, usize::MAX);
479 }
480
481 let candidates = self.search_layer(&q_rot, &[entry], 0, ef, None);
482
483 let mut sorted = candidates;
484 sorted.sort_by(|a, b| {
485 a.score
486 .partial_cmp(&b.score)
487 .unwrap_or(std::cmp::Ordering::Equal)
488 });
489 Ok(sorted
490 .into_iter()
491 .filter(|c| !self.nodes[c.idx].deleted)
492 .take(k)
493 .map(|c| SearchResult {
494 id: self.nodes[c.idx].id,
495 score: c.score,
496 })
497 .collect())
498 }
499
500 fn bytes_per_vec(&self) -> usize {
503 usize::from(self.dim)
504 }
505
506 fn encode_one(&self, vector: &[f32]) -> (Vec<u8>, f32) {
522 let dim = usize::from(self.dim);
523 let mut norm_sq = 0.0_f32;
525 for &x in vector {
526 norm_sq += x * x;
527 }
528 let norm = norm_sq.sqrt();
529 let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 };
530 let mut unit = vec![0.0_f32; dim];
531 for (d, slot) in unit.iter_mut().enumerate().take(dim) {
532 *slot = vector[d] * inv_norm;
533 }
534 let u_rot = self.rotate(&unit);
536 let mut packed = vec![0_u8; dim];
540 let mut inner = 0.0_f32;
541 for (j, &uj) in u_rot.iter().enumerate().take(dim) {
542 let mut code = 0_u8;
543 for &b in &self.boundaries {
544 if uj > b {
545 code += 1;
546 }
547 }
548 inner += uj * self.centroids[usize::from(code)];
549 packed[j] = code;
550 }
551 let inner = inner.max(1e-10_f32);
555 let scale = norm / inner;
556 (packed, scale)
557 }
558
559 fn codes(&self, slot: usize) -> &[u8] {
561 let dim = usize::from(self.dim);
562 let row_start = slot * dim;
563 &self.packed[row_start..row_start + dim]
564 }
565
566 fn rotate(&self, q: &[f32]) -> Vec<f32> {
568 let dim = usize::from(self.dim);
569 let mut out = vec![0.0_f32; dim];
570 for (d, slot) in out.iter_mut().enumerate().take(dim) {
571 let row = &self.rotation[d * dim..(d + 1) * dim];
572 let mut sum = 0.0_f32;
573 for (e, &qe) in q.iter().enumerate().take(dim) {
574 sum += row[e] * qe;
575 }
576 *slot = sum;
577 }
578 out
579 }
580
581 fn similarity_query(&self, q_rot: &[f32], slot: usize) -> f32 {
589 let dim = usize::from(self.dim);
590 let codes = self.codes(slot);
591 let centroids = self.centroids.as_slice();
592 let mut acc = 0.0_f32;
593 for d in 0..dim {
594 acc += q_rot[d] * centroids[codes[d] as usize];
595 }
596 acc * self.scales[slot]
597 }
598
599 fn similarity_pair(&self, a: usize, b: usize) -> f32 {
608 let dim = usize::from(self.dim);
609 let ca = self.codes(a);
610 let cb = self.codes(b);
611 let centroids = self.centroids.as_slice();
612 let mut acc = 0.0_f32;
613 for d in 0..dim {
614 acc += centroids[ca[d] as usize] * centroids[cb[d] as usize];
615 }
616 acc * self.scales[a] * self.scales[b]
617 }
618
619 fn similarity_to_distance(&self, similarity: f32) -> f32 {
622 match self.distance {
623 Distance::DotProduct => -similarity,
624 Distance::Cosine => 1.0 - similarity,
625 Distance::Euclidean => (2.0 - 2.0 * similarity).max(0.0).sqrt(),
626 }
627 }
628
629 fn distance_query(&self, q_rot: &[f32], slot: usize) -> f32 {
630 self.similarity_to_distance(self.similarity_query(q_rot, slot))
631 }
632
633 fn distance_pair(&self, a: usize, b: usize) -> f32 {
634 self.similarity_to_distance(self.similarity_pair(a, b))
635 }
636
637 fn greedy_search_layer(
642 &self,
643 q_rot: &[f32],
644 entry: usize,
645 lc: usize,
646 skip_idx: usize,
647 ) -> usize {
648 let mut current = entry;
649 let mut current_score = self.distance_query(q_rot, current);
650 loop {
651 let mut improved = false;
652 let next = if lc < self.nodes[current].levels.len() {
653 let neighbours = self.nodes[current].levels[lc].as_slice();
654 let mut best = (current, current_score);
655 for &nb in neighbours {
656 if nb == skip_idx {
657 continue;
658 }
659 let s = self.distance_query(q_rot, nb);
660 if s < best.1 {
661 best = (nb, s);
662 improved = true;
663 }
664 }
665 best
666 } else {
667 (current, current_score)
668 };
669 current = next.0;
670 current_score = next.1;
671 if !improved {
672 break;
673 }
674 }
675 current
676 }
677
678 fn search_layer(
681 &self,
682 q_rot: &[f32],
683 entry_points: &[usize],
684 lc: usize,
685 ef: usize,
686 skip_idx: Option<usize>,
687 ) -> Vec<MaxCandidate> {
688 let mut visited: HashSet<usize> = HashSet::new();
689 let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
690 let mut top: BinaryHeap<MaxCandidate> = BinaryHeap::new();
691 for &ep in entry_points {
692 if Some(ep) == skip_idx {
693 continue;
694 }
695 if visited.insert(ep) {
696 let s = self.distance_query(q_rot, ep);
697 frontier.push(Candidate { idx: ep, score: s });
698 top.push(MaxCandidate { idx: ep, score: s });
699 }
700 }
701 while let Some(c) = frontier.pop() {
702 if top.len() >= ef {
703 if let Some(worst) = top.peek() {
704 if c.score > worst.score {
705 break;
706 }
707 }
708 }
709 if lc < self.nodes[c.idx].levels.len() {
710 let neighbours = self.nodes[c.idx].levels[lc].as_slice();
711 for &nb in neighbours {
712 if Some(nb) == skip_idx {
713 continue;
714 }
715 if !visited.insert(nb) {
716 continue;
717 }
718 let s = self.distance_query(q_rot, nb);
719 let admit = match top.peek() {
720 Some(worst) => s < worst.score || top.len() < ef,
721 None => true,
722 };
723 if admit {
724 frontier.push(Candidate { idx: nb, score: s });
725 top.push(MaxCandidate { idx: nb, score: s });
726 if top.len() > ef {
727 top.pop();
728 }
729 }
730 }
731 }
732 }
733 top.into_vec()
734 }
735
736 fn select_neighbours(candidates: &[MaxCandidate], m: usize) -> Vec<usize> {
738 let mut sorted: Vec<MaxCandidate> = candidates.to_vec();
739 sorted.sort_by(|a, b| {
740 a.score
741 .partial_cmp(&b.score)
742 .unwrap_or(std::cmp::Ordering::Equal)
743 });
744 sorted.into_iter().take(m).map(|c| c.idx).collect()
745 }
746
747 fn shrink_connections(&mut self, idx: usize, lc: usize, cap: usize) {
753 let neighbours = std::mem::take(&mut self.nodes[idx].levels[lc]);
754 let mut scored: Vec<(usize, f32)> = neighbours
755 .into_iter()
756 .map(|nb| {
757 let s = self.distance_pair(idx, nb);
758 (nb, s)
759 })
760 .collect();
761 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
762 scored.truncate(cap);
763 self.nodes[idx].levels[lc] = scored.into_iter().map(|(nb, _)| nb).collect();
764 }
765
766 fn rand_unit(&mut self) -> f64 {
768 let mut x = self.rng_state;
769 x ^= x >> 12;
770 x ^= x << 25;
771 x ^= x >> 27;
772 self.rng_state = x;
773 let r = x.wrapping_mul(0x2545_F491_4F6C_DD1D);
774 let bits = (r >> 11) & ((1_u64 << 53) - 1);
775 #[allow(
777 clippy::cast_precision_loss,
778 reason = "bits is in [0, 2^53), exactly representable as f64"
779 )]
780 let f = (bits as f64) / ((1_u64 << 53) as f64);
781 f
782 }
783
784 fn random_level(&mut self) -> usize {
787 let r = self.rand_unit().max(f64::MIN_POSITIVE);
788 let level = (-r.ln() * self.ml).floor();
789 let clamped = level.clamp(0.0, 16.0);
790 #[allow(
792 clippy::cast_possible_truncation,
793 clippy::cast_sign_loss,
794 reason = "clamped to [0, 16]"
795 )]
796 let lvl = clamped as usize;
797 lvl
798 }
799}
800
801impl<const BITS: u8> CodecDistance for TurboHnswIndex<BITS> {
802 fn distance(&self, a: NodeId, b: NodeId) -> f32 {
803 let Some(&sa) = self.id_to_idx.get(&a) else {
804 return f32::INFINITY;
805 };
806 let Some(&sb) = self.id_to_idx.get(&b) else {
807 return f32::INFINITY;
808 };
809 self.distance_pair(sa, sb)
810 }
811}
812
/// Returns `v` scaled to unit L2 norm.
///
/// The norm is accumulated in `f64`. Accumulating in `f32` overflowed to infinity
/// for components above ~1.8e19, which silently zeroed the whole vector. It also
/// underflowed for components below ~1e-19, which returned the vector unnormalised.
/// A zero-norm input (including an empty slice) is returned unchanged. NaN or
/// infinite components propagate into the output for callers to reject.
#[allow(
    clippy::cast_possible_truncation,
    reason = "each |x / n| <= 1, so narrowing f64 -> f32 only rounds"
)]
fn l2_normalise(v: &[f32]) -> Vec<f32> {
    let n = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if n <= 0.0 {
        return v.to_vec();
    }
    v.iter().map(|&x| (f64::from(x) / n) as f32).collect()
}
821
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: `dim` components in `[-1, 1)` from a xorshift64
    /// generator (a zero seed is replaced by `0xDEAD_BEEF`).
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        let mut state: u64 = if seed == 0 { 0xDEAD_BEEF } else { seed };
        let mut next = move || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let mantissa = (state >> 11) & ((1_u64 << 53) - 1);
            #[allow(
                clippy::cast_precision_loss,
                clippy::cast_possible_truncation,
                reason = "test fixture: PRNG narrowed to f32"
            )]
            let sample = (((mantissa as f64) / ((1_u64 << 53) as f64)) * 2.0 - 1.0) as f32;
            sample
        };
        (0..dim).map(|_| next()).collect()
    }

    /// Fresh 4-bit cosine index over 64 dimensions with default parameters.
    fn cosine_index() -> TurboHnswIndex<4> {
        TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
            .expect("4-bit ctor")
    }

    #[test]
    fn insert_and_search_returns_self_first_4bit() {
        let mut index = cosine_index();
        let target = rand_vec(42, 64);
        index.insert(0, target.clone()).unwrap();
        for id in 1..50_u64 {
            let seed = id.wrapping_mul(1_000_003) + 1;
            index.insert(id, rand_vec(seed, 64)).unwrap();
        }
        let hits = index.search(&target, 3, None).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 0);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut index = cosine_index();
        (0..30_u64).for_each(|id| index.insert(id, rand_vec(id + 1, 64)).unwrap());
        let query = rand_vec(1, 64);
        let victim = index.search(&query, 5, None).unwrap()[0].id;
        assert!(index.delete(victim));
        let remaining = index.search(&query, 5, None).unwrap();
        assert!(remaining.iter().all(|hit| hit.id != victim));
    }

    #[test]
    fn duplicate_id_rejected() {
        let mut index = cosine_index();
        index.insert(7, rand_vec(7, 64)).unwrap();
        let second = index.insert(7, rand_vec(8, 64));
        assert!(matches!(second, Err(IndexError::Duplicate(7))));
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let mut index = cosine_index();
        let outcome = index.insert(0, vec![0.1; 32]);
        assert!(matches!(outcome, Err(IndexError::DimensionMismatch { .. })));
    }

    #[test]
    fn empty_index_search_is_empty() {
        let index = cosine_index();
        let hits = index.search(&rand_vec(0, 64), 5, None).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn ctor_rejects_misaligned_dim() {
        let built = TurboHnswIndex::<4>::new(Distance::Cosine, 7, HnswParams::default());
        assert!(matches!(
            built,
            Err(IndexError::DimensionMismatch {
                expected: 8,
                got: 7
            })
        ));
    }
}