Skip to main content

contextdb_vector/
hnsw.rs

1use anndists::dist::distances::DistCosine;
2use contextdb_core::{Error, Result, RowId, VectorEntry};
3use hnsw_rs::hnsw::Hnsw;
4use parking_lot::RwLock;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8pub struct HnswIndex {
9    hnsw: Hnsw<'static, f32, DistCosine>,
10    id_to_row: RwLock<HashMap<usize, RowId>>,
11    row_to_id: RwLock<HashMap<RowId, usize>>,
12    next_id: AtomicUsize,
13    dimension: usize,
14    ef_search: usize,
15}
16
17impl HnswIndex {
18    pub fn new(entries: &[VectorEntry], dimension: usize) -> Self {
19        let (m, ef_construction, ef_search) = select_params(entries.len());
20        let max_elements = entries.len().max(1);
21        let mut hnsw = Hnsw::new(m, max_elements, 16, ef_construction, DistCosine);
22        hnsw.set_extend_candidates(true);
23        hnsw.set_keeping_pruned(true);
24        let id_to_row = RwLock::new(HashMap::with_capacity(entries.len()));
25        let row_to_id = RwLock::new(HashMap::with_capacity(entries.len()));
26        let mut sorted_entries = entries.iter().collect::<Vec<_>>();
27        sorted_entries.sort_by_key(|entry| {
28            (
29                insertion_key(entry),
30                entry.lsn,
31                entry.created_tx,
32                entry.row_id,
33            )
34        });
35
36        for (data_id, entry) in sorted_entries.into_iter().enumerate() {
37            hnsw.insert((&entry.vector, data_id));
38            id_to_row.write().insert(data_id, entry.row_id);
39            row_to_id.write().insert(entry.row_id, data_id);
40        }
41
42        Self {
43            hnsw,
44            id_to_row,
45            row_to_id,
46            next_id: AtomicUsize::new(entries.len()),
47            dimension,
48            ef_search,
49        }
50    }
51
52    pub fn insert(&self, row_id: RowId, vector: &[f32]) {
53        let data_id = self.next_id.fetch_add(1, Ordering::Relaxed);
54        self.hnsw.insert((vector, data_id));
55        self.id_to_row.write().insert(data_id, row_id);
56        self.row_to_id.write().insert(row_id, data_id);
57    }
58
59    /// Number of vectors currently indexed in the HNSW graph.
60    pub fn len(&self) -> usize {
61        self.next_id.load(Ordering::Relaxed)
62    }
63
64    pub fn is_empty(&self) -> bool {
65        self.len() == 0
66    }
67
68    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(RowId, f32)>> {
69        if k == 0 {
70            return Ok(Vec::new());
71        }
72
73        let got = query.len();
74        if got != self.dimension {
75            return Err(Error::VectorDimensionMismatch {
76                expected: self.dimension,
77                got,
78            });
79        }
80
81        let ef = self.ef_search.max(k.saturating_mul(10)).max(1);
82        let neighbors = self.hnsw.search(query, ef, ef);
83        let id_to_row = self.id_to_row.read();
84
85        Ok(neighbors
86            .into_iter()
87            .filter_map(|neighbor| {
88                id_to_row
89                    .get(&neighbor.d_id)
90                    .copied()
91                    .map(|row_id| (row_id, 1.0 - neighbor.distance))
92            })
93            .collect())
94    }
95}
96
/// Chooses HNSW parameters for an index built from `count` vectors.
///
/// Returns `(m, ef_construction, ef_search)`:
/// - up to 5 000 vectors: `(16, 200, max(count, 200))`
/// - 5 001 to 50 000 vectors: `(24, 400, 400)`
/// - more than 50 000 vectors: `(16, 200, 200)`
///
/// NOTE(review): the largest tier uses smaller values than the middle tier; presumably a
/// deliberate build-cost trade-off — confirm.
fn select_params(count: usize) -> (usize, usize, usize) {
    const SMALL_TIER_MAX: usize = 5_000;
    const MEDIUM_TIER_MAX: usize = 50_000;

    if count <= SMALL_TIER_MAX {
        // Small sets: search width grows with the set, never below 200.
        let ef_search = if count > 200 { count } else { 200 };
        return (16, 200, ef_search);
    }
    if count <= MEDIUM_TIER_MAX {
        return (24, 400, 400);
    }
    (16, 200, 200)
}
104
105fn insertion_key(entry: &VectorEntry) -> u64 {
106    let mut x = entry.row_id ^ entry.lsn ^ entry.created_tx;
107    x = x.wrapping_add(0x9e37_79b9_7f4a_7c15);
108    x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
109    x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
110    x ^ (x >> 31)
111}