rag_plusplus_core/index/
parallel.rs

1//! Parallel Search
2//!
3//! High-performance parallel search across multiple indexes using rayon.
4//!
5//! # Overview
6//!
7//! This module provides parallel search capabilities for multi-index scenarios,
8//! distributing queries across CPU cores for maximum throughput.
9//!
10//! # Architecture
11//!
12//! ```text
13//! ┌────────────────────────────────────────────────────────────┐
14//! │                    ParallelSearcher                         │
15//! ├────────────────────────────────────────────────────────────┤
16//! │  registry: &IndexRegistry                                   │
//! │  config: ParallelSearchConfig                               │
18//! ├────────────────────────────────────────────────────────────┤
19//! │  + search_parallel(query, k) -> MultiIndexResults           │
20//! │  + search_batch(queries, k) -> Vec<MultiIndexResults>       │
21//! │  + search_indexes_parallel(names, query, k)                 │
22//! └────────────────────────────────────────────────────────────┘
23//!                            │
24//!              ┌─────────────┴─────────────┐
25//!              │       rayon parallel       │
26//!              │         iteration          │
27//!              └────────────┬──────────────┘
28//!                           │
29//!       ┌───────────┬───────┴───────┬───────────┐
30//!       ▼           ▼               ▼           ▼
31//!   Index 0     Index 1         Index 2     Index N
32//! ```
33
34use crate::error::Result;
35use crate::index::registry::{IndexRegistry, MultiIndexResult, MultiIndexResults};
36use crate::index::traits::SearchResult;
37use rayon::prelude::*;
38
/// Tuning knobs for parallel multi-index search.
///
/// Start from [`ParallelSearchConfig::new`] (or `Default`) and refine the
/// value with the `with_*` builder methods; the fields are also public and
/// may be set directly.
#[derive(Debug, Clone)]
pub struct ParallelSearchConfig {
    /// Requested number of worker threads (`0` = auto-detect).
    ///
    /// NOTE(review): nothing in this file reads this field, so setting it
    /// currently has no visible effect on how searches are scheduled.
    /// Confirm whether a dedicated thread pool is planned.
    pub num_threads: usize,
    /// Minimum indexes per thread (guards against over-parallelization).
    ///
    /// A search fans out across threads only when it targets at least twice
    /// this many indexes; otherwise the indexes are searched sequentially.
    pub min_indexes_per_thread: usize,
    /// Whether batch searches distribute their queries across threads.
    ///
    /// Only takes effect for batches with more than one query.
    pub batch_parallel: bool,
}

impl Default for ParallelSearchConfig {
    /// Auto-detected thread count, one index per thread, batch parallelism on.
    fn default() -> Self {
        Self {
            num_threads: 0, // auto-detect
            min_indexes_per_thread: 1,
            batch_parallel: true,
        }
    }
}

impl ParallelSearchConfig {
    /// Create a config with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the requested number of threads (`0` = auto-detect).
    #[must_use]
    pub const fn with_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Set the minimum number of indexes per thread.
    #[must_use]
    pub const fn with_min_indexes_per_thread(mut self, min: usize) -> Self {
        self.min_indexes_per_thread = min;
        self
    }

    /// Enable or disable parallelization across the queries of a batch.
    ///
    /// Completes the builder so every field can be set fluently.
    #[must_use]
    pub const fn with_batch_parallel(mut self, enabled: bool) -> Self {
        self.batch_parallel = enabled;
        self
    }
}
81
82/// Parallel searcher for multi-index queries.
83///
84/// Wraps an `IndexRegistry` and provides parallel search operations.
85///
86/// # Example
87///
88/// ```ignore
89/// use rag_plusplus_core::index::{IndexRegistry, ParallelSearcher, FlatIndex, IndexConfig};
90///
91/// let mut registry = IndexRegistry::new();
92/// registry.register("text", FlatIndex::new(IndexConfig::new(768)))?;
93/// registry.register("code", FlatIndex::new(IndexConfig::new(512)))?;
94///
95/// let searcher = ParallelSearcher::new(&registry);
96/// let results = searcher.search_parallel(&query, 10)?;
97/// ```
98pub struct ParallelSearcher<'a> {
99    /// Reference to the registry
100    registry: &'a IndexRegistry,
101    /// Configuration
102    config: ParallelSearchConfig,
103}
104
105impl<'a> ParallelSearcher<'a> {
106    /// Create a new parallel searcher with default config.
107    #[must_use]
108    pub fn new(registry: &'a IndexRegistry) -> Self {
109        Self {
110            registry,
111            config: ParallelSearchConfig::default(),
112        }
113    }
114
115    /// Create a new parallel searcher with custom config.
116    #[must_use]
117    pub fn with_config(registry: &'a IndexRegistry, config: ParallelSearchConfig) -> Self {
118        Self { registry, config }
119    }
120
121    /// Search all compatible indexes in parallel.
122    ///
123    /// # Arguments
124    ///
125    /// * `query` - Query vector
126    /// * `k` - Number of results per index
127    ///
128    /// # Returns
129    ///
130    /// Results from all indexes with matching dimension.
131    pub fn search_parallel(&self, query: &[f32], k: usize) -> Result<MultiIndexResults> {
132        let query_dim = query.len();
133
134        // Collect compatible indexes
135        let indexes: Vec<_> = self
136            .registry
137            .info()
138            .into_iter()
139            .filter(|info| info.dimension == query_dim)
140            .map(|info| info.name)
141            .collect();
142
143        if indexes.is_empty() {
144            return Ok(MultiIndexResults::new());
145        }
146
147        // Decide on parallelization strategy
148        let use_parallel = indexes.len() >= self.config.min_indexes_per_thread * 2;
149
150        let results: Vec<MultiIndexResult> = if use_parallel {
151            // Parallel execution
152            indexes
153                .par_iter()
154                .filter_map(|name| {
155                    self.registry
156                        .search(name, query, k)
157                        .ok()
158                        .map(|results| MultiIndexResult {
159                            index_name: name.clone(),
160                            results,
161                        })
162                })
163                .collect()
164        } else {
165            // Sequential execution for small number of indexes
166            indexes
167                .iter()
168                .filter_map(|name| {
169                    self.registry
170                        .search(name, query, k)
171                        .ok()
172                        .map(|results| MultiIndexResult {
173                            index_name: name.clone(),
174                            results,
175                        })
176                })
177                .collect()
178        };
179
180        let total_count = results.iter().map(|r| r.results.len()).sum();
181
182        Ok(MultiIndexResults {
183            by_index: results,
184            total_count,
185        })
186    }
187
188    /// Search specific indexes in parallel.
189    ///
190    /// # Arguments
191    ///
192    /// * `names` - Index names to search
193    /// * `query` - Query vector
194    /// * `k` - Number of results per index
195    pub fn search_indexes_parallel(
196        &self,
197        names: &[&str],
198        query: &[f32],
199        k: usize,
200    ) -> Result<MultiIndexResults> {
201        let use_parallel = names.len() >= self.config.min_indexes_per_thread * 2;
202
203        let results: Vec<MultiIndexResult> = if use_parallel {
204            names
205                .par_iter()
206                .filter_map(|name| {
207                    self.registry
208                        .search(name, query, k)
209                        .ok()
210                        .map(|results| MultiIndexResult {
211                            index_name: (*name).to_string(),
212                            results,
213                        })
214                })
215                .collect()
216        } else {
217            names
218                .iter()
219                .filter_map(|name| {
220                    self.registry
221                        .search(name, query, k)
222                        .ok()
223                        .map(|results| MultiIndexResult {
224                            index_name: (*name).to_string(),
225                            results,
226                        })
227                })
228                .collect()
229        };
230
231        let total_count = results.iter().map(|r| r.results.len()).sum();
232
233        Ok(MultiIndexResults {
234            by_index: results,
235            total_count,
236        })
237    }
238
239    /// Batch search: run multiple queries in parallel.
240    ///
241    /// # Arguments
242    ///
243    /// * `queries` - Multiple query vectors
244    /// * `k` - Number of results per query per index
245    ///
246    /// # Returns
247    ///
248    /// Results for each query.
249    pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Vec<Result<MultiIndexResults>> {
250        if self.config.batch_parallel && queries.len() > 1 {
251            queries
252                .par_iter()
253                .map(|query| self.search_parallel(query, k))
254                .collect()
255        } else {
256            queries
257                .iter()
258                .map(|query| self.search_parallel(query, k))
259                .collect()
260        }
261    }
262
263    /// Batch search specific indexes.
264    pub fn search_indexes_batch(
265        &self,
266        names: &[&str],
267        queries: &[Vec<f32>],
268        k: usize,
269    ) -> Vec<Result<MultiIndexResults>> {
270        if self.config.batch_parallel && queries.len() > 1 {
271            queries
272                .par_iter()
273                .map(|query| self.search_indexes_parallel(names, query, k))
274                .collect()
275        } else {
276            queries
277                .iter()
278                .map(|query| self.search_indexes_parallel(names, query, k))
279                .collect()
280        }
281    }
282}
283
284/// Parallel add operation for batch indexing.
285///
286/// Adds vectors to multiple indexes in parallel.
287pub fn parallel_add_batch(
288    registry: &mut IndexRegistry,
289    index_name: &str,
290    ids: Vec<String>,
291    vectors: &[Vec<f32>],
292) -> Result<()> {
293    // Validate inputs
294    if ids.len() != vectors.len() {
295        return Err(crate::error::Error::InvalidQuery {
296            reason: format!(
297                "IDs count ({}) doesn't match vectors count ({})",
298                ids.len(),
299                vectors.len()
300            ),
301        });
302    }
303
304    // For now, we add sequentially but could use interior mutability
305    // patterns for true parallel writes in the future
306    for (id, vector) in ids.into_iter().zip(vectors.iter()) {
307        registry.add(index_name, id, vector)?;
308    }
309
310    Ok(())
311}
312
313/// Results aggregator for parallel searches.
314#[derive(Debug, Default)]
315pub struct ResultsAggregator {
316    /// Results by query index
317    results: Vec<MultiIndexResults>,
318}
319
320impl ResultsAggregator {
321    /// Create new aggregator.
322    #[must_use]
323    pub fn new() -> Self {
324        Self::default()
325    }
326
327    /// Add results from a query.
328    pub fn add(&mut self, results: MultiIndexResults) {
329        self.results.push(results);
330    }
331
332    /// Get all results.
333    #[must_use]
334    pub fn results(&self) -> &[MultiIndexResults] {
335        &self.results
336    }
337
338    /// Total number of results across all queries.
339    #[must_use]
340    pub fn total_count(&self) -> usize {
341        self.results.iter().map(|r| r.total_count).sum()
342    }
343
344    /// Flatten all results with query index.
345    #[must_use]
346    pub fn flatten_with_query(&self) -> Vec<(usize, String, SearchResult)> {
347        self.results
348            .iter()
349            .enumerate()
350            .flat_map(|(qi, mir)| {
351                mir.by_index.iter().flat_map(move |idx_result| {
352                    idx_result
353                        .results
354                        .iter()
355                        .cloned()
356                        .map(move |r| (qi, idx_result.index_name.clone(), r))
357                })
358            })
359            .collect()
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig, VectorIndex};

    /// Build a 4-dimensional flat index holding the given `(id, vector)` pairs.
    fn flat_index_4d(entries: &[(&str, [f32; 4])]) -> FlatIndex {
        let mut index = FlatIndex::new(IndexConfig::new(4));
        for (id, vector) in entries.iter() {
            index.add(id.to_string(), &vector[..]).unwrap();
        }
        index
    }

    /// Registry with three 4-dimensional indexes (`idx1`..`idx3`), two vectors each.
    fn setup_test_registry() -> IndexRegistry {
        let first = flat_index_4d(&[
            ("a1", [1.0, 0.0, 0.0, 0.0]),
            ("a2", [0.9, 0.1, 0.0, 0.0]),
        ]);
        let second = flat_index_4d(&[
            ("b1", [0.0, 1.0, 0.0, 0.0]),
            ("b2", [0.1, 0.9, 0.0, 0.0]),
        ]);
        let third = flat_index_4d(&[
            ("c1", [0.0, 0.0, 1.0, 0.0]),
            ("c2", [0.0, 0.1, 0.9, 0.0]),
        ]);

        let mut registry = IndexRegistry::new();
        registry.register("idx1", first).unwrap();
        registry.register("idx2", second).unwrap();
        registry.register("idx3", third).unwrap();
        registry
    }

    #[test]
    fn test_parallel_search() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);

        let query = [1.0, 0.0, 0.0, 0.0];
        let found = searcher.search_parallel(&query, 10).unwrap();

        // All three indexes share the query's dimension and hold two vectors each.
        assert_eq!(found.by_index.len(), 3);
        assert_eq!(found.total_count, 6);
    }

    #[test]
    fn test_search_indexes_parallel() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);

        let query = [1.0, 0.0, 0.0, 0.0];
        let targets = ["idx1", "idx2"];
        let found = searcher
            .search_indexes_parallel(&targets, &query, 10)
            .unwrap();

        // Only the two named indexes contribute.
        assert_eq!(found.by_index.len(), 2);
        assert_eq!(found.total_count, 4);
    }

    #[test]
    fn test_search_batch() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);

        // One query per coordinate axis covered by the fixture.
        let queries = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let outcomes = searcher.search_batch(&queries, 10);

        assert_eq!(outcomes.len(), 3);
        assert!(outcomes.iter().all(|outcome| outcome.is_ok()));
    }

    #[test]
    fn test_config_builder() {
        let config = ParallelSearchConfig::new()
            .with_threads(4)
            .with_min_indexes_per_thread(2);

        assert_eq!(config.num_threads, 4);
        assert_eq!(config.min_indexes_per_thread, 2);
    }

    #[test]
    fn test_results_aggregator() {
        let mut aggregator = ResultsAggregator::new();

        // Two queries, each with a single hit from a different index.
        for (index_name, id, score) in [("idx1", "a", 0.5), ("idx2", "b", 0.3)] {
            let mut per_query = MultiIndexResults::new();
            per_query.add(
                index_name.to_string(),
                vec![SearchResult::new(
                    id.to_string(),
                    score,
                    crate::index::DistanceType::L2,
                )],
            );
            aggregator.add(per_query);
        }

        assert_eq!(aggregator.results().len(), 2);
        assert_eq!(aggregator.total_count(), 2);

        let flat = aggregator.flatten_with_query();
        assert_eq!(flat.len(), 2);
        assert_eq!(flat[0].0, 0); // hit of the first query
        assert_eq!(flat[1].0, 1); // hit of the second query
    }

    #[test]
    fn test_incompatible_dimension_skipped() {
        // Two indexes with different dimensions: 4 and 8.
        let small = flat_index_4d(&[("a", [1.0, 0.0, 0.0, 0.0])]);

        let mut large = FlatIndex::new(IndexConfig::new(8));
        large.add("b".to_string(), &[1.0; 8]).unwrap();

        let mut registry = IndexRegistry::new();
        registry.register("idx1", small).unwrap();
        registry.register("idx2", large).unwrap();

        let searcher = ParallelSearcher::new(&registry);

        // A 4-dimensional query must reach only `idx1`.
        let query = [1.0, 0.0, 0.0, 0.0];
        let found = searcher.search_parallel(&query, 10).unwrap();

        assert_eq!(found.by_index.len(), 1);
    }
}