use crate::error::Result;
use crate::index::registry::{IndexRegistry, MultiIndexResult, MultiIndexResults};
use crate::index::traits::SearchResult;
use rayon::prelude::*;
/// Tuning knobs for [`ParallelSearcher`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelSearchConfig {
    /// Requested number of worker threads; `0` (the default) requests no specific count.
    pub num_threads: usize,
    /// Minimum number of indexes each thread should handle. A multi-index search is
    /// dispatched in parallel only when it covers at least twice this many indexes.
    pub min_indexes_per_thread: usize,
    /// Whether batch searches may process their queries in parallel.
    pub batch_parallel: bool,
}

impl Default for ParallelSearchConfig {
    /// `num_threads = 0`, `min_indexes_per_thread = 1`, `batch_parallel = true`.
    fn default() -> Self {
        Self {
            num_threads: 0,
            min_indexes_per_thread: 1,
            batch_parallel: true,
        }
    }
}

impl ParallelSearchConfig {
    /// Creates the default configuration (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the requested worker-thread count (`0` = no specific count).
    #[must_use]
    pub const fn with_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Sets the minimum number of indexes per thread used to decide on parallel dispatch.
    #[must_use]
    pub const fn with_min_indexes_per_thread(mut self, min: usize) -> Self {
        self.min_indexes_per_thread = min;
        self
    }

    /// Enables or disables parallel processing of query batches.
    #[must_use]
    pub const fn with_batch_parallel(mut self, batch_parallel: bool) -> Self {
        self.batch_parallel = batch_parallel;
        self
    }
}
pub struct ParallelSearcher<'a> {
registry: &'a IndexRegistry,
config: ParallelSearchConfig,
}
impl<'a> ParallelSearcher<'a> {
#[must_use]
pub fn new(registry: &'a IndexRegistry) -> Self {
Self {
registry,
config: ParallelSearchConfig::default(),
}
}
#[must_use]
pub fn with_config(registry: &'a IndexRegistry, config: ParallelSearchConfig) -> Self {
Self { registry, config }
}
pub fn search_parallel(&self, query: &[f32], k: usize) -> Result<MultiIndexResults> {
let query_dim = query.len();
let indexes: Vec<_> = self
.registry
.info()
.into_iter()
.filter(|info| info.dimension == query_dim)
.map(|info| info.name)
.collect();
if indexes.is_empty() {
return Ok(MultiIndexResults::new());
}
let use_parallel = indexes.len() >= self.config.min_indexes_per_thread * 2;
let results: Vec<MultiIndexResult> = if use_parallel {
indexes
.par_iter()
.filter_map(|name| {
self.registry
.search(name, query, k)
.ok()
.map(|results| MultiIndexResult {
index_name: name.clone(),
results,
})
})
.collect()
} else {
indexes
.iter()
.filter_map(|name| {
self.registry
.search(name, query, k)
.ok()
.map(|results| MultiIndexResult {
index_name: name.clone(),
results,
})
})
.collect()
};
let total_count = results.iter().map(|r| r.results.len()).sum();
Ok(MultiIndexResults {
by_index: results,
total_count,
})
}
pub fn search_indexes_parallel(
&self,
names: &[&str],
query: &[f32],
k: usize,
) -> Result<MultiIndexResults> {
let use_parallel = names.len() >= self.config.min_indexes_per_thread * 2;
let results: Vec<MultiIndexResult> = if use_parallel {
names
.par_iter()
.filter_map(|name| {
self.registry
.search(name, query, k)
.ok()
.map(|results| MultiIndexResult {
index_name: (*name).to_string(),
results,
})
})
.collect()
} else {
names
.iter()
.filter_map(|name| {
self.registry
.search(name, query, k)
.ok()
.map(|results| MultiIndexResult {
index_name: (*name).to_string(),
results,
})
})
.collect()
};
let total_count = results.iter().map(|r| r.results.len()).sum();
Ok(MultiIndexResults {
by_index: results,
total_count,
})
}
pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Vec<Result<MultiIndexResults>> {
if self.config.batch_parallel && queries.len() > 1 {
queries
.par_iter()
.map(|query| self.search_parallel(query, k))
.collect()
} else {
queries
.iter()
.map(|query| self.search_parallel(query, k))
.collect()
}
}
pub fn search_indexes_batch(
&self,
names: &[&str],
queries: &[Vec<f32>],
k: usize,
) -> Vec<Result<MultiIndexResults>> {
if self.config.batch_parallel && queries.len() > 1 {
queries
.par_iter()
.map(|query| self.search_indexes_parallel(names, query, k))
.collect()
} else {
queries
.iter()
.map(|query| self.search_indexes_parallel(names, query, k))
.collect()
}
}
}
pub fn parallel_add_batch(
registry: &mut IndexRegistry,
index_name: &str,
ids: Vec<String>,
vectors: &[Vec<f32>],
) -> Result<()> {
if ids.len() != vectors.len() {
return Err(crate::error::Error::InvalidQuery {
reason: format!(
"IDs count ({}) doesn't match vectors count ({})",
ids.len(),
vectors.len()
),
});
}
for (id, vector) in ids.into_iter().zip(vectors.iter()) {
registry.add(index_name, id, vector)?;
}
Ok(())
}
/// Accumulates the [`MultiIndexResults`] of several queries in insertion order.
#[derive(Debug, Default)]
pub struct ResultsAggregator {
    /// One entry per added query, in the order `add` was called.
    results: Vec<MultiIndexResults>,
}

impl ResultsAggregator {
    /// Creates an aggregator holding no results.
    #[must_use]
    pub fn new() -> Self {
        Self {
            results: Vec::new(),
        }
    }

    /// Appends the results of one more query.
    pub fn add(&mut self, results: MultiIndexResults) {
        self.results.push(results);
    }

    /// All stored per-query results, in insertion order.
    #[must_use]
    pub fn results(&self) -> &[MultiIndexResults] {
        self.results.as_slice()
    }

    /// Sum of the stored `total_count` of every added query.
    #[must_use]
    pub fn total_count(&self) -> usize {
        self.results
            .iter()
            .fold(0, |acc, per_query| acc + per_query.total_count)
    }

    /// Flattens everything into `(query_position, index_name, result)` tuples.
    ///
    /// `query_position` is the zero-based order in which the query's results were added.
    /// Tuples are ordered by query, then by index, then by result.
    #[must_use]
    pub fn flatten_with_query(&self) -> Vec<(usize, String, SearchResult)> {
        let mut flat = Vec::new();
        for (query_pos, per_query) in self.results.iter().enumerate() {
            for per_index in &per_query.by_index {
                for hit in &per_index.results {
                    flat.push((query_pos, per_index.index_name.clone(), hit.clone()));
                }
            }
        }
        flat
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig, VectorIndex};

    /// Builds a registry with three 4-dimensional flat indexes holding two vectors each.
    fn setup_test_registry() -> IndexRegistry {
        let mut registry = IndexRegistry::new();
        let mut idx1 = FlatIndex::new(IndexConfig::new(4));
        idx1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        idx1.add("a2".to_string(), &[0.9, 0.1, 0.0, 0.0]).unwrap();
        let mut idx2 = FlatIndex::new(IndexConfig::new(4));
        idx2.add("b1".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        idx2.add("b2".to_string(), &[0.1, 0.9, 0.0, 0.0]).unwrap();
        let mut idx3 = FlatIndex::new(IndexConfig::new(4));
        idx3.add("c1".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();
        idx3.add("c2".to_string(), &[0.0, 0.1, 0.9, 0.0]).unwrap();
        registry.register("idx1", idx1).unwrap();
        registry.register("idx2", idx2).unwrap();
        registry.register("idx3", idx3).unwrap();
        registry
    }

    #[test]
    fn test_parallel_search() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);
        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher.search_parallel(&query, 10).unwrap();
        assert_eq!(results.by_index.len(), 3);
        // k exceeds each index's size, so all 2 vectors of all 3 indexes come back.
        assert_eq!(results.total_count, 6);
    }

    #[test]
    fn test_sequential_path_matches_parallel() {
        let registry = setup_test_registry();
        // Threshold 10 requires 20 indexes before going parallel, forcing the sequential path.
        let config = ParallelSearchConfig::new().with_min_indexes_per_thread(10);
        let searcher = ParallelSearcher::with_config(&registry, config);
        let results = searcher.search_parallel(&[1.0, 0.0, 0.0, 0.0], 10).unwrap();
        assert_eq!(results.by_index.len(), 3);
        assert_eq!(results.total_count, 6);
    }

    #[test]
    fn test_search_with_explicit_thread_count() {
        let registry = setup_test_registry();
        let config = ParallelSearchConfig::new().with_threads(2);
        let searcher = ParallelSearcher::with_config(&registry, config);
        let queries = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
        let results = searcher.search_batch(&queries, 10);
        assert_eq!(results.len(), 2);
        for result in results {
            assert_eq!(result.unwrap().total_count, 6);
        }
    }

    #[test]
    fn test_search_indexes_parallel() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);
        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher
            .search_indexes_parallel(&["idx1", "idx2"], &query, 10)
            .unwrap();
        assert_eq!(results.by_index.len(), 2);
        assert_eq!(results.total_count, 4);
    }

    #[test]
    fn test_unknown_index_is_skipped() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);
        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher
            .search_indexes_parallel(&["idx1", "missing"], &query, 10)
            .unwrap();
        assert_eq!(results.by_index.len(), 1);
        assert_eq!(results.by_index[0].index_name, "idx1");
        assert_eq!(results.total_count, 2);
    }

    #[test]
    fn test_search_batch() {
        let registry = setup_test_registry();
        let searcher = ParallelSearcher::new(&registry);
        let queries = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let results = searcher.search_batch(&queries, 10);
        assert_eq!(results.len(), 3);
        for result in results {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_search_indexes_batch() {
        let registry = setup_test_registry();
        // Set the field directly to exercise the sequential batch path.
        let config = ParallelSearchConfig {
            batch_parallel: false,
            ..ParallelSearchConfig::default()
        };
        let searcher = ParallelSearcher::with_config(&registry, config);
        let queries = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
        let results = searcher.search_indexes_batch(&["idx1"], &queries, 10);
        assert_eq!(results.len(), 2);
        for result in results {
            let result = result.unwrap();
            assert_eq!(result.by_index.len(), 1);
            assert_eq!(result.total_count, 2);
        }
    }

    #[test]
    fn test_config_builder() {
        let config = ParallelSearchConfig::new()
            .with_threads(4)
            .with_min_indexes_per_thread(2);
        assert_eq!(config.num_threads, 4);
        assert_eq!(config.min_indexes_per_thread, 2);
    }

    #[test]
    fn test_results_aggregator() {
        let mut aggregator = ResultsAggregator::new();
        let mut results1 = MultiIndexResults::new();
        results1.add(
            "idx1".to_string(),
            vec![SearchResult::new(
                "a".to_string(),
                0.5,
                crate::index::DistanceType::L2,
            )],
        );
        let mut results2 = MultiIndexResults::new();
        results2.add(
            "idx2".to_string(),
            vec![SearchResult::new(
                "b".to_string(),
                0.3,
                crate::index::DistanceType::L2,
            )],
        );
        aggregator.add(results1);
        aggregator.add(results2);
        assert_eq!(aggregator.results().len(), 2);
        assert_eq!(aggregator.total_count(), 2);
        let flat = aggregator.flatten_with_query();
        assert_eq!(flat.len(), 2);
        // The first tuple element is the position of the originating query.
        assert_eq!(flat[0].0, 0);
        assert_eq!(flat[1].0, 1);
    }

    #[test]
    fn test_incompatible_dimension_skipped() {
        let mut registry = IndexRegistry::new();
        let mut idx1 = FlatIndex::new(IndexConfig::new(4));
        idx1.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let mut idx2 = FlatIndex::new(IndexConfig::new(8));
        idx2.add("b".to_string(), &[1.0; 8]).unwrap();
        registry.register("idx1", idx1).unwrap();
        registry.register("idx2", idx2).unwrap();
        let searcher = ParallelSearcher::new(&registry);
        let query = [1.0, 0.0, 0.0, 0.0];
        let results = searcher.search_parallel(&query, 10).unwrap();
        assert_eq!(results.by_index.len(), 1);
    }
}