oxirs_vec/
result_fusion.rs

1//! Result fusion and score combination algorithms for merging vector search results
2//!
3//! This module provides advanced algorithms for combining vector search results from
4//! multiple sources, including federated endpoints, different similarity metrics,
5//! and heterogeneous scoring schemes.
6
7use crate::{sparql_integration::VectorServiceResult, Vector};
8use anyhow::{anyhow, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::Duration;
12
/// Configuration for result fusion.
///
/// Consumed by [`ResultFusionEngine`]; default values are provided by
/// `impl Default for FusionConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Maximum number of results to return after fusion (the final list is
    /// truncated to this length)
    pub max_results: usize,
    /// Minimum *fused* score a result needs to be kept (compared with `>=`)
    pub min_score_threshold: f32,
    /// Score normalization strategy applied to raw source scores before fusion
    pub normalization_strategy: ScoreNormalizationStrategy,
    /// Fusion algorithm to use
    pub fusion_algorithm: FusionAlgorithm,
    /// Weights for different sources (source_id -> weight); read by the
    /// `WeightedSum` and `MLFusion` algorithms, which use 1.0 for sources that
    /// have no entry
    pub source_weights: HashMap<String, f32>,
    /// Enable result diversification
    pub enable_diversification: bool,
    /// Diversification factor: weight of the diversity term versus relevance
    /// (0.0 = no diversification, results pass through unchanged;
    /// 1.0 = maximum diversification)
    pub diversification_factor: f32,
    /// Enable result explanation: when true, each fused result carries a
    /// human-readable `explanation` string
    pub enable_explanation: bool,
}
33
34impl Default for FusionConfig {
35    fn default() -> Self {
36        Self {
37            max_results: 100,
38            min_score_threshold: 0.0,
39            normalization_strategy: ScoreNormalizationStrategy::MinMax,
40            fusion_algorithm: FusionAlgorithm::CombSum,
41            source_weights: HashMap::new(),
42            enable_diversification: false,
43            diversification_factor: 0.2,
44            enable_explanation: false,
45        }
46    }
47}
48
/// Score normalization strategies.
///
/// The normalized value is stored in `VectorSearchResult::normalized_score`
/// and is what the fusion algorithms combine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreNormalizationStrategy {
    /// No normalization: the raw score is used unchanged
    None,
    /// Min-max normalization to [0, 1]; if all scores are equal, every result
    /// gets 1.0
    MinMax,
    /// Z-score normalization (standardization): `(score - mean) / std_dev` with
    /// the population standard deviation; if `std_dev` is 0, every result gets 0.0
    ZScore,
    /// Rank-based normalization: `(n - i) / n`, where `i` is the 0-based position
    /// after sorting the scores in descending order
    Rank,
    /// Sigmoid normalization: logistic function `1 / (1 + e^(-score))`
    Sigmoid,
    /// Softmax normalization: `exp(score - max) / sum(exp(score_j - max))` over
    /// the results being normalized; the outputs sum to 1 (up to rounding)
    Softmax,
}
65
/// Fusion algorithms for combining scores.
///
/// Each algorithm reduces all results that share a `resource` identifier to a
/// single fused result. Unless stated otherwise, it operates on the normalized
/// scores and falls back to the raw score when no normalized score is set.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum FusionAlgorithm {
    /// Sum of normalized scores (the default)
    #[default]
    CombSum,
    /// Maximum score across sources
    CombMax,
    /// Minimum score across sources
    CombMin,
    /// Average of scores
    CombAvg,
    /// Median of scores (mean of the two middle values for an even count)
    CombMedian,
    /// Weighted combination using source weights (missing weights count as 1.0).
    /// Despite the name this is a weighted *average*, `sum(w * s) / sum(w)`,
    /// and it yields 0.0 when the total weight is not positive
    WeightedSum,
    /// Reciprocal rank fusion: sum of `1 / (60 + rank + 1)` over the sources,
    /// with `rank` being the 0-based position in each source; scores are ignored
    RRF,
    /// Borda count fusion: each appearance earns points that decrease with its
    /// rank position, summed across sources; scores are ignored
    BordaCount,
    /// Condorcet fusion. This is a simplified stand-in rather than a pairwise
    /// Condorcet vote: the mean of `normalized score * 1 / (rank + 1)`
    Condorcet,
    /// Machine learning-based fusion (placeholder: no trained model). Uses the
    /// mean of the fixed linear combination
    /// `0.5 * score + 0.3 * 1 / (rank + 1) + 0.2 * source_weight`
    MLFusion,
}
91
/// A single result from a vector search.
///
/// Used both for per-source input results and for the fused output results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchResult {
    /// Resource identifier; results from different sources with the same
    /// identifier are fused into one result
    pub resource: String,
    /// Similarity score (on fused results this is the fused score)
    pub score: f32,
    /// Normalized score (computed during fusion)
    pub normalized_score: Option<f32>,
    /// Source identifier (fused results use `"FUSED"`)
    pub source: String,
    /// Original 0-based rank in the source's result list; `fuse_results`
    /// overwrites it with the position in the source list, and fused results
    /// carry 0
    pub original_rank: usize,
    /// Final 1-based rank after fusion (`None` until fusion assigns it)
    pub final_rank: Option<usize>,
    /// Associated vector (optional)
    pub vector: Option<Vector>,
    /// Additional metadata. Fused results carry the first contributing result's
    /// metadata plus `fusion_algorithm`, `source_count` and `sources` entries
    pub metadata: HashMap<String, String>,
    /// Explanation of score computation (only set when
    /// `FusionConfig::enable_explanation` is true)
    pub explanation: Option<String>,
}
114
/// Collection of results from a single source.
///
/// `results` is expected to be in ranked order: its positions become each
/// result's `original_rank` during fusion.
#[derive(Debug, Clone)]
pub struct SourceResults {
    /// Source identifier (copied into each result's `source` field during fusion)
    pub source_id: String,
    /// Results from this source, best first
    pub results: Vec<VectorSearchResult>,
    /// Source-specific metadata
    pub metadata: HashMap<String, String>,
    /// Response time from source
    pub response_time: Option<Duration>,
    /// Source weight, for when it should differ from
    /// `FusionConfig::source_weights`
    pub weight: Option<f32>,
}
129
/// Fused results from multiple sources, as returned by
/// `ResultFusionEngine::fuse_results`.
#[derive(Debug, Clone)]
pub struct FusedResults {
    /// Final ranked results, best first, each with a 1-based `final_rank`
    pub results: Vec<VectorSearchResult>,
    /// Fusion statistics
    pub fusion_stats: FusionStats,
    /// Copy of the configuration the engine used for this fusion
    pub config: FusionConfig,
    /// Total processing time (wall-clock time spent in `fuse_results`)
    pub processing_time: Duration,
}
142
/// Statistics about the fusion process.
#[derive(Debug, Clone, Default)]
pub struct FusionStats {
    /// Number of input sources
    pub source_count: usize,
    /// Total number of input results across all sources
    pub total_input_results: usize,
    /// Number of distinct resource identifiers across all sources
    pub unique_resources: usize,
    /// Number of results after fusion, thresholding and truncation
    pub final_result_count: usize,
    /// Mean raw score of all input results before normalization
    /// (0.0 when there are none)
    pub avg_score_before: f32,
    /// Mean fused score of the final results (0.0 when there are none)
    pub avg_score_after: f32,
    /// Raw-score distribution by source id; sources without results are omitted
    pub score_distribution: HashMap<String, ScoreDistribution>,
    /// Fusion algorithm used
    pub fusion_algorithm: FusionAlgorithm,
}
163
/// Score distribution statistics for a source, computed over its raw
/// (un-normalized) scores.
#[derive(Debug, Clone, Default)]
pub struct ScoreDistribution {
    /// Smallest raw score
    pub min: f32,
    /// Largest raw score
    pub max: f32,
    /// Arithmetic mean of the raw scores
    pub mean: f32,
    /// Population standard deviation (variance divided by `count`, not `count - 1`)
    pub std_dev: f32,
    /// Number of results the statistics were computed from
    pub count: usize,
}
173
/// Result fusion engine.
///
/// Holds only a [`FusionConfig`]. `fuse_results` takes `&self`, so one engine
/// can be reused for many fusion calls.
pub struct ResultFusionEngine {
    config: FusionConfig,
}
178
179impl ResultFusionEngine {
180    /// Create a new fusion engine with default configuration
181    pub fn new() -> Self {
182        Self {
183            config: FusionConfig::default(),
184        }
185    }
186
187    /// Create fusion engine with custom configuration
188    pub fn with_config(config: FusionConfig) -> Self {
189        Self { config }
190    }
191
192    /// Fuse results from multiple sources
193    pub fn fuse_results(&self, sources: Vec<SourceResults>) -> Result<FusedResults> {
194        let start_time = std::time::Instant::now();
195
196        if sources.is_empty() {
197            return Ok(FusedResults {
198                results: Vec::new(),
199                fusion_stats: FusionStats::default(),
200                config: self.config.clone(),
201                processing_time: start_time.elapsed(),
202            });
203        }
204
205        // Collect all results with source information
206        let mut all_results = Vec::new();
207        let mut fusion_stats = FusionStats {
208            source_count: sources.len(),
209            fusion_algorithm: self.config.fusion_algorithm.clone(),
210            ..Default::default()
211        };
212
213        for source in &sources {
214            for (rank, result) in source.results.iter().enumerate() {
215                let mut enriched_result = result.clone();
216                enriched_result.original_rank = rank;
217                enriched_result.source = source.source_id.clone();
218                all_results.push(enriched_result);
219            }
220            fusion_stats.total_input_results += source.results.len();
221        }
222
223        // Calculate score distributions
224        self.calculate_score_distributions(&sources, &mut fusion_stats);
225
226        // Normalize scores
227        let normalized_results = self.normalize_scores(all_results)?;
228
229        // Group results by resource
230        let grouped_results = self.group_by_resource(normalized_results);
231        fusion_stats.unique_resources = grouped_results.len();
232
233        // Apply fusion algorithm
234        let fused_results = self.apply_fusion_algorithm(grouped_results)?;
235
236        // Apply diversification if enabled
237        let diversified_results = if self.config.enable_diversification {
238            self.apply_diversification(fused_results)?
239        } else {
240            fused_results
241        };
242
243        // Filter by threshold and limit
244        let mut final_results = diversified_results
245            .into_iter()
246            .filter(|r| r.score >= self.config.min_score_threshold)
247            .take(self.config.max_results)
248            .collect::<Vec<_>>();
249
250        // Assign final ranks
251        for (rank, result) in final_results.iter_mut().enumerate() {
252            result.final_rank = Some(rank + 1);
253        }
254
255        // Update statistics
256        fusion_stats.final_result_count = final_results.len();
257        if !final_results.is_empty() {
258            fusion_stats.avg_score_after =
259                final_results.iter().map(|r| r.score).sum::<f32>() / final_results.len() as f32;
260        }
261
262        Ok(FusedResults {
263            results: final_results,
264            fusion_stats,
265            config: self.config.clone(),
266            processing_time: start_time.elapsed(),
267        })
268    }
269
270    /// Calculate score distributions for each source
271    fn calculate_score_distributions(&self, sources: &[SourceResults], stats: &mut FusionStats) {
272        for source in sources {
273            if source.results.is_empty() {
274                continue;
275            }
276
277            let scores: Vec<f32> = source.results.iter().map(|r| r.score).collect();
278            let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
279            let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
280            let mean = scores.iter().sum::<f32>() / scores.len() as f32;
281
282            let variance =
283                scores.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / scores.len() as f32;
284            let std_dev = variance.sqrt();
285
286            stats.score_distribution.insert(
287                source.source_id.clone(),
288                ScoreDistribution {
289                    min,
290                    max,
291                    mean,
292                    std_dev,
293                    count: scores.len(),
294                },
295            );
296        }
297
298        // Calculate overall statistics
299        let all_scores: Vec<f32> = sources
300            .iter()
301            .flat_map(|s| s.results.iter().map(|r| r.score))
302            .collect();
303
304        if !all_scores.is_empty() {
305            stats.avg_score_before = all_scores.iter().sum::<f32>() / all_scores.len() as f32;
306        }
307    }
308
309    /// Normalize scores across all results
310    fn normalize_scores(
311        &self,
312        mut results: Vec<VectorSearchResult>,
313    ) -> Result<Vec<VectorSearchResult>> {
314        match self.config.normalization_strategy {
315            ScoreNormalizationStrategy::None => {
316                for result in &mut results {
317                    result.normalized_score = Some(result.score);
318                }
319            }
320            ScoreNormalizationStrategy::MinMax => {
321                self.apply_minmax_normalization(&mut results)?;
322            }
323            ScoreNormalizationStrategy::ZScore => {
324                self.apply_zscore_normalization(&mut results)?;
325            }
326            ScoreNormalizationStrategy::Rank => {
327                self.apply_rank_normalization(&mut results)?;
328            }
329            ScoreNormalizationStrategy::Sigmoid => {
330                self.apply_sigmoid_normalization(&mut results)?;
331            }
332            ScoreNormalizationStrategy::Softmax => {
333                self.apply_softmax_normalization(&mut results)?;
334            }
335        }
336
337        Ok(results)
338    }
339
340    /// Apply min-max normalization
341    fn apply_minmax_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
342        if results.is_empty() {
343            return Ok(());
344        }
345
346        let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
347        let min_score = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
348        let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
349
350        let range = max_score - min_score;
351        if range == 0.0 {
352            for result in results {
353                result.normalized_score = Some(1.0);
354            }
355        } else {
356            for result in results {
357                result.normalized_score = Some((result.score - min_score) / range);
358            }
359        }
360
361        Ok(())
362    }
363
364    /// Apply z-score normalization
365    fn apply_zscore_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
366        if results.is_empty() {
367            return Ok(());
368        }
369
370        let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
371        let mean = scores.iter().sum::<f32>() / scores.len() as f32;
372        let variance =
373            scores.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / scores.len() as f32;
374        let std_dev = variance.sqrt();
375
376        if std_dev == 0.0 {
377            for result in results {
378                result.normalized_score = Some(0.0);
379            }
380        } else {
381            for result in results {
382                result.normalized_score = Some((result.score - mean) / std_dev);
383            }
384        }
385
386        Ok(())
387    }
388
389    /// Apply rank-based normalization
390    fn apply_rank_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
391        if results.is_empty() {
392            return Ok(());
393        }
394
395        // Sort by score descending
396        let mut indexed_results: Vec<(usize, f32)> = results
397            .iter()
398            .enumerate()
399            .map(|(i, r)| (i, r.score))
400            .collect();
401        indexed_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
402
403        // Assign rank-based scores
404        let total_count = results.len() as f32;
405        for (rank, (original_index, _)) in indexed_results.iter().enumerate() {
406            let normalized_score = (total_count - rank as f32) / total_count;
407            results[*original_index].normalized_score = Some(normalized_score);
408        }
409
410        Ok(())
411    }
412
413    /// Apply sigmoid normalization
414    fn apply_sigmoid_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
415        for result in results {
416            let sigmoid_score = 1.0 / (1.0 + (-result.score).exp());
417            result.normalized_score = Some(sigmoid_score);
418        }
419        Ok(())
420    }
421
422    /// Apply softmax normalization
423    fn apply_softmax_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
424        if results.is_empty() {
425            return Ok(());
426        }
427
428        // Calculate softmax
429        let max_score = results
430            .iter()
431            .map(|r| r.score)
432            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
433
434        let exp_scores: Vec<f32> = results
435            .iter()
436            .map(|r| (r.score - max_score).exp())
437            .collect();
438
439        let sum_exp: f32 = exp_scores.iter().sum();
440
441        for (i, result) in results.iter_mut().enumerate() {
442            result.normalized_score = Some(exp_scores[i] / sum_exp);
443        }
444
445        Ok(())
446    }
447
448    /// Group results by resource identifier
449    fn group_by_resource(
450        &self,
451        results: Vec<VectorSearchResult>,
452    ) -> HashMap<String, Vec<VectorSearchResult>> {
453        let mut grouped = HashMap::new();
454
455        for result in results {
456            grouped
457                .entry(result.resource.clone())
458                .or_insert_with(Vec::new)
459                .push(result);
460        }
461
462        grouped
463    }
464
465    /// Apply the configured fusion algorithm
466    fn apply_fusion_algorithm(
467        &self,
468        grouped_results: HashMap<String, Vec<VectorSearchResult>>,
469    ) -> Result<Vec<VectorSearchResult>> {
470        let mut fused_results = Vec::new();
471
472        for (_resource, mut resource_results) in grouped_results {
473            let fused_result = match &self.config.fusion_algorithm {
474                FusionAlgorithm::CombSum => self.apply_combsum(&resource_results)?,
475                FusionAlgorithm::CombMax => self.apply_combmax(&resource_results)?,
476                FusionAlgorithm::CombMin => self.apply_combmin(&resource_results)?,
477                FusionAlgorithm::CombAvg => self.apply_combavg(&resource_results)?,
478                FusionAlgorithm::CombMedian => self.apply_combmedian(&mut resource_results)?,
479                FusionAlgorithm::WeightedSum => self.apply_weighted_sum(&resource_results)?,
480                FusionAlgorithm::RRF => self.apply_rrf(&resource_results)?,
481                FusionAlgorithm::BordaCount => self.apply_borda_count(&resource_results)?,
482                FusionAlgorithm::Condorcet => self.apply_condorcet(&resource_results)?,
483                FusionAlgorithm::MLFusion => self.apply_ml_fusion(&resource_results)?,
484            };
485
486            fused_results.push(fused_result);
487        }
488
489        // Sort by fused score descending
490        fused_results.sort_by(|a, b| {
491            b.score
492                .partial_cmp(&a.score)
493                .unwrap_or(std::cmp::Ordering::Equal)
494        });
495
496        Ok(fused_results)
497    }
498
499    /// CombSum: Sum of normalized scores
500    fn apply_combsum(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
501        let sum_score = results
502            .iter()
503            .map(|r| r.normalized_score.unwrap_or(r.score))
504            .sum::<f32>();
505
506        Ok(self.create_fused_result(results, sum_score, "CombSum"))
507    }
508
509    /// CombMax: Maximum normalized score
510    fn apply_combmax(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
511        let max_score = results
512            .iter()
513            .map(|r| r.normalized_score.unwrap_or(r.score))
514            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
515
516        Ok(self.create_fused_result(results, max_score, "CombMax"))
517    }
518
519    /// CombMin: Minimum normalized score
520    fn apply_combmin(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
521        let min_score = results
522            .iter()
523            .map(|r| r.normalized_score.unwrap_or(r.score))
524            .fold(f32::INFINITY, |a, b| a.min(b));
525
526        Ok(self.create_fused_result(results, min_score, "CombMin"))
527    }
528
529    /// CombAvg: Average of normalized scores
530    fn apply_combavg(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
531        let avg_score = results
532            .iter()
533            .map(|r| r.normalized_score.unwrap_or(r.score))
534            .sum::<f32>()
535            / results.len() as f32;
536
537        Ok(self.create_fused_result(results, avg_score, "CombAvg"))
538    }
539
540    /// CombMedian: Median of normalized scores
541    fn apply_combmedian(&self, results: &mut [VectorSearchResult]) -> Result<VectorSearchResult> {
542        let mut scores: Vec<f32> = results
543            .iter()
544            .map(|r| r.normalized_score.unwrap_or(r.score))
545            .collect();
546
547        scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
548
549        let median_score = if scores.len() % 2 == 0 {
550            let mid = scores.len() / 2;
551            (scores[mid - 1] + scores[mid]) / 2.0
552        } else {
553            scores[scores.len() / 2]
554        };
555
556        Ok(self.create_fused_result(results, median_score, "CombMedian"))
557    }
558
559    /// WeightedSum: Weighted sum with source weights
560    fn apply_weighted_sum(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
561        let mut weighted_sum = 0.0;
562        let mut total_weight = 0.0;
563
564        for result in results {
565            let weight = self
566                .config
567                .source_weights
568                .get(&result.source)
569                .copied()
570                .unwrap_or(1.0);
571            let score = result.normalized_score.unwrap_or(result.score);
572            weighted_sum += score * weight;
573            total_weight += weight;
574        }
575
576        let final_score = if total_weight > 0.0 {
577            weighted_sum / total_weight
578        } else {
579            0.0
580        };
581
582        Ok(self.create_fused_result(results, final_score, "WeightedSum"))
583    }
584
585    /// Reciprocal Rank Fusion (RRF)
586    fn apply_rrf(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
587        let k = 60.0; // Standard RRF parameter
588        let rrf_score = results
589            .iter()
590            .map(|r| 1.0 / (k + r.original_rank as f32 + 1.0))
591            .sum::<f32>();
592
593        Ok(self.create_fused_result(results, rrf_score, "RRF"))
594    }
595
596    /// Borda Count fusion
597    fn apply_borda_count(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
598        // For each source, assign points based on rank (higher rank = more points)
599        let total_sources = results.len();
600        let borda_score = results
601            .iter()
602            .map(|r| (total_sources - r.original_rank) as f32)
603            .sum::<f32>();
604
605        Ok(self.create_fused_result(results, borda_score, "BordaCount"))
606    }
607
608    /// Condorcet fusion (simplified pairwise comparison)
609    fn apply_condorcet(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
610        // Simplified Condorcet: average of normalized scores with rank consideration
611        let condorcet_score = results
612            .iter()
613            .map(|r| {
614                let score = r.normalized_score.unwrap_or(r.score);
615                let rank_penalty = 1.0 / (r.original_rank as f32 + 1.0);
616                score * rank_penalty
617            })
618            .sum::<f32>()
619            / results.len() as f32;
620
621        Ok(self.create_fused_result(results, condorcet_score, "Condorcet"))
622    }
623
624    /// Machine learning-based fusion (placeholder)
625    fn apply_ml_fusion(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
626        // Placeholder: In practice, this would use a trained ML model
627        // For now, use a weighted combination of features
628        let mut ml_score = 0.0;
629
630        for result in results {
631            let score = result.normalized_score.unwrap_or(result.score);
632            let rank_feature = 1.0 / (result.original_rank as f32 + 1.0);
633            let source_weight = self
634                .config
635                .source_weights
636                .get(&result.source)
637                .copied()
638                .unwrap_or(1.0);
639
640            // Simple linear combination (in practice, would use trained weights)
641            ml_score += 0.5 * score + 0.3 * rank_feature + 0.2 * source_weight;
642        }
643
644        ml_score /= results.len() as f32;
645
646        Ok(self.create_fused_result(results, ml_score, "MLFusion"))
647    }
648
649    /// Create a fused result from multiple source results
650    fn create_fused_result(
651        &self,
652        results: &[VectorSearchResult],
653        fused_score: f32,
654        algorithm: &str,
655    ) -> VectorSearchResult {
656        let first_result = &results[0];
657        let mut metadata = first_result.metadata.clone();
658
659        // Add fusion information
660        metadata.insert("fusion_algorithm".to_string(), algorithm.to_string());
661        metadata.insert("source_count".to_string(), results.len().to_string());
662        metadata.insert(
663            "sources".to_string(),
664            results
665                .iter()
666                .map(|r| r.source.clone())
667                .collect::<Vec<_>>()
668                .join(","),
669        );
670
671        let explanation = if self.config.enable_explanation {
672            Some(format!(
673                "{} fusion of {} results from sources: [{}] with final score: {:.4}",
674                algorithm,
675                results.len(),
676                results
677                    .iter()
678                    .map(|r| format!("{}:{:.3}", r.source, r.score))
679                    .collect::<Vec<_>>()
680                    .join(", "),
681                fused_score
682            ))
683        } else {
684            None
685        };
686
687        VectorSearchResult {
688            resource: first_result.resource.clone(),
689            score: fused_score,
690            normalized_score: Some(fused_score),
691            source: "FUSED".to_string(),
692            original_rank: 0,
693            final_rank: None,
694            vector: first_result.vector.clone(),
695            metadata,
696            explanation,
697        }
698    }
699
700    /// Apply result diversification to reduce redundancy
701    fn apply_diversification(
702        &self,
703        results: Vec<VectorSearchResult>,
704    ) -> Result<Vec<VectorSearchResult>> {
705        if results.len() <= 1 || self.config.diversification_factor == 0.0 {
706            return Ok(results);
707        }
708
709        let mut diversified = Vec::new();
710        let mut remaining = results;
711
712        // Always take the top result
713        if !remaining.is_empty() {
714            diversified.push(remaining.remove(0));
715        }
716
717        // For each subsequent position, balance relevance and diversity
718        while !remaining.is_empty() && diversified.len() < self.config.max_results {
719            let mut best_index = 0;
720            let mut best_score = f32::NEG_INFINITY;
721
722            for (i, candidate) in remaining.iter().enumerate() {
723                // Calculate diversity penalty
724                let diversity_penalty = self.calculate_diversity_penalty(candidate, &diversified);
725
726                // Combine relevance and diversity
727                let combined_score = (1.0 - self.config.diversification_factor) * candidate.score
728                    + self.config.diversification_factor * diversity_penalty;
729
730                if combined_score > best_score {
731                    best_score = combined_score;
732                    best_index = i;
733                }
734            }
735
736            diversified.push(remaining.remove(best_index));
737        }
738
739        Ok(diversified)
740    }
741
742    /// Calculate diversity penalty for a candidate result
743    fn calculate_diversity_penalty(
744        &self,
745        candidate: &VectorSearchResult,
746        selected: &[VectorSearchResult],
747    ) -> f32 {
748        if selected.is_empty() {
749            return 1.0;
750        }
751
752        // Simple diversity measure based on string similarity
753        let mut min_similarity = f32::INFINITY;
754
755        for selected_result in selected {
756            let similarity =
757                self.calculate_string_similarity(&candidate.resource, &selected_result.resource);
758            min_similarity = min_similarity.min(similarity);
759        }
760
761        // Convert similarity to diversity (higher diversity = lower similarity)
762        1.0 - min_similarity
763    }
764
765    /// Calculate string similarity between two resources
766    fn calculate_string_similarity(&self, s1: &str, s2: &str) -> f32 {
767        // Simple Jaccard similarity on character bigrams
768        let bigrams1 = self.get_character_bigrams(s1);
769        let bigrams2 = self.get_character_bigrams(s2);
770
771        let intersection: usize = bigrams1
772            .iter()
773            .filter(|&bigram| bigrams2.contains(bigram))
774            .count();
775
776        let union_size = bigrams1.len() + bigrams2.len() - intersection;
777
778        if union_size == 0 {
779            1.0
780        } else {
781            intersection as f32 / union_size as f32
782        }
783    }
784
785    /// Get character bigrams from a string
786    fn get_character_bigrams(&self, s: &str) -> std::collections::HashSet<String> {
787        let chars: Vec<char> = s.chars().collect();
788        let mut bigrams = std::collections::HashSet::new();
789
790        for i in 0..chars.len().saturating_sub(1) {
791            let bigram = format!("{}{}", chars[i], chars[i + 1]);
792            bigrams.insert(bigram);
793        }
794
795        bigrams
796    }
797}
798
799impl Default for ResultFusionEngine {
800    fn default() -> Self {
801        Self::new()
802    }
803}
804
805/// Utility functions for working with fusion results
806pub mod fusion_utils {
807    use super::*;
808
809    /// Convert vector service results to source results
810    pub fn convert_service_results(
811        source_id: String,
812        service_result: VectorServiceResult,
813    ) -> Result<SourceResults> {
814        let results = match service_result {
815            VectorServiceResult::SimilarityList(list) => list
816                .into_iter()
817                .enumerate()
818                .map(|(rank, (resource, score))| VectorSearchResult {
819                    resource,
820                    score,
821                    normalized_score: None,
822                    source: source_id.clone(),
823                    original_rank: rank,
824                    final_rank: None,
825                    vector: None,
826                    metadata: HashMap::new(),
827                    explanation: None,
828                })
829                .collect(),
830            VectorServiceResult::DetailedSimilarityList(detailed_list) => detailed_list
831                .into_iter()
832                .enumerate()
833                .map(|(rank, detailed)| VectorSearchResult {
834                    resource: detailed.0,
835                    score: detailed.1,
836                    normalized_score: None,
837                    source: source_id.clone(),
838                    original_rank: rank,
839                    final_rank: None,
840                    vector: None,
841                    metadata: detailed.2,
842                    explanation: None,
843                })
844                .collect(),
845            _ => {
846                return Err(anyhow!(
847                    "Cannot convert non-similarity result to source results"
848                ));
849            }
850        };
851
852        Ok(SourceResults {
853            source_id,
854            results,
855            metadata: HashMap::new(),
856            response_time: None,
857            weight: None,
858        })
859    }
860
861    /// Create source results from simple tuples
862    pub fn create_source_results(source_id: String, results: Vec<(String, f32)>) -> SourceResults {
863        let search_results = results
864            .into_iter()
865            .enumerate()
866            .map(|(rank, (resource, score))| VectorSearchResult {
867                resource,
868                score,
869                normalized_score: None,
870                source: source_id.clone(),
871                original_rank: rank,
872                final_rank: None,
873                vector: None,
874                metadata: HashMap::new(),
875                explanation: None,
876            })
877            .collect();
878
879        SourceResults {
880            source_id,
881            results: search_results,
882            metadata: HashMap::new(),
883            response_time: None,
884            weight: None,
885        }
886    }
887
888    /// Calculate fusion quality metrics
889    pub fn calculate_fusion_quality(
890        fused_results: &FusedResults,
891        ground_truth: Option<&[String]>,
892    ) -> FusionQualityMetrics {
893        let mut metrics = FusionQualityMetrics {
894            result_count: fused_results.results.len(),
895            ..Default::default()
896        };
897        if !fused_results.results.is_empty() {
898            metrics.avg_score = fused_results.results.iter().map(|r| r.score).sum::<f32>()
899                / fused_results.results.len() as f32;
900            metrics.min_score = fused_results
901                .results
902                .iter()
903                .map(|r| r.score)
904                .fold(f32::INFINITY, |a, b| a.min(b));
905            metrics.max_score = fused_results
906                .results
907                .iter()
908                .map(|r| r.score)
909                .fold(f32::NEG_INFINITY, |a, b| a.max(b));
910        }
911
912        // Calculate diversity
913        metrics.diversity = calculate_result_diversity(&fused_results.results);
914
915        // Calculate relevance metrics if ground truth is provided
916        if let Some(gt) = ground_truth {
917            let relevant_count = fused_results
918                .results
919                .iter()
920                .filter(|r| gt.contains(&r.resource))
921                .count();
922
923            metrics.precision = if fused_results.results.is_empty() {
924                0.0
925            } else {
926                relevant_count as f32 / fused_results.results.len() as f32
927            };
928
929            metrics.recall = if gt.is_empty() {
930                0.0
931            } else {
932                relevant_count as f32 / gt.len() as f32
933            };
934
935            metrics.f1_score = if metrics.precision + metrics.recall == 0.0 {
936                0.0
937            } else {
938                2.0 * metrics.precision * metrics.recall / (metrics.precision + metrics.recall)
939            };
940        }
941
942        metrics
943    }
944
945    /// Calculate diversity among results
946    fn calculate_result_diversity(results: &[VectorSearchResult]) -> f32 {
947        if results.len() <= 1 {
948            return 1.0;
949        }
950
951        let mut total_similarity = 0.0;
952        let mut pair_count = 0;
953
954        for i in 0..results.len() {
955            for j in i + 1..results.len() {
956                // Simple string similarity
957                let sim = jaccard_similarity(&results[i].resource, &results[j].resource);
958                total_similarity += sim;
959                pair_count += 1;
960            }
961        }
962
963        if pair_count == 0 {
964            1.0
965        } else {
966            1.0 - (total_similarity / pair_count as f32)
967        }
968    }
969
970    /// Calculate Jaccard similarity between two strings
971    fn jaccard_similarity(s1: &str, s2: &str) -> f32 {
972        let chars1: std::collections::HashSet<char> = s1.chars().collect();
973        let chars2: std::collections::HashSet<char> = s2.chars().collect();
974
975        let intersection = chars1.intersection(&chars2).count();
976        let union = chars1.union(&chars2).count();
977
978        if union == 0 {
979            1.0
980        } else {
981            intersection as f32 / union as f32
982        }
983    }
984}
985
/// Quality metrics for fusion results
///
/// Produced by `fusion_utils::calculate_fusion_quality`. Every field starts at zero
/// via `Default`. The relevance fields (`precision`, `recall`, `f1_score`) are only
/// filled in when ground truth is supplied to that function.
#[derive(Debug, Clone, Default)]
pub struct FusionQualityMetrics {
    /// Number of fused results inspected.
    pub result_count: usize,
    /// Mean of the results' `score` values (0.0 when there are no results).
    pub avg_score: f32,
    /// Smallest result `score` (0.0 when there are no results).
    pub min_score: f32,
    /// Largest result `score` (0.0 when there are no results).
    pub max_score: f32,
    /// One minus the mean pairwise character-set Jaccard similarity of the result
    /// resource strings; 1.0 when there are fewer than two results.
    pub diversity: f32,
    /// Fraction of returned results whose resource appears in the ground truth.
    pub precision: f32,
    /// Share of the ground truth that appears among the returned results.
    pub recall: f32,
    /// Harmonic mean of `precision` and `recall` (0.0 when both are zero).
    pub f1_score: f32,
}
998
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a raw (un-normalized, un-ranked) search result fixture.
    fn raw_result(
        resource: &str,
        score: f32,
        source: &str,
        original_rank: usize,
    ) -> VectorSearchResult {
        VectorSearchResult {
            resource: resource.to_string(),
            score,
            normalized_score: None,
            source: source.to_string(),
            original_rank,
            final_rank: None,
            vector: None,
            metadata: HashMap::new(),
            explanation: None,
        }
    }

    /// Wraps results in a `SourceResults` with no metadata, timing or weight.
    fn source_of(source_id: &str, results: Vec<VectorSearchResult>) -> SourceResults {
        SourceResults {
            source_id: source_id.to_string(),
            results,
            metadata: HashMap::new(),
            response_time: None,
            weight: None,
        }
    }

    #[test]
    fn test_combsum_fusion() {
        let engine = ResultFusionEngine::new();

        let first = source_of(
            "source1",
            vec![
                raw_result("doc1", 0.9, "source1", 0),
                raw_result("doc2", 0.7, "source1", 1),
            ],
        );
        let second = source_of(
            "source2",
            vec![
                raw_result("doc1", 0.8, "source2", 0),
                raw_result("doc3", 0.6, "source2", 1),
            ],
        );

        let fused = engine.fuse_results(vec![first, second]).unwrap();

        // doc1, doc2 and doc3 each appear exactly once after fusion.
        assert_eq!(fused.results.len(), 3);
        assert_eq!(fused.fusion_stats.source_count, 2);
        assert_eq!(fused.fusion_stats.unique_resources, 3);

        // doc1 is reported by both sources (0.9 and 0.8), so it must come out on top.
        assert_eq!(fused.results[0].resource, "doc1");
        assert!(fused.results[0].score > fused.results[1].score);
    }

    #[test]
    fn test_rrf_fusion() {
        let engine = ResultFusionEngine::with_config(FusionConfig {
            fusion_algorithm: FusionAlgorithm::RRF,
            ..Default::default()
        });

        // doc2 is ranked by both sources, at different positions.
        let first = fusion_utils::create_source_results(
            "source1".to_string(),
            vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.7)],
        );
        let second = fusion_utils::create_source_results(
            "source2".to_string(),
            vec![("doc2".to_string(), 0.8), ("doc3".to_string(), 0.6)],
        );

        let fused = engine.fuse_results(vec![first, second]).unwrap();

        assert!(!fused.results.is_empty());
        assert_eq!(fused.fusion_stats.unique_resources, 3);
    }

    #[test]
    fn test_score_normalization() {
        let engine = ResultFusionEngine::with_config(FusionConfig {
            normalization_strategy: ScoreNormalizationStrategy::MinMax,
            ..Default::default()
        });

        let single = fusion_utils::create_source_results(
            "test".to_string(),
            vec![
                ("doc1".to_string(), 0.2),
                ("doc2".to_string(), 0.8),
                ("doc3".to_string(), 0.5),
            ],
        );

        let fused = engine.fuse_results(vec![single]).unwrap();

        // Min-max normalization must keep every score inside [0, 1].
        assert!(fused
            .results
            .iter()
            .all(|r| (0.0..=1.0).contains(&r.score)));
    }

    #[test]
    fn test_fusion_quality_metrics() {
        let fused = FusedResults {
            results: vec![
                VectorSearchResult {
                    normalized_score: Some(0.9),
                    final_rank: Some(1),
                    ..raw_result("relevant1", 0.9, "test", 0)
                },
                VectorSearchResult {
                    normalized_score: Some(0.8),
                    final_rank: Some(2),
                    ..raw_result("irrelevant1", 0.8, "test", 1)
                },
            ],
            fusion_stats: FusionStats::default(),
            config: FusionConfig::default(),
            processing_time: Duration::from_millis(10),
        };

        let ground_truth = vec!["relevant1".to_string(), "relevant2".to_string()];
        let metrics = fusion_utils::calculate_fusion_quality(&fused, Some(&ground_truth));

        assert_eq!(metrics.result_count, 2);
        // relevant1 is the only relevant hit: 1 of 2 results, and 1 of 2 ground-truth items.
        assert_eq!(metrics.precision, 0.5);
        assert_eq!(metrics.recall, 0.5);
        assert!(metrics.diversity > 0.0);
    }
}