Skip to main content

vector/hnsw/
usearch.rs

1//! HNSW implementation using the usearch library.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::sync::RwLock;
6
7use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
8
9use crate::error::{Error, Result};
10
11use crate::serde::centroid_chunk::CentroidEntry;
12use crate::serde::collection_meta::DistanceMetric;
13
14use super::CentroidGraph;
15
/// Initial capacity reserved for the usearch index.
///
/// Used by `UsearchCentroidGraph::build` when reserving index capacity.
/// Kept artificially high to avoid usearch deadlock issues near capacity limits.
const INITIAL_CAPACITY: usize = 200_000;
19
/// Inner state for UsearchCentroidGraph, protected by a single RwLock.
///
/// `key_to_centroid`, `centroid_to_key` and `centroid_vectors` are updated together
/// by the methods on this type; `key_to_centroid.len()` is what `len()` reports.
struct UsearchCentroidGraphInner {
    /// The usearch index (thread-safe internally)
    index: Index,
    /// Map from usearch key to centroid_id. Search results are translated through
    /// this map, and result keys that are missing from it are dropped.
    key_to_centroid: HashMap<u64, u64>,
    /// Reverse map from centroid_id to usearch key (for O(1) removal)
    centroid_to_key: HashMap<u64, u64>,
    /// Centroid vectors indexed by centroid_id
    centroid_vectors: HashMap<u64, Vec<f32>>,
    /// Next usearch key to allocate. Starts at the number of centroids passed to
    /// `build` and advances each time a key is handed out, so keys are not reused.
    next_key: u64,
}
33
/// HNSW graph implementation using the usearch library.
///
/// Uses interior mutability for thread-safe mutation behind `Arc<dyn CentroidGraph>`.
/// The usearch `Index` is internally thread-safe for `add`/`remove`/`search`.
///
/// All state lives in `inner`: the `CentroidGraph` methods take the read lock for
/// `search`, `get_centroid_vector` and `len`, and the write lock for `add_centroid`
/// and `remove_centroid`.
pub struct UsearchCentroidGraph {
    /// Index plus bookkeeping maps, guarded together by one lock.
    inner: RwLock<UsearchCentroidGraphInner>,
}
41
42impl fmt::Debug for UsearchCentroidGraph {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        let inner = self.inner.read().unwrap();
45        f.debug_struct("UsearchCentroidGraph")
46            .field("num_centroids", &inner.key_to_centroid.len())
47            .finish()
48    }
49}
50
51impl UsearchCentroidGraph {
52    /// Build a new HNSW graph from centroids using usearch.
53    ///
54    /// # Arguments
55    /// * `centroids` - Vector of centroid entries with their IDs and vectors
56    /// * `distance_metric` - Distance metric to use for similarity computation
57    ///
58    /// # Returns
59    /// A UsearchCentroidGraph ready for searching
60    pub fn build(centroids: Vec<CentroidEntry>, distance_metric: DistanceMetric) -> Result<Self> {
61        if centroids.is_empty() {
62            return Err(Error::InvalidInput(
63                "Cannot build HNSW graph with no centroids".to_string(),
64            ));
65        }
66
67        // Validate all centroids have the same dimensionality
68        let dimensions = centroids[0].dimensions();
69        for centroid in &centroids {
70            if centroid.dimensions() != dimensions {
71                return Err(Error::InvalidInput(format!(
72                    "Centroid dimension mismatch: expected {}, got {}",
73                    dimensions,
74                    centroid.dimensions()
75                )));
76            }
77        }
78
79        // Convert distance metric to usearch MetricKind
80        let metric = match distance_metric {
81            DistanceMetric::L2 => MetricKind::L2sq,
82            DistanceMetric::DotProduct => MetricKind::IP,
83        };
84
85        // Create index options
86        let options = IndexOptions {
87            dimensions,
88            metric,
89            quantization: ScalarKind::F32,
90            connectivity: 16,      // M parameter
91            expansion_add: 200,    // ef_construction
92            expansion_search: 100, // ef_search default
93            multi: false,
94        };
95
96        // Create index
97        let index = Index::new(&options).map_err(|e| Error::Internal(e.to_string()))?;
98
99        // Reserve 200K capacity upfront
100        index
101            .reserve(INITIAL_CAPACITY)
102            .map_err(|e| Error::Internal(e.to_string()))?;
103
104        // Build mappings and insert
105        let mut key_to_centroid = HashMap::with_capacity(centroids.len());
106        let mut centroid_to_key = HashMap::with_capacity(centroids.len());
107        let mut centroid_vectors = HashMap::with_capacity(centroids.len());
108
109        for (key, centroid) in centroids.iter().enumerate() {
110            let key = key as u64;
111            index
112                .add(key, &centroid.vector)
113                .map_err(|e| Error::Internal(e.to_string()))?;
114            key_to_centroid.insert(key, centroid.centroid_id);
115            centroid_to_key.insert(centroid.centroid_id, key);
116            centroid_vectors.insert(centroid.centroid_id, centroid.vector.clone());
117        }
118
119        let next_key = centroids.len() as u64;
120
121        Ok(Self {
122            inner: RwLock::new(UsearchCentroidGraphInner {
123                index,
124                key_to_centroid,
125                centroid_to_key,
126                centroid_vectors,
127                next_key,
128            }),
129        })
130    }
131}
132
133impl CentroidGraph for UsearchCentroidGraph {
134    fn search(&self, query: &[f32], k: usize) -> Vec<u64> {
135        self.inner.read().expect("lock poisoned").search(query, k)
136    }
137
138    fn add_centroid(&self, entry: &CentroidEntry) -> Result<()> {
139        self.inner
140            .write()
141            .expect("lock poisoned")
142            .add_centroid(entry)
143    }
144
145    fn remove_centroid(&self, centroid_id: u64) -> Result<()> {
146        self.inner
147            .write()
148            .expect("lock poisoned")
149            .remove_centroid(centroid_id)
150    }
151
152    fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>> {
153        self.inner
154            .read()
155            .expect("lock poisoned")
156            .get_centroid_vector(centroid_id)
157    }
158
159    fn len(&self) -> usize {
160        self.inner.read().expect("lock poisoned").len()
161    }
162}
163
164impl UsearchCentroidGraphInner {
165    fn search(&self, query: &[f32], k: usize) -> Vec<u64> {
166        let k = k.min(self.key_to_centroid.len());
167        if k == 0 {
168            return Vec::new();
169        }
170
171        // Search usearch index — request more than k to account for removed entries
172        let search_k = (k + 10).min(self.key_to_centroid.len() + 10);
173        let results = match self.index.search(query, search_k) {
174            Ok(matches) => matches,
175            Err(_) => return Vec::new(),
176        };
177
178        // Map keys back to centroid_ids, filtering out removed entries
179        results
180            .keys
181            .iter()
182            .filter_map(|&key| self.key_to_centroid.get(&key).copied())
183            .take(k)
184            .collect()
185    }
186
187    fn add_centroid(&mut self, entry: &CentroidEntry) -> Result<()> {
188        let key = self.next_key;
189        self.next_key += 1;
190
191        self.index
192            .add(key, &entry.vector)
193            .map_err(|e| Error::Internal(e.to_string()))?;
194
195        self.key_to_centroid.insert(key, entry.centroid_id);
196        self.centroid_to_key.insert(entry.centroid_id, key);
197        self.centroid_vectors
198            .insert(entry.centroid_id, entry.vector.clone());
199
200        Ok(())
201    }
202
203    fn remove_centroid(&mut self, centroid_id: u64) -> Result<()> {
204        let key = self.centroid_to_key.remove(&centroid_id).ok_or_else(|| {
205            Error::Internal(format!("Centroid {} not found in graph", centroid_id))
206        })?;
207
208        self.key_to_centroid.remove(&key);
209        self.centroid_vectors.remove(&centroid_id);
210        self.index
211            .remove(key)
212            .map_err(|e| Error::Internal(e.to_string()))?;
213
214        Ok(())
215    }
216
217    fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>> {
218        self.centroid_vectors.get(&centroid_id).cloned()
219    }
220
221    fn len(&self) -> usize {
222        self.key_to_centroid.len()
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a centroid entry built from a slice literal.
    fn entry(centroid_id: u64, vector: &[f32]) -> CentroidEntry {
        CentroidEntry::new(centroid_id, vector.to_vec())
    }

    /// Builds an L2 graph from `entries`, panicking if the build fails.
    fn l2_graph(entries: Vec<CentroidEntry>) -> UsearchCentroidGraph {
        UsearchCentroidGraph::build(entries, DistanceMetric::L2).unwrap()
    }

    #[test]
    fn should_build_and_search_l2_graph() {
        // given - one centroid per axis
        let graph = l2_graph(vec![
            entry(1, &[1.0, 0.0, 0.0]),
            entry(2, &[0.0, 1.0, 0.0]),
            entry(3, &[0.0, 0.0, 1.0]),
        ]);

        // when - query sits closest to the first axis
        let nearest = graph.search(&[0.9, 0.1, 0.1], 1);

        // then
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0], 1);
    }

    #[test]
    fn should_return_multiple_neighbors() {
        // given - five evenly spaced points on a line
        let graph = l2_graph(vec![
            entry(1, &[0.0]),
            entry(2, &[1.0]),
            entry(3, &[2.0]),
            entry(4, &[3.0]),
            entry(5, &[4.0]),
        ]);

        // when
        let neighbors = graph.search(&[2.1], 3);

        // then - three hits, the closest one first
        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0], 3);
    }

    #[test]
    fn should_reject_empty_centroids() {
        // given
        let nothing: Vec<CentroidEntry> = Vec::new();

        // when
        let outcome = UsearchCentroidGraph::build(nothing, DistanceMetric::L2);

        // then
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot build HNSW graph with no centroids"));
    }

    #[test]
    fn should_reject_mismatched_dimensions() {
        // given - the second entry has one component too many
        let uneven = vec![entry(1, &[1.0, 2.0]), entry(2, &[3.0, 4.0, 5.0])];

        // when
        let outcome = UsearchCentroidGraph::build(uneven, DistanceMetric::L2);

        // then
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Centroid dimension mismatch"));
    }

    #[test]
    fn should_handle_k_larger_than_centroid_count() {
        // given - only two centroids exist
        let graph = l2_graph(vec![entry(1, &[1.0]), entry(2, &[2.0])]);

        // when - far more neighbors are requested than exist
        let neighbors = graph.search(&[1.5], 10);

        // then
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn should_add_centroid_and_find_it() {
        // given - a graph seeded with two centroids
        let graph = l2_graph(vec![entry(1, &[1.0, 0.0, 0.0]), entry(2, &[0.0, 1.0, 0.0])]);
        assert_eq!(graph.len(), 2);

        // when - a third centroid is inserted
        graph.add_centroid(&entry(3, &[0.0, 0.0, 1.0])).unwrap();

        // then - it is counted and is the nearest match for a nearby query
        assert_eq!(graph.len(), 3);
        let nearest = graph.search(&[0.0, 0.0, 0.9], 1);
        assert_eq!(nearest[0], 3);
    }

    #[test]
    fn should_remove_centroid() {
        // given - one centroid per axis
        let graph = l2_graph(vec![
            entry(1, &[1.0, 0.0, 0.0]),
            entry(2, &[0.0, 1.0, 0.0]),
            entry(3, &[0.0, 0.0, 1.0]),
        ]);

        // when - the middle centroid is dropped
        graph.remove_centroid(2).unwrap();

        // then - only the other two remain, even for a query right next to it
        assert_eq!(graph.len(), 2);
        let neighbors = graph.search(&[0.0, 0.9, 0.0], 2);
        assert_eq!(neighbors.len(), 2);
        assert!(!neighbors.contains(&2), "removed centroid should not appear");
    }

    #[test]
    fn should_search_after_add_and_remove() {
        // given - three centroids in the plane
        let graph = l2_graph(vec![
            entry(1, &[1.0, 0.0]),
            entry(2, &[0.0, 1.0]),
            entry(3, &[-1.0, 0.0]),
        ]);

        // when - centroid 1 is replaced by a new centroid 4
        graph.remove_centroid(1).unwrap();
        graph.add_centroid(&entry(4, &[0.5, 0.5])).unwrap();

        // then - the count is unchanged and the newcomer is found
        assert_eq!(graph.len(), 3);
        let nearest = graph.search(&[0.5, 0.5], 1);
        assert_eq!(nearest[0], 4);
    }

    #[test]
    fn should_get_centroid_vector() {
        // given
        let graph = l2_graph(vec![entry(1, &[1.0, 0.0, 0.0]), entry(2, &[0.0, 1.0, 0.0])]);

        // when/then - known ids return their vectors, unknown ids return nothing
        assert_eq!(graph.get_centroid_vector(1), Some(vec![1.0, 0.0, 0.0]));
        assert_eq!(graph.get_centroid_vector(2), Some(vec![0.0, 1.0, 0.0]));
        assert_eq!(graph.get_centroid_vector(99), None);
    }

    #[test]
    fn should_track_vectors_on_add_and_remove() {
        // given
        let graph = l2_graph(vec![entry(1, &[1.0, 0.0])]);

        // when - a centroid is added, its vector becomes retrievable
        graph.add_centroid(&entry(2, &[0.0, 1.0])).unwrap();
        assert_eq!(graph.get_centroid_vector(2), Some(vec![0.0, 1.0]));

        // when - it is removed again, the vector is gone
        graph.remove_centroid(2).unwrap();
        assert_eq!(graph.get_centroid_vector(2), None);
    }
}