oxify_vector/
multi_index.rs

1//! Multi-Index Search
2//!
3//! Search across multiple indexes in parallel and combine results.
4//!
5//! ## Use Cases
6//! - Federated search across data shards
7//! - Searching different index types (exact + approximate)
8//! - Temporal data with separate indexes per time period
9//! - Multi-tenant scenarios with per-tenant indexes
10//!
11//! ## Example
12//!
13//! ```rust
14//! use oxify_vector::{MultiIndexSearch, VectorSearchIndex, SearchConfig};
15//! use std::collections::HashMap;
16//!
17//! # fn example() -> anyhow::Result<()> {
18//! // Create multiple indexes
19//! let mut index1 = VectorSearchIndex::new(SearchConfig::default());
20//! let mut embeddings1 = HashMap::new();
21//! embeddings1.insert("doc1".to_string(), vec![1.0, 0.0]);
22//! index1.build(&embeddings1)?;
23//!
24//! let mut index2 = VectorSearchIndex::new(SearchConfig::default());
25//! let mut embeddings2 = HashMap::new();
26//! embeddings2.insert("doc2".to_string(), vec![0.0, 1.0]);
27//! index2.build(&embeddings2)?;
28//!
29//! // Search across both indexes
30//! let multi_search = MultiIndexSearch::new();
31//! let query = vec![0.5, 0.5];
32//! let results = multi_search.search(&[&index1, &index2], &query, 10)?;
33//!
34//! // Results are merged and sorted by score
35//! for result in results {
36//!     println!("{}: score = {:.4}", result.entity_id, result.score);
37//! }
38//! # Ok(())
39//! # }
40//! ```
41
42use crate::search::VectorSearchIndex;
43use crate::types::SearchResult;
44use anyhow::{anyhow, Result};
45use rayon::prelude::*;
46use serde::{Deserialize, Serialize};
47use std::collections::HashMap;
48use tracing::{debug, info};
49
50/// Configuration for multi-index search
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct MultiIndexConfig {
53    /// Whether to search indexes in parallel
54    pub parallel: bool,
55    /// Whether to deduplicate results across indexes
56    pub deduplicate: bool,
57    /// How to merge scores from different indexes
58    pub merge_strategy: ScoreMergeStrategy,
59}
60
61impl Default for MultiIndexConfig {
62    fn default() -> Self {
63        Self {
64            parallel: true,
65            deduplicate: true,
66            merge_strategy: ScoreMergeStrategy::Max,
67        }
68    }
69}
70
71/// Strategy for merging scores when same entity appears in multiple indexes
72#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
73pub enum ScoreMergeStrategy {
74    /// Take maximum score
75    Max,
76    /// Take minimum score
77    Min,
78    /// Average scores
79    Average,
80    /// Take first occurrence
81    First,
82}
83
84/// Multi-index search coordinator
85#[derive(Debug, Clone)]
86pub struct MultiIndexSearch {
87    config: MultiIndexConfig,
88}
89
90impl MultiIndexSearch {
91    /// Create a new multi-index search with default config
92    pub fn new() -> Self {
93        Self {
94            config: MultiIndexConfig::default(),
95        }
96    }
97
98    /// Create with custom configuration
99    pub fn with_config(config: MultiIndexConfig) -> Self {
100        Self { config }
101    }
102
103    /// Search across multiple indexes
104    pub fn search(
105        &self,
106        indexes: &[&VectorSearchIndex],
107        query: &[f32],
108        k: usize,
109    ) -> Result<Vec<SearchResult>> {
110        if indexes.is_empty() {
111            return Err(anyhow!("Cannot search across zero indexes"));
112        }
113
114        info!("Searching across {} indexes", indexes.len());
115
116        // Search each index
117        let all_results: Vec<Vec<SearchResult>> = if self.config.parallel {
118            indexes
119                .par_iter()
120                .map(|index| index.search(query, k).unwrap_or_default())
121                .collect()
122        } else {
123            indexes
124                .iter()
125                .map(|index| index.search(query, k).unwrap_or_default())
126                .collect()
127        };
128
129        // Merge results
130        let merged = self.merge_results(all_results, k);
131
132        info!("Multi-index search returned {} results", merged.len());
133        Ok(merged)
134    }
135
136    /// Batch search across multiple indexes
137    pub fn batch_search(
138        &self,
139        indexes: &[&VectorSearchIndex],
140        queries: &[Vec<f32>],
141        k: usize,
142    ) -> Result<Vec<Vec<SearchResult>>> {
143        if indexes.is_empty() {
144            return Err(anyhow!("Cannot search across zero indexes"));
145        }
146
147        info!(
148            "Batch searching {} queries across {} indexes",
149            queries.len(),
150            indexes.len()
151        );
152
153        // Search all queries
154        let results: Vec<Vec<SearchResult>> = if self.config.parallel {
155            queries
156                .par_iter()
157                .map(|query| self.search(indexes, query, k).unwrap_or_default())
158                .collect()
159        } else {
160            queries
161                .iter()
162                .map(|query| self.search(indexes, query, k).unwrap_or_default())
163                .collect()
164        };
165
166        Ok(results)
167    }
168
169    /// Merge results from multiple indexes
170    fn merge_results(&self, all_results: Vec<Vec<SearchResult>>, k: usize) -> Vec<SearchResult> {
171        if !self.config.deduplicate {
172            // Simple concatenation and sorting
173            let mut merged: Vec<SearchResult> = all_results.into_iter().flatten().collect();
174            merged.sort_by(|a, b| {
175                b.score
176                    .partial_cmp(&a.score)
177                    .unwrap_or(std::cmp::Ordering::Equal)
178            });
179            merged.truncate(k);
180
181            // Re-rank
182            for (i, result) in merged.iter_mut().enumerate() {
183                result.rank = i + 1;
184            }
185
186            return merged;
187        }
188
189        // Deduplicate by entity_id and merge scores
190        let mut entity_scores: HashMap<String, Vec<f32>> = HashMap::new();
191        let mut entity_distance: HashMap<String, Vec<f32>> = HashMap::new();
192
193        for results in all_results {
194            for result in results {
195                entity_scores
196                    .entry(result.entity_id.clone())
197                    .or_default()
198                    .push(result.score);
199                entity_distance
200                    .entry(result.entity_id.clone())
201                    .or_default()
202                    .push(result.distance);
203            }
204        }
205
206        // Merge scores according to strategy
207        let mut merged: Vec<SearchResult> = entity_scores
208            .into_iter()
209            .map(|(entity_id, scores)| {
210                let merged_score = match self.config.merge_strategy {
211                    ScoreMergeStrategy::Max => {
212                        scores.iter().copied().fold(f32::NEG_INFINITY, f32::max)
213                    }
214                    ScoreMergeStrategy::Min => scores.iter().copied().fold(f32::INFINITY, f32::min),
215                    ScoreMergeStrategy::Average => scores.iter().sum::<f32>() / scores.len() as f32,
216                    ScoreMergeStrategy::First => scores[0],
217                };
218
219                let merged_distance = match self.config.merge_strategy {
220                    ScoreMergeStrategy::Max => entity_distance[&entity_id]
221                        .iter()
222                        .copied()
223                        .fold(f32::INFINITY, f32::min),
224                    ScoreMergeStrategy::Min => entity_distance[&entity_id]
225                        .iter()
226                        .copied()
227                        .fold(f32::NEG_INFINITY, f32::max),
228                    ScoreMergeStrategy::Average => {
229                        entity_distance[&entity_id].iter().sum::<f32>()
230                            / entity_distance[&entity_id].len() as f32
231                    }
232                    ScoreMergeStrategy::First => entity_distance[&entity_id][0],
233                };
234
235                SearchResult {
236                    entity_id,
237                    score: merged_score,
238                    distance: merged_distance,
239                    rank: 0, // Will be set later
240                }
241            })
242            .collect();
243
244        // Sort by merged score
245        merged.sort_by(|a, b| {
246            b.score
247                .partial_cmp(&a.score)
248                .unwrap_or(std::cmp::Ordering::Equal)
249        });
250
251        // Take top-k and set ranks
252        merged.truncate(k);
253        for (i, result) in merged.iter_mut().enumerate() {
254            result.rank = i + 1;
255        }
256
257        debug!("Merged and deduplicated to {} results", merged.len());
258        merged
259    }
260
261    /// Get configuration
262    pub fn config(&self) -> &MultiIndexConfig {
263        &self.config
264    }
265}
266
267impl Default for MultiIndexSearch {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchConfig;
    use std::collections::HashMap;

    /// Build an index of `count` synthetic `dim`-dimensional vectors with ids `{id_prefix}_{i}`.
    fn create_test_index(id_prefix: &str, count: usize, dim: usize) -> VectorSearchIndex {
        let mut embeddings = HashMap::new();
        for i in 0..count {
            let vec: Vec<f32> = (0..dim).map(|j| (i + j) as f32 * 0.1).collect();
            embeddings.insert(format!("{}_{}", id_prefix, i), vec);
        }

        let mut index = VectorSearchIndex::new(SearchConfig::default());
        index.build(&embeddings).unwrap();
        index
    }

    /// Build an index from explicit `(entity_id, vector)` pairs.
    fn index_from(entries: &[(&str, [f32; 3])]) -> VectorSearchIndex {
        let embeddings: HashMap<String, Vec<f32>> = entries
            .iter()
            .map(|&(id, vector)| (id.to_string(), vector.to_vec()))
            .collect();

        let mut index = VectorSearchIndex::new(SearchConfig::default());
        index.build(&embeddings).unwrap();
        index
    }

    /// Score that a single index assigns to `entity_id` for `query`.
    fn score_in(index: &VectorSearchIndex, query: &[f32], entity_id: &str) -> f32 {
        index
            .search(query, 10)
            .unwrap()
            .into_iter()
            .find(|r| r.entity_id == entity_id)
            .expect("entity should be returned by its own index")
            .score
    }

    /// Merge `doc1` from two indexes that score it differently.
    ///
    /// Returns `(merged_score, score_in_index1, score_in_index2)`.
    fn merged_doc1_score(strategy: ScoreMergeStrategy) -> (f32, f32, f32) {
        let index1 = index_from(&[("doc1", [1.0, 0.0, 0.0]), ("doc2", [0.0, 1.0, 0.0])]);
        let index2 = index_from(&[("doc1", [0.6, 0.8, 0.0]), ("doc3", [0.0, 0.0, 1.0])]);
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];

        let s1 = score_in(&index1, &query, "doc1");
        let s2 = score_in(&index2, &query, "doc1");
        assert!(
            (s1 - s2).abs() > 1e-6,
            "test setup needs distinct per-index scores"
        );

        let multi_search = MultiIndexSearch::with_config(MultiIndexConfig {
            parallel: false,
            deduplicate: true,
            merge_strategy: strategy,
        });
        let results = multi_search
            .search(&[&index1, &index2], &query, 10)
            .unwrap();

        let doc1_hits: Vec<&SearchResult> =
            results.iter().filter(|r| r.entity_id == "doc1").collect();
        assert_eq!(doc1_hits.len(), 1, "doc1 should be deduplicated");
        (doc1_hits[0].score, s1, s2)
    }

    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_multi_index_search() {
        let index1 = create_test_index("doc", 5, 3);
        let index2 = create_test_index("article", 5, 3);

        let multi_search = MultiIndexSearch::new();
        let query: Vec<f32> = vec![0.1, 0.2, 0.3];
        let results = multi_search.search(&[&index1, &index2], &query, 5).unwrap();

        assert!(results.len() <= 5);
        assert!(!results.is_empty());

        // Verify results are sorted by score
        for i in 1..results.len() {
            assert!(results[i - 1].score >= results[i].score);
        }

        // Ranks are re-assigned 1..=n after merging
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.rank, i + 1);
        }
    }

    #[test]
    fn test_multi_index_deduplication() {
        // Create two indexes with overlapping entities
        let mut embeddings1 = HashMap::new();
        embeddings1.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings1.insert("doc2".to_string(), vec![0.9, 0.1, 0.0]);
        let mut index1 = VectorSearchIndex::new(SearchConfig::default());
        index1.build(&embeddings1).unwrap();

        let mut embeddings2 = HashMap::new();
        embeddings2.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]); // Duplicate
        embeddings2.insert("doc3".to_string(), vec![0.8, 0.2, 0.0]);
        let mut index2 = VectorSearchIndex::new(SearchConfig::default());
        index2.build(&embeddings2).unwrap();

        let config = MultiIndexConfig {
            parallel: false,
            deduplicate: true,
            merge_strategy: ScoreMergeStrategy::Max,
        };

        let multi_search = MultiIndexSearch::with_config(config);
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let results = multi_search
            .search(&[&index1, &index2], &query, 10)
            .unwrap();

        // Should have 3 unique entities (doc1, doc2, doc3)
        assert_eq!(results.len(), 3);

        let entity_ids: Vec<String> = results.iter().map(|r| r.entity_id.clone()).collect();
        assert!(entity_ids.contains(&"doc1".to_string()));
        assert!(entity_ids.contains(&"doc2".to_string()));
        assert!(entity_ids.contains(&"doc3".to_string()));
    }

    #[test]
    fn test_multi_index_no_deduplication() {
        let mut embeddings1 = HashMap::new();
        embeddings1.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        let mut index1 = VectorSearchIndex::new(SearchConfig::default());
        index1.build(&embeddings1).unwrap();

        let mut embeddings2 = HashMap::new();
        embeddings2.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]); // Duplicate
        let mut index2 = VectorSearchIndex::new(SearchConfig::default());
        index2.build(&embeddings2).unwrap();

        let config = MultiIndexConfig {
            parallel: false,
            deduplicate: false,
            merge_strategy: ScoreMergeStrategy::Max,
        };

        let multi_search = MultiIndexSearch::with_config(config);
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let results = multi_search
            .search(&[&index1, &index2], &query, 10)
            .unwrap();

        // Without deduplication, we get both occurrences
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_config_accessor() {
        let config = MultiIndexConfig {
            parallel: false,
            deduplicate: true,
            merge_strategy: ScoreMergeStrategy::Min,
        };
        let multi_search = MultiIndexSearch::with_config(config);

        assert_eq!(
            multi_search.config().merge_strategy,
            ScoreMergeStrategy::Min
        );
        assert!(!multi_search.config().parallel);
    }

    #[test]
    fn test_merge_strategy_max() {
        let (merged, s1, s2) = merged_doc1_score(ScoreMergeStrategy::Max);
        assert_close(merged, s1.max(s2));
    }

    #[test]
    fn test_merge_strategy_min() {
        let (merged, s1, s2) = merged_doc1_score(ScoreMergeStrategy::Min);
        assert_close(merged, s1.min(s2));
    }

    #[test]
    fn test_merge_strategy_average() {
        let (merged, s1, s2) = merged_doc1_score(ScoreMergeStrategy::Average);
        assert_close(merged, (s1 + s2) / 2.0);
    }

    #[test]
    fn test_merge_strategy_first() {
        // With sequential search, index1 is the first index to report doc1.
        let (merged, s1, _s2) = merged_doc1_score(ScoreMergeStrategy::First);
        assert_close(merged, s1);
    }

    #[test]
    fn test_batch_search() {
        let index1 = create_test_index("doc", 5, 3);
        let index2 = create_test_index("article", 5, 3);

        let multi_search = MultiIndexSearch::new();
        let queries: Vec<Vec<f32>> = vec![
            vec![0.1, 0.2, 0.3],
            vec![0.2, 0.3, 0.4],
            vec![0.3, 0.4, 0.5],
        ];

        let results = multi_search
            .batch_search(&[&index1, &index2], &queries, 3)
            .unwrap();

        assert_eq!(results.len(), 3);
        for result_set in results {
            assert!(result_set.len() <= 3);
        }
    }

    #[test]
    fn test_empty_indexes() {
        let multi_search = MultiIndexSearch::new();
        let query: Vec<f32> = vec![0.1, 0.2, 0.3];
        let result = multi_search.search(&[], &query, 10);

        assert!(result.is_err());
    }

    #[test]
    fn test_batch_search_empty_indexes() {
        let multi_search = MultiIndexSearch::new();
        let queries: Vec<Vec<f32>> = vec![vec![0.1, 0.2, 0.3]];

        assert!(multi_search.batch_search(&[], &queries, 10).is_err());
    }
}