// lens_core/pipeline/fusion.rs

1use crate::search::SearchResult;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use anyhow::Result;
5
6/// Result fusion engine for combining multiple search results
7pub struct ResultFusion {
8    strategies: Vec<Box<dyn FusionStrategy + Send + Sync>>,
9    weights: HashMap<String, f64>,
10}
11
12/// Trait for different fusion strategies
13#[async_trait::async_trait]
14pub trait FusionStrategy {
15    /// Get strategy name
16    fn name(&self) -> &str;
17    
18    /// Fuse results from multiple search systems
19    async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>>;
20    
21    /// Calculate confidence score for fusion
22    fn confidence(&self, results: &[SystemResults]) -> f64;
23}
24
25/// Results from a single search system
26#[derive(Debug, Clone)]
27pub struct SystemResults {
28    pub system_name: String,
29    pub results: Vec<SearchResult>,
30    pub latency_ms: f64,
31    pub confidence: f64,
32}
33
34/// Fused result with provenance tracking
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct FusedResult {
37    pub result: SearchResult,
38    pub fusion_score: f64,
39    pub contributing_systems: Vec<String>,
40    pub fusion_strategy: String,
41    pub confidence: f64,
42}
43
44impl ResultFusion {
45    /// Create new result fusion engine
46    pub fn new() -> Self {
47        let mut fusion = Self {
48            strategies: Vec::new(),
49            weights: HashMap::new(),
50        };
51        
52        // Register default strategies
53        fusion.add_strategy(Box::new(CombSumStrategy::new()));
54        fusion.add_strategy(Box::new(CombMnzStrategy::new()));
55        fusion.add_strategy(Box::new(RankBasedFusion::new()));
56        fusion.add_strategy(Box::new(BordaCountFusion::new()));
57        
58        // Set default weights
59        fusion.set_weight("lex".to_string(), 0.3);
60        fusion.set_weight("symbols".to_string(), 0.4);
61        fusion.set_weight("semantic".to_string(), 0.3);
62        
63        fusion
64    }
65
66    /// Add fusion strategy
67    pub fn add_strategy(&mut self, strategy: Box<dyn FusionStrategy + Send + Sync>) {
68        self.strategies.push(strategy);
69    }
70
71    /// Set weight for search system
72    pub fn set_weight(&mut self, system: String, weight: f64) {
73        self.weights.insert(system, weight);
74    }
75
76    /// Fuse results from multiple systems
77    pub async fn fuse_results(&self, system_results: &[SystemResults]) -> Result<Vec<FusedResult>> {
78        if system_results.is_empty() {
79            return Ok(Vec::new());
80        }
81
82        let mut all_fused_results = Vec::new();
83
84        // Apply each fusion strategy
85        for strategy in &self.strategies {
86            let fused = strategy.fuse(system_results).await?;
87            let confidence = strategy.confidence(system_results);
88            
89            for result in fused {
90                let contributing_systems = system_results
91                    .iter()
92                    .filter(|sys| sys.results.iter().any(|r| r.file_path == result.file_path))
93                    .map(|sys| sys.system_name.clone())
94                    .collect();
95                    
96                all_fused_results.push(FusedResult {
97                    fusion_score: result.score,
98                    contributing_systems,
99                    fusion_strategy: strategy.name().to_string(),
100                    confidence,
101                    result,
102                });
103            }
104        }
105
106        // Select best fusion results (could use ensemble methods here)
107        Ok(self.select_best_fusion(all_fused_results))
108    }
109
110    /// Select the best fusion results using ensemble approach
111    fn select_best_fusion(&self, fused_results: Vec<FusedResult>) -> Vec<FusedResult> {
112        // Group by file path and select highest scoring fusion for each
113        let mut best_results: HashMap<String, FusedResult> = HashMap::new();
114        
115        for result in fused_results {
116            let key = format!("{}:{}", result.result.file_path, result.result.line_number);
117            
118            if let Some(existing) = best_results.get(&key) {
119                if result.fusion_score > existing.fusion_score {
120                    best_results.insert(key, result);
121                }
122            } else {
123                best_results.insert(key, result);
124            }
125        }
126
127        let mut final_results: Vec<FusedResult> = best_results.into_values().collect();
128        final_results.sort_by(|a, b| b.fusion_score.partial_cmp(&a.fusion_score).unwrap());
129        
130        // Limit to top results
131        final_results.truncate(50);
132        
133        final_results
134    }
135}
136
/// CombSUM fusion: per location, the sum of each system's max-normalised scores.
#[derive(Debug, Clone)]
pub struct CombSumStrategy {
    name: String,
}

impl CombSumStrategy {
    /// Creates the strategy; it reports its name as `"combsum"`.
    pub fn new() -> Self {
        Self {
            name: "combsum".to_string(),
        }
    }
}

impl Default for CombSumStrategy {
    /// Same as [`CombSumStrategy::new`].
    fn default() -> Self {
        Self::new()
    }
}
149
150#[async_trait::async_trait]
151impl FusionStrategy for CombSumStrategy {
152    fn name(&self) -> &str {
153        &self.name
154    }
155
156    async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
157        let mut score_map: HashMap<String, (SearchResult, f64)> = HashMap::new();
158        
159        // Normalize scores per system and combine
160        for system_result in results {
161            let max_score = system_result.results
162                .iter()
163                .map(|r| r.score)
164                .fold(0.0, f64::max);
165                
166            if max_score > 0.0 {
167                for result in &system_result.results {
168                    let normalized_score = result.score / max_score;
169                    let key = format!("{}:{}", result.file_path, result.line_number);
170                    
171                    if let Some((_, current_score)) = score_map.get(&key) {
172                        score_map.insert(key, (result.clone(), current_score + normalized_score));
173                    } else {
174                        score_map.insert(key, (result.clone(), normalized_score));
175                    }
176                }
177            }
178        }
179
180        let mut fused_results: Vec<SearchResult> = score_map
181            .into_values()
182            .map(|(mut result, score)| {
183                result.score = score;
184                result
185            })
186            .collect();
187            
188        fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
189        
190        Ok(fused_results)
191    }
192
193    fn confidence(&self, results: &[SystemResults]) -> f64 {
194        if results.is_empty() {
195            return 0.0;
196        }
197        
198        // Higher confidence when multiple systems agree
199        let avg_confidence: f64 = results.iter().map(|r| r.confidence).sum::<f64>() / results.len() as f64;
200        let agreement_bonus = if results.len() > 1 { 0.1 } else { 0.0 };
201        
202        (avg_confidence + agreement_bonus).min(1.0)
203    }
204}
205
/// CombMNZ fusion: the CombSUM score multiplied by how often a location was returned.
#[derive(Debug, Clone)]
pub struct CombMnzStrategy {
    name: String,
}

impl CombMnzStrategy {
    /// Creates the strategy; it reports its name as `"combmnz"`.
    pub fn new() -> Self {
        Self {
            name: "combmnz".to_string(),
        }
    }
}

impl Default for CombMnzStrategy {
    /// Same as [`CombMnzStrategy::new`].
    fn default() -> Self {
        Self::new()
    }
}
218
219#[async_trait::async_trait]
220impl FusionStrategy for CombMnzStrategy {
221    fn name(&self) -> &str {
222        &self.name
223    }
224
225    async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
226        let mut score_map: HashMap<String, (SearchResult, f64, usize)> = HashMap::new();
227        
228        // Track sum of scores and count of contributing systems
229        for system_result in results {
230            let max_score = system_result.results
231                .iter()
232                .map(|r| r.score)
233                .fold(0.0, f64::max);
234                
235            if max_score > 0.0 {
236                for result in &system_result.results {
237                    let normalized_score = result.score / max_score;
238                    let key = format!("{}:{}", result.file_path, result.line_number);
239                    
240                    if let Some((_, current_score, count)) = score_map.get(&key) {
241                        score_map.insert(key, (result.clone(), current_score + normalized_score, count + 1));
242                    } else {
243                        score_map.insert(key, (result.clone(), normalized_score, 1));
244                    }
245                }
246            }
247        }
248
249        let mut fused_results: Vec<SearchResult> = score_map
250            .into_values()
251            .map(|(mut result, sum_score, count)| {
252                result.score = sum_score * count as f64; // CombMNZ formula
253                result
254            })
255            .collect();
256            
257        fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
258        
259        Ok(fused_results)
260    }
261
262    fn confidence(&self, results: &[SystemResults]) -> f64 {
263        // CombMNZ gives higher confidence to results agreed upon by multiple systems
264        if results.is_empty() {
265            return 0.0;
266        }
267        
268        let base_confidence: f64 = results.iter().map(|r| r.confidence).sum::<f64>() / results.len() as f64;
269        let system_bonus = (results.len() as f64 - 1.0) * 0.1; // Bonus for multiple systems
270        
271        (base_confidence + system_bonus).min(1.0)
272    }
273}
274
/// Reciprocal-rank fusion: per location, the sum of `1 / rank` (1-based) over systems.
#[derive(Debug, Clone)]
pub struct RankBasedFusion {
    name: String,
}

impl RankBasedFusion {
    /// Creates the strategy; it reports its name as `"rank_fusion"`.
    pub fn new() -> Self {
        Self {
            name: "rank_fusion".to_string(),
        }
    }
}

impl Default for RankBasedFusion {
    /// Same as [`RankBasedFusion::new`].
    fn default() -> Self {
        Self::new()
    }
}
287
288#[async_trait::async_trait]
289impl FusionStrategy for RankBasedFusion {
290    fn name(&self) -> &str {
291        &self.name
292    }
293
294    async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
295        let mut score_map: HashMap<String, (SearchResult, f64)> = HashMap::new();
296        
297        for system_result in results {
298            for (rank, result) in system_result.results.iter().enumerate() {
299                let reciprocal_rank = 1.0 / (rank + 1) as f64;
300                let key = format!("{}:{}", result.file_path, result.line_number);
301                
302                if let Some((_, current_score)) = score_map.get(&key) {
303                    score_map.insert(key, (result.clone(), current_score + reciprocal_rank));
304                } else {
305                    score_map.insert(key, (result.clone(), reciprocal_rank));
306                }
307            }
308        }
309
310        let mut fused_results: Vec<SearchResult> = score_map
311            .into_values()
312            .map(|(mut result, score)| {
313                result.score = score;
314                result
315            })
316            .collect();
317            
318        fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
319        
320        Ok(fused_results)
321    }
322
323    fn confidence(&self, _results: &[SystemResults]) -> f64 {
324        0.8 // Rank-based fusion is generally reliable
325    }
326}
327
/// Borda count fusion: each system awards `n - rank` points to its results.
#[derive(Debug, Clone)]
pub struct BordaCountFusion {
    name: String,
}

impl BordaCountFusion {
    /// Creates the strategy; it reports its name as `"borda_count"`.
    pub fn new() -> Self {
        Self {
            name: "borda_count".to_string(),
        }
    }
}

impl Default for BordaCountFusion {
    /// Same as [`BordaCountFusion::new`].
    fn default() -> Self {
        Self::new()
    }
}
340
341#[async_trait::async_trait]
342impl FusionStrategy for BordaCountFusion {
343    fn name(&self) -> &str {
344        &self.name
345    }
346
347    async fn fuse(&self, results: &[SystemResults]) -> Result<Vec<SearchResult>> {
348        let mut score_map: HashMap<String, (SearchResult, f64)> = HashMap::new();
349        
350        for system_result in results {
351            let num_results = system_result.results.len();
352            
353            for (rank, result) in system_result.results.iter().enumerate() {
354                let borda_score = (num_results - rank) as f64; // Higher score for higher rank
355                let key = format!("{}:{}", result.file_path, result.line_number);
356                
357                if let Some((_, current_score)) = score_map.get(&key) {
358                    score_map.insert(key, (result.clone(), current_score + borda_score));
359                } else {
360                    score_map.insert(key, (result.clone(), borda_score));
361                }
362            }
363        }
364
365        let mut fused_results: Vec<SearchResult> = score_map
366            .into_values()
367            .map(|(mut result, score)| {
368                result.score = score;
369                result
370            })
371            .collect();
372            
373        fused_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
374        
375        Ok(fused_results)
376    }
377
378    fn confidence(&self, _results: &[SystemResults]) -> f64 {
379        0.75 // Borda count is moderately reliable
380    }
381}
382
383impl Default for ResultFusion {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::SearchResult;

    /// A `TextMatch` result at `file_path:line_number` carrying `score`.
    fn create_test_search_result(file_path: &str, line_number: u32, score: f64) -> SearchResult {
        let content = format!("test content for {}", file_path);
        SearchResult {
            file_path: String::from(file_path),
            line_number,
            column: 1,
            content,
            score,
            result_type: crate::search::SearchResultType::TextMatch,
            language: Some(String::from("rust")),
            context_lines: Some(Vec::new()),
            lsp_metadata: None,
        }
    }

    /// Wraps `results` as the output of `system_name` (50 ms, confidence 0.8).
    fn create_test_system_results(system_name: &str, results: Vec<SearchResult>) -> SystemResults {
        SystemResults {
            system_name: system_name.into(),
            results,
            latency_ms: 50.0,
            confidence: 0.8,
        }
    }

    /// A system with no results, used by the confidence tests.
    fn empty_system(system_name: &str, latency_ms: f64, confidence: f64) -> SystemResults {
        SystemResults {
            system_name: system_name.into(),
            results: Vec::new(),
            latency_ms,
            confidence,
        }
    }

    #[test]
    fn test_result_fusion_creation() {
        let fusion = ResultFusion::new();
        assert_eq!(fusion.strategies.len(), 4);
        assert_eq!(fusion.weights.len(), 3);
        for (system, expected) in [("lex", 0.3), ("symbols", 0.4), ("semantic", 0.3)] {
            assert_eq!(fusion.weights.get(system), Some(&expected), "weight for {}", system);
        }
    }

    #[test]
    fn test_result_fusion_add_strategy() {
        let mut fusion = ResultFusion::new();
        let before = fusion.strategies.len();

        fusion.add_strategy(Box::new(CombSumStrategy::new()));

        assert_eq!(fusion.strategies.len(), before + 1);
    }

    #[test]
    fn test_result_fusion_set_weight() {
        let mut fusion = ResultFusion::new();
        fusion.set_weight(String::from("new_system"), 0.5);
        assert_eq!(fusion.weights.get("new_system"), Some(&0.5));
    }

    #[tokio::test]
    async fn test_fuse_empty_results() {
        let fusion = ResultFusion::new();
        assert!(fusion.fuse_results(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_fuse_single_system() {
        let fusion = ResultFusion::new();
        let system_results = [create_test_system_results(
            "lex",
            vec![
                create_test_search_result("file1.rs", 10, 0.9),
                create_test_search_result("file2.rs", 20, 0.7),
            ],
        )];

        let fused = fusion.fuse_results(&system_results).await.unwrap();

        assert!(!fused.is_empty());
        // Every fused result gets a positive score and is credited to "lex" alone.
        assert!(fused.iter().all(|r| r.fusion_score > 0.0));
        assert!(fused.iter().all(|r| r.contributing_systems == ["lex"]));
    }

    #[tokio::test]
    async fn test_fuse_multiple_systems() {
        let fusion = ResultFusion::new();
        let system_results = vec![
            create_test_system_results(
                "lex",
                vec![
                    create_test_search_result("file1.rs", 10, 0.9),
                    create_test_search_result("file2.rs", 20, 0.8),
                ],
            ),
            create_test_system_results(
                "symbols",
                vec![
                    create_test_search_result("file1.rs", 10, 0.8), // shared with lex
                    create_test_search_result("file3.rs", 30, 0.7),
                ],
            ),
        ];

        let fused = fusion.fuse_results(&system_results).await.unwrap();

        assert!(!fused.is_empty());
        let overlap_present = fused
            .iter()
            .any(|r| (r.result.file_path.as_str(), r.result.line_number) == ("file1.rs", 10));
        assert!(overlap_present);
    }

    #[test]
    fn test_select_best_fusion() {
        let fusion = ResultFusion::new();
        let shared = create_test_search_result("file1.rs", 10, 0.9);
        let candidate = |score: f64, system: &str, strategy: &str, confidence: f64| FusedResult {
            result: shared.clone(),
            fusion_score: score,
            contributing_systems: vec![system.to_string()],
            fusion_strategy: strategy.to_string(),
            confidence,
        };

        let best = fusion.select_best_fusion(vec![
            candidate(0.8, "lex", "combsum", 0.8),
            candidate(0.9, "symbols", "combmnz", 0.9), // higher score wins
        ]);

        assert_eq!(best.len(), 1);
        assert_eq!(best[0].fusion_score, 0.9);
        assert_eq!(best[0].fusion_strategy, "combmnz");
    }

    #[test]
    fn test_combsum_strategy() {
        assert_eq!(CombSumStrategy::new().name(), "combsum");
    }

    #[tokio::test]
    async fn test_combsum_fuse() {
        let strategy = CombSumStrategy::new();
        let system_results = vec![
            create_test_system_results(
                "system1",
                vec![
                    create_test_search_result("file1.rs", 10, 1.0),
                    create_test_search_result("file2.rs", 20, 0.8),
                ],
            ),
            create_test_system_results(
                "system2",
                vec![
                    create_test_search_result("file1.rs", 10, 0.6), // combines with system1
                    create_test_search_result("file3.rs", 30, 1.0),
                ],
            ),
        ];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert!(!fused.is_empty());

        let file1 = fused
            .iter()
            .find(|r| r.file_path == "file1.rs" && r.line_number == 10)
            .expect("file1.rs:10 should be present in the fused output");
        // Sum of normalised scores from both systems.
        assert!(file1.score > 1.0);
    }

    #[test]
    fn test_combsum_confidence() {
        let strategy = CombSumStrategy::new();
        let systems = [
            empty_system("system1", 50.0, 0.8),
            empty_system("system2", 60.0, 0.9),
        ];

        let confidence = strategy.confidence(&systems);

        assert!(confidence > 0.8); // includes the agreement bonus
        assert!(confidence <= 1.0);
        assert_eq!(strategy.confidence(&[]), 0.0);
    }

    #[tokio::test]
    async fn test_combmnz_fuse() {
        let strategy = CombMnzStrategy::new();
        assert_eq!(strategy.name(), "combmnz");

        let system_results = vec![
            create_test_system_results(
                "system1",
                vec![create_test_search_result("file1.rs", 10, 1.0)],
            ),
            create_test_system_results(
                "system2",
                vec![create_test_search_result("file1.rs", 10, 0.8)], // same location
            ),
        ];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert!(!fused.is_empty());
        // Agreement across systems boosts the score.
        assert!(fused[0].score > 1.0);
    }

    #[test]
    fn test_combmnz_confidence() {
        let strategy = CombMnzStrategy::new();
        let systems = [
            empty_system("system1", 50.0, 0.8),
            empty_system("system2", 60.0, 0.8),
        ];

        assert!(strategy.confidence(&systems) > 0.8); // includes the system bonus
        assert_eq!(strategy.confidence(&[]), 0.0);
    }

    #[tokio::test]
    async fn test_rank_based_fusion() {
        let strategy = RankBasedFusion::new();
        assert_eq!(strategy.name(), "rank_fusion");
        assert_eq!(strategy.confidence(&[]), 0.8);

        let system_results = vec![create_test_system_results(
            "system1",
            vec![
                create_test_search_result("file1.rs", 10, 1.0), // rank 1
                create_test_search_result("file2.rs", 20, 0.9), // rank 2
            ],
        )];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert_eq!(fused.len(), 2);

        let scores: Vec<f64> = fused.iter().map(|r| r.score).collect();
        assert!(scores[0] >= scores[1]);
        assert_eq!(scores, vec![1.0, 0.5]); // 1/(0+1), 1/(1+1)
    }

    #[tokio::test]
    async fn test_borda_count_fusion() {
        let strategy = BordaCountFusion::new();
        assert_eq!(strategy.name(), "borda_count");
        assert_eq!(strategy.confidence(&[]), 0.75);

        let system_results = vec![create_test_system_results(
            "system1",
            vec![
                create_test_search_result("file1.rs", 10, 1.0), // 2 - 0 points
                create_test_search_result("file2.rs", 20, 0.9), // 2 - 1 points
            ],
        )];

        let fused = strategy.fuse(&system_results).await.unwrap();
        assert_eq!(fused.len(), 2);

        let scores: Vec<f64> = fused.iter().map(|r| r.score).collect();
        assert_eq!(scores, vec![2.0, 1.0]);
    }

    #[test]
    fn test_system_results_creation() {
        let system_results = create_test_system_results(
            "test_system",
            vec![create_test_search_result("test.rs", 1, 0.9)],
        );

        assert_eq!(system_results.system_name, "test_system");
        assert_eq!(system_results.results.len(), 1);
        assert_eq!(system_results.latency_ms, 50.0);
        assert_eq!(system_results.confidence, 0.8);
    }

    #[test]
    fn test_fused_result_creation() {
        let fused_result = FusedResult {
            result: create_test_search_result("test.rs", 1, 0.9),
            fusion_score: 1.5,
            contributing_systems: ["system1", "system2"].iter().map(|s| s.to_string()).collect(),
            fusion_strategy: String::from("combsum"),
            confidence: 0.9,
        };

        assert_eq!(fused_result.fusion_score, 1.5);
        assert_eq!(fused_result.contributing_systems.len(), 2);
        assert_eq!(fused_result.fusion_strategy, "combsum");
        assert_eq!(fused_result.confidence, 0.9);
    }
}