1use std::cmp::Ordering;
37use std::collections::{BinaryHeap, HashMap, HashSet};
38
39use serde::{Deserialize, Serialize};
40
41use crate::distance::Distance;
42
/// Caller-supplied identifier for a vector stored in an [`HnswIndex`].
pub type NodeId = u64;
48
/// Tuning parameters for [`HnswIndex`] graph construction and search.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct HnswParams {
    /// Maximum neighbours kept per node on layers above 0. Also sets the
    /// level multiplier `1 / ln(m)` used when assigning node levels.
    pub m: usize,
    /// Maximum neighbours kept per node on the base layer (layer 0).
    pub m0: usize,
    /// Candidate-list size used by the layer search during insertion.
    pub ef_construction: usize,
    /// Default candidate-list size for [`HnswIndex::search`] when the caller
    /// passes no `ef` override (it is raised to at least `k`).
    pub ef_search: usize,
    /// Initial state of the PRNG that assigns node levels on insert.
    pub seed: u64,
}
66
67impl Default for HnswParams {
68 fn default() -> Self {
69 Self {
70 m: 16,
71 m0: 32,
72 ef_construction: 200,
73 ef_search: 50,
74 seed: 0xDEAD_BEEF_CAFE_F00D,
75 }
76 }
77}
78
/// A stored vector together with its per-layer adjacency lists.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct HnswNode {
    /// Caller-supplied id this vector was inserted under.
    id: NodeId,
    /// The stored vector.
    vector: Vec<f32>,
    /// `levels[l]` holds the positions (in the index's node list) of this
    /// node's neighbours on layer `l`. The node exists on layers
    /// `0..levels.len()`.
    levels: Vec<Vec<usize>>,
    /// Soft-delete flag. A deleted node is still traversed during search but
    /// is filtered out of returned results.
    deleted: bool,
}
98
99impl HnswNode {
100 fn level(&self) -> usize {
101 self.levels.len().saturating_sub(1)
102 }
103}
104
/// Hierarchical navigable small-world (HNSW) graph over `f32` vectors keyed
/// by [`NodeId`].
///
/// Deletion is soft: deleted nodes keep their slot and edges and are only
/// filtered out of search results.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HnswIndex {
    /// Parameters supplied at construction.
    params: HnswParams,
    /// Scoring function. Lower scores are treated as closer.
    distance: Distance,
    /// Every node ever inserted, including soft-deleted ones. Positions are
    /// stable and are used as edge targets.
    nodes: Vec<HnswNode>,
    /// Maps a caller id to its position in `nodes`. `delete` never removes
    /// entries, so deleted ids stay mapped.
    id_to_idx: HashMap<NodeId, usize>,
    /// Position of the search entry point. Set by the first insert and
    /// replaced whenever a node is inserted at a strictly higher level.
    /// `None` only while no node has been inserted.
    entry: Option<usize>,
    /// Level-generation multiplier: `1 / ln(m)`, or `1.0` when `m <= 1`.
    ml: f64,
    /// xorshift64* state used to draw node levels.
    rng_state: u64,
    /// Vector dimensionality fixed by the first insert (0 before that).
    /// Saturates at `u16::MAX`.
    dim: u16,
}
129
/// Errors reported by [`HnswIndex`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum IndexError {
    /// A vector's length does not match the index dimensionality.
    #[error("dimension mismatch: index has {expected}, got {got}")]
    DimensionMismatch {
        /// Dimensionality of the index (saturated to `u16::MAX`).
        expected: u16,
        /// Length of the rejected vector (saturated to `u16::MAX`).
        got: u16,
    },
    /// The id is already mapped in the index. Ids of soft-deleted nodes
    /// remain mapped.
    #[error("id {0} already present in the index")]
    Duplicate(NodeId),
    /// An empty vector was supplied for insertion.
    #[error("empty vector")]
    Empty,
}
149
/// One hit returned by [`HnswIndex::search`].
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult {
    /// Caller id of the matching vector.
    pub id: NodeId,
    /// Score computed by the index's [`Distance`]. Lower is closer, and
    /// results are returned in ascending score order.
    pub score: f32,
}
158
/// Frontier entry for best-first layer search.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the candidate
/// with the *lowest* score compares greatest and is popped first.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    idx: usize,
    score: f32,
}

// Equality is defined through `cmp` so that `Eq` and `Ord` agree and
// equality is reflexive even for NaN scores.
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed for min-heap behaviour. `total_cmp` gives a genuine total
        // order (NaN included), unlike `partial_cmp(..).unwrap_or(Equal)`,
        // which is intransitive as soon as a NaN appears.
        other.score.total_cmp(&self.score)
    }
}
187
/// Result-set entry for best-first layer search.
///
/// Ordered by score ascending, so a `BinaryHeap<MaxCandidate>` keeps the
/// *worst* (highest-score) candidate on top, ready to be evicted.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    idx: usize,
    score: f32,
}

// Equality is defined through `cmp` so that `Eq` and `Ord` agree and
// equality is reflexive even for NaN scores.
impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` keeps the order total when a score is NaN.
        self.score.total_cmp(&other.score)
    }
}
214
215impl HnswIndex {
216 #[must_use]
218 pub fn new(distance: Distance, params: HnswParams) -> Self {
219 let ml = if params.m > 1 {
220 1.0 / f64::from(u32::try_from(params.m).unwrap_or(u32::MAX)).ln()
221 } else {
222 1.0
223 };
224 Self {
225 params,
226 distance,
227 nodes: Vec::new(),
228 id_to_idx: HashMap::new(),
229 entry: None,
230 ml,
231 rng_state: params.seed,
232 dim: 0,
233 }
234 }
235
236 #[must_use]
238 pub fn len(&self) -> usize {
239 self.nodes.iter().filter(|n| !n.deleted).count()
240 }
241
242 #[must_use]
244 pub fn is_empty(&self) -> bool {
245 self.len() == 0
246 }
247
248 #[must_use]
250 pub fn dim(&self) -> u16 {
251 self.dim
252 }
253
254 #[must_use]
256 pub fn distance(&self) -> Distance {
257 self.distance
258 }
259
260 pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
270 if vector.is_empty() {
271 return Err(IndexError::Empty);
272 }
273 let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
274 if self.nodes.is_empty() {
275 self.dim = got;
276 } else if self.dim != got {
277 return Err(IndexError::DimensionMismatch {
278 expected: self.dim,
279 got,
280 });
281 }
282 if self.id_to_idx.contains_key(&id) {
283 return Err(IndexError::Duplicate(id));
284 }
285
286 let level = self.random_level();
287 let mut levels: Vec<Vec<usize>> = Vec::with_capacity(level + 1);
288 for _ in 0..=level {
289 levels.push(Vec::new());
290 }
291
292 let new_idx = self.nodes.len();
293 self.nodes.push(HnswNode {
294 id,
295 vector,
296 levels,
297 deleted: false,
298 });
299 self.id_to_idx.insert(id, new_idx);
300
301 let Some(entry) = self.entry else {
302 self.entry = Some(new_idx);
303 return Ok(());
304 };
305 let entry_level = self.nodes[entry].level();
306
307 let mut current = entry;
310 if entry_level > level {
311 for lc in (level + 1..=entry_level).rev() {
312 current = self.greedy_search_layer(current, new_idx, lc);
313 }
314 }
315
316 let start_layer = level.min(entry_level);
319 let mut entry_points = vec![current];
320 for lc in (0..=start_layer).rev() {
321 let neighbours = self.search_layer(
322 new_idx,
323 &entry_points,
324 lc,
325 self.params.ef_construction,
326 true,
327 );
328 let m = if lc == 0 {
329 self.params.m0
330 } else {
331 self.params.m
332 };
333 let selected = Self::select_neighbours(&neighbours, m);
334 for &nb in &selected {
336 self.nodes[new_idx].levels[lc].push(nb);
337 self.nodes[nb].levels[lc].push(new_idx);
338 let cap = if lc == 0 {
341 self.params.m0
342 } else {
343 self.params.m
344 };
345 if self.nodes[nb].levels[lc].len() > cap {
346 self.shrink_connections(nb, lc, cap);
347 }
348 }
349 entry_points = selected;
350 if entry_points.is_empty() {
351 entry_points = vec![current];
352 }
353 }
354
355 if level > entry_level {
358 self.entry = Some(new_idx);
359 }
360 Ok(())
361 }
362
363 pub fn delete(&mut self, id: NodeId) -> bool {
368 let Some(&idx) = self.id_to_idx.get(&id) else {
369 return false;
370 };
371 if self.nodes[idx].deleted {
372 return false;
373 }
374 self.nodes[idx].deleted = true;
375 true
376 }
377
378 pub fn search(
390 &self,
391 query: &[f32],
392 k: usize,
393 ef: Option<usize>,
394 ) -> Result<Vec<SearchResult>, IndexError> {
395 if query.is_empty() {
396 return Ok(Vec::new());
397 }
398 if self.nodes.is_empty() {
399 return Ok(Vec::new());
400 }
401 let got = u16::try_from(query.len()).unwrap_or(u16::MAX);
402 if self.dim != got {
403 return Err(IndexError::DimensionMismatch {
404 expected: self.dim,
405 got,
406 });
407 }
408
409 let mut entry = self.entry.unwrap_or(0);
410 let entry_level = self.nodes[entry].level();
411 let ef = ef.unwrap_or(self.params.ef_search).max(k);
412
413 let query_owned = query.to_vec();
415 for lc in (1..=entry_level).rev() {
416 entry = self.greedy_search_layer_against(&query_owned, entry, lc);
417 }
418
419 let candidates = self.search_layer_against(&query_owned, &[entry], 0, ef, true);
420
421 let mut sorted: Vec<MaxCandidate> = candidates;
422 sorted.sort_by(|a, b| {
423 a.score
424 .partial_cmp(&b.score)
425 .unwrap_or(std::cmp::Ordering::Equal)
426 });
427 Ok(sorted
428 .into_iter()
429 .filter(|c| !self.nodes[c.idx].deleted)
430 .take(k)
431 .map(|c| SearchResult {
432 id: self.nodes[c.idx].id,
433 score: c.score,
434 })
435 .collect())
436 }
437
438 #[must_use]
440 pub fn contains(&self, id: NodeId) -> bool {
441 self.id_to_idx
442 .get(&id)
443 .is_some_and(|&idx| !self.nodes[idx].deleted)
444 }
445
446 fn random_level(&mut self) -> usize {
449 let r = self.rand_unit();
450 let r = r.max(f64::MIN_POSITIVE);
452 let level = (-r.ln() * self.ml).floor();
453 let max_level = 16_f64;
456 let clamped = level.clamp(0.0, max_level);
457 #[allow(
460 clippy::cast_possible_truncation,
461 clippy::cast_sign_loss,
462 reason = "clamped to [0, 16]"
463 )]
464 let lvl = clamped as usize;
465 lvl
466 }
467
468 fn rand_unit(&mut self) -> f64 {
470 let mut x = self.rng_state;
471 x ^= x >> 12;
472 x ^= x << 25;
473 x ^= x >> 27;
474 self.rng_state = x;
475 let r = x.wrapping_mul(0x2545_F491_4F6C_DD1D);
476 let bits = (r >> 11) & ((1u64 << 53) - 1);
479 #[allow(
481 clippy::cast_precision_loss,
482 reason = "bits is in [0, 2^53), exactly representable as f64"
483 )]
484 let f = (bits as f64) / ((1_u64 << 53) as f64);
485 f
486 }
487
488 fn greedy_search_layer(&self, entry: usize, query_idx: usize, lc: usize) -> usize {
490 let q = self.nodes[query_idx].vector.clone();
491 self.greedy_search_layer_against(&q, entry, lc)
492 }
493
494 fn greedy_search_layer_against(&self, query: &[f32], entry: usize, lc: usize) -> usize {
496 let mut current = entry;
497 let mut current_score = self.distance.score(query, &self.nodes[current].vector);
498 loop {
499 let mut improved = false;
500 if lc < self.nodes[current].levels.len() {
501 let neighbours: Vec<usize> = self.nodes[current].levels[lc].clone();
502 for nb in neighbours {
503 let s = self.distance.score(query, &self.nodes[nb].vector);
504 if s < current_score {
505 current_score = s;
506 current = nb;
507 improved = true;
508 }
509 }
510 }
511 if !improved {
512 break;
513 }
514 }
515 current
516 }
517
518 fn search_layer(
521 &self,
522 query_idx: usize,
523 entry_points: &[usize],
524 lc: usize,
525 ef: usize,
526 include_deleted: bool,
527 ) -> Vec<MaxCandidate> {
528 let q = self.nodes[query_idx].vector.clone();
529 self.search_layer_against(&q, entry_points, lc, ef, include_deleted)
530 }
531
532 fn search_layer_against(
534 &self,
535 query: &[f32],
536 entry_points: &[usize],
537 lc: usize,
538 ef: usize,
539 include_deleted: bool,
540 ) -> Vec<MaxCandidate> {
541 let mut visited: HashSet<usize> = HashSet::new();
542 let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
543 let mut top: BinaryHeap<MaxCandidate> = BinaryHeap::new();
544 for &ep in entry_points {
545 if visited.insert(ep) {
546 let s = self.distance.score(query, &self.nodes[ep].vector);
547 frontier.push(Candidate { idx: ep, score: s });
548 if include_deleted || !self.nodes[ep].deleted {
549 top.push(MaxCandidate { idx: ep, score: s });
550 }
551 }
552 }
553 while let Some(c) = frontier.pop() {
554 if top.len() >= ef {
557 if let Some(worst) = top.peek() {
558 if c.score > worst.score {
559 break;
560 }
561 }
562 }
563 if lc < self.nodes[c.idx].levels.len() {
564 let neighbours: Vec<usize> = self.nodes[c.idx].levels[lc].clone();
565 for nb in neighbours {
566 if !visited.insert(nb) {
567 continue;
568 }
569 let s = self.distance.score(query, &self.nodes[nb].vector);
570 let admit = match top.peek() {
571 Some(worst) => s < worst.score || top.len() < ef,
572 None => true,
573 };
574 if admit {
575 frontier.push(Candidate { idx: nb, score: s });
576 if include_deleted || !self.nodes[nb].deleted {
577 top.push(MaxCandidate { idx: nb, score: s });
578 if top.len() > ef {
579 top.pop();
580 }
581 }
582 }
583 }
584 }
585 }
586 top.into_vec()
587 }
588
589 fn select_neighbours(candidates: &[MaxCandidate], m: usize) -> Vec<usize> {
594 let mut sorted: Vec<MaxCandidate> = candidates.to_vec();
595 sorted.sort_by(|a, b| {
596 a.score
597 .partial_cmp(&b.score)
598 .unwrap_or(std::cmp::Ordering::Equal)
599 });
600 sorted.into_iter().take(m).map(|c| c.idx).collect()
601 }
602
603 fn shrink_connections(&mut self, idx: usize, lc: usize, cap: usize) {
606 let q = self.nodes[idx].vector.clone();
607 let neighbours = std::mem::take(&mut self.nodes[idx].levels[lc]);
608 let mut scored: Vec<(usize, f32)> = neighbours
609 .into_iter()
610 .map(|nb| {
611 let s = self.distance.score(&q, &self.nodes[nb].vector);
612 (nb, s)
613 })
614 .collect();
615 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
616 scored.truncate(cap);
617 self.nodes[idx].levels[lc] = scored.into_iter().map(|(nb, _)| nb).collect();
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::Distance;

    /// Deterministic pseudo-random vector of length `dim` with components in
    /// `[-1, 1)`, generated by an xorshift sequence seeded with `seed`.
    /// (Not normalised, despite the name.)
    fn unit(seed: u64, dim: usize) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                let mantissa = (state >> 11) & ((1_u64 << 53) - 1);
                #[allow(
                    clippy::cast_precision_loss,
                    clippy::cast_possible_truncation,
                    reason = "test fixture; PRNG output narrowed to f32"
                )]
                let centred = ((mantissa as f64) / ((1_u64 << 53) as f64)) * 2.0 - 1.0;
                #[allow(
                    clippy::cast_possible_truncation,
                    reason = "test fixture; f64 -> f32 narrowing is intentional"
                )]
                let narrowed = centred as f32;
                narrowed
            })
            .collect()
    }

    /// Builds an empty Euclidean index with default parameters.
    fn fresh_index() -> HnswIndex {
        HnswIndex::new(Distance::Euclidean, HnswParams::default())
    }

    #[test]
    fn insert_and_search_small() {
        let mut index = fresh_index();
        let target = unit(42, 8);
        index.insert(0, target.clone()).unwrap();
        (1..50_u64).for_each(|i| {
            index
                .insert(i, unit(i.wrapping_mul(1_000_003) + 1, 8))
                .unwrap();
        });
        let hits = index.search(&target, 3, None).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 0);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut index = fresh_index();
        (0..30_u64).for_each(|i| index.insert(i, unit(i + 1, 8)).unwrap());
        let query = unit(1, 8);
        let victim = index.search(&query, 5, None).unwrap()[0].id;
        assert!(index.delete(victim));
        let remaining = index.search(&query, 5, None).unwrap();
        assert!(remaining.iter().all(|hit| hit.id != victim));
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let mut index = fresh_index();
        index.insert(0, vec![0.1, 0.2, 0.3]).unwrap();
        let outcome = index.insert(1, vec![0.1, 0.2]);
        assert!(matches!(
            outcome,
            Err(IndexError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn duplicate_id_rejected() {
        let mut index = fresh_index();
        index.insert(7, vec![0.1, 0.2]).unwrap();
        let outcome = index.insert(7, vec![0.3, 0.4]);
        assert!(matches!(outcome, Err(IndexError::Duplicate(7))));
    }

    #[test]
    fn empty_index_search_is_empty() {
        let index = fresh_index();
        let hits = index.search(&[0.1, 0.2], 5, None).unwrap();
        assert!(hits.is_empty());
    }
}