1use crate::search::VectorSearchIndex;
43use crate::types::SearchResult;
44use anyhow::{anyhow, Result};
45use rayon::prelude::*;
46use serde::{Deserialize, Serialize};
47use std::collections::HashMap;
48use tracing::{debug, info};
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct MultiIndexConfig {
53 pub parallel: bool,
55 pub deduplicate: bool,
57 pub merge_strategy: ScoreMergeStrategy,
59}
60
61impl Default for MultiIndexConfig {
62 fn default() -> Self {
63 Self {
64 parallel: true,
65 deduplicate: true,
66 merge_strategy: ScoreMergeStrategy::Max,
67 }
68 }
69}
70
71#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
73pub enum ScoreMergeStrategy {
74 Max,
76 Min,
78 Average,
80 First,
82}
83
84#[derive(Debug, Clone)]
86pub struct MultiIndexSearch {
87 config: MultiIndexConfig,
88}
89
90impl MultiIndexSearch {
91 pub fn new() -> Self {
93 Self {
94 config: MultiIndexConfig::default(),
95 }
96 }
97
98 pub fn with_config(config: MultiIndexConfig) -> Self {
100 Self { config }
101 }
102
103 pub fn search(
105 &self,
106 indexes: &[&VectorSearchIndex],
107 query: &[f32],
108 k: usize,
109 ) -> Result<Vec<SearchResult>> {
110 if indexes.is_empty() {
111 return Err(anyhow!("Cannot search across zero indexes"));
112 }
113
114 info!("Searching across {} indexes", indexes.len());
115
116 let all_results: Vec<Vec<SearchResult>> = if self.config.parallel {
118 indexes
119 .par_iter()
120 .map(|index| index.search(query, k).unwrap_or_default())
121 .collect()
122 } else {
123 indexes
124 .iter()
125 .map(|index| index.search(query, k).unwrap_or_default())
126 .collect()
127 };
128
129 let merged = self.merge_results(all_results, k);
131
132 info!("Multi-index search returned {} results", merged.len());
133 Ok(merged)
134 }
135
136 pub fn batch_search(
138 &self,
139 indexes: &[&VectorSearchIndex],
140 queries: &[Vec<f32>],
141 k: usize,
142 ) -> Result<Vec<Vec<SearchResult>>> {
143 if indexes.is_empty() {
144 return Err(anyhow!("Cannot search across zero indexes"));
145 }
146
147 info!(
148 "Batch searching {} queries across {} indexes",
149 queries.len(),
150 indexes.len()
151 );
152
153 let results: Vec<Vec<SearchResult>> = if self.config.parallel {
155 queries
156 .par_iter()
157 .map(|query| self.search(indexes, query, k).unwrap_or_default())
158 .collect()
159 } else {
160 queries
161 .iter()
162 .map(|query| self.search(indexes, query, k).unwrap_or_default())
163 .collect()
164 };
165
166 Ok(results)
167 }
168
169 fn merge_results(&self, all_results: Vec<Vec<SearchResult>>, k: usize) -> Vec<SearchResult> {
171 if !self.config.deduplicate {
172 let mut merged: Vec<SearchResult> = all_results.into_iter().flatten().collect();
174 merged.sort_by(|a, b| {
175 b.score
176 .partial_cmp(&a.score)
177 .unwrap_or(std::cmp::Ordering::Equal)
178 });
179 merged.truncate(k);
180
181 for (i, result) in merged.iter_mut().enumerate() {
183 result.rank = i + 1;
184 }
185
186 return merged;
187 }
188
189 let mut entity_scores: HashMap<String, Vec<f32>> = HashMap::new();
191 let mut entity_distance: HashMap<String, Vec<f32>> = HashMap::new();
192
193 for results in all_results {
194 for result in results {
195 entity_scores
196 .entry(result.entity_id.clone())
197 .or_default()
198 .push(result.score);
199 entity_distance
200 .entry(result.entity_id.clone())
201 .or_default()
202 .push(result.distance);
203 }
204 }
205
206 let mut merged: Vec<SearchResult> = entity_scores
208 .into_iter()
209 .map(|(entity_id, scores)| {
210 let merged_score = match self.config.merge_strategy {
211 ScoreMergeStrategy::Max => {
212 scores.iter().copied().fold(f32::NEG_INFINITY, f32::max)
213 }
214 ScoreMergeStrategy::Min => scores.iter().copied().fold(f32::INFINITY, f32::min),
215 ScoreMergeStrategy::Average => scores.iter().sum::<f32>() / scores.len() as f32,
216 ScoreMergeStrategy::First => scores[0],
217 };
218
219 let merged_distance = match self.config.merge_strategy {
220 ScoreMergeStrategy::Max => entity_distance[&entity_id]
221 .iter()
222 .copied()
223 .fold(f32::INFINITY, f32::min),
224 ScoreMergeStrategy::Min => entity_distance[&entity_id]
225 .iter()
226 .copied()
227 .fold(f32::NEG_INFINITY, f32::max),
228 ScoreMergeStrategy::Average => {
229 entity_distance[&entity_id].iter().sum::<f32>()
230 / entity_distance[&entity_id].len() as f32
231 }
232 ScoreMergeStrategy::First => entity_distance[&entity_id][0],
233 };
234
235 SearchResult {
236 entity_id,
237 score: merged_score,
238 distance: merged_distance,
239 rank: 0, }
241 })
242 .collect();
243
244 merged.sort_by(|a, b| {
246 b.score
247 .partial_cmp(&a.score)
248 .unwrap_or(std::cmp::Ordering::Equal)
249 });
250
251 merged.truncate(k);
253 for (i, result) in merged.iter_mut().enumerate() {
254 result.rank = i + 1;
255 }
256
257 debug!("Merged and deduplicated to {} results", merged.len());
258 merged
259 }
260
261 pub fn config(&self) -> &MultiIndexConfig {
263 &self.config
264 }
265}
266
267impl Default for MultiIndexSearch {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchConfig;
    use std::collections::HashMap;

    /// Builds an index of `count` vectors of length `dim`, keyed `"{id_prefix}_{i}"`.
    fn create_test_index(id_prefix: &str, count: usize, dim: usize) -> VectorSearchIndex {
        let mut embeddings = HashMap::new();
        for i in 0..count {
            let vec: Vec<f32> = (0..dim).map(|j| (i + j) as f32 * 0.1).collect();
            embeddings.insert(format!("{}_{}", id_prefix, i), vec);
        }

        let mut index = VectorSearchIndex::new(SearchConfig::default());
        index.build(&embeddings).unwrap();
        index
    }

    /// A raw, unranked hit as an individual index would return it.
    fn hit(entity_id: &str, score: f32, distance: f32) -> SearchResult {
        SearchResult {
            entity_id: entity_id.to_string(),
            score,
            distance,
            rank: 0,
        }
    }

    /// Merges a fixed two-index fixture: "a" is returned by both indexes (score 0.5
    /// from the first, 0.9 from the second) and "b" only by the second (score 0.6).
    /// Scores are chosen so no strategy produces a tie.
    fn merge_fixture(
        strategy: ScoreMergeStrategy,
        deduplicate: bool,
        k: usize,
    ) -> Vec<SearchResult> {
        let multi_search = MultiIndexSearch::with_config(MultiIndexConfig {
            parallel: false,
            deduplicate,
            merge_strategy: strategy,
        });
        multi_search.merge_results(
            vec![
                vec![hit("a", 0.5, 0.5)],
                vec![hit("a", 0.9, 0.1), hit("b", 0.6, 0.3)],
            ],
            k,
        )
    }

    fn ids(results: &[SearchResult]) -> Vec<&str> {
        results.iter().map(|r| r.entity_id.as_str()).collect()
    }

    #[test]
    fn test_multi_index_search() {
        let index1 = create_test_index("doc", 5, 3);
        let index2 = create_test_index("article", 5, 3);

        let multi_search = MultiIndexSearch::new();
        let query = vec![0.1, 0.2, 0.3];
        let results = multi_search.search(&[&index1, &index2], &query, 5).unwrap();

        assert!(results.len() <= 5);
        assert!(!results.is_empty());

        for i in 1..results.len() {
            assert!(results[i - 1].score >= results[i].score);
        }
    }

    #[test]
    fn test_multi_index_deduplication() {
        let mut embeddings1 = HashMap::new();
        embeddings1.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings1.insert("doc2".to_string(), vec![0.9, 0.1, 0.0]);
        let mut index1 = VectorSearchIndex::new(SearchConfig::default());
        index1.build(&embeddings1).unwrap();

        let mut embeddings2 = HashMap::new();
        embeddings2.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings2.insert("doc3".to_string(), vec![0.8, 0.2, 0.0]);
        let mut index2 = VectorSearchIndex::new(SearchConfig::default());
        index2.build(&embeddings2).unwrap();

        let config = MultiIndexConfig {
            parallel: false,
            deduplicate: true,
            merge_strategy: ScoreMergeStrategy::Max,
        };

        let multi_search = MultiIndexSearch::with_config(config);
        let query = vec![1.0, 0.0, 0.0];
        let results = multi_search
            .search(&[&index1, &index2], &query, 10)
            .unwrap();

        assert_eq!(results.len(), 3);

        let entity_ids: Vec<String> = results.iter().map(|r| r.entity_id.clone()).collect();
        assert!(entity_ids.contains(&"doc1".to_string()));
        assert!(entity_ids.contains(&"doc2".to_string()));
        assert!(entity_ids.contains(&"doc3".to_string()));
    }

    #[test]
    fn test_multi_index_no_deduplication() {
        let mut embeddings1 = HashMap::new();
        embeddings1.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        let mut index1 = VectorSearchIndex::new(SearchConfig::default());
        index1.build(&embeddings1).unwrap();

        let mut embeddings2 = HashMap::new();
        embeddings2.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        let mut index2 = VectorSearchIndex::new(SearchConfig::default());
        index2.build(&embeddings2).unwrap();

        let config = MultiIndexConfig {
            parallel: false,
            deduplicate: false,
            merge_strategy: ScoreMergeStrategy::Max,
        };

        let multi_search = MultiIndexSearch::with_config(config);
        let query = vec![1.0, 0.0, 0.0];
        let results = multi_search
            .search(&[&index1, &index2], &query, 10)
            .unwrap();

        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_merge_strategy_max() {
        let config = MultiIndexConfig {
            parallel: false,
            deduplicate: true,
            merge_strategy: ScoreMergeStrategy::Max,
        };
        assert_eq!(
            MultiIndexSearch::with_config(config).config().merge_strategy,
            ScoreMergeStrategy::Max
        );

        let merged = merge_fixture(ScoreMergeStrategy::Max, true, 10);
        assert_eq!(ids(&merged), ["a", "b"]);
        assert_eq!(merged[0].score, 0.9);
        // Max keeps the smallest distance.
        assert_eq!(merged[0].distance, 0.1);
        assert_eq!(merged[0].rank, 1);
        assert_eq!(merged[1].rank, 2);
    }

    #[test]
    fn test_merge_strategy_min() {
        let merged = merge_fixture(ScoreMergeStrategy::Min, true, 10);
        assert_eq!(ids(&merged), ["b", "a"]);
        assert_eq!(merged[1].score, 0.5);
        // Min keeps the largest distance.
        assert_eq!(merged[1].distance, 0.5);
    }

    #[test]
    fn test_merge_strategy_average() {
        let merged = merge_fixture(ScoreMergeStrategy::Average, true, 10);
        assert_eq!(ids(&merged), ["a", "b"]);
        assert!((merged[0].score - 0.7).abs() < 1e-6);
        assert!((merged[0].distance - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_merge_strategy_first() {
        let merged = merge_fixture(ScoreMergeStrategy::First, true, 10);
        assert_eq!(ids(&merged), ["b", "a"]);
        // "a" takes the values reported by the first index.
        assert_eq!(merged[1].score, 0.5);
        assert_eq!(merged[1].distance, 0.5);
    }

    #[test]
    fn test_merge_without_deduplication_keeps_every_hit() {
        let merged = merge_fixture(ScoreMergeStrategy::Max, false, 10);
        assert_eq!(ids(&merged), ["a", "b", "a"]);
        let ranks: Vec<usize> = merged.iter().map(|r| r.rank).collect();
        assert_eq!(ranks, [1, 2, 3]);
    }

    #[test]
    fn test_merge_truncates_to_k() {
        let merged = merge_fixture(ScoreMergeStrategy::Max, true, 1);
        assert_eq!(ids(&merged), ["a"]);
        assert_eq!(merged[0].rank, 1);
        assert!(merge_fixture(ScoreMergeStrategy::Max, true, 0).is_empty());
    }

    #[test]
    fn test_batch_search() {
        let index1 = create_test_index("doc", 5, 3);
        let index2 = create_test_index("article", 5, 3);

        let multi_search = MultiIndexSearch::new();
        let queries = vec![
            vec![0.1, 0.2, 0.3],
            vec![0.2, 0.3, 0.4],
            vec![0.3, 0.4, 0.5],
        ];

        let results = multi_search
            .batch_search(&[&index1, &index2], &queries, 3)
            .unwrap();

        assert_eq!(results.len(), 3);
        for result_set in results {
            assert!(result_set.len() <= 3);
        }
    }

    #[test]
    fn test_empty_indexes() {
        let multi_search = MultiIndexSearch::new();
        let query = vec![0.1, 0.2, 0.3];
        let result = multi_search.search(&[], &query, 10);

        assert!(result.is_err());
    }
}