Skip to main content

grafeo_core/index/vector/
mod.rs

1//! Vector similarity search support.
2//!
3//! This module provides infrastructure for storing and searching vector embeddings,
4//! enabling AI/ML use cases like RAG, semantic search, and recommendations.
5//!
6//! # Distance Metrics
7//!
8//! Choose the metric based on your embedding type:
9//!
10//! | Metric | Best For | Range |
11//! |--------|----------|-------|
12//! | [`Cosine`](DistanceMetric::Cosine) | Normalized embeddings (text) | [0, 2] |
13//! | [`Euclidean`](DistanceMetric::Euclidean) | Raw embeddings | [0, inf) |
14//! | [`DotProduct`](DistanceMetric::DotProduct) | Max inner product search | (-inf, inf) |
15//! | [`Manhattan`](DistanceMetric::Manhattan) | Outlier-resistant | [0, inf) |
16//!
17//! # Index Types
18//!
19//! | Index | Complexity | Use Case |
20//! |-------|------------|----------|
21//! | [`brute_force_knn`] | O(n) | Small datasets, exact results |
22//! | [`HnswIndex`] | O(log n) | Large datasets, approximate results |
23//!
24//! # Example
25//!
26//! ```
27//! use grafeo_core::index::vector::{compute_distance, DistanceMetric, brute_force_knn};
28//! use grafeo_common::types::NodeId;
29//!
30//! // Compute distance between two vectors
31//! let query = [0.1f32, 0.2, 0.3];
32//! let doc1 = [0.1f32, 0.2, 0.35];
33//! let doc2 = [0.5f32, 0.6, 0.7];
34//!
35//! let dist1 = compute_distance(&query, &doc1, DistanceMetric::Cosine);
36//! let dist2 = compute_distance(&query, &doc2, DistanceMetric::Cosine);
37//!
38//! // doc1 is more similar (smaller distance)
39//! assert!(dist1 < dist2);
40//!
41//! // Brute-force k-NN search
42//! let vectors = vec![
43//!     (NodeId::new(1), doc1.as_slice()),
44//!     (NodeId::new(2), doc2.as_slice()),
45//! ];
46//!
47//! let results = brute_force_knn(vectors.into_iter(), &query, 1, DistanceMetric::Cosine);
48//! assert_eq!(results[0].0, NodeId::new(1)); // doc1 is closest
49//! ```
50//!
51//! # HNSW Index (requires `vector-index` feature)
52//!
53//! For larger datasets, use the HNSW approximate nearest neighbor index:
54//!
55//! ```no_run
56//! # #[cfg(feature = "vector-index")]
57//! # {
58//! use grafeo_core::index::vector::{HnswIndex, HnswConfig, DistanceMetric, VectorAccessor};
59//! use grafeo_common::types::NodeId;
60//! use std::sync::Arc;
61//! use std::collections::HashMap;
62//!
63//! let config = HnswConfig::new(384, DistanceMetric::Cosine);
64//! let index = HnswIndex::new(config);
65//!
66//! // Build an accessor backed by a HashMap
67//! let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
68//! let embedding: Arc<[f32]> = vec![0.1f32; 384].into();
69//! map.insert(NodeId::new(1), embedding.clone());
70//! let accessor = |id: NodeId| -> Option<Arc<[f32]>> { map.get(&id).cloned() };
71//!
72//! // Insert vectors (requires accessor for neighbor lookups)
73//! index.insert(NodeId::new(1), &embedding, &accessor);
74//!
75//! // Search (O(log n))
76//! let query = vec![0.15f32; 384];
77//! let results = index.search(&query, 10, &accessor);
78//! # }
79//! ```
80
81mod accessor;
82mod distance;
83mod mmr;
84pub mod quantization;
85mod simd;
86pub mod storage;
87pub mod zone_map;
88
89#[cfg(feature = "vector-index")]
90mod config;
91#[cfg(feature = "vector-index")]
92mod hnsw;
93#[cfg(feature = "vector-index")]
94mod quantized_hnsw;
95
96pub use accessor::{PropertyVectorAccessor, VectorAccessor};
97pub use distance::{
98    DistanceMetric, compute_distance, cosine_distance, cosine_similarity, dot_product,
99    euclidean_distance, euclidean_distance_squared, l2_norm, manhattan_distance, normalize,
100    simd_support,
101};
102pub use mmr::mmr_select;
103pub use quantization::{BinaryQuantizer, ProductQuantizer, QuantizationType, ScalarQuantizer};
104#[cfg(feature = "mmap")]
105pub use storage::MmapStorage;
106pub use storage::{RamStorage, StorageBackend, VectorStorage};
107pub use zone_map::VectorZoneMap;
108
109#[cfg(feature = "vector-index")]
110pub use config::HnswConfig;
111#[cfg(feature = "vector-index")]
112pub use hnsw::HnswIndex;
113#[cfg(feature = "vector-index")]
114pub use quantized_hnsw::QuantizedHnswIndex;
115
116use grafeo_common::types::NodeId;
117
118/// Configuration for vector search operations.
119#[derive(Debug, Clone)]
120pub struct VectorConfig {
121    /// Expected vector dimensions (for validation).
122    pub dimensions: usize,
123    /// Distance metric for similarity computation.
124    pub metric: DistanceMetric,
125}
126
127impl VectorConfig {
128    /// Creates a new vector configuration.
129    #[must_use]
130    pub const fn new(dimensions: usize, metric: DistanceMetric) -> Self {
131        Self { dimensions, metric }
132    }
133
134    /// Creates a configuration for cosine similarity with the given dimensions.
135    #[must_use]
136    pub const fn cosine(dimensions: usize) -> Self {
137        Self::new(dimensions, DistanceMetric::Cosine)
138    }
139
140    /// Creates a configuration for Euclidean distance with the given dimensions.
141    #[must_use]
142    pub const fn euclidean(dimensions: usize) -> Self {
143        Self::new(dimensions, DistanceMetric::Euclidean)
144    }
145}
146
147impl Default for VectorConfig {
148    fn default() -> Self {
149        Self {
150            dimensions: 384, // Common embedding size (MiniLM, etc.)
151            metric: DistanceMetric::default(),
152        }
153    }
154}
155
156/// Performs brute-force k-nearest neighbor search.
157///
158/// This is O(n) where n is the number of vectors. Use this for:
159/// - Small datasets (< 10K vectors)
160/// - Baseline comparisons
161/// - Exact nearest neighbor search
162///
163/// For larger datasets, use an approximate index like HNSW.
164///
165/// # Arguments
166///
167/// * `vectors` - Iterator of (id, vector) pairs to search
168/// * `query` - The query vector
169/// * `k` - Number of nearest neighbors to return
170/// * `metric` - Distance metric to use
171///
172/// # Returns
173///
174/// Vector of (id, distance) pairs sorted by distance (ascending).
175///
176/// # Example
177///
178/// ```
179/// use grafeo_core::index::vector::{brute_force_knn, DistanceMetric};
180/// use grafeo_common::types::NodeId;
181///
182/// let vectors = vec![
183///     (NodeId::new(1), [0.1f32, 0.2, 0.3].as_slice()),
184///     (NodeId::new(2), [0.4f32, 0.5, 0.6].as_slice()),
185///     (NodeId::new(3), [0.7f32, 0.8, 0.9].as_slice()),
186/// ];
187///
188/// let query = [0.15f32, 0.25, 0.35];
189/// let results = brute_force_knn(vectors.into_iter(), &query, 2, DistanceMetric::Euclidean);
190///
191/// assert_eq!(results.len(), 2);
192/// assert_eq!(results[0].0, NodeId::new(1)); // Closest
193/// ```
194pub fn brute_force_knn<'a, I>(
195    vectors: I,
196    query: &[f32],
197    k: usize,
198    metric: DistanceMetric,
199) -> Vec<(NodeId, f32)>
200where
201    I: Iterator<Item = (NodeId, &'a [f32])>,
202{
203    let mut results: Vec<(NodeId, f32)> = vectors
204        .map(|(id, vec)| (id, compute_distance(query, vec, metric)))
205        .collect();
206
207    // Sort by distance (ascending)
208    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
209
210    // Truncate to k
211    results.truncate(k);
212    results
213}
214
215/// Performs brute-force k-nearest neighbor search with a filter predicate.
216///
217/// Only considers vectors where the predicate returns true.
218///
219/// # Arguments
220///
221/// * `vectors` - Iterator of (id, vector) pairs to search
222/// * `query` - The query vector
223/// * `k` - Number of nearest neighbors to return
224/// * `metric` - Distance metric to use
225/// * `predicate` - Filter function; only vectors where this returns true are considered
226///
227/// # Returns
228///
229/// Vector of (id, distance) pairs sorted by distance (ascending).
230pub fn brute_force_knn_filtered<'a, I, F>(
231    vectors: I,
232    query: &[f32],
233    k: usize,
234    metric: DistanceMetric,
235    predicate: F,
236) -> Vec<(NodeId, f32)>
237where
238    I: Iterator<Item = (NodeId, &'a [f32])>,
239    F: Fn(NodeId) -> bool,
240{
241    let mut results: Vec<(NodeId, f32)> = vectors
242        .filter(|(id, _)| predicate(*id))
243        .map(|(id, vec)| (id, compute_distance(query, vec, metric)))
244        .collect();
245
246    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
247    results.truncate(k);
248    results
249}
250
251/// Computes the distance between a query and multiple vectors in batch.
252///
253/// More efficient than computing distances one by one for large batches.
254///
255/// # Returns
256///
257/// Vector of (id, distance) pairs in the same order as input.
258pub fn batch_distances<'a, I>(
259    vectors: I,
260    query: &[f32],
261    metric: DistanceMetric,
262) -> Vec<(NodeId, f32)>
263where
264    I: Iterator<Item = (NodeId, &'a [f32])>,
265{
266    vectors
267        .map(|(id, vec)| (id, compute_distance(query, vec, metric)))
268        .collect()
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point distances.
    const TOLERANCE: f32 = 0.001;

    #[test]
    fn test_vector_config_default() {
        let config = VectorConfig::default();
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.metric, DistanceMetric::Cosine);
    }

    #[test]
    fn test_vector_config_constructors() {
        let custom = VectorConfig::new(128, DistanceMetric::Manhattan);
        assert_eq!(custom.dimensions, 128);
        assert_eq!(custom.metric, DistanceMetric::Manhattan);

        let cosine = VectorConfig::cosine(768);
        assert_eq!(cosine.dimensions, 768);
        assert_eq!(cosine.metric, DistanceMetric::Cosine);

        let euclidean = VectorConfig::euclidean(1536);
        assert_eq!(euclidean.dimensions, 1536);
        assert_eq!(euclidean.metric, DistanceMetric::Euclidean);
    }

    #[test]
    fn test_brute_force_knn() {
        let vectors = vec![
            (NodeId::new(1), [0.0f32, 0.0, 0.0].as_slice()),
            (NodeId::new(2), [1.0f32, 0.0, 0.0].as_slice()),
            (NodeId::new(3), [2.0f32, 0.0, 0.0].as_slice()),
            (NodeId::new(4), [3.0f32, 0.0, 0.0].as_slice()),
        ];

        let query = [0.5f32, 0.0, 0.0];
        let results = brute_force_knn(
            vectors.iter().copied(),
            &query,
            2,
            DistanceMetric::Euclidean,
        );

        assert_eq!(results.len(), 2);
        // Nodes 1 and 2 are both at distance 0.5, so either may come first,
        // but together they must be exactly the two nearest neighbors.
        let ids: Vec<NodeId> = results.iter().map(|(id, _)| *id).collect();
        assert!(ids.contains(&NodeId::new(1)));
        assert!(ids.contains(&NodeId::new(2)));
        for (_, dist) in &results {
            assert!((dist - 0.5).abs() < TOLERANCE);
        }

        // Asking for every vector must yield an ascending ranking.
        let all = brute_force_knn(vectors.into_iter(), &query, 4, DistanceMetric::Euclidean);
        assert_eq!(all.len(), 4);
        assert!(all.windows(2).all(|pair| pair[0].1 <= pair[1].1));
        assert_eq!(all[2].0, NodeId::new(3));
        assert_eq!(all[3].0, NodeId::new(4));
    }

    #[test]
    fn test_brute_force_knn_empty() {
        let vectors: Vec<(NodeId, &[f32])> = vec![];
        let query = [0.0f32, 0.0];
        let results = brute_force_knn(vectors.into_iter(), &query, 10, DistanceMetric::Cosine);
        assert!(results.is_empty());
    }

    #[test]
    fn test_brute_force_knn_k_zero() {
        let vectors = vec![
            (NodeId::new(1), [0.0f32, 0.0].as_slice()),
            (NodeId::new(2), [1.0f32, 0.0].as_slice()),
        ];

        let query = [0.0f32, 0.0];
        let results = brute_force_knn(vectors.into_iter(), &query, 0, DistanceMetric::Euclidean);

        // Requesting zero neighbors returns nothing, even with candidates available
        assert!(results.is_empty());
    }

    #[test]
    fn test_brute_force_knn_k_larger_than_n() {
        let vectors = vec![
            (NodeId::new(1), [0.0f32, 0.0].as_slice()),
            (NodeId::new(2), [1.0f32, 0.0].as_slice()),
        ];

        let query = [0.0f32, 0.0];
        let results = brute_force_knn(vectors.into_iter(), &query, 10, DistanceMetric::Euclidean);

        // Should return all 2 vectors, not 10
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_brute_force_knn_filtered() {
        let vectors = vec![
            (NodeId::new(1), [0.0f32, 0.0].as_slice()),
            (NodeId::new(2), [1.0f32, 0.0].as_slice()),
            (NodeId::new(3), [2.0f32, 0.0].as_slice()),
        ];

        let query = [0.0f32, 0.0];

        // Only consider even IDs
        let results = brute_force_knn_filtered(
            vectors.iter().copied(),
            &query,
            10,
            DistanceMetric::Euclidean,
            |id| id.as_u64() % 2 == 0,
        );

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, NodeId::new(2));

        // Filtering and truncation combine: node 1 is closest but excluded,
        // and k = 1 keeps only the nearest of the remaining candidates.
        let nearest_allowed = brute_force_knn_filtered(
            vectors.into_iter(),
            &query,
            1,
            DistanceMetric::Euclidean,
            |id| id != NodeId::new(1),
        );

        assert_eq!(nearest_allowed.len(), 1);
        assert_eq!(nearest_allowed[0].0, NodeId::new(2));
        assert!((nearest_allowed[0].1 - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_batch_distances() {
        let vectors = vec![
            (NodeId::new(1), [0.0f32, 0.0].as_slice()),
            (NodeId::new(2), [3.0f32, 4.0].as_slice()),
        ];

        let query = [0.0f32, 0.0];
        let results = batch_distances(vectors.into_iter(), &query, DistanceMetric::Euclidean);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, NodeId::new(1));
        assert!((results[0].1 - 0.0).abs() < TOLERANCE);
        assert_eq!(results[1].0, NodeId::new(2));
        assert!((results[1].1 - 5.0).abs() < TOLERANCE); // 3-4-5 triangle
    }
}