1use anndists::dist::distances::DistCosine;
2use contextdb_core::{Error, Result, RowId, VectorEntry};
3use hnsw_rs::hnsw::Hnsw;
4use parking_lot::RwLock;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8pub struct HnswIndex {
9 hnsw: Hnsw<'static, f32, DistCosine>,
10 id_to_row: RwLock<HashMap<usize, RowId>>,
11 row_to_id: RwLock<HashMap<RowId, usize>>,
12 next_id: AtomicUsize,
13 dimension: usize,
14 ef_search: usize,
15}
16
17impl HnswIndex {
18 pub fn new(entries: &[VectorEntry], dimension: usize) -> Self {
19 let (m, ef_construction, ef_search) = select_params(entries.len());
20 let max_elements = entries.len().max(1);
21 let mut hnsw = Hnsw::new(m, max_elements, 16, ef_construction, DistCosine);
22 hnsw.set_extend_candidates(true);
23 hnsw.set_keeping_pruned(true);
24 let id_to_row = RwLock::new(HashMap::with_capacity(entries.len()));
25 let row_to_id = RwLock::new(HashMap::with_capacity(entries.len()));
26 let mut sorted_entries = entries.iter().collect::<Vec<_>>();
27 sorted_entries.sort_by_key(|entry| {
28 (
29 insertion_key(entry),
30 entry.lsn,
31 entry.created_tx,
32 entry.row_id,
33 )
34 });
35
36 for (data_id, entry) in sorted_entries.into_iter().enumerate() {
37 hnsw.insert((&entry.vector, data_id));
38 id_to_row.write().insert(data_id, entry.row_id);
39 row_to_id.write().insert(entry.row_id, data_id);
40 }
41
42 Self {
43 hnsw,
44 id_to_row,
45 row_to_id,
46 next_id: AtomicUsize::new(entries.len()),
47 dimension,
48 ef_search,
49 }
50 }
51
52 pub fn insert(&self, row_id: RowId, vector: &[f32]) {
53 let data_id = self.next_id.fetch_add(1, Ordering::Relaxed);
54 self.hnsw.insert((vector, data_id));
55 self.id_to_row.write().insert(data_id, row_id);
56 self.row_to_id.write().insert(row_id, data_id);
57 }
58
59 pub fn len(&self) -> usize {
61 self.next_id.load(Ordering::Relaxed)
62 }
63
64 pub fn is_empty(&self) -> bool {
65 self.len() == 0
66 }
67
68 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(RowId, f32)>> {
69 if k == 0 {
70 return Ok(Vec::new());
71 }
72
73 let got = query.len();
74 if got != self.dimension {
75 return Err(Error::VectorDimensionMismatch {
76 expected: self.dimension,
77 got,
78 });
79 }
80
81 let ef = self.ef_search.max(k.saturating_mul(10)).max(1);
82 let neighbors = self.hnsw.search(query, ef, ef);
83 let id_to_row = self.id_to_row.read();
84
85 Ok(neighbors
86 .into_iter()
87 .filter_map(|neighbor| {
88 id_to_row
89 .get(&neighbor.d_id)
90 .copied()
91 .map(|row_id| (row_id, 1.0 - neighbor.distance))
92 })
93 .collect())
94 }
95}
96
/// Chooses HNSW parameters `(m, ef_construction, ef_search)` for an index
/// built over `count` vectors.
///
/// - up to 5,000 vectors: `m = 16`, `ef_construction = 200`, and
///   `ef_search = max(count, 200)`;
/// - 5,001 to 50,000 vectors: `m = 24`, `ef_construction = 400`,
///   `ef_search = 400`;
/// - above that: back to `m = 16`, `ef_construction = 200`,
///   `ef_search = 200` (NOTE(review): presumably to bound build cost for
///   large indexes — confirm this drop is intended).
fn select_params(count: usize) -> (usize, usize, usize) {
    const SMALL_LIMIT: usize = 5_000;
    const MEDIUM_LIMIT: usize = 50_000;

    if count <= SMALL_LIMIT {
        let ef_search = if count > 200 { count } else { 200 };
        (16, 200, ef_search)
    } else if count <= MEDIUM_LIMIT {
        (24, 400, 400)
    } else {
        (16, 200, 200)
    }
}
104
105fn insertion_key(entry: &VectorEntry) -> u64 {
106 let mut x = entry.row_id ^ entry.lsn ^ entry.created_tx;
107 x = x.wrapping_add(0x9e37_79b9_7f4a_7c15);
108 x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
109 x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
110 x ^ (x >> 31)
111}