1use crate::error::Result;
35use crate::index::registry::{IndexRegistry, MultiIndexResult, MultiIndexResults};
36use crate::index::traits::SearchResult;
37use rayon::prelude::*;
38
/// Tuning knobs for [`ParallelSearcher`].
#[derive(Clone, Debug)]
pub struct ParallelSearchConfig {
    /// Requested number of worker threads; the default is `0`.
    // NOTE(review): nothing visible in this file reads this field. Presumably
    // `0` means "let rayon pick the pool size" — confirm against the caller
    // that builds the thread pool.
    pub num_threads: usize,
    /// Parallelism threshold: a single query is fanned out across rayon only
    /// when at least twice this many indexes are being searched.
    pub min_indexes_per_thread: usize,
    /// When `true`, batch searches with more than one query distribute the
    /// queries themselves across rayon.
    pub batch_parallel: bool,
}
49
50impl Default for ParallelSearchConfig {
51 fn default() -> Self {
52 Self {
53 num_threads: 0, min_indexes_per_thread: 1,
55 batch_parallel: true,
56 }
57 }
58}
59
60impl ParallelSearchConfig {
61 #[must_use]
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 #[must_use]
69 pub const fn with_threads(mut self, num_threads: usize) -> Self {
70 self.num_threads = num_threads;
71 self
72 }
73
74 #[must_use]
76 pub const fn with_min_indexes_per_thread(mut self, min: usize) -> Self {
77 self.min_indexes_per_thread = min;
78 self
79 }
80}
81
82pub struct ParallelSearcher<'a> {
99 registry: &'a IndexRegistry,
101 config: ParallelSearchConfig,
103}
104
105impl<'a> ParallelSearcher<'a> {
106 #[must_use]
108 pub fn new(registry: &'a IndexRegistry) -> Self {
109 Self {
110 registry,
111 config: ParallelSearchConfig::default(),
112 }
113 }
114
115 #[must_use]
117 pub fn with_config(registry: &'a IndexRegistry, config: ParallelSearchConfig) -> Self {
118 Self { registry, config }
119 }
120
121 pub fn search_parallel(&self, query: &[f32], k: usize) -> Result<MultiIndexResults> {
132 let query_dim = query.len();
133
134 let indexes: Vec<_> = self
136 .registry
137 .info()
138 .into_iter()
139 .filter(|info| info.dimension == query_dim)
140 .map(|info| info.name)
141 .collect();
142
143 if indexes.is_empty() {
144 return Ok(MultiIndexResults::new());
145 }
146
147 let use_parallel = indexes.len() >= self.config.min_indexes_per_thread * 2;
149
150 let results: Vec<MultiIndexResult> = if use_parallel {
151 indexes
153 .par_iter()
154 .filter_map(|name| {
155 self.registry
156 .search(name, query, k)
157 .ok()
158 .map(|results| MultiIndexResult {
159 index_name: name.clone(),
160 results,
161 })
162 })
163 .collect()
164 } else {
165 indexes
167 .iter()
168 .filter_map(|name| {
169 self.registry
170 .search(name, query, k)
171 .ok()
172 .map(|results| MultiIndexResult {
173 index_name: name.clone(),
174 results,
175 })
176 })
177 .collect()
178 };
179
180 let total_count = results.iter().map(|r| r.results.len()).sum();
181
182 Ok(MultiIndexResults {
183 by_index: results,
184 total_count,
185 })
186 }
187
188 pub fn search_indexes_parallel(
196 &self,
197 names: &[&str],
198 query: &[f32],
199 k: usize,
200 ) -> Result<MultiIndexResults> {
201 let use_parallel = names.len() >= self.config.min_indexes_per_thread * 2;
202
203 let results: Vec<MultiIndexResult> = if use_parallel {
204 names
205 .par_iter()
206 .filter_map(|name| {
207 self.registry
208 .search(name, query, k)
209 .ok()
210 .map(|results| MultiIndexResult {
211 index_name: (*name).to_string(),
212 results,
213 })
214 })
215 .collect()
216 } else {
217 names
218 .iter()
219 .filter_map(|name| {
220 self.registry
221 .search(name, query, k)
222 .ok()
223 .map(|results| MultiIndexResult {
224 index_name: (*name).to_string(),
225 results,
226 })
227 })
228 .collect()
229 };
230
231 let total_count = results.iter().map(|r| r.results.len()).sum();
232
233 Ok(MultiIndexResults {
234 by_index: results,
235 total_count,
236 })
237 }
238
239 pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Vec<Result<MultiIndexResults>> {
250 if self.config.batch_parallel && queries.len() > 1 {
251 queries
252 .par_iter()
253 .map(|query| self.search_parallel(query, k))
254 .collect()
255 } else {
256 queries
257 .iter()
258 .map(|query| self.search_parallel(query, k))
259 .collect()
260 }
261 }
262
263 pub fn search_indexes_batch(
265 &self,
266 names: &[&str],
267 queries: &[Vec<f32>],
268 k: usize,
269 ) -> Vec<Result<MultiIndexResults>> {
270 if self.config.batch_parallel && queries.len() > 1 {
271 queries
272 .par_iter()
273 .map(|query| self.search_indexes_parallel(names, query, k))
274 .collect()
275 } else {
276 queries
277 .iter()
278 .map(|query| self.search_indexes_parallel(names, query, k))
279 .collect()
280 }
281 }
282}
283
284pub fn parallel_add_batch(
288 registry: &mut IndexRegistry,
289 index_name: &str,
290 ids: Vec<String>,
291 vectors: &[Vec<f32>],
292) -> Result<()> {
293 if ids.len() != vectors.len() {
295 return Err(crate::error::Error::InvalidQuery {
296 reason: format!(
297 "IDs count ({}) doesn't match vectors count ({})",
298 ids.len(),
299 vectors.len()
300 ),
301 });
302 }
303
304 for (id, vector) in ids.into_iter().zip(vectors.iter()) {
307 registry.add(index_name, id, vector)?;
308 }
309
310 Ok(())
311}
312
313#[derive(Debug, Default)]
315pub struct ResultsAggregator {
316 results: Vec<MultiIndexResults>,
318}
319
320impl ResultsAggregator {
321 #[must_use]
323 pub fn new() -> Self {
324 Self::default()
325 }
326
327 pub fn add(&mut self, results: MultiIndexResults) {
329 self.results.push(results);
330 }
331
332 #[must_use]
334 pub fn results(&self) -> &[MultiIndexResults] {
335 &self.results
336 }
337
338 #[must_use]
340 pub fn total_count(&self) -> usize {
341 self.results.iter().map(|r| r.total_count).sum()
342 }
343
344 #[must_use]
346 pub fn flatten_with_query(&self) -> Vec<(usize, String, SearchResult)> {
347 self.results
348 .iter()
349 .enumerate()
350 .flat_map(|(qi, mir)| {
351 mir.by_index.iter().flat_map(move |idx_result| {
352 idx_result
353 .results
354 .iter()
355 .cloned()
356 .map(move |r| (qi, idx_result.index_name.clone(), r))
357 })
358 })
359 .collect()
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig, VectorIndex};

    /// Builds a registry with three 4-dimensional flat indexes (`idx1`,
    /// `idx2`, `idx3`), each holding two vectors.
    fn setup_test_registry() -> IndexRegistry {
        let mut registry = IndexRegistry::new();

        let mut idx1 = FlatIndex::new(IndexConfig::new(4));
        idx1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        idx1.add("a2".to_string(), &[0.9, 0.1, 0.0, 0.0]).unwrap();

        let mut idx2 = FlatIndex::new(IndexConfig::new(4));
        idx2.add("b1".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        idx2.add("b2".to_string(), &[0.1, 0.9, 0.0, 0.0]).unwrap();

        let mut idx3 = FlatIndex::new(IndexConfig::new(4));
        idx3.add("c1".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();
        idx3.add("c2".to_string(), &[0.0, 0.1, 0.9, 0.0]).unwrap();

        registry.register("idx1", idx1).unwrap();
        registry.register("idx2", idx2).unwrap();
        registry.register("idx3", idx3).unwrap();

        registry
    }

    #[test]
    fn test_parallel_search() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);

        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher.search_parallel(&query, 10).unwrap();

        // All three indexes share the query's dimension.
        assert_eq!(results.by_index.len(), 3);
        // Three indexes with two vectors each, and k = 10 exceeds both.
        assert_eq!(results.total_count, 6);
    }

    #[test]
    fn test_search_indexes_parallel() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);

        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher
            .search_indexes_parallel(&["idx1", "idx2"], &query, 10)
            .unwrap();

        // Only the two named indexes are searched.
        assert_eq!(results.by_index.len(), 2);
        assert_eq!(results.total_count, 4);
    }

    #[test]
    fn test_search_batch() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);

        let queries = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let results = searcher.search_batch(&queries, 10);

        // One result per query, each of them successful.
        assert_eq!(results.len(), 3);
        for result in results {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_config_builder() {
        let config = ParallelSearchConfig::new()
            .with_threads(4)
            .with_min_indexes_per_thread(2);

        assert_eq!(config.num_threads, 4);
        assert_eq!(config.min_indexes_per_thread, 2);
    }

    #[test]
    fn test_results_aggregator() {
        let mut aggregator = ResultsAggregator::new();

        let mut results1 = MultiIndexResults::new();
        results1.add(
            "idx1".to_string(),
            vec![SearchResult::new(
                "a".to_string(),
                0.5,
                crate::index::DistanceType::L2,
            )],
        );

        let mut results2 = MultiIndexResults::new();
        results2.add(
            "idx2".to_string(),
            vec![SearchResult::new(
                "b".to_string(),
                0.3,
                crate::index::DistanceType::L2,
            )],
        );

        aggregator.add(results1);
        aggregator.add(results2);

        assert_eq!(aggregator.results().len(), 2);
        assert_eq!(aggregator.total_count(), 2);

        let flat = aggregator.flatten_with_query();
        assert_eq!(flat.len(), 2);
        // The first tuple element is the position of the originating query.
        assert_eq!(flat[0].0, 0);
        assert_eq!(flat[1].0, 1);
    }

    #[test]
    fn test_incompatible_dimension_skipped() {
        let mut registry = IndexRegistry::new();

        let mut idx1 = FlatIndex::new(IndexConfig::new(4));
        idx1.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();

        let mut idx2 = FlatIndex::new(IndexConfig::new(8));
        idx2.add("b".to_string(), &[1.0; 8]).unwrap();

        registry.register("idx1", idx1).unwrap();
        registry.register("idx2", idx2).unwrap();

        let searcher = ParallelSearcher::new(&registry);

        // A 4-dimensional query must skip the 8-dimensional index.
        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher.search_parallel(&query, 10).unwrap();

        assert_eq!(results.by_index.len(), 1);
    }
}