Skip to main content

oxirs_embed/
vector_search.rs

1//! Vector Search Integration
2//!
3//! This module provides high-performance vector search capabilities for knowledge graph
4//! embeddings, enabling semantic similarity search, approximate nearest neighbor (ANN)
5//! search, and integration with popular vector databases.
6//!
7//! # Features
8//!
//! - **Exact Search**: Brute-force search over every indexed entity using the configured metric
//! - **Approximate Search**: Configuration hooks for HNSW (Hierarchical Navigable Small World)
//!   ANN search; the current implementation falls back to exact search
//! - **Index Building**: Efficient index construction for large-scale search
//! - **Batch Search**: Process multiple queries in parallel
//! - **Radius Filtering**: Return only the entities within a distance threshold of a query
//! - **Distance Metrics**: Cosine similarity, Euclidean distance, dot product, Manhattan distance
15//!
16//! # Quick Start
17//!
18//! ```rust,no_run
19//! use oxirs_embed::{
20//!     vector_search::{VectorSearchIndex, SearchConfig, DistanceMetric},
21//!     EmbeddingModel, TransE, ModelConfig,
22//! };
23//! use std::collections::HashMap;
24//!
25//! # async fn example() -> anyhow::Result<()> {
26//! // Build search index from embeddings
27//! let mut embeddings = HashMap::new();
28//! // ... populate embeddings from trained model
29//!
30//! let config = SearchConfig {
31//!     metric: DistanceMetric::Cosine,
32//!     ..Default::default()
33//! };
34//!
35//! let mut index = VectorSearchIndex::new(config);
36//! index.build(&embeddings)?;
37//!
38//! // Search for similar entities
39//! let query_embedding = vec![0.1, 0.2, 0.3]; // ... your query embedding
40//! let results = index.search(&query_embedding, 10)?;
41//!
42//! for result in results {
43//!     println!("{}: similarity = {:.4}", result.entity_id, result.score);
44//! }
45//! # Ok(())
46//! # }
47//! ```
48
49use anyhow::{anyhow, Result};
50use rayon::prelude::*;
51use scirs2_core::ndarray_ext::Array1;
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use tracing::{debug, info};
55
56/// Distance metric for vector search
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58pub enum DistanceMetric {
59    /// Cosine similarity (normalized dot product)
60    Cosine,
61    /// Euclidean distance (L2 norm)
62    Euclidean,
63    /// Dot product similarity
64    DotProduct,
65    /// Manhattan distance (L1 norm)
66    Manhattan,
67}
68
69/// Vector search configuration
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SearchConfig {
72    /// Distance metric to use
73    pub metric: DistanceMetric,
74    /// Use approximate search (HNSW) for large datasets
75    pub use_approximate: bool,
76    /// Number of neighbors for HNSW graph construction
77    pub hnsw_m: usize,
78    /// Size of dynamic candidate list for HNSW
79    pub hnsw_ef_construction: usize,
80    /// Size of dynamic candidate list for HNSW search
81    pub hnsw_ef_search: usize,
82    /// Enable parallel search
83    pub parallel: bool,
84    /// Normalize vectors before search
85    pub normalize: bool,
86}
87
88impl Default for SearchConfig {
89    fn default() -> Self {
90        Self {
91            metric: DistanceMetric::Cosine,
92            use_approximate: false,
93            hnsw_m: 16,
94            hnsw_ef_construction: 200,
95            hnsw_ef_search: 50,
96            parallel: true,
97            normalize: true,
98        }
99    }
100}
101
102/// Search result
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct SearchResult {
105    /// Entity ID
106    pub entity_id: String,
107    /// Similarity score (higher is better)
108    pub score: f32,
109    /// Distance (lower is better, depends on metric)
110    pub distance: f32,
111    /// Rank in results (1-indexed)
112    pub rank: usize,
113}
114
115/// Vector search index
116pub struct VectorSearchIndex {
117    config: SearchConfig,
118    embeddings: HashMap<String, Array1<f32>>,
119    entity_ids: Vec<String>,
120    embedding_matrix: Option<Vec<Vec<f32>>>,
121    dimensions: usize,
122    is_built: bool,
123}
124
125impl VectorSearchIndex {
126    /// Create new vector search index
127    pub fn new(config: SearchConfig) -> Self {
128        info!(
129            "Initialized vector search index: metric={:?}, approximate={}",
130            config.metric, config.use_approximate
131        );
132
133        Self {
134            config,
135            embeddings: HashMap::new(),
136            entity_ids: Vec::new(),
137            embedding_matrix: None,
138            dimensions: 0,
139            is_built: false,
140        }
141    }
142
143    /// Build search index from embeddings
144    pub fn build(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
145        if embeddings.is_empty() {
146            return Err(anyhow!("Cannot build index from empty embeddings"));
147        }
148
149        info!(
150            "Building vector search index for {} entities",
151            embeddings.len()
152        );
153
154        // Store embeddings
155        self.embeddings = embeddings.clone();
156        self.entity_ids = embeddings.keys().cloned().collect();
157        self.dimensions = embeddings
158            .values()
159            .next()
160            .expect("embeddings should not be empty")
161            .len();
162
163        // Build embedding matrix for efficient search
164        let mut matrix = Vec::new();
165        for entity_id in &self.entity_ids {
166            let mut emb = self.embeddings[entity_id].to_vec();
167
168            // Normalize if configured
169            if self.config.normalize {
170                self.normalize_vector(&mut emb);
171            }
172
173            matrix.push(emb);
174        }
175        self.embedding_matrix = Some(matrix);
176
177        self.is_built = true;
178
179        info!("Vector search index built successfully");
180        Ok(())
181    }
182
183    /// Search for K nearest neighbors
184    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
185        if !self.is_built {
186            return Err(anyhow!("Index not built. Call build() first"));
187        }
188
189        if query.len() != self.dimensions {
190            return Err(anyhow!(
191                "Query dimension {} doesn't match index dimension {}",
192                query.len(),
193                self.dimensions
194            ));
195        }
196
197        // Normalize query if configured
198        let mut normalized_query = query.to_vec();
199        if self.config.normalize {
200            self.normalize_vector(&mut normalized_query);
201        }
202
203        debug!("Searching for {} nearest neighbors", k);
204
205        if self.config.use_approximate && self.embeddings.len() > 1000 {
206            self.approximate_search(&normalized_query, k)
207        } else {
208            self.exact_search(&normalized_query, k)
209        }
210    }
211
212    /// Exact brute-force search
213    fn exact_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
214        let matrix = self
215            .embedding_matrix
216            .as_ref()
217            .expect("embedding matrix should be built before search");
218
219        // Compute distances/similarities to all entities
220        let scores: Vec<(usize, f32)> = if self.config.parallel {
221            (0..self.entity_ids.len())
222                .into_par_iter()
223                .map(|i| {
224                    let score = self.compute_similarity(query, &matrix[i]);
225                    (i, score)
226                })
227                .collect()
228        } else {
229            (0..self.entity_ids.len())
230                .map(|i| {
231                    let score = self.compute_similarity(query, &matrix[i]);
232                    (i, score)
233                })
234                .collect()
235        };
236
237        // Sort by score descending
238        let mut sorted_scores = scores;
239        sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
240
241        // Return top-K results
242        let results: Vec<SearchResult> = sorted_scores
243            .iter()
244            .take(k.min(self.entity_ids.len()))
245            .enumerate()
246            .map(|(rank, &(idx, score))| SearchResult {
247                entity_id: self.entity_ids[idx].clone(),
248                score,
249                distance: self.score_to_distance(score),
250                rank: rank + 1,
251            })
252            .collect();
253
254        debug!("Found {} results", results.len());
255        Ok(results)
256    }
257
258    /// Approximate search using simplified HNSW
259    fn approximate_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
260        // For now, fall back to exact search
261        // TODO: Implement full HNSW for very large datasets
262        debug!("Using exact search (HNSW not yet fully implemented)");
263        self.exact_search(query, k)
264    }
265
266    /// Batch search for multiple queries
267    pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
268        if !self.is_built {
269            return Err(anyhow!("Index not built. Call build() first"));
270        }
271
272        info!("Batch searching for {} queries", queries.len());
273
274        let results: Vec<Vec<SearchResult>> = if self.config.parallel {
275            queries
276                .par_iter()
277                .map(|query| self.search(query, k).unwrap_or_default())
278                .collect()
279        } else {
280            queries
281                .iter()
282                .map(|query| self.search(query, k).unwrap_or_default())
283                .collect()
284        };
285
286        Ok(results)
287    }
288
289    /// Compute similarity between two vectors
290    fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
291        match self.config.metric {
292            DistanceMetric::Cosine => {
293                // Dot product (vectors are already normalized)
294                a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
295            }
296            DistanceMetric::Euclidean => {
297                // Negative Euclidean distance (so higher is better)
298                let dist: f32 = a
299                    .iter()
300                    .zip(b.iter())
301                    .map(|(x, y)| (x - y).powi(2))
302                    .sum::<f32>()
303                    .sqrt();
304                -dist
305            }
306            DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
307            DistanceMetric::Manhattan => {
308                // Negative Manhattan distance
309                let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
310                -dist
311            }
312        }
313    }
314
315    /// Convert score to distance
316    fn score_to_distance(&self, score: f32) -> f32 {
317        match self.config.metric {
318            DistanceMetric::Cosine => 1.0 - score, // Cosine distance
319            DistanceMetric::Euclidean | DistanceMetric::Manhattan => -score, // Already negative
320            DistanceMetric::DotProduct => -score,
321        }
322    }
323
324    /// Normalize vector in-place
325    fn normalize_vector(&self, vec: &mut [f32]) {
326        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
327        if norm > 1e-10 {
328            for x in vec.iter_mut() {
329                *x /= norm;
330            }
331        }
332    }
333
334    /// Get index statistics
335    pub fn get_stats(&self) -> IndexStats {
336        IndexStats {
337            num_entities: self.entity_ids.len(),
338            dimensions: self.dimensions,
339            is_built: self.is_built,
340            metric: self.config.metric,
341            use_approximate: self.config.use_approximate,
342        }
343    }
344
345    /// Find entities within a radius
346    pub fn radius_search(&self, query: &[f32], radius: f32) -> Result<Vec<SearchResult>> {
347        if !self.is_built {
348            return Err(anyhow!("Index not built. Call build() first"));
349        }
350
351        let all_results = self.search(query, self.entity_ids.len())?;
352
353        Ok(all_results
354            .into_iter()
355            .filter(|r| r.distance <= radius)
356            .collect())
357    }
358}
359
360/// Index statistics
361#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct IndexStats {
363    /// Number of entities in index
364    pub num_entities: usize,
365    /// Embedding dimensions
366    pub dimensions: usize,
367    /// Whether index is built
368    pub is_built: bool,
369    /// Distance metric
370    pub metric: DistanceMetric,
371    /// Using approximate search
372    pub use_approximate: bool,
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Five 3-dimensional embeddings: three axis-aligned and two mixes of the x/y axes.
    fn create_test_embeddings() -> HashMap<String, Array1<f32>> {
        let mut embeddings = HashMap::new();

        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("entity2".to_string(), array![0.9, 0.1, 0.0]);
        embeddings.insert("entity3".to_string(), array![0.0, 1.0, 0.0]);
        embeddings.insert("entity4".to_string(), array![0.0, 0.0, 1.0]);
        embeddings.insert("entity5".to_string(), array![0.7, 0.7, 0.0]);

        embeddings
    }

    /// Build an index over `create_test_embeddings()` with the given configuration.
    fn built_index(config: SearchConfig) -> VectorSearchIndex {
        let mut index = VectorSearchIndex::new(config);
        index.build(&create_test_embeddings()).unwrap();
        index
    }

    /// Assert that results carry ranks 1..=n and are ordered by non-increasing score.
    fn assert_ranked(results: &[SearchResult]) {
        for (position, result) in results.iter().enumerate() {
            assert_eq!(result.rank, position + 1);
        }
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_index_creation() {
        let index = VectorSearchIndex::new(SearchConfig::default());

        assert!(!index.is_built);
        assert_eq!(index.dimensions, 0);
    }

    #[test]
    fn test_index_building() {
        let embeddings = create_test_embeddings();
        let mut index = VectorSearchIndex::new(SearchConfig::default());

        let result = index.build(&embeddings);
        assert!(result.is_ok());
        assert!(index.is_built);
        assert_eq!(index.dimensions, 3);
        assert_eq!(index.entity_ids.len(), 5);
    }

    #[test]
    fn test_build_rejects_empty_embeddings() {
        let mut index = VectorSearchIndex::new(SearchConfig::default());

        assert!(index.build(&HashMap::new()).is_err());
        assert!(!index.is_built);
    }

    #[test]
    fn test_search_requires_built_index() {
        let index = VectorSearchIndex::new(SearchConfig::default());
        let query = vec![1.0, 0.0, 0.0];

        assert!(index.search(&query, 1).is_err());
        assert!(index.batch_search(&[query.clone()], 1).is_err());
        assert!(index.radius_search(&query, 0.5).is_err());
    }

    #[test]
    fn test_query_dimension_mismatch() {
        let index = built_index(SearchConfig::default());

        assert!(index.search(&[1.0, 0.0], 1).is_err());
    }

    #[test]
    fn test_exact_search() {
        let index = built_index(SearchConfig::default());

        // Search for entities similar to [1, 0, 0]
        let results = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

        assert_eq!(results.len(), 3);
        // entity1 should be most similar
        assert_eq!(results[0].entity_id, "entity1");
        assert!(results[0].score > 0.8);
        assert_ranked(&results);
    }

    #[test]
    fn test_k_larger_than_index() {
        let index = built_index(SearchConfig::default());

        let results = index.search(&[1.0, 0.0, 0.0], 10).unwrap();

        assert_eq!(results.len(), 5);
        assert_ranked(&results);
    }

    #[test]
    fn test_cosine_similarity() {
        let index = built_index(SearchConfig {
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });

        let results = index.search(&[1.0, 1.0, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        // entity5 [0.7, 0.7, 0] points in exactly the query's direction
        assert_eq!(results[0].entity_id, "entity5");
        assert_ranked(&results);
    }

    #[test]
    fn test_batch_search() {
        let index = built_index(SearchConfig::default());

        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let results = index.batch_search(&queries, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 2);
        assert_eq!(results[1].len(), 2);
        // Results stay aligned with query order.
        assert_eq!(results[0][0].entity_id, "entity1");
        assert_eq!(results[1][0].entity_id, "entity3");
    }

    #[test]
    fn test_radius_search() {
        let index = built_index(SearchConfig::default());

        let results = index.radius_search(&[1.0, 0.0, 0.0], 0.3).unwrap();

        // The exact match (cosine distance 0) must be inside the radius.
        assert!(results.iter().any(|r| r.entity_id == "entity1"));
        for result in &results {
            assert!(result.distance <= 0.3);
        }
    }

    #[test]
    fn test_different_metrics() {
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            let index = built_index(SearchConfig {
                metric,
                ..Default::default()
            });

            let results = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

            assert_eq!(results.len(), 3, "metric {:?}", metric);
            // entity1 equals the query, so every metric must rank it first.
            assert_eq!(results[0].entity_id, "entity1", "metric {:?}", metric);
            assert_ranked(&results);
        }
    }

    #[test]
    fn test_index_stats() {
        let index = built_index(SearchConfig::default());

        let stats = index.get_stats();
        assert_eq!(stats.num_entities, 5);
        assert_eq!(stats.dimensions, 3);
        assert!(stats.is_built);
        assert_eq!(stats.metric, DistanceMetric::Cosine);
        assert!(!stats.use_approximate);
    }
}