1use crate::search::SearchResult;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use anyhow::Result;
5
6pub struct ResultFusion {
8 strategies: Vec<Box<dyn FusionStrategy + Send + Sync>>,
9 weights: HashMap<String, f64>,
10}
11
12#[async_trait::async_trait]
14pub trait FusionStrategy {
15 fn name(&self) -> &str;
17
18 async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>>;
20
21 fn confidence(&self, results: &[SystemResults]) -> f64;
23}
24
25#[derive(Debug, Clone)]
27pub struct SystemResults {
28 pub system_name: String,
29 pub results: Vec<SearchResult>,
30 pub latency_ms: f64,
31 pub confidence: f64,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct FusedResult {
37 pub result: SearchResult,
38 pub fusion_score: f64,
39 pub contributing_systems: Vec<String>,
40 pub fusion_strategy: String,
41 pub confidence: f64,
42}
43
44impl ResultFusion {
45 pub fn new() -> Self {
47 let mut fusion = Self {
48 strategies: Vec::new(),
49 weights: HashMap::new(),
50 };
51
52 fusion.add_strategy(Box::new(CombSumStrategy::new()));
54 fusion.add_strategy(Box::new(CombMnzStrategy::new()));
55 fusion.add_strategy(Box::new(RankBasedFusion::new()));
56 fusion.add_strategy(Box::new(BordaCountFusion::new()));
57
58 fusion.set_weight("lex".to_string(), 0.3);
60 fusion.set_weight("symbols".to_string(), 0.4);
61 fusion.set_weight("semantic".to_string(), 0.3);
62
63 fusion
64 }
65
66 pub fn add_strategy(&mut self, strategy: Box<dyn FusionStrategy + Send + Sync>) {
68 self.strategies.push(strategy);
69 }
70
71 pub fn set_weight(&mut self, system: String, weight: f64) {
73 self.weights.insert(system, weight);
74 }
75
76 pub async fn fuse_results(&self, system_results: &[SystemResults]) -> Result<Vec<FusedResult>> {
78 if system_results.is_empty() {
79 return Ok(Vec::new());
80 }
81
82 let mut all_fused_results = Vec::new();
83
84 for strategy in &self.strategies {
86 let fused = strategy.fuse(system_results).await?;
87 let confidence = strategy.confidence(system_results);
88
89 for result in fused {
90 let contributing_systems = system_results
91 .iter()
92 .filter(|sys| sys.results.iter().any(|r| r.file_path == result.file_path))
93 .map(|sys| sys.system_name.clone())
94 .collect();
95
96 all_fused_results.push(FusedResult {
97 fusion_score: result.score,
98 contributing_systems,
99 fusion_strategy: strategy.name().to_string(),
100 confidence,
101 result,
102 });
103 }
104 }
105
106 Ok(self.select_best_fusion(all_fused_results))
108 }
109
110 fn select_best_fusion(&self, fused_results: Vec<FusedResult>) -> Vec<FusedResult> {
112 let mut best_results: HashMap<String, FusedResult> = HashMap::new();
114
115 for result in fused_results {
116 let key = format!("{}:{}", result.result.file_path, result.result.line_number);
117
118 if let Some(existing) = best_results.get(&key) {
119 if result.fusion_score > existing.fusion_score {
120 best_results.insert(key, result);
121 }
122 } else {
123 best_results.insert(key, result);
124 }
125 }
126
127 let mut final_results: Vec<FusedResult> = best_results.into_values().collect();
128 final_results.sort_by(|a, b| b.fusion_score.partial_cmp(&a.fusion_score).unwrap());
129
130 final_results.truncate(50);
132
133 final_results
134 }
135}
136
/// CombSUM fusion strategy, reported under the name `"combsum"`.
pub struct CombSumStrategy {
    name: String,
}

impl CombSumStrategy {
    /// Creates the strategy; its `FusionStrategy::name` is `"combsum"`.
    pub fn new() -> Self {
        Self {
            name: "combsum".to_string(),
        }
    }
}

// Public types with a no-argument `new` should also implement `Default`
// (clippy::new_without_default).
impl Default for CombSumStrategy {
    fn default() -> Self {
        Self::new()
    }
}
149
150#[async_trait::async_trait]
151impl FusionStrategy for CombSumStrategy {
152 fn name(&self) -> &str {
153 &self.name
154 }
155
156 async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
157 let mut score_map: HashMap<String, (SearchResult, f64)> = HashMap::new();
158
159 for system_result in results {
161 let max_score = system_result.results
162 .iter()
163 .map(|r| r.score)
164 .fold(0.0, f64::max);
165
166 if max_score > 0.0 {
167 for result in &system_result.results {
168 let normalized_score = result.score / max_score;
169 let key = format!("{}:{}", result.file_path, result.line_number);
170
171 if let Some((_, current_score)) = score_map.get(&key) {
172 score_map.insert(key, (result.clone(), current_score + normalized_score));
173 } else {
174 score_map.insert(key, (result.clone(), normalized_score));
175 }
176 }
177 }
178 }
179
180 let mut fused_results: Vec<SearchResult> = score_map
181 .into_values()
182 .map(|(mut result, score)| {
183 result.score = score;
184 result
185 })
186 .collect();
187
188 fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
189
190 Ok(fused_results)
191 }
192
193 fn confidence(&self, results: &[SystemResults]) -> f64 {
194 if results.is_empty() {
195 return 0.0;
196 }
197
198 let avg_confidence: f64 = results.iter().map(|r| r.confidence).sum::<f64>() / results.len() as f64;
200 let agreement_bonus = if results.len() > 1 { 0.1 } else { 0.0 };
201
202 (avg_confidence + agreement_bonus).min(1.0)
203 }
204}
205
/// CombMNZ fusion strategy, reported under the name `"combmnz"`.
pub struct CombMnzStrategy {
    name: String,
}

impl CombMnzStrategy {
    /// Creates the strategy; its `FusionStrategy::name` is `"combmnz"`.
    pub fn new() -> Self {
        Self {
            name: "combmnz".to_string(),
        }
    }
}

// Public types with a no-argument `new` should also implement `Default`
// (clippy::new_without_default).
impl Default for CombMnzStrategy {
    fn default() -> Self {
        Self::new()
    }
}
218
219#[async_trait::async_trait]
220impl FusionStrategy for CombMnzStrategy {
221 fn name(&self) -> &str {
222 &self.name
223 }
224
225 async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
226 let mut score_map: HashMap<String, (SearchResult, f64, usize)> = HashMap::new();
227
228 for system_result in results {
230 let max_score = system_result.results
231 .iter()
232 .map(|r| r.score)
233 .fold(0.0, f64::max);
234
235 if max_score > 0.0 {
236 for result in &system_result.results {
237 let normalized_score = result.score / max_score;
238 let key = format!("{}:{}", result.file_path, result.line_number);
239
240 if let Some((_, current_score, count)) = score_map.get(&key) {
241 score_map.insert(key, (result.clone(), current_score + normalized_score, count + 1));
242 } else {
243 score_map.insert(key, (result.clone(), normalized_score, 1));
244 }
245 }
246 }
247 }
248
249 let mut fused_results: Vec<SearchResult> = score_map
250 .into_values()
251 .map(|(mut result, sum_score, count)| {
252 result.score = sum_score * count as f64; result
254 })
255 .collect();
256
257 fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
258
259 Ok(fused_results)
260 }
261
262 fn confidence(&self, results: &[SystemResults]) -> f64 {
263 if results.is_empty() {
265 return 0.0;
266 }
267
268 let base_confidence: f64 = results.iter().map(|r| r.confidence).sum::<f64>() / results.len() as f64;
269 let system_bonus = (results.len() as f64 - 1.0) * 0.1; (base_confidence + system_bonus).min(1.0)
272 }
273}
274
/// Reciprocal-rank fusion strategy, reported under the name `"rank_fusion"`.
pub struct RankBasedFusion {
    name: String,
}

impl RankBasedFusion {
    /// Creates the strategy; its `FusionStrategy::name` is `"rank_fusion"`.
    pub fn new() -> Self {
        Self {
            name: "rank_fusion".to_string(),
        }
    }
}

// Public types with a no-argument `new` should also implement `Default`
// (clippy::new_without_default).
impl Default for RankBasedFusion {
    fn default() -> Self {
        Self::new()
    }
}
287
288#[async_trait::async_trait]
289impl FusionStrategy for RankBasedFusion {
290 fn name(&self) -> &str {
291 &self.name
292 }
293
294 async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
295 let mut score_map: HashMap<String, (SearchResult, f64)> = HashMap::new();
296
297 for system_result in results {
298 for (rank, result) in system_result.results.iter().enumerate() {
299 let reciprocal_rank = 1.0 / (rank + 1) as f64;
300 let key = format!("{}:{}", result.file_path, result.line_number);
301
302 if let Some((_, current_score)) = score_map.get(&key) {
303 score_map.insert(key, (result.clone(), current_score + reciprocal_rank));
304 } else {
305 score_map.insert(key, (result.clone(), reciprocal_rank));
306 }
307 }
308 }
309
310 let mut fused_results: Vec<SearchResult> = score_map
311 .into_values()
312 .map(|(mut result, score)| {
313 result.score = score;
314 result
315 })
316 .collect();
317
318 fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
319
320 Ok(fused_results)
321 }
322
323 fn confidence(&self, _results: &[SystemResults]) -> f64 {
324 0.8 }
326}
327
/// Borda-count fusion strategy, reported under the name `"borda_count"`.
pub struct BordaCountFusion {
    name: String,
}

impl BordaCountFusion {
    /// Creates the strategy; its `FusionStrategy::name` is `"borda_count"`.
    pub fn new() -> Self {
        Self {
            name: "borda_count".to_string(),
        }
    }
}

// Public types with a no-argument `new` should also implement `Default`
// (clippy::new_without_default).
impl Default for BordaCountFusion {
    fn default() -> Self {
        Self::new()
    }
}
340
341#[async_trait::async_trait]
342impl FusionStrategy for BordaCountFusion {
343 fn name(&self) -> &str {
344 &self.name
345 }
346
347 async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
348 let mut score_map: HashMap<String, (SearchResult, f64)> = HashMap::new();
349
350 for system_result in results {
351 let num_results = system_result.results.len();
352
353 for (rank, result) in system_result.results.iter().enumerate() {
354 let borda_score = (num_results - rank) as f64; let key = format!("{}:{}", result.file_path, result.line_number);
356
357 if let Some((_, current_score)) = score_map.get(&key) {
358 score_map.insert(key, (result.clone(), current_score + borda_score));
359 } else {
360 score_map.insert(key, (result.clone(), borda_score));
361 }
362 }
363 }
364
365 let mut fused_results: Vec<SearchResult> = score_map
366 .into_values()
367 .map(|(mut result, score)| {
368 result.score = score;
369 result
370 })
371 .collect();
372
373 fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
374
375 Ok(fused_results)
376 }
377
378 fn confidence(&self, _results: &[SystemResults]) -> f64 {
379 0.75 }
381}
382
383impl Default for ResultFusion {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::SearchResult;

    /// A text-match hit at `file_path:line_number` (column 1) with the given score.
    fn create_test_search_result(file_path: &str, line_number: u32, score: f64) -> SearchResult {
        SearchResult {
            file_path: file_path.to_string(),
            line_number,
            column: 1,
            content: format!("test content for {}", file_path),
            score,
            result_type: crate::search::SearchResultType::TextMatch,
            language: Some("rust".to_string()),
            context_lines: Some(vec![]),
            lsp_metadata: None,
        }
    }

    /// Wraps `results` as the output of `system_name` (latency 50 ms, confidence 0.8).
    fn create_test_system_results(system_name: &str, results: Vec<SearchResult>) -> SystemResults {
        SystemResults {
            system_name: system_name.to_string(),
            results,
            latency_ms: 50.0,
            confidence: 0.8,
        }
    }

    /// A system that returned no hits, with the given latency and confidence.
    fn empty_system(system_name: &str, latency_ms: f64, confidence: f64) -> SystemResults {
        SystemResults {
            system_name: system_name.to_string(),
            results: vec![],
            latency_ms,
            confidence,
        }
    }

    #[test]
    fn test_result_fusion_creation() {
        let fusion = ResultFusion::new();
        assert_eq!(fusion.strategies.len(), 4);
        assert_eq!(fusion.weights.len(), 3);
        assert_eq!(fusion.weights.get("lex"), Some(&0.3));
        assert_eq!(fusion.weights.get("symbols"), Some(&0.4));
        assert_eq!(fusion.weights.get("semantic"), Some(&0.3));
    }

    #[test]
    fn test_result_fusion_add_strategy() {
        let mut fusion = ResultFusion::new();
        let initial_count = fusion.strategies.len();

        fusion.add_strategy(Box::new(CombSumStrategy::new()));
        assert_eq!(fusion.strategies.len(), initial_count + 1);
    }

    #[test]
    fn test_result_fusion_set_weight() {
        let mut fusion = ResultFusion::new();
        fusion.set_weight("new_system".to_string(), 0.5);
        assert_eq!(fusion.weights.get("new_system"), Some(&0.5));
    }

    #[tokio::test]
    async fn test_fuse_empty_results() {
        let fusion = ResultFusion::new();
        let results = fusion.fuse_results(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_fuse_single_system() {
        let fusion = ResultFusion::new();
        let search_results = vec![
            create_test_search_result("file1.rs", 10, 0.9),
            create_test_search_result("file2.rs", 20, 0.7),
        ];
        let system_results = vec![create_test_system_results("lex", search_results)];

        let fused = fusion.fuse_results(&system_results).await.unwrap();
        assert!(!fused.is_empty());

        for result in &fused {
            assert!(result.fusion_score > 0.0);
            assert_eq!(result.contributing_systems, vec!["lex"]);
        }
    }

    #[tokio::test]
    async fn test_fuse_multiple_systems() {
        let fusion = ResultFusion::new();

        let lex_results = vec![
            create_test_search_result("file1.rs", 10, 0.9),
            create_test_search_result("file2.rs", 20, 0.8),
        ];
        let symbols_results = vec![
            create_test_search_result("file1.rs", 10, 0.8),
            create_test_search_result("file3.rs", 30, 0.7),
        ];

        let system_results = vec![
            create_test_system_results("lex", lex_results),
            create_test_system_results("symbols", symbols_results),
        ];

        let fused = fusion.fuse_results(&system_results).await.unwrap();
        assert!(!fused.is_empty());

        let overlapping = fused
            .iter()
            .find(|r| r.result.file_path == "file1.rs" && r.result.line_number == 10)
            .expect("file1.rs:10 was returned by both systems and must be fused");
        assert_eq!(overlapping.contributing_systems, vec!["lex", "symbols"]);
    }

    #[tokio::test]
    async fn test_fuse_results_are_sorted_and_capped() {
        let fusion = ResultFusion::new();
        let hits: Vec<SearchResult> = (0..60u32)
            .map(|i| create_test_search_result("many.rs", i + 1, 1.0 - f64::from(i) / 100.0))
            .collect();

        let fused = fusion
            .fuse_results(&[create_test_system_results("lex", hits)])
            .await
            .unwrap();

        assert_eq!(fused.len(), 50);
        assert!(fused
            .windows(2)
            .all(|pair| pair[0].fusion_score >= pair[1].fusion_score));
    }

    #[test]
    fn test_select_best_fusion() {
        let fusion = ResultFusion::new();
        let search_result = create_test_search_result("file1.rs", 10, 0.9);

        let fused_results = vec![
            FusedResult {
                result: search_result.clone(),
                fusion_score: 0.8,
                contributing_systems: vec!["lex".to_string()],
                fusion_strategy: "combsum".to_string(),
                confidence: 0.8,
            },
            FusedResult {
                result: search_result.clone(),
                fusion_score: 0.9,
                contributing_systems: vec!["symbols".to_string()],
                fusion_strategy: "combmnz".to_string(),
                confidence: 0.9,
            },
        ];

        let best = fusion.select_best_fusion(fused_results);
        assert_eq!(best.len(), 1);
        assert_eq!(best[0].fusion_score, 0.9);
        assert_eq!(best[0].fusion_strategy, "combmnz");
    }

    #[test]
    fn test_combsum_strategy() {
        let strategy = CombSumStrategy::new();
        assert_eq!(strategy.name(), "combsum");
    }

    #[tokio::test]
    async fn test_combsum_fuse() {
        let strategy = CombSumStrategy::new();
        let system_results = vec![
            create_test_system_results("system1", vec![
                create_test_search_result("file1.rs", 10, 1.0),
                create_test_search_result("file2.rs", 20, 0.8),
            ]),
            create_test_system_results("system2", vec![
                create_test_search_result("file1.rs", 10, 0.6),
                create_test_search_result("file3.rs", 30, 1.0),
            ]),
        ];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert!(!fused.is_empty());

        let file1_result = fused
            .iter()
            .find(|r| r.file_path == "file1.rs" && r.line_number == 10)
            .expect("file1.rs:10 must be present in the fused output");
        // Returned by both systems, so the summed normalised score exceeds 1.
        assert!(file1_result.score > 1.0);
    }

    #[test]
    fn test_combsum_confidence() {
        let strategy = CombSumStrategy::new();
        let system_results = vec![
            empty_system("system1", 50.0, 0.8),
            empty_system("system2", 60.0, 0.9),
        ];

        let confidence = strategy.confidence(&system_results);
        assert!(confidence > 0.8);
        assert!(confidence <= 1.0);

        assert_eq!(strategy.confidence(&[]), 0.0);
    }

    #[tokio::test]
    async fn test_combmnz_fuse() {
        let strategy = CombMnzStrategy::new();
        assert_eq!(strategy.name(), "combmnz");

        let system_results = vec![
            create_test_system_results("system1", vec![
                create_test_search_result("file1.rs", 10, 1.0),
            ]),
            create_test_system_results("system2", vec![
                create_test_search_result("file1.rs", 10, 0.8),
            ]),
        ];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert_eq!(fused.len(), 1);

        let result = fused.first().expect("one fused location");
        // Found by both systems: the sum is multiplied by the hit count.
        assert!(result.score > 1.0);
    }

    #[test]
    fn test_combmnz_confidence() {
        let strategy = CombMnzStrategy::new();
        let system_results = vec![
            empty_system("system1", 50.0, 0.8),
            empty_system("system2", 60.0, 0.8),
        ];

        let confidence = strategy.confidence(&system_results);
        assert!(confidence > 0.8);
        assert_eq!(strategy.confidence(&[]), 0.0);
    }

    #[tokio::test]
    async fn test_rank_based_fusion() {
        let strategy = RankBasedFusion::new();
        assert_eq!(strategy.name(), "rank_fusion");
        assert_eq!(strategy.confidence(&[]), 0.8);

        let system_results = vec![create_test_system_results("system1", vec![
            create_test_search_result("file1.rs", 10, 1.0),
            create_test_search_result("file2.rs", 20, 0.9),
        ])];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert_eq!(fused.len(), 2);

        assert!(fused[0].score >= fused[1].score);
        // Reciprocal ranks of positions 0 and 1.
        assert_eq!(fused[0].score, 1.0);
        assert_eq!(fused[1].score, 0.5);
    }

    #[tokio::test]
    async fn test_borda_count_fusion() {
        let strategy = BordaCountFusion::new();
        assert_eq!(strategy.name(), "borda_count");
        assert_eq!(strategy.confidence(&[]), 0.75);

        let system_results = vec![create_test_system_results("system1", vec![
            create_test_search_result("file1.rs", 10, 1.0),
            create_test_search_result("file2.rs", 20, 0.9),
        ])];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert_eq!(fused.len(), 2);

        // With two hits, positions 0 and 1 earn 2 and 1 points.
        assert_eq!(fused[0].score, 2.0);
        assert_eq!(fused[1].score, 1.0);
    }

    #[test]
    fn test_system_results_creation() {
        let results = vec![create_test_search_result("test.rs", 1, 0.9)];
        let system_results = create_test_system_results("test_system", results);

        assert_eq!(system_results.system_name, "test_system");
        assert_eq!(system_results.results.len(), 1);
        assert_eq!(system_results.latency_ms, 50.0);
        assert_eq!(system_results.confidence, 0.8);
    }

    #[test]
    fn test_fused_result_creation() {
        let search_result = create_test_search_result("test.rs", 1, 0.9);
        let fused_result = FusedResult {
            result: search_result,
            fusion_score: 1.5,
            contributing_systems: vec!["system1".to_string(), "system2".to_string()],
            fusion_strategy: "combsum".to_string(),
            confidence: 0.9,
        };

        assert_eq!(fused_result.fusion_score, 1.5);
        assert_eq!(fused_result.contributing_systems.len(), 2);
        assert_eq!(fused_result.fusion_strategy, "combsum");
        assert_eq!(fused_result.confidence, 0.9);
    }
}