oxirs_embed/
vector_search.rs

1//! Vector Search Integration
2//!
3//! This module provides high-performance vector search capabilities for knowledge graph
4//! embeddings, enabling semantic similarity search, approximate nearest neighbor (ANN)
5//! search, and integration with popular vector databases.
6//!
7//! # Features
8//!
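//! - **Exact Search**: Brute-force k-nearest-neighbor search under any supported metric, suited to small and medium datasets
//! - **Approximate Search**: Opt-in ANN mode intended for HNSW (Hierarchical Navigable Small World); HNSW is not implemented yet, so it currently falls back to exact search
//! - **Index Building**: Efficient index construction for large-scale search
//! - **Batch Search**: Process multiple queries, in parallel when `parallel` is enabled
//! - **Radius Search**: Return every entity within a distance threshold of the query (metadata filtering is not yet supported)
//! - **Distance Metrics**: Cosine similarity, Euclidean (L2) distance, dot product, Manhattan (L1) distance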
15//!
16//! # Quick Start
17//!
18//! ```rust,no_run
19//! use oxirs_embed::{
20//!     vector_search::{VectorSearchIndex, SearchConfig, DistanceMetric},
21//!     EmbeddingModel, TransE, ModelConfig,
22//! };
23//! use std::collections::HashMap;
24//!
25//! # async fn example() -> anyhow::Result<()> {
26//! // Build search index from embeddings
27//! let mut embeddings = HashMap::new();
28//! // ... populate embeddings from trained model
29//!
30//! let config = SearchConfig {
31//!     metric: DistanceMetric::Cosine,
32//!     ..Default::default()
33//! };
34//!
35//! let mut index = VectorSearchIndex::new(config);
36//! index.build(&embeddings)?;
37//!
38//! // Search for similar entities
39//! let query_embedding = vec![0.1, 0.2, 0.3]; // ... your query embedding
40//! let results = index.search(&query_embedding, 10)?;
41//!
42//! for result in results {
43//!     println!("{}: similarity = {:.4}", result.entity_id, result.score);
44//! }
45//! # Ok(())
46//! # }
47//! ```
48
49use anyhow::{anyhow, Result};
50use rayon::prelude::*;
51use scirs2_core::ndarray_ext::Array1;
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use tracing::{debug, info};
55
56/// Distance metric for vector search
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58pub enum DistanceMetric {
59    /// Cosine similarity (normalized dot product)
60    Cosine,
61    /// Euclidean distance (L2 norm)
62    Euclidean,
63    /// Dot product similarity
64    DotProduct,
65    /// Manhattan distance (L1 norm)
66    Manhattan,
67}
68
69/// Vector search configuration
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SearchConfig {
72    /// Distance metric to use
73    pub metric: DistanceMetric,
74    /// Use approximate search (HNSW) for large datasets
75    pub use_approximate: bool,
76    /// Number of neighbors for HNSW graph construction
77    pub hnsw_m: usize,
78    /// Size of dynamic candidate list for HNSW
79    pub hnsw_ef_construction: usize,
80    /// Size of dynamic candidate list for HNSW search
81    pub hnsw_ef_search: usize,
82    /// Enable parallel search
83    pub parallel: bool,
84    /// Normalize vectors before search
85    pub normalize: bool,
86}
87
88impl Default for SearchConfig {
89    fn default() -> Self {
90        Self {
91            metric: DistanceMetric::Cosine,
92            use_approximate: false,
93            hnsw_m: 16,
94            hnsw_ef_construction: 200,
95            hnsw_ef_search: 50,
96            parallel: true,
97            normalize: true,
98        }
99    }
100}
101
102/// Search result
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct SearchResult {
105    /// Entity ID
106    pub entity_id: String,
107    /// Similarity score (higher is better)
108    pub score: f32,
109    /// Distance (lower is better, depends on metric)
110    pub distance: f32,
111    /// Rank in results (1-indexed)
112    pub rank: usize,
113}
114
115/// Vector search index
116pub struct VectorSearchIndex {
117    config: SearchConfig,
118    embeddings: HashMap<String, Array1<f32>>,
119    entity_ids: Vec<String>,
120    embedding_matrix: Option<Vec<Vec<f32>>>,
121    dimensions: usize,
122    is_built: bool,
123}
124
125impl VectorSearchIndex {
126    /// Create new vector search index
127    pub fn new(config: SearchConfig) -> Self {
128        info!(
129            "Initialized vector search index: metric={:?}, approximate={}",
130            config.metric, config.use_approximate
131        );
132
133        Self {
134            config,
135            embeddings: HashMap::new(),
136            entity_ids: Vec::new(),
137            embedding_matrix: None,
138            dimensions: 0,
139            is_built: false,
140        }
141    }
142
143    /// Build search index from embeddings
144    pub fn build(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
145        if embeddings.is_empty() {
146            return Err(anyhow!("Cannot build index from empty embeddings"));
147        }
148
149        info!(
150            "Building vector search index for {} entities",
151            embeddings.len()
152        );
153
154        // Store embeddings
155        self.embeddings = embeddings.clone();
156        self.entity_ids = embeddings.keys().cloned().collect();
157        self.dimensions = embeddings.values().next().unwrap().len();
158
159        // Build embedding matrix for efficient search
160        let mut matrix = Vec::new();
161        for entity_id in &self.entity_ids {
162            let mut emb = self.embeddings[entity_id].to_vec();
163
164            // Normalize if configured
165            if self.config.normalize {
166                self.normalize_vector(&mut emb);
167            }
168
169            matrix.push(emb);
170        }
171        self.embedding_matrix = Some(matrix);
172
173        self.is_built = true;
174
175        info!("Vector search index built successfully");
176        Ok(())
177    }
178
179    /// Search for K nearest neighbors
180    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
181        if !self.is_built {
182            return Err(anyhow!("Index not built. Call build() first"));
183        }
184
185        if query.len() != self.dimensions {
186            return Err(anyhow!(
187                "Query dimension {} doesn't match index dimension {}",
188                query.len(),
189                self.dimensions
190            ));
191        }
192
193        // Normalize query if configured
194        let mut normalized_query = query.to_vec();
195        if self.config.normalize {
196            self.normalize_vector(&mut normalized_query);
197        }
198
199        debug!("Searching for {} nearest neighbors", k);
200
201        if self.config.use_approximate && self.embeddings.len() > 1000 {
202            self.approximate_search(&normalized_query, k)
203        } else {
204            self.exact_search(&normalized_query, k)
205        }
206    }
207
208    /// Exact brute-force search
209    fn exact_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
210        let matrix = self.embedding_matrix.as_ref().unwrap();
211
212        // Compute distances/similarities to all entities
213        let scores: Vec<(usize, f32)> = if self.config.parallel {
214            (0..self.entity_ids.len())
215                .into_par_iter()
216                .map(|i| {
217                    let score = self.compute_similarity(query, &matrix[i]);
218                    (i, score)
219                })
220                .collect()
221        } else {
222            (0..self.entity_ids.len())
223                .map(|i| {
224                    let score = self.compute_similarity(query, &matrix[i]);
225                    (i, score)
226                })
227                .collect()
228        };
229
230        // Sort by score descending
231        let mut sorted_scores = scores;
232        sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
233
234        // Return top-K results
235        let results: Vec<SearchResult> = sorted_scores
236            .iter()
237            .take(k.min(self.entity_ids.len()))
238            .enumerate()
239            .map(|(rank, &(idx, score))| SearchResult {
240                entity_id: self.entity_ids[idx].clone(),
241                score,
242                distance: self.score_to_distance(score),
243                rank: rank + 1,
244            })
245            .collect();
246
247        debug!("Found {} results", results.len());
248        Ok(results)
249    }
250
251    /// Approximate search using simplified HNSW
252    fn approximate_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
253        // For now, fall back to exact search
254        // TODO: Implement full HNSW for very large datasets
255        debug!("Using exact search (HNSW not yet fully implemented)");
256        self.exact_search(query, k)
257    }
258
259    /// Batch search for multiple queries
260    pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
261        if !self.is_built {
262            return Err(anyhow!("Index not built. Call build() first"));
263        }
264
265        info!("Batch searching for {} queries", queries.len());
266
267        let results: Vec<Vec<SearchResult>> = if self.config.parallel {
268            queries
269                .par_iter()
270                .map(|query| self.search(query, k).unwrap_or_default())
271                .collect()
272        } else {
273            queries
274                .iter()
275                .map(|query| self.search(query, k).unwrap_or_default())
276                .collect()
277        };
278
279        Ok(results)
280    }
281
282    /// Compute similarity between two vectors
283    fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
284        match self.config.metric {
285            DistanceMetric::Cosine => {
286                // Dot product (vectors are already normalized)
287                a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
288            }
289            DistanceMetric::Euclidean => {
290                // Negative Euclidean distance (so higher is better)
291                let dist: f32 = a
292                    .iter()
293                    .zip(b.iter())
294                    .map(|(x, y)| (x - y).powi(2))
295                    .sum::<f32>()
296                    .sqrt();
297                -dist
298            }
299            DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
300            DistanceMetric::Manhattan => {
301                // Negative Manhattan distance
302                let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
303                -dist
304            }
305        }
306    }
307
308    /// Convert score to distance
309    fn score_to_distance(&self, score: f32) -> f32 {
310        match self.config.metric {
311            DistanceMetric::Cosine => 1.0 - score, // Cosine distance
312            DistanceMetric::Euclidean | DistanceMetric::Manhattan => -score, // Already negative
313            DistanceMetric::DotProduct => -score,
314        }
315    }
316
317    /// Normalize vector in-place
318    fn normalize_vector(&self, vec: &mut [f32]) {
319        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
320        if norm > 1e-10 {
321            for x in vec.iter_mut() {
322                *x /= norm;
323            }
324        }
325    }
326
327    /// Get index statistics
328    pub fn get_stats(&self) -> IndexStats {
329        IndexStats {
330            num_entities: self.entity_ids.len(),
331            dimensions: self.dimensions,
332            is_built: self.is_built,
333            metric: self.config.metric,
334            use_approximate: self.config.use_approximate,
335        }
336    }
337
338    /// Find entities within a radius
339    pub fn radius_search(&self, query: &[f32], radius: f32) -> Result<Vec<SearchResult>> {
340        if !self.is_built {
341            return Err(anyhow!("Index not built. Call build() first"));
342        }
343
344        let all_results = self.search(query, self.entity_ids.len())?;
345
346        Ok(all_results
347            .into_iter()
348            .filter(|r| r.distance <= radius)
349            .collect())
350    }
351}
352
353/// Index statistics
354#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct IndexStats {
356    /// Number of entities in index
357    pub num_entities: usize,
358    /// Embedding dimensions
359    pub dimensions: usize,
360    /// Whether index is built
361    pub is_built: bool,
362    /// Distance metric
363    pub metric: DistanceMetric,
364    /// Using approximate search
365    pub use_approximate: bool,
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Five 3-dimensional embeddings with known geometry: entity1 and entity2 are
    /// nearly parallel, entity3 and entity4 are orthogonal to entity1, and entity5
    /// lies on the x/y diagonal.
    fn create_test_embeddings() -> HashMap<String, Array1<f32>> {
        let mut embeddings = HashMap::new();

        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("entity2".to_string(), array![0.9, 0.1, 0.0]);
        embeddings.insert("entity3".to_string(), array![0.0, 1.0, 0.0]);
        embeddings.insert("entity4".to_string(), array![0.0, 0.0, 1.0]);
        embeddings.insert("entity5".to_string(), array![0.7, 0.7, 0.0]);

        embeddings
    }

    /// Build a default-configured index over `create_test_embeddings`.
    fn built_default_index() -> VectorSearchIndex {
        let mut index = VectorSearchIndex::new(SearchConfig::default());
        index.build(&create_test_embeddings()).unwrap();
        index
    }

    #[test]
    fn test_index_creation() {
        let config = SearchConfig::default();
        let index = VectorSearchIndex::new(config);

        assert!(!index.is_built);
        assert_eq!(index.dimensions, 0);
    }

    #[test]
    fn test_index_building() {
        let embeddings = create_test_embeddings();
        let mut index = VectorSearchIndex::new(SearchConfig::default());

        let result = index.build(&embeddings);
        assert!(result.is_ok());
        assert!(index.is_built);
        assert_eq!(index.dimensions, 3);
        assert_eq!(index.entity_ids.len(), 5);
    }

    #[test]
    fn test_exact_search() {
        let index = built_default_index();

        // Search for entities similar to [1, 0, 0]
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        // entity1 should be most similar
        assert_eq!(results[0].entity_id, "entity1");
        assert!(results[0].score > 0.8);
    }

    #[test]
    fn test_cosine_similarity() {
        let config = SearchConfig {
            metric: DistanceMetric::Cosine,
            ..Default::default()
        };

        let embeddings = create_test_embeddings();
        let mut index = VectorSearchIndex::new(config);
        index.build(&embeddings).unwrap();

        let query = vec![1.0, 1.0, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        // entity5 [0.7, 0.7, 0] should be most similar
        assert_eq!(results[0].entity_id, "entity5");
    }

    #[test]
    fn test_batch_search() {
        let index = built_default_index();

        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let results = index.batch_search(&queries, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 2);
        assert_eq!(results[1].len(), 2);
    }

    #[test]
    fn test_radius_search() {
        let index = built_default_index();

        let query = vec![1.0, 0.0, 0.0];
        let results = index.radius_search(&query, 0.3).unwrap();

        // Should find entities within distance 0.3
        assert!(!results.is_empty());
        for result in results {
            assert!(result.distance <= 0.3);
        }
    }

    #[test]
    fn test_different_metrics() {
        let embeddings = create_test_embeddings();

        for metric in &[
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            let config = SearchConfig {
                metric: *metric,
                ..Default::default()
            };

            let mut index = VectorSearchIndex::new(config);
            index.build(&embeddings).unwrap();

            let query = vec![1.0, 0.0, 0.0];
            let results = index.search(&query, 3).unwrap();

            assert_eq!(results.len(), 3);
        }
    }

    #[test]
    fn test_index_stats() {
        let index = built_default_index();

        let stats = index.get_stats();
        assert_eq!(stats.num_entities, 5);
        assert_eq!(stats.dimensions, 3);
        assert!(stats.is_built);
    }

    #[test]
    fn test_queries_before_build_fail() {
        let index = VectorSearchIndex::new(SearchConfig::default());

        assert!(index.search(&[1.0, 0.0, 0.0], 1).is_err());
        assert!(index.batch_search(&[vec![1.0, 0.0, 0.0]], 1).is_err());
        assert!(index.radius_search(&[1.0, 0.0, 0.0], 0.5).is_err());
    }

    #[test]
    fn test_build_from_empty_fails() {
        let mut index = VectorSearchIndex::new(SearchConfig::default());

        assert!(index.build(&HashMap::new()).is_err());
        assert!(!index.is_built);
    }

    #[test]
    fn test_query_dimension_mismatch_fails() {
        let index = built_default_index();

        assert!(index.search(&[1.0, 0.0], 1).is_err());
        assert!(index.search(&[1.0, 0.0, 0.0, 0.0], 1).is_err());
    }

    #[test]
    fn test_k_is_clamped_and_zero_k_is_empty() {
        let index = built_default_index();

        assert_eq!(index.search(&[1.0, 0.0, 0.0], 100).unwrap().len(), 5);
        assert!(index.search(&[1.0, 0.0, 0.0], 0).unwrap().is_empty());
    }

    #[test]
    fn test_results_are_ranked_by_descending_score() {
        let index = built_default_index();

        let results = index.search(&[1.0, 0.5, 0.0], 5).unwrap();

        for (position, result) in results.iter().enumerate() {
            assert_eq!(result.rank, position + 1);
        }
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_batch_search_keeps_positions_for_invalid_queries() {
        let index = built_default_index();

        // The second query has the wrong dimensionality.
        let queries = vec![vec![1.0, 0.0, 0.0], vec![1.0, 0.0]];
        let results = index.batch_search(&queries, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 2);
        assert!(results[1].is_empty());
    }

    #[test]
    fn test_sequential_and_parallel_agree() {
        let embeddings = create_test_embeddings();
        let query = vec![0.6, 0.3, 0.1];

        for metric in &[
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            let mut parallel = VectorSearchIndex::new(SearchConfig {
                metric: *metric,
                parallel: true,
                ..Default::default()
            });
            let mut sequential = VectorSearchIndex::new(SearchConfig {
                metric: *metric,
                parallel: false,
                ..Default::default()
            });
            parallel.build(&embeddings).unwrap();
            sequential.build(&embeddings).unwrap();

            let from_parallel = parallel.search(&query, 5).unwrap();
            let from_sequential = sequential.search(&query, 5).unwrap();

            assert_eq!(from_parallel.len(), from_sequential.len());
            for (p, s) in from_parallel.iter().zip(&from_sequential) {
                assert_eq!(p.entity_id, s.entity_id);
                assert_eq!(p.score, s.score);
                assert_eq!(p.rank, s.rank);
            }
        }
    }
}