Skip to main content

contextdb_vector/
hnsw.rs

1use crate::quantized::{StoredVector, StoredVectorEntry, quantized_hnsw_distance};
2use anndists::dist::distances::{DistCosine, Distance};
3use contextdb_core::{Error, Result, RowId, VectorIndexRef, VectorQuantization};
4use hnsw_rs::hnsw::Hnsw;
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9pub struct HnswIndex {
10    hnsw: HnswInner,
11    id_to_row: RwLock<HashMap<usize, RowId>>,
12    row_to_id: RwLock<HashMap<RowId, usize>>,
13    next_id: AtomicUsize,
14    dimension: usize,
15    quantization: VectorQuantization,
16    ef_search: usize,
17}
18
19enum HnswInner {
20    F32(Hnsw<'static, f32, DistCosine>),
21    Quantized(Hnsw<'static, u8, DistQuantizedCosine>),
22}
23
24#[derive(Debug, Clone, Copy)]
25struct DistQuantizedCosine {
26    quantization: VectorQuantization,
27}
28
29impl Distance<u8> for DistQuantizedCosine {
30    fn eval(&self, va: &[u8], vb: &[u8]) -> f32 {
31        quantized_hnsw_distance(va, vb, self.quantization)
32    }
33}
34
/// Snapshot of structural counters read from an HNSW graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswGraphStats {
    /// Number of points the graph reports as stored.
    pub point_count: usize,
    /// Number of points reported for layer 0.
    pub layer0_points: usize,
    /// Sum of the layer-0 neighbour-list lengths over all layer-0 points.
    pub layer0_neighbor_edges: usize,
    /// Highest layer the graph reports having observed.
    pub max_level_observed: u8,
    /// Vector dimension the index was built with.
    pub dimension: usize,
}
43
44impl HnswIndex {
45    pub(crate) fn new(
46        entries: &[StoredVectorEntry],
47        dimension: usize,
48        quantization: VectorQuantization,
49    ) -> Self {
50        let (m, ef_construction, ef_search) = select_params(entries.len(), quantization);
51        let max_elements = entries.len().max(1);
52        let hnsw = match quantization {
53            VectorQuantization::F32 => {
54                let mut hnsw = Hnsw::new(m, max_elements, 16, ef_construction, DistCosine);
55                hnsw.set_extend_candidates(true);
56                hnsw.set_keeping_pruned(true);
57                HnswInner::F32(hnsw)
58            }
59            VectorQuantization::SQ8 | VectorQuantization::SQ4 => {
60                let mut hnsw = Hnsw::new(
61                    m,
62                    max_elements,
63                    16,
64                    ef_construction,
65                    DistQuantizedCosine { quantization },
66                );
67                hnsw.set_extend_candidates(true);
68                hnsw.set_keeping_pruned(true);
69                HnswInner::Quantized(hnsw)
70            }
71        };
72        let id_to_row = RwLock::new(HashMap::with_capacity(entries.len()));
73        let row_to_id = RwLock::new(HashMap::with_capacity(entries.len()));
74        let mut sorted_entries = entries.iter().collect::<Vec<_>>();
75        sorted_entries.sort_by_key(|entry| {
76            (
77                insertion_key(entry),
78                entry.lsn,
79                entry.created_tx,
80                entry.row_id,
81            )
82        });
83
84        match &hnsw {
85            HnswInner::F32(index) => {
86                let data = sorted_entries
87                    .iter()
88                    .enumerate()
89                    .filter_map(|(data_id, entry)| {
90                        entry.vector.as_f32_slice().map(|vector| {
91                            id_to_row.write().insert(data_id, entry.row_id);
92                            row_to_id.write().insert(entry.row_id, data_id);
93                            (vector.to_vec(), data_id)
94                        })
95                    })
96                    .collect::<Vec<_>>();
97                let refs = data
98                    .iter()
99                    .map(|(vector, data_id)| (vector, *data_id))
100                    .collect::<Vec<_>>();
101                index.parallel_insert(&refs);
102            }
103            HnswInner::Quantized(index) => {
104                let data = sorted_entries
105                    .iter()
106                    .enumerate()
107                    .filter_map(|(data_id, entry)| {
108                        let encoded = entry.vector.to_hnsw_u8();
109                        (!encoded.is_empty()).then(|| {
110                            id_to_row.write().insert(data_id, entry.row_id);
111                            row_to_id.write().insert(entry.row_id, data_id);
112                            (encoded, data_id)
113                        })
114                    })
115                    .collect::<Vec<_>>();
116                let refs = data
117                    .iter()
118                    .map(|(vector, data_id)| (vector, *data_id))
119                    .collect::<Vec<_>>();
120                index.parallel_insert(&refs);
121            }
122        }
123
124        Self {
125            hnsw,
126            id_to_row,
127            row_to_id,
128            next_id: AtomicUsize::new(entries.len()),
129            dimension,
130            quantization,
131            ef_search,
132        }
133    }
134
135    pub(crate) fn insert(&self, row_id: RowId, vector: &StoredVector) {
136        let data_id = self.next_id.fetch_add(1, Ordering::Relaxed);
137        insert_into_hnsw(&self.hnsw, vector, data_id);
138        self.id_to_row.write().insert(data_id, row_id);
139        self.row_to_id.write().insert(row_id, data_id);
140    }
141
142    /// Number of vectors currently indexed in the HNSW graph.
143    pub fn len(&self) -> usize {
144        self.next_id.load(Ordering::Relaxed)
145    }
146
147    pub fn is_empty(&self) -> bool {
148        self.len() == 0
149    }
150
151    #[doc(hidden)]
152    pub fn graph_stats(&self) -> HnswGraphStats {
153        let (point_count, layer0_neighbor_edges, max_level_observed) = match &self.hnsw {
154            HnswInner::F32(hnsw) => hnsw_stats(hnsw),
155            HnswInner::Quantized(hnsw) => hnsw_stats(hnsw),
156        };
157
158        HnswGraphStats {
159            point_count,
160            layer0_points: point_count,
161            layer0_neighbor_edges,
162            max_level_observed,
163            dimension: self.dimension,
164        }
165    }
166
167    pub fn search(
168        &self,
169        index: &VectorIndexRef,
170        query: &[f32],
171        k: usize,
172    ) -> Result<Vec<(RowId, f32)>> {
173        if k == 0 {
174            return Ok(Vec::new());
175        }
176
177        let got = query.len();
178        if got != self.dimension {
179            return Err(Error::VectorIndexDimensionMismatch {
180                index: index.clone(),
181                expected: self.dimension,
182                actual: got,
183            });
184        }
185
186        let ef = self.ef_search.max(k.saturating_mul(10)).max(1);
187        let neighbors = match &self.hnsw {
188            HnswInner::F32(hnsw) => hnsw.search(query, ef, ef),
189            HnswInner::Quantized(hnsw) => {
190                let encoded = StoredVector::from_f32(query, self.quantization).to_hnsw_u8();
191                hnsw.search(&encoded, ef, ef)
192            }
193        };
194        let id_to_row = self.id_to_row.read();
195
196        Ok(neighbors
197            .into_iter()
198            .filter_map(|neighbor| {
199                id_to_row
200                    .get(&neighbor.d_id)
201                    .copied()
202                    .map(|row_id| (row_id, 1.0 - neighbor.distance))
203            })
204            .collect())
205    }
206}
207
208fn insert_into_hnsw(hnsw: &HnswInner, vector: &StoredVector, data_id: usize) {
209    match hnsw {
210        HnswInner::F32(hnsw) => {
211            let Some(vector) = vector.as_f32_slice() else {
212                return;
213            };
214            hnsw.insert((vector, data_id));
215        }
216        HnswInner::Quantized(hnsw) => {
217            let encoded = vector.to_hnsw_u8();
218            if !encoded.is_empty() {
219                hnsw.insert((&encoded, data_id));
220            }
221        }
222    }
223}
224
225fn hnsw_stats<T, D>(hnsw: &Hnsw<'_, T, D>) -> (usize, usize, u8)
226where
227    T: Clone + Send + Sync,
228    D: Distance<T> + Send + Sync,
229{
230    let indexation = hnsw.get_point_indexation();
231    let layer0_neighbor_edges = indexation
232        .get_layer_iterator(0)
233        .map(|point| {
234            point
235                .get_neighborhood_id()
236                .first()
237                .map_or(0, |neighbors| neighbors.len())
238        })
239        .sum();
240    (
241        hnsw.get_nb_point(),
242        layer0_neighbor_edges,
243        hnsw.get_max_level_observed(),
244    )
245}
246
247fn select_params(count: usize, quantization: VectorQuantization) -> (usize, usize, usize) {
248    if !matches!(quantization, VectorQuantization::F32) {
249        return match count {
250            0..=5000 => (8, 32, 96.min(count.max(32))),
251            5001..=50000 => (12, 64, 128),
252            _ => (12, 64, 128),
253        };
254    }
255    match count {
256        0..=5000 => (16, 200, count.max(200)),
257        5001..=50000 => (24, 400, 400),
258        _ => (16, 200, 200),
259    }
260}
261
262fn insertion_key(entry: &StoredVectorEntry) -> u64 {
263    let mut x = entry.row_id.0 ^ entry.lsn.0 ^ entry.created_tx.0;
264    x = x.wrapping_add(0x9e37_79b9_7f4a_7c15);
265    x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
266    x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
267    x ^ (x >> 31)
268}