Skip to main content

vector/hnsw/
mod.rs

1//! HNSW graph implementations for centroid search.
2//!
3//! This module provides a trait-based abstraction for HNSW indexes with
4//! an implementation backed by the usearch library.
5
6mod usearch;
7
8use crate::error::Result;
9
10use crate::serde::centroid_chunk::CentroidEntry;
11use crate::serde::collection_meta::DistanceMetric;
12
13// Re-export implementations
14pub use usearch::UsearchCentroidGraph;
15
16/// Trait for HNSW-based centroid graph implementations.
17///
18/// The graph stores centroids and enables fast approximate nearest neighbor search
19/// to find relevant clusters during query execution.
20pub trait CentroidGraph: Send + Sync {
21    /// Search for k nearest centroids to a query vector.
22    ///
23    /// # Arguments
24    /// * `query` - Query vector
25    /// * `k` - Number of nearest centroids to return
26    ///
27    /// # Returns
28    /// Vector of centroid_ids sorted by similarity (closest first)
29    fn search(&self, query: &[f32], k: usize) -> Vec<u64>;
30
31    /// Add a centroid to the graph.
32    ///
33    /// Uses interior mutability since the graph is behind `Arc<dyn CentroidGraph>`.
34    fn add_centroid(&self, entry: &CentroidEntry) -> Result<()>;
35
36    /// Remove a centroid from the graph by its ID.
37    ///
38    /// Uses interior mutability since the graph is behind `Arc<dyn CentroidGraph>`.
39    fn remove_centroid(&self, centroid_id: u64) -> Result<()>;
40
41    /// Get the vector for a centroid by its ID.
42    ///
43    /// Returns `None` if the centroid is not in the graph.
44    fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>>;
45
46    /// Returns the number of centroids in the graph.
47    fn len(&self) -> usize;
48
49    /// Returns true if the graph has no centroids.
50    fn is_empty(&self) -> bool {
51        self.len() == 0
52    }
53}
54
55/// Build a centroid graph using the default implementation (usearch).
56pub fn build_centroid_graph(
57    centroids: Vec<CentroidEntry>,
58    distance_metric: DistanceMetric,
59) -> Result<Box<dyn CentroidGraph>> {
60    let graph = UsearchCentroidGraph::build(centroids, distance_metric)?;
61    Ok(Box::new(graph))
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises a usearch-backed graph purely through `dyn CentroidGraph`.
    #[test]
    fn should_work_through_trait_interface() {
        // given: three centroids, one per axis of a 3-dimensional space
        let axis_centroids = vec![
            CentroidEntry::new(1, vec![1.0, 0.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0, 0.0]),
            CentroidEntry::new(3, vec![0.0, 0.0, 1.0]),
        ];
        let concrete = UsearchCentroidGraph::build(axis_centroids, DistanceMetric::L2).unwrap();
        let graph: Box<dyn CentroidGraph> = Box::new(concrete);

        // when / then: size accessors reflect the three inserted centroids
        assert_eq!(graph.len(), 3);
        assert!(!graph.is_empty());

        // when / then: a query leaning towards the first axis finds centroid 1 only
        let nearest = graph.search(&[0.9, 0.1, 0.1], 1);
        assert_eq!(nearest, vec![1]);
    }

    /// Exercises the `build_centroid_graph` convenience constructor.
    #[test]
    fn should_build_with_default_function() {
        // given: two centroids, one per axis of a 2-dimensional space
        let axis_centroids = vec![
            CentroidEntry::new(1, vec![1.0, 0.0]),
            CentroidEntry::new(2, vec![0.0, 1.0]),
        ];

        // when
        let graph = build_centroid_graph(axis_centroids, DistanceMetric::L2).unwrap();

        // then: both centroids are present and the closest one is centroid 1
        assert_eq!(graph.len(), 2);

        let nearest = graph.search(&[0.9, 0.1], 1);
        assert_eq!(nearest, vec![1]);
    }
}