1use crate::{sparql_integration::VectorServiceResult, Vector};
8use anyhow::{anyhow, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::Duration;
12
/// Settings that control how result lists from several sources are fused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Upper bound on the number of results kept in the fused output.
    pub max_results: usize,
    /// Fused results whose score is below this value are dropped.
    pub min_score_threshold: f32,
    /// How raw scores are normalized before they are combined.
    pub normalization_strategy: ScoreNormalizationStrategy,
    /// Algorithm used to combine all hits that refer to the same resource.
    pub fusion_algorithm: FusionAlgorithm,
    /// Per-source weights keyed by source id; the weighted algorithms fall
    /// back to 1.0 for sources without an entry.
    pub source_weights: HashMap<String, f32>,
    /// Whether the diversification re-ranking pass runs after fusion.
    pub enable_diversification: bool,
    /// Share of the re-ranking score given to diversity; the remainder goes
    /// to the fused score. A value of 0.0 disables re-ranking.
    pub diversification_factor: f32,
    /// Whether each fused result carries a human-readable explanation.
    pub enable_explanation: bool,
}
33
34impl Default for FusionConfig {
35 fn default() -> Self {
36 Self {
37 max_results: 100,
38 min_score_threshold: 0.0,
39 normalization_strategy: ScoreNormalizationStrategy::MinMax,
40 fusion_algorithm: FusionAlgorithm::CombSum,
41 source_weights: HashMap::new(),
42 enable_diversification: false,
43 diversification_factor: 0.2,
44 enable_explanation: false,
45 }
46 }
47}
48
/// Strategy used to bring raw scores onto a comparable scale before fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreNormalizationStrategy {
    /// Use the raw score unchanged.
    None,
    /// Rescale linearly so the smallest score maps to 0 and the largest to 1.
    MinMax,
    /// Subtract the mean and divide by the standard deviation.
    ZScore,
    /// Replace the score by a value derived from its position when sorted
    /// by descending score.
    Rank,
    /// Apply the logistic function `1 / (1 + e^-score)`.
    Sigmoid,
    /// Exponentiate and divide by the sum of all exponentials.
    Softmax,
}
65
/// Algorithm used to combine the hits that different sources returned for
/// the same resource into one fused score.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum FusionAlgorithm {
    /// Sum of the (normalized) scores. This is the default.
    #[default]
    CombSum,
    /// Largest (normalized) score.
    CombMax,
    /// Smallest (normalized) score.
    CombMin,
    /// Arithmetic mean of the (normalized) scores.
    CombAvg,
    /// Median of the (normalized) scores.
    CombMedian,
    /// Weighted mean using the configured per-source weights.
    WeightedSum,
    /// Reciprocal rank fusion over the original ranks.
    RRF,
    /// Rank-based point count.
    BordaCount,
    /// Mean of scores scaled by a reciprocal-rank factor.
    Condorcet,
    /// Fixed linear blend of score, reciprocal rank and source weight.
    MLFusion,
}
91
/// A single search hit, either as returned by one source or as produced by
/// fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchResult {
    /// Identifier of the matched resource; hits are grouped on this value.
    pub resource: String,
    /// Raw score from the source, or the fused score for a fused result.
    pub score: f32,
    /// Score after normalization, if normalization has been applied.
    pub normalized_score: Option<f32>,
    /// Id of the source that produced the hit (`"FUSED"` for fused results).
    pub source: String,
    /// 0-based position of the hit in its source's result list.
    pub original_rank: usize,
    /// 1-based position in the final fused list, once assigned.
    pub final_rank: Option<usize>,
    /// Vector associated with the resource, if available.
    pub vector: Option<Vector>,
    /// Free-form key/value metadata attached to the hit.
    pub metadata: HashMap<String, String>,
    /// Optional human-readable description of how the score was obtained.
    pub explanation: Option<String>,
}
114
/// The ranked result list returned by one source.
#[derive(Debug, Clone)]
pub struct SourceResults {
    /// Identifier of the source; stamped onto every hit during fusion.
    pub source_id: String,
    /// Hits in the order the source returned them.
    pub results: Vec<VectorSearchResult>,
    /// Free-form key/value metadata about the source.
    pub metadata: HashMap<String, String>,
    /// How long the source took to answer, if known.
    pub response_time: Option<Duration>,
    /// Optional weight for the source.
    // NOTE(review): nothing in this file reads `weight`; the weighted
    // algorithms use `FusionConfig::source_weights` instead. Confirm intent.
    pub weight: Option<f32>,
}
129
/// Outcome of one fusion run.
#[derive(Debug, Clone)]
pub struct FusedResults {
    /// Fused hits in final rank order.
    pub results: Vec<VectorSearchResult>,
    /// Statistics gathered while fusing.
    pub fusion_stats: FusionStats,
    /// Copy of the configuration the run used.
    pub config: FusionConfig,
    /// Wall-clock time the run took.
    pub processing_time: Duration,
}
142
/// Statistics describing the input and output of a fusion run.
#[derive(Debug, Clone, Default)]
pub struct FusionStats {
    /// Number of sources passed in.
    pub source_count: usize,
    /// Total number of hits across all sources.
    pub total_input_results: usize,
    /// Number of distinct resources among the input hits.
    pub unique_resources: usize,
    /// Number of hits in the final fused list.
    pub final_result_count: usize,
    /// Mean raw score over all input hits.
    pub avg_score_before: f32,
    /// Mean fused score over the final hits.
    pub avg_score_after: f32,
    /// Raw-score statistics per source, keyed by source id.
    pub score_distribution: HashMap<String, ScoreDistribution>,
    /// Algorithm that was used for the run.
    pub fusion_algorithm: FusionAlgorithm,
}
163
/// Summary statistics of one source's raw scores.
#[derive(Debug, Clone, Default)]
pub struct ScoreDistribution {
    /// Smallest score.
    pub min: f32,
    /// Largest score.
    pub max: f32,
    /// Arithmetic mean of the scores.
    pub mean: f32,
    /// Population standard deviation of the scores.
    pub std_dev: f32,
    /// Number of scores summarized.
    pub count: usize,
}
173
/// Engine that fuses ranked result lists from several sources according to
/// a [`FusionConfig`].
pub struct ResultFusionEngine {
    /// Settings applied to every fusion run.
    config: FusionConfig,
}
178
179impl ResultFusionEngine {
180 pub fn new() -> Self {
182 Self {
183 config: FusionConfig::default(),
184 }
185 }
186
187 pub fn with_config(config: FusionConfig) -> Self {
189 Self { config }
190 }
191
192 pub fn fuse_results(&self, sources: Vec<SourceResults>) -> Result<FusedResults> {
194 let start_time = std::time::Instant::now();
195
196 if sources.is_empty() {
197 return Ok(FusedResults {
198 results: Vec::new(),
199 fusion_stats: FusionStats::default(),
200 config: self.config.clone(),
201 processing_time: start_time.elapsed(),
202 });
203 }
204
205 let mut all_results = Vec::new();
207 let mut fusion_stats = FusionStats {
208 source_count: sources.len(),
209 fusion_algorithm: self.config.fusion_algorithm.clone(),
210 ..Default::default()
211 };
212
213 for source in &sources {
214 for (rank, result) in source.results.iter().enumerate() {
215 let mut enriched_result = result.clone();
216 enriched_result.original_rank = rank;
217 enriched_result.source = source.source_id.clone();
218 all_results.push(enriched_result);
219 }
220 fusion_stats.total_input_results += source.results.len();
221 }
222
223 self.calculate_score_distributions(&sources, &mut fusion_stats);
225
226 let normalized_results = self.normalize_scores(all_results)?;
228
229 let grouped_results = self.group_by_resource(normalized_results);
231 fusion_stats.unique_resources = grouped_results.len();
232
233 let fused_results = self.apply_fusion_algorithm(grouped_results)?;
235
236 let diversified_results = if self.config.enable_diversification {
238 self.apply_diversification(fused_results)?
239 } else {
240 fused_results
241 };
242
243 let mut final_results = diversified_results
245 .into_iter()
246 .filter(|r| r.score >= self.config.min_score_threshold)
247 .take(self.config.max_results)
248 .collect::<Vec<_>>();
249
250 for (rank, result) in final_results.iter_mut().enumerate() {
252 result.final_rank = Some(rank + 1);
253 }
254
255 fusion_stats.final_result_count = final_results.len();
257 if !final_results.is_empty() {
258 fusion_stats.avg_score_after =
259 final_results.iter().map(|r| r.score).sum::<f32>() / final_results.len() as f32;
260 }
261
262 Ok(FusedResults {
263 results: final_results,
264 fusion_stats,
265 config: self.config.clone(),
266 processing_time: start_time.elapsed(),
267 })
268 }
269
270 fn calculate_score_distributions(&self, sources: &[SourceResults], stats: &mut FusionStats) {
272 for source in sources {
273 if source.results.is_empty() {
274 continue;
275 }
276
277 let scores: Vec<f32> = source.results.iter().map(|r| r.score).collect();
278 let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
279 let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
280 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
281
282 let variance =
283 scores.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / scores.len() as f32;
284 let std_dev = variance.sqrt();
285
286 stats.score_distribution.insert(
287 source.source_id.clone(),
288 ScoreDistribution {
289 min,
290 max,
291 mean,
292 std_dev,
293 count: scores.len(),
294 },
295 );
296 }
297
298 let all_scores: Vec<f32> = sources
300 .iter()
301 .flat_map(|s| s.results.iter().map(|r| r.score))
302 .collect();
303
304 if !all_scores.is_empty() {
305 stats.avg_score_before = all_scores.iter().sum::<f32>() / all_scores.len() as f32;
306 }
307 }
308
309 fn normalize_scores(
311 &self,
312 mut results: Vec<VectorSearchResult>,
313 ) -> Result<Vec<VectorSearchResult>> {
314 match self.config.normalization_strategy {
315 ScoreNormalizationStrategy::None => {
316 for result in &mut results {
317 result.normalized_score = Some(result.score);
318 }
319 }
320 ScoreNormalizationStrategy::MinMax => {
321 self.apply_minmax_normalization(&mut results)?;
322 }
323 ScoreNormalizationStrategy::ZScore => {
324 self.apply_zscore_normalization(&mut results)?;
325 }
326 ScoreNormalizationStrategy::Rank => {
327 self.apply_rank_normalization(&mut results)?;
328 }
329 ScoreNormalizationStrategy::Sigmoid => {
330 self.apply_sigmoid_normalization(&mut results)?;
331 }
332 ScoreNormalizationStrategy::Softmax => {
333 self.apply_softmax_normalization(&mut results)?;
334 }
335 }
336
337 Ok(results)
338 }
339
340 fn apply_minmax_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
342 if results.is_empty() {
343 return Ok(());
344 }
345
346 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
347 let min_score = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
348 let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
349
350 let range = max_score - min_score;
351 if range == 0.0 {
352 for result in results {
353 result.normalized_score = Some(1.0);
354 }
355 } else {
356 for result in results {
357 result.normalized_score = Some((result.score - min_score) / range);
358 }
359 }
360
361 Ok(())
362 }
363
364 fn apply_zscore_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
366 if results.is_empty() {
367 return Ok(());
368 }
369
370 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
371 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
372 let variance =
373 scores.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / scores.len() as f32;
374 let std_dev = variance.sqrt();
375
376 if std_dev == 0.0 {
377 for result in results {
378 result.normalized_score = Some(0.0);
379 }
380 } else {
381 for result in results {
382 result.normalized_score = Some((result.score - mean) / std_dev);
383 }
384 }
385
386 Ok(())
387 }
388
389 fn apply_rank_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
391 if results.is_empty() {
392 return Ok(());
393 }
394
395 let mut indexed_results: Vec<(usize, f32)> = results
397 .iter()
398 .enumerate()
399 .map(|(i, r)| (i, r.score))
400 .collect();
401 indexed_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
402
403 let total_count = results.len() as f32;
405 for (rank, (original_index, _)) in indexed_results.iter().enumerate() {
406 let normalized_score = (total_count - rank as f32) / total_count;
407 results[*original_index].normalized_score = Some(normalized_score);
408 }
409
410 Ok(())
411 }
412
413 fn apply_sigmoid_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
415 for result in results {
416 let sigmoid_score = 1.0 / (1.0 + (-result.score).exp());
417 result.normalized_score = Some(sigmoid_score);
418 }
419 Ok(())
420 }
421
422 fn apply_softmax_normalization(&self, results: &mut [VectorSearchResult]) -> Result<()> {
424 if results.is_empty() {
425 return Ok(());
426 }
427
428 let max_score = results
430 .iter()
431 .map(|r| r.score)
432 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
433
434 let exp_scores: Vec<f32> = results
435 .iter()
436 .map(|r| (r.score - max_score).exp())
437 .collect();
438
439 let sum_exp: f32 = exp_scores.iter().sum();
440
441 for (i, result) in results.iter_mut().enumerate() {
442 result.normalized_score = Some(exp_scores[i] / sum_exp);
443 }
444
445 Ok(())
446 }
447
448 fn group_by_resource(
450 &self,
451 results: Vec<VectorSearchResult>,
452 ) -> HashMap<String, Vec<VectorSearchResult>> {
453 let mut grouped = HashMap::new();
454
455 for result in results {
456 grouped
457 .entry(result.resource.clone())
458 .or_insert_with(Vec::new)
459 .push(result);
460 }
461
462 grouped
463 }
464
465 fn apply_fusion_algorithm(
467 &self,
468 grouped_results: HashMap<String, Vec<VectorSearchResult>>,
469 ) -> Result<Vec<VectorSearchResult>> {
470 let mut fused_results = Vec::new();
471
472 for (_resource, mut resource_results) in grouped_results {
473 let fused_result = match &self.config.fusion_algorithm {
474 FusionAlgorithm::CombSum => self.apply_combsum(&resource_results)?,
475 FusionAlgorithm::CombMax => self.apply_combmax(&resource_results)?,
476 FusionAlgorithm::CombMin => self.apply_combmin(&resource_results)?,
477 FusionAlgorithm::CombAvg => self.apply_combavg(&resource_results)?,
478 FusionAlgorithm::CombMedian => self.apply_combmedian(&mut resource_results)?,
479 FusionAlgorithm::WeightedSum => self.apply_weighted_sum(&resource_results)?,
480 FusionAlgorithm::RRF => self.apply_rrf(&resource_results)?,
481 FusionAlgorithm::BordaCount => self.apply_borda_count(&resource_results)?,
482 FusionAlgorithm::Condorcet => self.apply_condorcet(&resource_results)?,
483 FusionAlgorithm::MLFusion => self.apply_ml_fusion(&resource_results)?,
484 };
485
486 fused_results.push(fused_result);
487 }
488
489 fused_results.sort_by(|a, b| {
491 b.score
492 .partial_cmp(&a.score)
493 .unwrap_or(std::cmp::Ordering::Equal)
494 });
495
496 Ok(fused_results)
497 }
498
499 fn apply_combsum(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
501 let sum_score = results
502 .iter()
503 .map(|r| r.normalized_score.unwrap_or(r.score))
504 .sum::<f32>();
505
506 Ok(self.create_fused_result(results, sum_score, "CombSum"))
507 }
508
509 fn apply_combmax(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
511 let max_score = results
512 .iter()
513 .map(|r| r.normalized_score.unwrap_or(r.score))
514 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
515
516 Ok(self.create_fused_result(results, max_score, "CombMax"))
517 }
518
519 fn apply_combmin(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
521 let min_score = results
522 .iter()
523 .map(|r| r.normalized_score.unwrap_or(r.score))
524 .fold(f32::INFINITY, |a, b| a.min(b));
525
526 Ok(self.create_fused_result(results, min_score, "CombMin"))
527 }
528
529 fn apply_combavg(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
531 let avg_score = results
532 .iter()
533 .map(|r| r.normalized_score.unwrap_or(r.score))
534 .sum::<f32>()
535 / results.len() as f32;
536
537 Ok(self.create_fused_result(results, avg_score, "CombAvg"))
538 }
539
540 fn apply_combmedian(&self, results: &mut [VectorSearchResult]) -> Result<VectorSearchResult> {
542 let mut scores: Vec<f32> = results
543 .iter()
544 .map(|r| r.normalized_score.unwrap_or(r.score))
545 .collect();
546
547 scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
548
549 let median_score = if scores.len() % 2 == 0 {
550 let mid = scores.len() / 2;
551 (scores[mid - 1] + scores[mid]) / 2.0
552 } else {
553 scores[scores.len() / 2]
554 };
555
556 Ok(self.create_fused_result(results, median_score, "CombMedian"))
557 }
558
559 fn apply_weighted_sum(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
561 let mut weighted_sum = 0.0;
562 let mut total_weight = 0.0;
563
564 for result in results {
565 let weight = self
566 .config
567 .source_weights
568 .get(&result.source)
569 .copied()
570 .unwrap_or(1.0);
571 let score = result.normalized_score.unwrap_or(result.score);
572 weighted_sum += score * weight;
573 total_weight += weight;
574 }
575
576 let final_score = if total_weight > 0.0 {
577 weighted_sum / total_weight
578 } else {
579 0.0
580 };
581
582 Ok(self.create_fused_result(results, final_score, "WeightedSum"))
583 }
584
585 fn apply_rrf(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
587 let k = 60.0; let rrf_score = results
589 .iter()
590 .map(|r| 1.0 / (k + r.original_rank as f32 + 1.0))
591 .sum::<f32>();
592
593 Ok(self.create_fused_result(results, rrf_score, "RRF"))
594 }
595
596 fn apply_borda_count(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
598 let total_sources = results.len();
600 let borda_score = results
601 .iter()
602 .map(|r| (total_sources - r.original_rank) as f32)
603 .sum::<f32>();
604
605 Ok(self.create_fused_result(results, borda_score, "BordaCount"))
606 }
607
608 fn apply_condorcet(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
610 let condorcet_score = results
612 .iter()
613 .map(|r| {
614 let score = r.normalized_score.unwrap_or(r.score);
615 let rank_penalty = 1.0 / (r.original_rank as f32 + 1.0);
616 score * rank_penalty
617 })
618 .sum::<f32>()
619 / results.len() as f32;
620
621 Ok(self.create_fused_result(results, condorcet_score, "Condorcet"))
622 }
623
624 fn apply_ml_fusion(&self, results: &[VectorSearchResult]) -> Result<VectorSearchResult> {
626 let mut ml_score = 0.0;
629
630 for result in results {
631 let score = result.normalized_score.unwrap_or(result.score);
632 let rank_feature = 1.0 / (result.original_rank as f32 + 1.0);
633 let source_weight = self
634 .config
635 .source_weights
636 .get(&result.source)
637 .copied()
638 .unwrap_or(1.0);
639
640 ml_score += 0.5 * score + 0.3 * rank_feature + 0.2 * source_weight;
642 }
643
644 ml_score /= results.len() as f32;
645
646 Ok(self.create_fused_result(results, ml_score, "MLFusion"))
647 }
648
649 fn create_fused_result(
651 &self,
652 results: &[VectorSearchResult],
653 fused_score: f32,
654 algorithm: &str,
655 ) -> VectorSearchResult {
656 let first_result = &results[0];
657 let mut metadata = first_result.metadata.clone();
658
659 metadata.insert("fusion_algorithm".to_string(), algorithm.to_string());
661 metadata.insert("source_count".to_string(), results.len().to_string());
662 metadata.insert(
663 "sources".to_string(),
664 results
665 .iter()
666 .map(|r| r.source.clone())
667 .collect::<Vec<_>>()
668 .join(","),
669 );
670
671 let explanation = if self.config.enable_explanation {
672 Some(format!(
673 "{} fusion of {} results from sources: [{}] with final score: {:.4}",
674 algorithm,
675 results.len(),
676 results
677 .iter()
678 .map(|r| format!("{}:{:.3}", r.source, r.score))
679 .collect::<Vec<_>>()
680 .join(", "),
681 fused_score
682 ))
683 } else {
684 None
685 };
686
687 VectorSearchResult {
688 resource: first_result.resource.clone(),
689 score: fused_score,
690 normalized_score: Some(fused_score),
691 source: "FUSED".to_string(),
692 original_rank: 0,
693 final_rank: None,
694 vector: first_result.vector.clone(),
695 metadata,
696 explanation,
697 }
698 }
699
700 fn apply_diversification(
702 &self,
703 results: Vec<VectorSearchResult>,
704 ) -> Result<Vec<VectorSearchResult>> {
705 if results.len() <= 1 || self.config.diversification_factor == 0.0 {
706 return Ok(results);
707 }
708
709 let mut diversified = Vec::new();
710 let mut remaining = results;
711
712 if !remaining.is_empty() {
714 diversified.push(remaining.remove(0));
715 }
716
717 while !remaining.is_empty() && diversified.len() < self.config.max_results {
719 let mut best_index = 0;
720 let mut best_score = f32::NEG_INFINITY;
721
722 for (i, candidate) in remaining.iter().enumerate() {
723 let diversity_penalty = self.calculate_diversity_penalty(candidate, &diversified);
725
726 let combined_score = (1.0 - self.config.diversification_factor) * candidate.score
728 + self.config.diversification_factor * diversity_penalty;
729
730 if combined_score > best_score {
731 best_score = combined_score;
732 best_index = i;
733 }
734 }
735
736 diversified.push(remaining.remove(best_index));
737 }
738
739 Ok(diversified)
740 }
741
742 fn calculate_diversity_penalty(
744 &self,
745 candidate: &VectorSearchResult,
746 selected: &[VectorSearchResult],
747 ) -> f32 {
748 if selected.is_empty() {
749 return 1.0;
750 }
751
752 let mut min_similarity = f32::INFINITY;
754
755 for selected_result in selected {
756 let similarity =
757 self.calculate_string_similarity(&candidate.resource, &selected_result.resource);
758 min_similarity = min_similarity.min(similarity);
759 }
760
761 1.0 - min_similarity
763 }
764
765 fn calculate_string_similarity(&self, s1: &str, s2: &str) -> f32 {
767 let bigrams1 = self.get_character_bigrams(s1);
769 let bigrams2 = self.get_character_bigrams(s2);
770
771 let intersection: usize = bigrams1
772 .iter()
773 .filter(|&bigram| bigrams2.contains(bigram))
774 .count();
775
776 let union_size = bigrams1.len() + bigrams2.len() - intersection;
777
778 if union_size == 0 {
779 1.0
780 } else {
781 intersection as f32 / union_size as f32
782 }
783 }
784
785 fn get_character_bigrams(&self, s: &str) -> std::collections::HashSet<String> {
787 let chars: Vec<char> = s.chars().collect();
788 let mut bigrams = std::collections::HashSet::new();
789
790 for i in 0..chars.len().saturating_sub(1) {
791 let bigram = format!("{}{}", chars[i], chars[i + 1]);
792 bigrams.insert(bigram);
793 }
794
795 bigrams
796 }
797}
798
799impl Default for ResultFusionEngine {
800 fn default() -> Self {
801 Self::new()
802 }
803}
804
805pub mod fusion_utils {
807 use super::*;
808
809 pub fn convert_service_results(
811 source_id: String,
812 service_result: VectorServiceResult,
813 ) -> Result<SourceResults> {
814 let results = match service_result {
815 VectorServiceResult::SimilarityList(list) => list
816 .into_iter()
817 .enumerate()
818 .map(|(rank, (resource, score))| VectorSearchResult {
819 resource,
820 score,
821 normalized_score: None,
822 source: source_id.clone(),
823 original_rank: rank,
824 final_rank: None,
825 vector: None,
826 metadata: HashMap::new(),
827 explanation: None,
828 })
829 .collect(),
830 VectorServiceResult::DetailedSimilarityList(detailed_list) => detailed_list
831 .into_iter()
832 .enumerate()
833 .map(|(rank, detailed)| VectorSearchResult {
834 resource: detailed.0,
835 score: detailed.1,
836 normalized_score: None,
837 source: source_id.clone(),
838 original_rank: rank,
839 final_rank: None,
840 vector: None,
841 metadata: detailed.2,
842 explanation: None,
843 })
844 .collect(),
845 _ => {
846 return Err(anyhow!(
847 "Cannot convert non-similarity result to source results"
848 ));
849 }
850 };
851
852 Ok(SourceResults {
853 source_id,
854 results,
855 metadata: HashMap::new(),
856 response_time: None,
857 weight: None,
858 })
859 }
860
861 pub fn create_source_results(source_id: String, results: Vec<(String, f32)>) -> SourceResults {
863 let search_results = results
864 .into_iter()
865 .enumerate()
866 .map(|(rank, (resource, score))| VectorSearchResult {
867 resource,
868 score,
869 normalized_score: None,
870 source: source_id.clone(),
871 original_rank: rank,
872 final_rank: None,
873 vector: None,
874 metadata: HashMap::new(),
875 explanation: None,
876 })
877 .collect();
878
879 SourceResults {
880 source_id,
881 results: search_results,
882 metadata: HashMap::new(),
883 response_time: None,
884 weight: None,
885 }
886 }
887
888 pub fn calculate_fusion_quality(
890 fused_results: &FusedResults,
891 ground_truth: Option<&[String]>,
892 ) -> FusionQualityMetrics {
893 let mut metrics = FusionQualityMetrics {
894 result_count: fused_results.results.len(),
895 ..Default::default()
896 };
897 if !fused_results.results.is_empty() {
898 metrics.avg_score = fused_results.results.iter().map(|r| r.score).sum::<f32>()
899 / fused_results.results.len() as f32;
900 metrics.min_score = fused_results
901 .results
902 .iter()
903 .map(|r| r.score)
904 .fold(f32::INFINITY, |a, b| a.min(b));
905 metrics.max_score = fused_results
906 .results
907 .iter()
908 .map(|r| r.score)
909 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
910 }
911
912 metrics.diversity = calculate_result_diversity(&fused_results.results);
914
915 if let Some(gt) = ground_truth {
917 let relevant_count = fused_results
918 .results
919 .iter()
920 .filter(|r| gt.contains(&r.resource))
921 .count();
922
923 metrics.precision = if fused_results.results.is_empty() {
924 0.0
925 } else {
926 relevant_count as f32 / fused_results.results.len() as f32
927 };
928
929 metrics.recall = if gt.is_empty() {
930 0.0
931 } else {
932 relevant_count as f32 / gt.len() as f32
933 };
934
935 metrics.f1_score = if metrics.precision + metrics.recall == 0.0 {
936 0.0
937 } else {
938 2.0 * metrics.precision * metrics.recall / (metrics.precision + metrics.recall)
939 };
940 }
941
942 metrics
943 }
944
945 fn calculate_result_diversity(results: &[VectorSearchResult]) -> f32 {
947 if results.len() <= 1 {
948 return 1.0;
949 }
950
951 let mut total_similarity = 0.0;
952 let mut pair_count = 0;
953
954 for i in 0..results.len() {
955 for j in i + 1..results.len() {
956 let sim = jaccard_similarity(&results[i].resource, &results[j].resource);
958 total_similarity += sim;
959 pair_count += 1;
960 }
961 }
962
963 if pair_count == 0 {
964 1.0
965 } else {
966 1.0 - (total_similarity / pair_count as f32)
967 }
968 }
969
/// Jaccard similarity of the two strings' character sets: characters
/// common to both divided by distinct characters overall. Two empty
/// strings count as identical (1.0).
fn jaccard_similarity(s1: &str, s2: &str) -> f32 {
    let left: std::collections::HashSet<char> = s1.chars().collect();
    let right: std::collections::HashSet<char> = s2.chars().collect();

    let shared = left.iter().filter(|&c| right.contains(c)).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let combined = left.len() + right.len() - shared;

    if combined == 0 {
        return 1.0;
    }
    shared as f32 / combined as f32
}
984}
985
/// Quality metrics for a fused result list.
#[derive(Debug, Clone, Default)]
pub struct FusionQualityMetrics {
    /// Number of results in the fused list.
    pub result_count: usize,
    /// Mean fused score.
    pub avg_score: f32,
    /// Smallest fused score.
    pub min_score: f32,
    /// Largest fused score.
    pub max_score: f32,
    /// One minus the mean pairwise similarity of the result resources.
    pub diversity: f32,
    /// Fraction of results that appear in the ground truth.
    pub precision: f32,
    /// Fraction of the ground truth found among the results.
    pub recall: f32,
    /// Harmonic mean of precision and recall.
    pub f1_score: f32,
}
998
999#[cfg(test)]
1000mod tests {
1001 use super::*;
1002
/// CombSum over two overlapping sources: `doc1` is returned by both, so it
/// must come out on top of the three distinct resources.
#[test]
fn test_combsum_fusion() {
    let engine = ResultFusionEngine::new();

    let first = fusion_utils::create_source_results(
        "source1".to_string(),
        vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.7)],
    );
    let second = fusion_utils::create_source_results(
        "source2".to_string(),
        vec![("doc1".to_string(), 0.8), ("doc3".to_string(), 0.6)],
    );

    let fused = engine.fuse_results(vec![first, second]).unwrap();

    // doc1, doc2 and doc3.
    assert_eq!(fused.results.len(), 3);
    assert_eq!(fused.fusion_stats.source_count, 2);
    assert_eq!(fused.fusion_stats.unique_resources, 3);

    // The resource found by both sources accumulates the highest score.
    let (top, runner_up) = (&fused.results[0], &fused.results[1]);
    assert_eq!(top.resource, "doc1");
    assert!(top.score > runner_up.score);
}
1079
/// RRF over two overlapping sources. `doc2` is the only resource returned
/// by both (rank 1 and rank 0), so its reciprocal-rank sum is the largest
/// and it must be ranked first.
#[test]
fn test_rrf_fusion() {
    let config = FusionConfig {
        fusion_algorithm: FusionAlgorithm::RRF,
        ..Default::default()
    };
    let fusion_engine = ResultFusionEngine::with_config(config);

    let source1 = fusion_utils::create_source_results(
        "source1".to_string(),
        vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.7)],
    );

    let source2 = fusion_utils::create_source_results(
        "source2".to_string(),
        vec![("doc2".to_string(), 0.8), ("doc3".to_string(), 0.6)],
    );

    let result = fusion_engine.fuse_results(vec![source1, source2]).unwrap();

    assert!(!result.results.is_empty());
    assert_eq!(result.results.len(), 3);
    assert_eq!(result.fusion_stats.unique_resources, 3);

    // Ranking: the doubly-listed resource wins, and scores never increase
    // further down the list.
    assert_eq!(result.results[0].resource, "doc2");
    assert_eq!(result.results[0].final_rank, Some(1));
    assert!(result
        .results
        .windows(2)
        .all(|pair| pair[0].score >= pair[1].score));
}
1104
/// With min-max normalization and a single source, every fused score must
/// lie in the closed interval [0, 1].
#[test]
fn test_score_normalization() {
    let engine = ResultFusionEngine::with_config(FusionConfig {
        normalization_strategy: ScoreNormalizationStrategy::MinMax,
        ..Default::default()
    });

    let scored_docs = vec![
        ("doc1".to_string(), 0.2),
        ("doc2".to_string(), 0.8),
        ("doc3".to_string(), 0.5),
    ];
    let source = fusion_utils::create_source_results("test".to_string(), scored_docs);

    let fused = engine.fuse_results(vec![source]).unwrap();

    for hit in fused.results.iter() {
        assert!((0.0..=1.0).contains(&hit.score));
    }
}
1129
/// Two results, one of which is in a two-entry ground truth: precision and
/// recall are both 0.5, and the differing resource names give a positive
/// diversity.
#[test]
fn test_fusion_quality_metrics() {
    // Builds an already-fused hit from source "test" at the given 0-based rank.
    let make_hit = |resource: &str, score: f32, rank: usize| VectorSearchResult {
        resource: resource.to_string(),
        score,
        normalized_score: Some(score),
        source: "test".to_string(),
        original_rank: rank,
        final_rank: Some(rank + 1),
        vector: None,
        metadata: HashMap::new(),
        explanation: None,
    };

    let fused = FusedResults {
        results: vec![make_hit("relevant1", 0.9, 0), make_hit("irrelevant1", 0.8, 1)],
        fusion_stats: FusionStats::default(),
        config: FusionConfig::default(),
        processing_time: Duration::from_millis(10),
    };

    let ground_truth = vec!["relevant1".to_string(), "relevant2".to_string()];
    let metrics = fusion_utils::calculate_fusion_quality(&fused, Some(&ground_truth));

    assert_eq!(metrics.result_count, 2);
    // One of the two results is relevant.
    assert_eq!(metrics.precision, 0.5);
    // One of the two relevant resources was found.
    assert_eq!(metrics.recall, 0.5);
    assert!(metrics.diversity > 0.0);
}
1170}