use crate::Vector;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Configuration for merging ranked result lists from multiple sources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultMergingConfig {
    /// How per-source contributions are recombined into one final score.
    pub combination_strategy: ScoreCombinationStrategy,
    /// How raw scores are normalized within each source before fusion.
    pub normalization_method: ScoreNormalizationMethod,
    /// Rank-fusion algorithm that produces the initial merged scores.
    pub fusion_algorithm: RankFusionAlgorithm,
    /// Per-source weight keyed by source id; absent sources default to 1.0.
    pub source_weights: HashMap<String, f32>,
    /// When true, attach 95% confidence intervals to merged results.
    pub confidence_intervals: bool,
    /// When true, attach per-source score breakdowns to merged results.
    pub enable_explanations: bool,
    /// Optional diversity (MMR-style) re-ranking; `None` disables the pass.
    pub diversity_config: Option<DiversityConfig>,
}
impl Default for ResultMergingConfig {
fn default() -> Self {
let mut source_weights = HashMap::new();
source_weights.insert("primary".to_string(), 1.0);
Self {
combination_strategy: ScoreCombinationStrategy::WeightedSum,
normalization_method: ScoreNormalizationMethod::MinMax,
fusion_algorithm: RankFusionAlgorithm::CombSUM,
source_weights,
confidence_intervals: true,
enable_explanations: false,
diversity_config: None,
}
}
}
/// Strategy for recomputing a final score from per-source contributions.
///
/// Only `Average`, `WeightedSum` (a no-op, since fusion already produced a
/// weighted sum), `Maximum`, `Minimum` and `GeometricMean` are currently
/// acted on by the merger; the remaining variants fall through unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreCombinationStrategy {
    Average,
    WeightedSum,
    Maximum,
    Minimum,
    GeometricMean,
    HarmonicMean,
    Product,
    BordaCount,
    /// Named custom strategy; currently not interpreted by the merger.
    Custom(String),
}
/// Per-source score normalization applied before rank fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreNormalizationMethod {
    /// Use raw scores unchanged.
    None,
    /// Scale into [0, 1] via (score - min) / (max - min).
    MinMax,
    /// Standardize via (score - mean) / std-dev.
    ZScore,
    /// NOTE(review): despite the name, the implementation scales by the
    /// source's maximum score rather than by rank position.
    RankBased,
    /// NOTE(review): implemented as an unnormalized exponential of
    /// (score - min) — there is no division by the sum of exponentials.
    Softmax,
    /// Logistic squashing: 1 / (1 + e^-score).
    Sigmoid,
}
/// Algorithm used to fuse per-source rankings into merged scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RankFusionAlgorithm {
    /// Weighted sum of normalized scores across sources.
    CombSUM,
    /// CombSUM multiplied by the number of sources with a positive score.
    CombMNZ,
    /// Sum over sources of weight / (k + rank), with k fixed at 60.
    ReciprocalRankFusion,
    /// Positional (Borda count) voting.
    BordaFusion,
    /// NOTE(review): currently falls back to Borda fusion — no pairwise
    /// Condorcet comparison is implemented.
    CondorcetFusion,
}
/// Settings for the greedy MMR-style diversity re-ranking pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversityConfig {
    /// Master switch; re-ranking runs only when true.
    pub enable: bool,
    /// Requested diversity metric. NOTE(review): the merger currently uses
    /// the same score-difference heuristic regardless of this value.
    pub metric: DiversityMetric,
    /// Trade-off factor in the MMR objective (multiplies the relevance term).
    pub diversity_weight: f32,
    /// Maximum number of results kept after re-ranking.
    pub max_diverse_results: usize,
}
/// Available diversity metrics for result re-ranking.
///
/// NOTE(review): the current implementation does not branch on this enum;
/// all variants are treated identically by `enhance_diversity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DiversityMetric {
    /// Maximal Marginal Relevance.
    MMR,
    Angular,
    ClusterBased,
    ContentBased,
}
/// One source's ranked result list plus metadata about how it was produced.
#[derive(Debug, Clone)]
pub struct SourceResult {
    /// Identifier used to look up the source's weight in the config.
    pub source_id: String,
    /// Ranked, scored results from this source.
    pub results: Vec<ScoredResult>,
    /// Provenance and quality information for this result set.
    pub metadata: ResultMetadata,
}
/// A single scored item within one source's result list.
#[derive(Debug, Clone)]
pub struct ScoredResult {
    pub item_id: String,
    /// Raw score from the source; replaced by the normalized score once the
    /// normalization pass has run.
    pub score: f32,
    /// Position in the source's ranking (consumed by rank-based fusion).
    pub rank: usize,
    /// Optional embedding for the item.
    pub vector: Option<Vector>,
    /// Optional free-form item metadata.
    pub metadata: Option<HashMap<String, String>>,
}
/// Provenance information attached to a source's result list.
#[derive(Debug, Clone)]
pub struct ResultMetadata {
    pub source_type: SourceType,
    /// Name of the retrieval algorithm the source used.
    pub algorithm_used: String,
    /// Total number of candidates the source reported considering.
    pub total_candidates: usize,
    /// Wall-clock time the source spent producing its results.
    pub processing_time: std::time::Duration,
    /// Source-defined quality measurements, keyed by metric name.
    pub quality_metrics: HashMap<String, f32>,
}
/// Kind of retrieval backend that produced a result set.
#[derive(Debug, Clone)]
pub enum SourceType {
    VectorSearch,
    TextSearch,
    KnowledgeGraph,
    MultiModal,
    Hybrid,
}
/// A single item in the final merged ranking.
#[derive(Debug, Clone)]
pub struct MergedResult {
    pub item_id: String,
    /// Fused (and optionally recombined) score used for the final ordering.
    pub final_score: f32,
    /// 95% interval around the mean per-source score; populated only when
    /// enabled and the item was returned by at least two sources.
    pub confidence_interval: Option<ConfidenceInterval>,
    /// One entry per source that returned this item.
    pub source_contributions: Vec<SourceContribution>,
    /// Per-source score breakdown, populated when explanations are enabled.
    pub explanation: Option<ResultExplanation>,
    /// Set only by the diversity re-ranking pass.
    pub diversity_score: Option<f32>,
}
/// A confidence interval over a merged result's per-source scores.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    /// Lower bound (clamped to 0.0 by the merger).
    pub lower_bound: f32,
    /// Upper bound (clamped to 1.0 by the merger).
    pub upper_bound: f32,
    /// Confidence level, e.g. 0.95 for a 95% interval.
    pub confidence_level: f32,
}
/// How one source contributed to a merged result's score.
#[derive(Debug, Clone)]
pub struct SourceContribution {
    pub source_id: String,
    /// Score as seen by the fusion step. NOTE(review): sources are
    /// normalized before fusion, so this is the normalized value, not the
    /// source's raw score.
    pub original_score: f32,
    /// Score after the fusion-specific transformation (e.g. the RRF or
    /// Borda contribution); equals `original_score` for CombSUM/CombMNZ.
    pub normalized_score: f32,
    /// Weight applied to this source (1.0 when not configured).
    pub weight: f32,
    /// Item's rank within the source's list.
    pub rank: usize,
}
/// Human-readable explanation of why a result ranked where it did.
#[derive(Debug, Clone)]
pub struct ResultExplanation {
    /// One factor per contributing source.
    pub ranking_factors: Vec<RankingFactor>,
    /// Normalized score keyed by source id.
    pub score_breakdown: HashMap<String, f32>,
    /// Related items; currently always left empty by the merger.
    pub similar_items: Vec<String>,
    /// Distinguishing features; currently always left empty by the merger.
    pub differentiating_features: Vec<String>,
}
/// A single named factor in a result explanation.
#[derive(Debug, Clone)]
pub struct RankingFactor {
    pub factor_name: String,
    /// Relative importance (the source's normalized score contribution).
    pub importance: f32,
    pub description: String,
}
/// Merges ranked result lists from multiple sources using configurable
/// normalization, rank fusion, score combination and optional diversity
/// re-ranking.
pub struct AdvancedResultMerger {
    config: ResultMergingConfig,
    /// Last-computed normalization statistics per source id. Written during
    /// normalization; not read back by the merge pipeline itself.
    normalization_cache: HashMap<String, NormalizationParams>,
    /// Running statistics across all merges performed by this instance.
    fusion_stats: FusionStatistics,
}
/// Summary statistics of one source's raw score distribution.
#[derive(Debug, Clone)]
struct NormalizationParams {
    min_score: f32,
    max_score: f32,
    mean_score: f32,
    /// Population (not sample) standard deviation.
    std_dev: f32,
}
/// Running statistics accumulated across merge operations.
#[derive(Debug, Clone, Default)]
pub struct FusionStatistics {
    /// Number of `merge_results` calls that received at least one source.
    pub total_merges: usize,
    /// Running mean of the number of sources per merge.
    pub average_sources_per_merge: f32,
    /// Reserved for score-distribution tracking; not yet populated.
    pub score_distribution: HashMap<String, f32>,
    /// Reserved for fusion-quality tracking; not yet populated.
    pub fusion_quality_metrics: HashMap<String, f32>,
}
impl AdvancedResultMerger {
pub fn new(config: ResultMergingConfig) -> Self {
Self {
config,
normalization_cache: HashMap::new(),
fusion_stats: FusionStatistics::default(),
}
}
pub fn merge_results(&mut self, sources: Vec<SourceResult>) -> Result<Vec<MergedResult>> {
if sources.is_empty() {
return Ok(Vec::new());
}
self.fusion_stats.total_merges += 1;
self.fusion_stats.average_sources_per_merge = (self.fusion_stats.average_sources_per_merge
* (self.fusion_stats.total_merges - 1) as f32
+ sources.len() as f32)
/ self.fusion_stats.total_merges as f32;
let normalized_sources = self.normalize_sources(&sources)?;
let all_items = self.collect_unique_items(&normalized_sources);
let mut merged_results = match self.config.fusion_algorithm {
RankFusionAlgorithm::CombSUM => self.apply_combsum(&normalized_sources, &all_items)?,
RankFusionAlgorithm::CombMNZ => self.apply_combmnz(&normalized_sources, &all_items)?,
RankFusionAlgorithm::ReciprocalRankFusion => {
self.apply_rrf(&normalized_sources, &all_items)?
}
RankFusionAlgorithm::BordaFusion => {
self.apply_borda(&normalized_sources, &all_items)?
}
RankFusionAlgorithm::CondorcetFusion => {
self.apply_condorcet(&normalized_sources, &all_items)?
}
};
merged_results = self.apply_score_combination(merged_results, &normalized_sources)?;
if self.config.confidence_intervals {
merged_results =
self.calculate_confidence_intervals(merged_results, &normalized_sources)?;
}
if self.config.enable_explanations {
merged_results = self.generate_explanations(merged_results, &normalized_sources)?;
}
if let Some(diversity_config) = &self.config.diversity_config {
if diversity_config.enable {
merged_results = self.enhance_diversity(merged_results, diversity_config)?;
}
}
merged_results.sort_by(|a, b| {
b.final_score
.partial_cmp(&a.final_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(merged_results)
}
fn normalize_sources(&mut self, sources: &[SourceResult]) -> Result<Vec<SourceResult>> {
let mut normalized = Vec::new();
for source in sources {
let normalized_source = self.normalize_source(source)?;
normalized.push(normalized_source);
}
Ok(normalized)
}
fn normalize_source(&mut self, source: &SourceResult) -> Result<SourceResult> {
if source.results.is_empty() {
return Ok(source.clone());
}
let scores: Vec<f32> = source.results.iter().map(|r| r.score).collect();
let normalization_params = self.calculate_normalization_params(&scores);
self.normalization_cache
.insert(source.source_id.clone(), normalization_params.clone());
let normalized_results: Vec<ScoredResult> = source
.results
.iter()
.map(|result| {
let normalized_score = self.normalize_score(result.score, &normalization_params);
ScoredResult {
item_id: result.item_id.clone(),
score: normalized_score,
rank: result.rank,
vector: result.vector.clone(),
metadata: result.metadata.clone(),
}
})
.collect();
Ok(SourceResult {
source_id: source.source_id.clone(),
results: normalized_results,
metadata: source.metadata.clone(),
})
}
fn calculate_normalization_params(&self, scores: &[f32]) -> NormalizationParams {
let min_score = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mean_score = scores.iter().sum::<f32>() / scores.len() as f32;
let variance = scores
.iter()
.map(|&x| (x - mean_score).powi(2))
.sum::<f32>()
/ scores.len() as f32;
let std_dev = variance.sqrt();
NormalizationParams {
min_score,
max_score,
mean_score,
std_dev,
}
}
fn normalize_score(&self, score: f32, params: &NormalizationParams) -> f32 {
match self.config.normalization_method {
ScoreNormalizationMethod::None => score,
ScoreNormalizationMethod::MinMax => {
if params.max_score == params.min_score {
0.5 } else {
(score - params.min_score) / (params.max_score - params.min_score)
}
}
ScoreNormalizationMethod::ZScore => {
if params.std_dev == 0.0 {
0.0 } else {
(score - params.mean_score) / params.std_dev
}
}
ScoreNormalizationMethod::Softmax => {
(score - params.min_score).exp()
}
ScoreNormalizationMethod::Sigmoid => 1.0 / (1.0 + (-score).exp()),
ScoreNormalizationMethod::RankBased => {
score / params.max_score
}
}
}
fn collect_unique_items(&self, sources: &[SourceResult]) -> HashSet<String> {
let mut items = HashSet::new();
for source in sources {
for result in &source.results {
items.insert(result.item_id.clone());
}
}
items
}
fn apply_combsum(
&self,
sources: &[SourceResult],
items: &HashSet<String>,
) -> Result<Vec<MergedResult>> {
let mut merged_results = Vec::new();
for item_id in items {
let mut total_score = 0.0;
let mut source_contributions = Vec::new();
for source in sources {
if let Some(result) = source.results.iter().find(|r| r.item_id == *item_id) {
let weight = self
.config
.source_weights
.get(&source.source_id)
.copied()
.unwrap_or(1.0);
let weighted_score = result.score * weight;
total_score += weighted_score;
source_contributions.push(SourceContribution {
source_id: source.source_id.clone(),
original_score: result.score,
normalized_score: result.score,
weight,
rank: result.rank,
});
}
}
merged_results.push(MergedResult {
item_id: item_id.clone(),
final_score: total_score,
confidence_interval: None,
source_contributions,
explanation: None,
diversity_score: None,
});
}
Ok(merged_results)
}
fn apply_combmnz(
&self,
sources: &[SourceResult],
items: &HashSet<String>,
) -> Result<Vec<MergedResult>> {
let mut merged_results = Vec::new();
for item_id in items {
let mut total_score = 0.0;
let mut non_zero_count = 0;
let mut source_contributions = Vec::new();
for source in sources {
if let Some(result) = source.results.iter().find(|r| r.item_id == *item_id) {
let weight = self
.config
.source_weights
.get(&source.source_id)
.copied()
.unwrap_or(1.0);
let weighted_score = result.score * weight;
if weighted_score > 0.0 {
total_score += weighted_score;
non_zero_count += 1;
}
source_contributions.push(SourceContribution {
source_id: source.source_id.clone(),
original_score: result.score,
normalized_score: result.score,
weight,
rank: result.rank,
});
}
}
let final_score = if non_zero_count > 0 {
total_score * non_zero_count as f32
} else {
0.0
};
merged_results.push(MergedResult {
item_id: item_id.clone(),
final_score,
confidence_interval: None,
source_contributions,
explanation: None,
diversity_score: None,
});
}
Ok(merged_results)
}
fn apply_rrf(
&self,
sources: &[SourceResult],
items: &HashSet<String>,
) -> Result<Vec<MergedResult>> {
let k = 60.0; let mut merged_results = Vec::new();
for item_id in items {
let mut rrf_score = 0.0;
let mut source_contributions = Vec::new();
for source in sources {
if let Some(result) = source.results.iter().find(|r| r.item_id == *item_id) {
let weight = self
.config
.source_weights
.get(&source.source_id)
.copied()
.unwrap_or(1.0);
let rrf_contribution = weight / (k + result.rank as f32);
rrf_score += rrf_contribution;
source_contributions.push(SourceContribution {
source_id: source.source_id.clone(),
original_score: result.score,
normalized_score: rrf_contribution,
weight,
rank: result.rank,
});
}
}
merged_results.push(MergedResult {
item_id: item_id.clone(),
final_score: rrf_score,
confidence_interval: None,
source_contributions,
explanation: None,
diversity_score: None,
});
}
Ok(merged_results)
}
fn apply_borda(
&self,
sources: &[SourceResult],
items: &HashSet<String>,
) -> Result<Vec<MergedResult>> {
let mut merged_results = Vec::new();
for item_id in items {
let mut borda_score = 0.0;
let mut source_contributions = Vec::new();
for source in sources {
if let Some(result) = source.results.iter().find(|r| r.item_id == *item_id) {
let weight = self
.config
.source_weights
.get(&source.source_id)
.copied()
.unwrap_or(1.0);
let max_rank = source.results.len() as f32;
let borda_contribution = weight * (max_rank - result.rank as f32);
borda_score += borda_contribution;
source_contributions.push(SourceContribution {
source_id: source.source_id.clone(),
original_score: result.score,
normalized_score: borda_contribution,
weight,
rank: result.rank,
});
}
}
merged_results.push(MergedResult {
item_id: item_id.clone(),
final_score: borda_score,
confidence_interval: None,
source_contributions,
explanation: None,
diversity_score: None,
});
}
Ok(merged_results)
}
fn apply_condorcet(
&self,
sources: &[SourceResult],
items: &HashSet<String>,
) -> Result<Vec<MergedResult>> {
self.apply_borda(sources, items)
}
fn apply_score_combination(
&self,
mut results: Vec<MergedResult>,
_sources: &[SourceResult],
) -> Result<Vec<MergedResult>> {
match self.config.combination_strategy {
ScoreCombinationStrategy::Average => {
for result in &mut results {
if !result.source_contributions.is_empty() {
result.final_score = result
.source_contributions
.iter()
.map(|c| c.normalized_score)
.sum::<f32>()
/ result.source_contributions.len() as f32;
}
}
}
ScoreCombinationStrategy::WeightedSum => {
}
ScoreCombinationStrategy::Maximum => {
for result in &mut results {
result.final_score = result
.source_contributions
.iter()
.map(|c| c.normalized_score)
.fold(0.0, f32::max);
}
}
ScoreCombinationStrategy::Minimum => {
for result in &mut results {
result.final_score = result
.source_contributions
.iter()
.map(|c| c.normalized_score)
.fold(f32::INFINITY, f32::min);
}
}
ScoreCombinationStrategy::GeometricMean => {
for result in &mut results {
let product: f32 = result
.source_contributions
.iter()
.map(|c| c.normalized_score.max(0.001)) .product();
result.final_score =
product.powf(1.0 / result.source_contributions.len() as f32);
}
}
_ => {
}
}
Ok(results)
}
fn calculate_confidence_intervals(
&self,
mut results: Vec<MergedResult>,
_sources: &[SourceResult],
) -> Result<Vec<MergedResult>> {
for result in &mut results {
if result.source_contributions.len() > 1 {
let scores: Vec<f32> = result
.source_contributions
.iter()
.map(|c| c.normalized_score)
.collect();
let mean = scores.iter().sum::<f32>() / scores.len() as f32;
let variance =
scores.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / scores.len() as f32;
let std_dev = variance.sqrt();
let margin = 1.96 * std_dev / (scores.len() as f32).sqrt();
result.confidence_interval = Some(ConfidenceInterval {
lower_bound: (mean - margin).max(0.0),
upper_bound: (mean + margin).min(1.0),
confidence_level: 0.95,
});
}
}
Ok(results)
}
fn generate_explanations(
&self,
mut results: Vec<MergedResult>,
_sources: &[SourceResult],
) -> Result<Vec<MergedResult>> {
for result in &mut results {
let mut ranking_factors = Vec::new();
let mut score_breakdown = HashMap::new();
for contribution in &result.source_contributions {
ranking_factors.push(RankingFactor {
factor_name: format!("Source: {}", contribution.source_id),
importance: contribution.normalized_score,
description: format!(
"Contribution from {} with weight {}",
contribution.source_id, contribution.weight
),
});
score_breakdown.insert(
contribution.source_id.clone(),
contribution.normalized_score,
);
}
result.explanation = Some(ResultExplanation {
ranking_factors,
score_breakdown,
similar_items: Vec::new(), differentiating_features: Vec::new(), });
}
Ok(results)
}
fn enhance_diversity(
&self,
results: Vec<MergedResult>,
diversity_config: &DiversityConfig,
) -> Result<Vec<MergedResult>> {
if results.len() <= diversity_config.max_diverse_results {
return Ok(results);
}
let mut selected = Vec::new();
let mut remaining = results;
if !remaining.is_empty() {
let top_result = remaining.remove(0);
selected.push(top_result);
}
while selected.len() < diversity_config.max_diverse_results && !remaining.is_empty() {
let mut best_idx = 0;
let mut best_mmr = f32::NEG_INFINITY;
for (i, candidate) in remaining.iter().enumerate() {
let relevance = candidate.final_score;
let max_similarity =
self.calculate_max_similarity_to_selected(candidate, &selected);
let mmr = diversity_config.diversity_weight * relevance
- (1.0 - diversity_config.diversity_weight) * max_similarity;
if mmr > best_mmr {
best_mmr = mmr;
best_idx = i;
}
}
let selected_result = remaining.remove(best_idx);
selected.push(selected_result);
}
for result in &mut selected {
result.diversity_score = Some(0.8); }
Ok(selected)
}
fn calculate_max_similarity_to_selected(
&self,
candidate: &MergedResult,
selected: &[MergedResult],
) -> f32 {
if selected.is_empty() {
return 0.0;
}
let mut max_similarity: f32 = 0.0;
for selected_result in selected {
let similarity: f32 = 1.0 - (candidate.final_score - selected_result.final_score).abs();
max_similarity = max_similarity.max(similarity);
}
max_similarity
}
pub fn get_statistics(&self) -> &FusionStatistics {
&self.fusion_stats
}
pub fn reset_statistics(&mut self) {
self.fusion_stats = FusionStatistics::default();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a `SourceResult` fixture from `(item_id, score, rank)` tuples,
    /// with fixed vector-search metadata.
    fn create_test_source(source_id: &str, results: Vec<(String, f32, usize)>) -> SourceResult {
        let scored_results = results
            .into_iter()
            .map(|(id, score, rank)| ScoredResult {
                item_id: id,
                score,
                rank,
                vector: None,
                metadata: None,
            })
            .collect();
        SourceResult {
            source_id: source_id.to_string(),
            results: scored_results,
            metadata: ResultMetadata {
                source_type: SourceType::VectorSearch,
                algorithm_used: "test".to_string(),
                total_candidates: 100,
                processing_time: Duration::from_millis(10),
                quality_metrics: HashMap::new(),
            },
        }
    }

    /// CombSUM over two overlapping sources yields the union of items, and
    /// an item returned by both sources outscores single-source items.
    #[test]
    fn test_combsum_fusion() -> Result<()> {
        let config = ResultMergingConfig::default();
        let mut merger = AdvancedResultMerger::new(config);
        let source1 = create_test_source(
            "source1",
            vec![("doc1".to_string(), 0.9, 1), ("doc2".to_string(), 0.8, 2)],
        );
        let source2 = create_test_source(
            "source2",
            vec![("doc1".to_string(), 0.7, 1), ("doc3".to_string(), 0.6, 2)],
        );
        let merged = merger.merge_results(vec![source1, source2])?;
        // Union of {doc1, doc2} and {doc1, doc3}.
        assert_eq!(merged.len(), 3);
        let doc1_result = merged
            .iter()
            .find(|r| r.item_id == "doc1")
            .expect("doc1 not found");
        // doc1 is top-ranked in both sources, so both min-max-normalized
        // scores are 1.0 and the weighted sum exceeds 1.0.
        assert!(doc1_result.final_score > 1.0);
        Ok(())
    }

    /// RRF produces a positive score and one contribution per source for
    /// items present in both rankings.
    #[test]
    fn test_reciprocal_rank_fusion() -> Result<()> {
        let config = ResultMergingConfig {
            fusion_algorithm: RankFusionAlgorithm::ReciprocalRankFusion,
            ..Default::default()
        };
        let mut merger = AdvancedResultMerger::new(config);
        let source1 = create_test_source(
            "source1",
            vec![("doc1".to_string(), 0.9, 1), ("doc2".to_string(), 0.8, 2)],
        );
        let source2 = create_test_source(
            "source2",
            vec![("doc2".to_string(), 0.7, 1), ("doc1".to_string(), 0.6, 2)],
        );
        let merged = merger.merge_results(vec![source1, source2])?;
        assert_eq!(merged.len(), 2);
        for result in &merged {
            assert!(result.final_score > 0.0);
            assert_eq!(result.source_contributions.len(), 2);
        }
        Ok(())
    }

    /// Items seen by two sources get a well-formed 95% confidence interval.
    #[test]
    fn test_confidence_intervals() -> Result<()> {
        let config = ResultMergingConfig {
            confidence_intervals: true,
            ..Default::default()
        };
        let mut merger = AdvancedResultMerger::new(config);
        let source1 = create_test_source("source1", vec![("doc1".to_string(), 0.9, 1)]);
        let source2 = create_test_source("source2", vec![("doc1".to_string(), 0.7, 1)]);
        let merged = merger.merge_results(vec![source1, source2])?;
        assert_eq!(merged.len(), 1);
        let result = &merged[0];
        assert!(result.confidence_interval.is_some());
        let ci = result
            .confidence_interval
            .as_ref()
            .expect("confidence_interval was None");
        assert!(ci.lower_bound <= ci.upper_bound);
        assert_eq!(ci.confidence_level, 0.95);
        Ok(())
    }

    /// Min-max normalization maps every score into [0, 1].
    #[test]
    fn test_score_normalization() -> Result<()> {
        let config = ResultMergingConfig {
            normalization_method: ScoreNormalizationMethod::MinMax,
            ..Default::default()
        };
        let mut merger = AdvancedResultMerger::new(config);
        let source = create_test_source(
            "source1",
            vec![
                ("doc1".to_string(), 10.0, 1),
                ("doc2".to_string(), 5.0, 2),
                ("doc3".to_string(), 0.0, 3),
            ],
        );
        let normalized = merger.normalize_source(&source)?;
        for result in &normalized.results {
            assert!(result.score >= 0.0 && result.score <= 1.0);
        }
        Ok(())
    }
}