1use std::collections::HashMap;
4use std::fmt;
5use std::sync::RwLock;
6
7use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
8
9use crate::error::{Error, Result};
10
11use crate::serde::centroid_chunk::CentroidEntry;
12use crate::serde::collection_meta::DistanceMetric;
13
14use super::CentroidGraph;
15
/// Baseline number of entries reserved in the usearch index when a graph is built.
const INITIAL_CAPACITY: usize = 200_000;
19
20struct UsearchCentroidGraphInner {
22 index: Index,
24 key_to_centroid: HashMap<u64, u64>,
26 centroid_to_key: HashMap<u64, u64>,
28 centroid_vectors: HashMap<u64, Vec<f32>>,
30 next_key: u64,
32}
33
34pub struct UsearchCentroidGraph {
39 inner: RwLock<UsearchCentroidGraphInner>,
40}
41
42impl fmt::Debug for UsearchCentroidGraph {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 let inner = self.inner.read().unwrap();
45 f.debug_struct("UsearchCentroidGraph")
46 .field("num_centroids", &inner.key_to_centroid.len())
47 .finish()
48 }
49}
50
51impl UsearchCentroidGraph {
52 pub fn build(centroids: Vec<CentroidEntry>, distance_metric: DistanceMetric) -> Result<Self> {
61 if centroids.is_empty() {
62 return Err(Error::InvalidInput(
63 "Cannot build HNSW graph with no centroids".to_string(),
64 ));
65 }
66
67 let dimensions = centroids[0].dimensions();
69 for centroid in ¢roids {
70 if centroid.dimensions() != dimensions {
71 return Err(Error::InvalidInput(format!(
72 "Centroid dimension mismatch: expected {}, got {}",
73 dimensions,
74 centroid.dimensions()
75 )));
76 }
77 }
78
79 let metric = match distance_metric {
81 DistanceMetric::L2 => MetricKind::L2sq,
82 DistanceMetric::DotProduct => MetricKind::IP,
83 };
84
85 let options = IndexOptions {
87 dimensions,
88 metric,
89 quantization: ScalarKind::F32,
90 connectivity: 16, expansion_add: 200, expansion_search: 100, multi: false,
94 };
95
96 let index = Index::new(&options).map_err(|e| Error::Internal(e.to_string()))?;
98
99 index
101 .reserve(INITIAL_CAPACITY)
102 .map_err(|e| Error::Internal(e.to_string()))?;
103
104 let mut key_to_centroid = HashMap::with_capacity(centroids.len());
106 let mut centroid_to_key = HashMap::with_capacity(centroids.len());
107 let mut centroid_vectors = HashMap::with_capacity(centroids.len());
108
109 for (key, centroid) in centroids.iter().enumerate() {
110 let key = key as u64;
111 index
112 .add(key, ¢roid.vector)
113 .map_err(|e| Error::Internal(e.to_string()))?;
114 key_to_centroid.insert(key, centroid.centroid_id);
115 centroid_to_key.insert(centroid.centroid_id, key);
116 centroid_vectors.insert(centroid.centroid_id, centroid.vector.clone());
117 }
118
119 let next_key = centroids.len() as u64;
120
121 Ok(Self {
122 inner: RwLock::new(UsearchCentroidGraphInner {
123 index,
124 key_to_centroid,
125 centroid_to_key,
126 centroid_vectors,
127 next_key,
128 }),
129 })
130 }
131}
132
133impl CentroidGraph for UsearchCentroidGraph {
134 fn search(&self, query: &[f32], k: usize) -> Vec<u64> {
135 self.inner.read().expect("lock poisoned").search(query, k)
136 }
137
138 fn add_centroid(&self, entry: &CentroidEntry) -> Result<()> {
139 self.inner
140 .write()
141 .expect("lock poisoned")
142 .add_centroid(entry)
143 }
144
145 fn remove_centroid(&self, centroid_id: u64) -> Result<()> {
146 self.inner
147 .write()
148 .expect("lock poisoned")
149 .remove_centroid(centroid_id)
150 }
151
152 fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>> {
153 self.inner
154 .read()
155 .expect("lock poisoned")
156 .get_centroid_vector(centroid_id)
157 }
158
159 fn len(&self) -> usize {
160 self.inner.read().expect("lock poisoned").len()
161 }
162}
163
164impl UsearchCentroidGraphInner {
165 fn search(&self, query: &[f32], k: usize) -> Vec<u64> {
166 let k = k.min(self.key_to_centroid.len());
167 if k == 0 {
168 return Vec::new();
169 }
170
171 let search_k = (k + 10).min(self.key_to_centroid.len() + 10);
173 let results = match self.index.search(query, search_k) {
174 Ok(matches) => matches,
175 Err(_) => return Vec::new(),
176 };
177
178 results
180 .keys
181 .iter()
182 .filter_map(|&key| self.key_to_centroid.get(&key).copied())
183 .take(k)
184 .collect()
185 }
186
187 fn add_centroid(&mut self, entry: &CentroidEntry) -> Result<()> {
188 let key = self.next_key;
189 self.next_key += 1;
190
191 self.index
192 .add(key, &entry.vector)
193 .map_err(|e| Error::Internal(e.to_string()))?;
194
195 self.key_to_centroid.insert(key, entry.centroid_id);
196 self.centroid_to_key.insert(entry.centroid_id, key);
197 self.centroid_vectors
198 .insert(entry.centroid_id, entry.vector.clone());
199
200 Ok(())
201 }
202
203 fn remove_centroid(&mut self, centroid_id: u64) -> Result<()> {
204 let key = self.centroid_to_key.remove(¢roid_id).ok_or_else(|| {
205 Error::Internal(format!("Centroid {} not found in graph", centroid_id))
206 })?;
207
208 self.key_to_centroid.remove(&key);
209 self.centroid_vectors.remove(¢roid_id);
210 self.index
211 .remove(key)
212 .map_err(|e| Error::Internal(e.to_string()))?;
213
214 Ok(())
215 }
216
217 fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>> {
218 self.centroid_vectors.get(¢roid_id).cloned()
219 }
220
221 fn len(&self) -> usize {
222 self.key_to_centroid.len()
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an L2 graph from `entries`, panicking if construction fails.
    fn l2_graph(entries: Vec<CentroidEntry>) -> UsearchCentroidGraph {
        UsearchCentroidGraph::build(entries, DistanceMetric::L2).unwrap()
    }

    #[test]
    fn should_build_and_search_l2_graph() {
        // given: three unit vectors along the axes
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![1.0, 0.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0, 0.0]),
            CentroidEntry::new(3, vec![0.0, 0.0, 1.0]),
        ]);

        // when: searching close to the first axis
        let nearest = graph.search(&[0.9, 0.1, 0.1], 1);

        // then
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0], 1);
    }

    #[test]
    fn should_return_multiple_neighbors() {
        // given: five points on a line
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![0.0]),
            CentroidEntry::new(2, vec![1.0]),
            CentroidEntry::new(3, vec![2.0]),
            CentroidEntry::new(4, vec![3.0]),
            CentroidEntry::new(5, vec![4.0]),
        ]);

        // when
        let nearest = graph.search(&[2.1], 3);

        // then: three hits, closest first
        assert_eq!(nearest.len(), 3);
        assert_eq!(nearest[0], 3);
    }

    #[test]
    fn should_reject_empty_centroids() {
        // given
        let no_centroids: Vec<CentroidEntry> = Vec::new();

        // when
        let outcome = UsearchCentroidGraph::build(no_centroids, DistanceMetric::L2);

        // then
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot build HNSW graph with no centroids"));
    }

    #[test]
    fn should_reject_mismatched_dimensions() {
        // given: a 2-d and a 3-d centroid
        let mixed = vec![
            CentroidEntry::new(1, vec![1.0, 2.0]),
            CentroidEntry::new(2, vec![3.0, 4.0, 5.0]),
        ];

        // when
        let outcome = UsearchCentroidGraph::build(mixed, DistanceMetric::L2);

        // then
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Centroid dimension mismatch"));
    }

    #[test]
    fn should_handle_k_larger_than_centroid_count() {
        // given
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![1.0]),
            CentroidEntry::new(2, vec![2.0]),
        ]);

        // when: asking for more neighbours than exist
        let nearest = graph.search(&[1.5], 10);

        // then
        assert_eq!(nearest.len(), 2);
    }

    #[test]
    fn should_add_centroid_and_find_it() {
        // given
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![1.0, 0.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0, 0.0]),
        ]);
        assert_eq!(graph.len(), 2);

        // when
        let added = CentroidEntry::new(3, vec![0.0, 0.0, 1.0]);
        graph.add_centroid(&added).unwrap();

        // then
        assert_eq!(graph.len(), 3);
        let nearest = graph.search(&[0.0, 0.0, 0.9], 1);
        assert_eq!(nearest[0], 3);
    }

    #[test]
    fn should_remove_centroid() {
        // given
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![1.0, 0.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0, 0.0]),
            CentroidEntry::new(3, vec![0.0, 0.0, 1.0]),
        ]);

        // when
        graph.remove_centroid(2).unwrap();

        // then
        assert_eq!(graph.len(), 2);
        let nearest = graph.search(&[0.0, 0.9, 0.0], 2);
        assert_eq!(nearest.len(), 2);
        assert!(!nearest.contains(&2), "removed centroid should not appear");
    }

    #[test]
    fn should_search_after_add_and_remove() {
        // given
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![1.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0]),
            CentroidEntry::new(3, vec![-1.0, 0.0]),
        ]);

        // when: one centroid is swapped out for another
        graph.remove_centroid(1).unwrap();
        let replacement = CentroidEntry::new(4, vec![0.5, 0.5]);
        graph.add_centroid(&replacement).unwrap();

        // then
        assert_eq!(graph.len(), 3);
        let nearest = graph.search(&[0.5, 0.5], 1);
        assert_eq!(nearest[0], 4);
    }

    #[test]
    fn should_get_centroid_vector() {
        // given
        let graph = l2_graph(vec![
            CentroidEntry::new(1, vec![1.0, 0.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0, 0.0]),
        ]);

        // then: known ids return their vectors, unknown ids return nothing
        assert_eq!(graph.get_centroid_vector(1), Some(vec![1.0, 0.0, 0.0]));
        assert_eq!(graph.get_centroid_vector(2), Some(vec![0.0, 1.0, 0.0]));
        assert_eq!(graph.get_centroid_vector(99), None);
    }

    #[test]
    fn should_track_vectors_on_add_and_remove() {
        // given
        let graph = l2_graph(vec![CentroidEntry::new(1, vec![1.0, 0.0])]);

        // when: a centroid is added
        let added = CentroidEntry::new(2, vec![0.0, 1.0]);
        graph.add_centroid(&added).unwrap();

        // then: its vector is retrievable
        assert_eq!(graph.get_centroid_vector(2), Some(vec![0.0, 1.0]));

        // when: it is removed again
        graph.remove_centroid(2).unwrap();

        // then: the vector is gone
        assert_eq!(graph.get_centroid_vector(2), None);
    }
}