1use anyhow::Result;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use tracing::debug;
50
/// Strategy used to combine dense (vector) and sparse (lexical) retrieval
/// scores into a single fused score per document.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum HybridFusionStrategy {
    /// Weighted linear combination of the (optionally normalized) scores.
    WeightedSum,
    /// Rank-based fusion: each list contributes 1 / (k + rank), then the
    /// contributions are combined with the configured weights.
    ReciprocalRankFusion,
    /// Learned fusion model; currently falls back to `WeightedSum` (see
    /// `HybridFusion::learned_fusion`).
    LearnedFusion,
    /// Convex combination of scores; with weights normalized to sum to 1.0
    /// this is identical to `WeightedSum`.
    ConvexCombination,
    /// Harmonic mean of the two scores — rewards documents that score well
    /// in BOTH lists.
    HarmonicMean,
    /// Geometric mean of the two scores.
    GeometricMean,
}
67
/// Configuration for hybrid dense/sparse score fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridFusionConfig {
    /// Which fusion algorithm to apply.
    pub strategy: HybridFusionStrategy,
    /// Weight given to dense (vector) scores. Rescaled together with
    /// `sparse_weight` to sum to 1.0 by `HybridFusion::new`.
    pub dense_weight: f32,
    /// Weight given to sparse (lexical) scores.
    pub sparse_weight: f32,
    /// Whether to normalize each score list before score-based fusion.
    pub normalize_scores: bool,
    /// Normalization applied when `normalize_scores` is true.
    pub normalization_method: NormalizationMethod,
    /// The `k` constant in reciprocal rank fusion: 1 / (k + rank).
    pub rrf_k: f32,
    /// Fused results scoring below this threshold are dropped.
    pub min_score_threshold: f32,
    /// Maximum number of fused results returned.
    pub max_results: usize,
    /// Boosting flag; not acted on anywhere in this module — presumably
    /// consumed by a caller. TODO confirm.
    pub enable_boosting: bool,
}
90
91impl Default for HybridFusionConfig {
92 fn default() -> Self {
93 Self {
94 strategy: HybridFusionStrategy::WeightedSum,
95 dense_weight: 0.7,
96 sparse_weight: 0.3,
97 normalize_scores: true,
98 normalization_method: NormalizationMethod::MinMax,
99 rrf_k: 60.0,
100 min_score_threshold: 0.0,
101 max_results: 100,
102 enable_boosting: false,
103 }
104 }
105}
106
/// Method used to bring dense and sparse scores onto a comparable scale
/// before fusion.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Linear rescale to [0, 1] using the observed min and max.
    MinMax,
    /// Standard score: (x - mean) / std_dev. Output may be negative.
    ZScore,
    /// Softmax over all scores; outputs sum to 1.
    Softmax,
    /// Position-based score derived from list order (assumes the input is
    /// already sorted best-first).
    Rank,
    /// Pass scores through unchanged.
    None,
}
121
/// A single document after fusion, carrying the combined score plus the
/// per-source scores/ranks that produced it (for debugging and re-ranking).
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Document identifier.
    pub id: String,
    /// Final fused score used for ordering.
    pub score: f32,
    /// Dense-retriever score, if the document appeared in the dense list.
    pub dense_score: Option<f32>,
    /// Sparse-retriever score, if the document appeared in the sparse list.
    pub sparse_score: Option<f32>,
    /// 0-based rank in the dense list (only populated by RRF).
    pub dense_rank: Option<usize>,
    /// 0-based rank in the sparse list (only populated by RRF).
    pub sparse_rank: Option<usize>,
}
138
/// Fuses dense and sparse retrieval result lists according to a
/// `HybridFusionConfig`, accumulating running statistics across calls.
pub struct HybridFusion {
    // Weight-normalized copy of the caller's config (see `new`).
    config: HybridFusionConfig,
    // Running averages over all `fuse` calls since construction/reset.
    stats: HybridFusionStatistics,
}
144
/// Running statistics over all fusions performed by a `HybridFusion`.
/// Averages are incremental means updated once per `fuse` call.
#[derive(Debug, Clone, Default)]
pub struct HybridFusionStatistics {
    /// Number of `fuse` calls since construction or `reset_stats`.
    pub total_fusions: usize,
    /// Average number of dense input results per fusion.
    pub avg_dense_results: f64,
    /// Average number of sparse input results per fusion.
    pub avg_sparse_results: f64,
    /// Average number of results returned (after threshold + truncation).
    pub avg_fused_results: f64,
    /// Average Jaccard overlap (|intersection| / |union|) of the two id sets.
    pub avg_overlap: f64,
}
154
155impl HybridFusion {
156 pub fn new(config: HybridFusionConfig) -> Self {
158 let total_weight = config.dense_weight + config.sparse_weight;
160 let normalized_config = if (total_weight - 1.0).abs() > 1e-6 {
161 debug!(
162 "Normalizing fusion weights: dense={}, sparse={} -> sum={}",
163 config.dense_weight, config.sparse_weight, total_weight
164 );
165 HybridFusionConfig {
166 dense_weight: config.dense_weight / total_weight,
167 sparse_weight: config.sparse_weight / total_weight,
168 ..config
169 }
170 } else {
171 config
172 };
173
174 Self {
175 config: normalized_config,
176 stats: HybridFusionStatistics::default(),
177 }
178 }
179
    /// Fuses dense and sparse result lists into one ranked list.
    ///
    /// Pipeline: update input-size statistics, optionally normalize each
    /// score list, apply the configured fusion strategy, record the overlap
    /// between the two id sets, then filter by `min_score_threshold`, sort
    /// by descending fused score, and truncate to `max_results`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` signature leaves room
    /// for fallible strategies (e.g. a learned model) without breaking
    /// callers.
    pub fn fuse(
        &mut self,
        dense_results: Vec<(String, f32)>,
        sparse_results: Vec<(String, f32)>,
    ) -> Result<Vec<FusedResult>> {
        // Bump the counter FIRST: update_avg divides by total_fusions, so
        // this ordering keeps the incremental mean well-defined (count >= 1).
        self.stats.total_fusions += 1;
        self.stats.avg_dense_results = self.update_avg(
            self.stats.avg_dense_results,
            dense_results.len() as f64,
            self.stats.total_fusions,
        );
        self.stats.avg_sparse_results = self.update_avg(
            self.stats.avg_sparse_results,
            sparse_results.len() as f64,
            self.stats.total_fusions,
        );

        // NOTE(review): these clones are avoidable when normalization is
        // off, but the originals must stay alive for the RRF branch and the
        // overlap statistics below.
        let normalized_dense = if self.config.normalize_scores {
            self.normalize(&dense_results)
        } else {
            dense_results.clone()
        };

        let normalized_sparse = if self.config.normalize_scores {
            self.normalize(&sparse_results)
        } else {
            sparse_results.clone()
        };

        let fused = match self.config.strategy {
            HybridFusionStrategy::WeightedSum => {
                self.weighted_sum_fusion(&normalized_dense, &normalized_sparse)
            }
            // RRF only looks at list positions, so it gets the raw
            // (un-normalized) lists.
            HybridFusionStrategy::ReciprocalRankFusion => {
                self.rrf_fusion(&dense_results, &sparse_results)
            }
            HybridFusionStrategy::LearnedFusion => {
                self.learned_fusion(&normalized_dense, &normalized_sparse)
            }
            HybridFusionStrategy::ConvexCombination => {
                self.convex_combination(&normalized_dense, &normalized_sparse)
            }
            HybridFusionStrategy::HarmonicMean => {
                self.harmonic_mean_fusion(&normalized_dense, &normalized_sparse)
            }
            HybridFusionStrategy::GeometricMean => {
                self.geometric_mean_fusion(&normalized_dense, &normalized_sparse)
            }
        };

        // Jaccard overlap of the two id sets, folded into the running mean.
        let dense_ids: std::collections::HashSet<_> =
            dense_results.iter().map(|(id, _)| id).collect();
        let sparse_ids: std::collections::HashSet<_> =
            sparse_results.iter().map(|(id, _)| id).collect();
        let overlap = dense_ids.intersection(&sparse_ids).count();
        let total_unique = dense_ids.union(&sparse_ids).count();
        let overlap_ratio = if total_unique > 0 {
            overlap as f64 / total_unique as f64
        } else {
            0.0
        };
        self.stats.avg_overlap = self.update_avg(
            self.stats.avg_overlap,
            overlap_ratio,
            self.stats.total_fusions,
        );

        // Threshold, then rank best-first. NaN scores compare as Equal and
        // keep their relative position rather than panicking.
        let mut filtered: Vec<_> = fused
            .into_iter()
            .filter(|r| r.score >= self.config.min_score_threshold)
            .collect();

        filtered.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        filtered.truncate(self.config.max_results);

        self.stats.avg_fused_results = self.update_avg(
            self.stats.avg_fused_results,
            filtered.len() as f64,
            self.stats.total_fusions,
        );

        Ok(filtered)
    }
276
277 fn weighted_sum_fusion(
279 &self,
280 dense: &[(String, f32)],
281 sparse: &[(String, f32)],
282 ) -> Vec<FusedResult> {
283 let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
284
285 for (id, score) in dense {
287 score_map.insert(id.clone(), (Some(*score), None));
288 }
289
290 for (id, score) in sparse {
292 score_map
293 .entry(id.clone())
294 .and_modify(|e| e.1 = Some(*score))
295 .or_insert((None, Some(*score)));
296 }
297
298 score_map
300 .into_iter()
301 .map(|(id, (dense_score, sparse_score))| {
302 let combined_score = dense_score.unwrap_or(0.0) * self.config.dense_weight
303 + sparse_score.unwrap_or(0.0) * self.config.sparse_weight;
304
305 FusedResult {
306 id,
307 score: combined_score,
308 dense_score,
309 sparse_score,
310 dense_rank: None,
311 sparse_rank: None,
312 }
313 })
314 .collect()
315 }
316
317 fn rrf_fusion(&self, dense: &[(String, f32)], sparse: &[(String, f32)]) -> Vec<FusedResult> {
319 let mut score_map: HashMap<String, (Option<usize>, Option<usize>)> = HashMap::new();
320
321 for (rank, (id, _)) in dense.iter().enumerate() {
323 score_map.insert(id.clone(), (Some(rank), None));
324 }
325
326 for (rank, (id, _)) in sparse.iter().enumerate() {
328 score_map
329 .entry(id.clone())
330 .and_modify(|e| e.1 = Some(rank))
331 .or_insert((None, Some(rank)));
332 }
333
334 score_map
336 .into_iter()
337 .map(|(id, (dense_rank, sparse_rank))| {
338 let dense_rrf = dense_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
339 let sparse_rrf = sparse_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
340
341 let combined_score =
342 dense_rrf * self.config.dense_weight + sparse_rrf * self.config.sparse_weight;
343
344 FusedResult {
345 id,
346 score: combined_score,
347 dense_score: dense_rank.map(|_| dense_rrf),
348 sparse_score: sparse_rank.map(|_| sparse_rrf),
349 dense_rank,
350 sparse_rank,
351 }
352 })
353 .collect()
354 }
355
    /// Learned fusion (not yet implemented).
    ///
    /// Placeholder that currently delegates to `weighted_sum_fusion`; a
    /// trained scoring model could replace this body without changing
    /// callers or the `HybridFusionStrategy::LearnedFusion` variant.
    fn learned_fusion(
        &self,
        dense: &[(String, f32)],
        sparse: &[(String, f32)],
    ) -> Vec<FusedResult> {
        // TODO: plug in a learned model; weighted sum is the fallback.
        self.weighted_sum_fusion(dense, sparse)
    }
366
    /// Convex combination of dense and sparse scores.
    ///
    /// Because `new` rescales the weights to sum to 1.0, a convex
    /// combination is mathematically identical to the weighted sum, so
    /// this simply delegates.
    fn convex_combination(
        &self,
        dense: &[(String, f32)],
        sparse: &[(String, f32)],
    ) -> Vec<FusedResult> {
        self.weighted_sum_fusion(dense, sparse)
    }
376
377 fn harmonic_mean_fusion(
379 &self,
380 dense: &[(String, f32)],
381 sparse: &[(String, f32)],
382 ) -> Vec<FusedResult> {
383 let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
384
385 for (id, score) in dense {
386 score_map.insert(id.clone(), (Some(*score), None));
387 }
388
389 for (id, score) in sparse {
390 score_map
391 .entry(id.clone())
392 .and_modify(|e| e.1 = Some(*score))
393 .or_insert((None, Some(*score)));
394 }
395
396 score_map
397 .into_iter()
398 .filter_map(
399 |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
400 (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
401 let harmonic = 2.0 / (1.0 / d + 1.0 / s);
402 Some(FusedResult {
403 id,
404 score: harmonic,
405 dense_score: Some(d),
406 sparse_score: Some(s),
407 dense_rank: None,
408 sparse_rank: None,
409 })
410 }
411 (Some(d), None) => Some(FusedResult {
412 id,
413 score: d * self.config.dense_weight,
414 dense_score: Some(d),
415 sparse_score: None,
416 dense_rank: None,
417 sparse_rank: None,
418 }),
419 (None, Some(s)) => Some(FusedResult {
420 id,
421 score: s * self.config.sparse_weight,
422 dense_score: None,
423 sparse_score: Some(s),
424 dense_rank: None,
425 sparse_rank: None,
426 }),
427 _ => None,
428 },
429 )
430 .collect()
431 }
432
433 fn geometric_mean_fusion(
435 &self,
436 dense: &[(String, f32)],
437 sparse: &[(String, f32)],
438 ) -> Vec<FusedResult> {
439 let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
440
441 for (id, score) in dense {
442 score_map.insert(id.clone(), (Some(*score), None));
443 }
444
445 for (id, score) in sparse {
446 score_map
447 .entry(id.clone())
448 .and_modify(|e| e.1 = Some(*score))
449 .or_insert((None, Some(*score)));
450 }
451
452 score_map
453 .into_iter()
454 .filter_map(
455 |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
456 (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
457 let geometric = (d * s).sqrt();
458 Some(FusedResult {
459 id,
460 score: geometric,
461 dense_score: Some(d),
462 sparse_score: Some(s),
463 dense_rank: None,
464 sparse_rank: None,
465 })
466 }
467 (Some(d), None) => Some(FusedResult {
468 id,
469 score: d * self.config.dense_weight,
470 dense_score: Some(d),
471 sparse_score: None,
472 dense_rank: None,
473 sparse_rank: None,
474 }),
475 (None, Some(s)) => Some(FusedResult {
476 id,
477 score: s * self.config.sparse_weight,
478 dense_score: None,
479 sparse_score: Some(s),
480 dense_rank: None,
481 sparse_rank: None,
482 }),
483 _ => None,
484 },
485 )
486 .collect()
487 }
488
489 fn normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
491 if results.is_empty() {
492 return Vec::new();
493 }
494
495 match self.config.normalization_method {
496 NormalizationMethod::MinMax => self.min_max_normalize(results),
497 NormalizationMethod::ZScore => self.z_score_normalize(results),
498 NormalizationMethod::Softmax => self.softmax_normalize(results),
499 NormalizationMethod::Rank => self.rank_normalize(results),
500 NormalizationMethod::None => results.to_vec(),
501 }
502 }
503
504 fn min_max_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
506 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
507 let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
508 let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
509
510 if (max - min).abs() < 1e-6 {
511 return results.iter().map(|(id, _)| (id.clone(), 1.0)).collect();
512 }
513
514 results
515 .iter()
516 .map(|(id, score)| {
517 let normalized = (score - min) / (max - min);
518 (id.clone(), normalized)
519 })
520 .collect()
521 }
522
523 fn z_score_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
525 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
526 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
527 let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
528 let std_dev = variance.sqrt();
529
530 if std_dev < 1e-6 {
531 return results.iter().map(|(id, _)| (id.clone(), 0.0)).collect();
532 }
533
534 results
535 .iter()
536 .map(|(id, score)| {
537 let normalized = (score - mean) / std_dev;
538 (id.clone(), normalized)
539 })
540 .collect()
541 }
542
543 fn softmax_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
545 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
546 let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
547
548 let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
550 let sum_exp: f32 = exp_scores.iter().sum();
551
552 results
553 .iter()
554 .enumerate()
555 .map(|(i, (id, _))| {
556 let normalized = exp_scores[i] / sum_exp;
557 (id.clone(), normalized)
558 })
559 .collect()
560 }
561
562 fn rank_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
564 let n = results.len() as f32;
565 results
566 .iter()
567 .enumerate()
568 .map(|(rank, (id, _))| {
569 let normalized = 1.0 - (rank as f32 / n);
570 (id.clone(), normalized)
571 })
572 .collect()
573 }
574
575 fn update_avg(&self, old_avg: f64, new_val: f64, count: usize) -> f64 {
577 old_avg + (new_val - old_avg) / count as f64
578 }
579
    /// Returns the running fusion statistics accumulated so far.
    pub fn stats(&self) -> &HybridFusionStatistics {
        &self.stats
    }
584
    /// Resets all running statistics to zero; the config is untouched.
    pub fn reset_stats(&mut self) {
        self.stats = HybridFusionStatistics::default();
    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Weighted-sum fusion must return results sorted best-first.
    #[test]
    fn test_weighted_sum_fusion() {
        let mut fusion = HybridFusion::new(HybridFusionConfig {
            strategy: HybridFusionStrategy::WeightedSum,
            dense_weight: 0.6,
            sparse_weight: 0.4,
            normalize_scores: false,
            ..Default::default()
        });

        let dense = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.8)];
        let sparse = vec![("doc2".to_string(), 0.7), ("doc3".to_string(), 0.6)];

        let results = fusion.fuse(dense, sparse).unwrap();
        assert!(!results.is_empty());
        // Scores must be non-increasing.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    /// RRF should favor documents appearing in both ranked lists.
    #[test]
    fn test_rrf_fusion() {
        let mut fusion = HybridFusion::new(HybridFusionConfig {
            strategy: HybridFusionStrategy::ReciprocalRankFusion,
            rrf_k: 60.0,
            ..Default::default()
        });

        let dense = vec![
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];
        let sparse = vec![
            ("doc2".to_string(), 0.85),
            ("doc3".to_string(), 0.75),
            ("doc4".to_string(), 0.65),
        ];

        let results = fusion.fuse(dense, sparse).unwrap();
        assert!(!results.is_empty());

        // doc2/doc3 appear in both lists, so at least one should be on top.
        let top_ids: Vec<_> = results.iter().take(2).map(|r| r.id.as_str()).collect();
        assert!(top_ids.contains(&"doc2") || top_ids.contains(&"doc3"));
    }

    /// Min-max scaling maps the extremes to 0.0 and 1.0.
    #[test]
    fn test_normalization() {
        let fusion = HybridFusion::new(HybridFusionConfig {
            normalize_scores: true,
            normalization_method: NormalizationMethod::MinMax,
            ..Default::default()
        });

        let results = vec![
            ("doc1".to_string(), 10.0),
            ("doc2".to_string(), 20.0),
            ("doc3".to_string(), 30.0),
        ];
        let normalized = fusion.min_max_normalize(&results);

        assert_eq!(normalized[0].1, 0.0);
        assert_eq!(normalized[2].1, 1.0);
        assert!((normalized[1].1 - 0.5).abs() < 0.01);
    }

    /// Harmonic mean should rank the document present in both lists first.
    #[test]
    fn test_harmonic_mean_fusion() {
        let mut fusion = HybridFusion::new(HybridFusionConfig {
            strategy: HybridFusionStrategy::HarmonicMean,
            ..Default::default()
        });

        let dense = vec![("doc1".to_string(), 0.8), ("doc2".to_string(), 0.6)];
        let sparse = vec![("doc1".to_string(), 0.9), ("doc3".to_string(), 0.7)];

        let results = fusion.fuse(dense, sparse).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "doc1");
    }

    /// Statistics accumulate across repeated fusions.
    #[test]
    fn test_statistics() {
        let mut fusion = HybridFusion::new(HybridFusionConfig::default());

        let dense = vec![("doc1".to_string(), 0.9)];
        let sparse = vec![("doc2".to_string(), 0.8)];

        fusion.fuse(dense.clone(), sparse.clone()).unwrap();
        fusion.fuse(dense, sparse).unwrap();

        let stats = fusion.stats();
        assert_eq!(stats.total_fusions, 2);
        assert!(stats.avg_dense_results > 0.0);
        assert!(stats.avg_sparse_results > 0.0);
    }
}