1use anyhow::Result;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use tracing::debug;
50
/// Strategy used to combine dense (vector) and sparse (lexical) result lists.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum HybridFusionStrategy {
    /// Weighted linear combination of the (optionally normalized) scores.
    WeightedSum,
    /// Rank-based fusion: weight / (rrf_k + rank), summed over both lists.
    ReciprocalRankFusion,
    /// Placeholder for model-based fusion; currently delegates to WeightedSum.
    LearnedFusion,
    /// Convex combination; currently delegates to WeightedSum (weights are
    /// renormalized to sum to 1 at construction).
    ConvexCombination,
    /// Harmonic mean of the two scores when both are present and positive.
    HarmonicMean,
    /// Geometric mean of the two scores when both are present and positive.
    GeometricMean,
}
67
/// Configuration controlling how hybrid (dense + sparse) results are fused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridFusionConfig {
    /// Fusion strategy to apply.
    pub strategy: HybridFusionStrategy,
    /// Weight applied to dense scores (renormalized with `sparse_weight` so
    /// the pair sums to 1.0 in `HybridFusion::new`).
    pub dense_weight: f32,
    /// Weight applied to sparse scores.
    pub sparse_weight: f32,
    /// Whether to normalize each input list before score-based fusion.
    pub normalize_scores: bool,
    /// Normalization method used when `normalize_scores` is true.
    pub normalization_method: NormalizationMethod,
    /// The `k` constant in reciprocal rank fusion: 1 / (k + rank).
    pub rrf_k: f32,
    /// Fused results scoring below this threshold are dropped.
    pub min_score_threshold: f32,
    /// Maximum number of fused results returned.
    pub max_results: usize,
    /// Whether score boosting is enabled.
    /// NOTE(review): not referenced by any fusion path in this module —
    /// presumably consumed elsewhere; confirm before removing.
    pub enable_boosting: bool,
}
90
impl Default for HybridFusionConfig {
    /// Defaults: weighted-sum fusion with a 70/30 dense/sparse split,
    /// min-max normalization enabled, RRF k = 60, no score threshold,
    /// at most 100 results, boosting disabled.
    fn default() -> Self {
        Self {
            strategy: HybridFusionStrategy::WeightedSum,
            dense_weight: 0.7,
            sparse_weight: 0.3,
            normalize_scores: true,
            normalization_method: NormalizationMethod::MinMax,
            rrf_k: 60.0,
            min_score_threshold: 0.0,
            max_results: 100,
            enable_boosting: false,
        }
    }
}
106
/// Method used to normalize a score list before fusion.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Scale linearly into [0, 1] using the min and max scores.
    MinMax,
    /// Standardize to zero mean / unit variance (population variance).
    ZScore,
    /// Softmax over the scores (shifted by the max for numerical stability).
    Softmax,
    /// Score derived from list position: first entry gets 1.0, last gets 1/n.
    Rank,
    /// Leave scores untouched.
    None,
}
121
/// A single fused search result, carrying the per-retriever components that
/// produced the combined score.
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Document identifier.
    pub id: String,
    /// Final combined score after fusion.
    pub score: f32,
    /// Dense-side component, if the document appeared in the dense list.
    /// For RRF this holds the dense RRF contribution rather than a raw score.
    pub dense_score: Option<f32>,
    /// Sparse-side component, if present (RRF: the sparse RRF contribution).
    pub sparse_score: Option<f32>,
    /// 0-based rank in the dense list (populated only by RRF fusion).
    pub dense_rank: Option<usize>,
    /// 0-based rank in the sparse list (populated only by RRF fusion).
    pub sparse_rank: Option<usize>,
}
138
/// Combines dense and sparse retrieval result lists into a single ranking
/// using a configurable fusion strategy, tracking running statistics across
/// calls to `fuse`.
pub struct HybridFusion {
    // Effective configuration (weights renormalized to sum to 1 in `new`).
    config: HybridFusionConfig,
    // Running averages, updated incrementally on every `fuse` call.
    stats: HybridFusionStatistics,
}
144
/// Running statistics over all `fuse` calls (incrementally updated means).
#[derive(Debug, Clone, Default)]
pub struct HybridFusionStatistics {
    /// Number of `fuse` calls performed.
    pub total_fusions: usize,
    /// Mean number of dense input results per fusion.
    pub avg_dense_results: f64,
    /// Mean number of sparse input results per fusion.
    pub avg_sparse_results: f64,
    /// Mean number of results returned after thresholding and truncation.
    pub avg_fused_results: f64,
    /// Mean Jaccard overlap (shared ids / unique ids) between the two inputs.
    pub avg_overlap: f64,
}
154
155impl HybridFusion {
156 pub fn new(config: HybridFusionConfig) -> Self {
158 let total_weight = config.dense_weight + config.sparse_weight;
160 let normalized_config = if (total_weight - 1.0).abs() > 1e-6 {
161 debug!(
162 "Normalizing fusion weights: dense={}, sparse={} -> sum={}",
163 config.dense_weight, config.sparse_weight, total_weight
164 );
165 HybridFusionConfig {
166 dense_weight: config.dense_weight / total_weight,
167 sparse_weight: config.sparse_weight / total_weight,
168 ..config
169 }
170 } else {
171 config
172 };
173
174 Self {
175 config: normalized_config,
176 stats: HybridFusionStatistics::default(),
177 }
178 }
179
180 pub fn fuse(
182 &mut self,
183 dense_results: Vec<(String, f32)>,
184 sparse_results: Vec<(String, f32)>,
185 ) -> Result<Vec<FusedResult>> {
186 self.stats.total_fusions += 1;
188 self.stats.avg_dense_results = self.update_avg(
189 self.stats.avg_dense_results,
190 dense_results.len() as f64,
191 self.stats.total_fusions,
192 );
193 self.stats.avg_sparse_results = self.update_avg(
194 self.stats.avg_sparse_results,
195 sparse_results.len() as f64,
196 self.stats.total_fusions,
197 );
198
199 let normalized_dense = if self.config.normalize_scores {
201 self.normalize(&dense_results)
202 } else {
203 dense_results.clone()
204 };
205
206 let normalized_sparse = if self.config.normalize_scores {
207 self.normalize(&sparse_results)
208 } else {
209 sparse_results.clone()
210 };
211
212 let fused = match self.config.strategy {
214 HybridFusionStrategy::WeightedSum => {
215 self.weighted_sum_fusion(&normalized_dense, &normalized_sparse)
216 }
217 HybridFusionStrategy::ReciprocalRankFusion => {
218 self.rrf_fusion(&dense_results, &sparse_results)
219 }
220 HybridFusionStrategy::LearnedFusion => {
221 self.learned_fusion(&normalized_dense, &normalized_sparse)
222 }
223 HybridFusionStrategy::ConvexCombination => {
224 self.convex_combination(&normalized_dense, &normalized_sparse)
225 }
226 HybridFusionStrategy::HarmonicMean => {
227 self.harmonic_mean_fusion(&normalized_dense, &normalized_sparse)
228 }
229 HybridFusionStrategy::GeometricMean => {
230 self.geometric_mean_fusion(&normalized_dense, &normalized_sparse)
231 }
232 };
233
234 let dense_ids: std::collections::HashSet<_> =
236 dense_results.iter().map(|(id, _)| id).collect();
237 let sparse_ids: std::collections::HashSet<_> =
238 sparse_results.iter().map(|(id, _)| id).collect();
239 let overlap = dense_ids.intersection(&sparse_ids).count();
240 let total_unique = dense_ids.union(&sparse_ids).count();
241 let overlap_ratio = if total_unique > 0 {
242 overlap as f64 / total_unique as f64
243 } else {
244 0.0
245 };
246 self.stats.avg_overlap = self.update_avg(
247 self.stats.avg_overlap,
248 overlap_ratio,
249 self.stats.total_fusions,
250 );
251
252 let mut filtered: Vec<_> = fused
254 .into_iter()
255 .filter(|r| r.score >= self.config.min_score_threshold)
256 .collect();
257
258 filtered.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
260
261 filtered.truncate(self.config.max_results);
263
264 self.stats.avg_fused_results = self.update_avg(
265 self.stats.avg_fused_results,
266 filtered.len() as f64,
267 self.stats.total_fusions,
268 );
269
270 Ok(filtered)
271 }
272
273 fn weighted_sum_fusion(
275 &self,
276 dense: &[(String, f32)],
277 sparse: &[(String, f32)],
278 ) -> Vec<FusedResult> {
279 let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
280
281 for (id, score) in dense {
283 score_map.insert(id.clone(), (Some(*score), None));
284 }
285
286 for (id, score) in sparse {
288 score_map
289 .entry(id.clone())
290 .and_modify(|e| e.1 = Some(*score))
291 .or_insert((None, Some(*score)));
292 }
293
294 score_map
296 .into_iter()
297 .map(|(id, (dense_score, sparse_score))| {
298 let combined_score = dense_score.unwrap_or(0.0) * self.config.dense_weight
299 + sparse_score.unwrap_or(0.0) * self.config.sparse_weight;
300
301 FusedResult {
302 id,
303 score: combined_score,
304 dense_score,
305 sparse_score,
306 dense_rank: None,
307 sparse_rank: None,
308 }
309 })
310 .collect()
311 }
312
313 fn rrf_fusion(&self, dense: &[(String, f32)], sparse: &[(String, f32)]) -> Vec<FusedResult> {
315 let mut score_map: HashMap<String, (Option<usize>, Option<usize>)> = HashMap::new();
316
317 for (rank, (id, _)) in dense.iter().enumerate() {
319 score_map.insert(id.clone(), (Some(rank), None));
320 }
321
322 for (rank, (id, _)) in sparse.iter().enumerate() {
324 score_map
325 .entry(id.clone())
326 .and_modify(|e| e.1 = Some(rank))
327 .or_insert((None, Some(rank)));
328 }
329
330 score_map
332 .into_iter()
333 .map(|(id, (dense_rank, sparse_rank))| {
334 let dense_rrf = dense_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
335 let sparse_rrf = sparse_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
336
337 let combined_score =
338 dense_rrf * self.config.dense_weight + sparse_rrf * self.config.sparse_weight;
339
340 FusedResult {
341 id,
342 score: combined_score,
343 dense_score: dense_rank.map(|_| dense_rrf),
344 sparse_score: sparse_rank.map(|_| sparse_rrf),
345 dense_rank,
346 sparse_rank,
347 }
348 })
349 .collect()
350 }
351
    /// Placeholder for a learned (model-based) fusion strategy.
    /// Currently delegates to `weighted_sum_fusion`; no learned model is
    /// implemented in this module.
    fn learned_fusion(
        &self,
        dense: &[(String, f32)],
        sparse: &[(String, f32)],
    ) -> Vec<FusedResult> {
        self.weighted_sum_fusion(dense, sparse)
    }
362
    /// Convex-combination fusion. Since `new` renormalizes the weights to sum
    /// to 1.0, a convex combination is exactly the weighted sum, so this
    /// delegates to `weighted_sum_fusion`.
    fn convex_combination(
        &self,
        dense: &[(String, f32)],
        sparse: &[(String, f32)],
    ) -> Vec<FusedResult> {
        self.weighted_sum_fusion(dense, sparse)
    }
372
373 fn harmonic_mean_fusion(
375 &self,
376 dense: &[(String, f32)],
377 sparse: &[(String, f32)],
378 ) -> Vec<FusedResult> {
379 let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
380
381 for (id, score) in dense {
382 score_map.insert(id.clone(), (Some(*score), None));
383 }
384
385 for (id, score) in sparse {
386 score_map
387 .entry(id.clone())
388 .and_modify(|e| e.1 = Some(*score))
389 .or_insert((None, Some(*score)));
390 }
391
392 score_map
393 .into_iter()
394 .filter_map(
395 |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
396 (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
397 let harmonic = 2.0 / (1.0 / d + 1.0 / s);
398 Some(FusedResult {
399 id,
400 score: harmonic,
401 dense_score: Some(d),
402 sparse_score: Some(s),
403 dense_rank: None,
404 sparse_rank: None,
405 })
406 }
407 (Some(d), None) => Some(FusedResult {
408 id,
409 score: d * self.config.dense_weight,
410 dense_score: Some(d),
411 sparse_score: None,
412 dense_rank: None,
413 sparse_rank: None,
414 }),
415 (None, Some(s)) => Some(FusedResult {
416 id,
417 score: s * self.config.sparse_weight,
418 dense_score: None,
419 sparse_score: Some(s),
420 dense_rank: None,
421 sparse_rank: None,
422 }),
423 _ => None,
424 },
425 )
426 .collect()
427 }
428
429 fn geometric_mean_fusion(
431 &self,
432 dense: &[(String, f32)],
433 sparse: &[(String, f32)],
434 ) -> Vec<FusedResult> {
435 let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
436
437 for (id, score) in dense {
438 score_map.insert(id.clone(), (Some(*score), None));
439 }
440
441 for (id, score) in sparse {
442 score_map
443 .entry(id.clone())
444 .and_modify(|e| e.1 = Some(*score))
445 .or_insert((None, Some(*score)));
446 }
447
448 score_map
449 .into_iter()
450 .filter_map(
451 |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
452 (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
453 let geometric = (d * s).sqrt();
454 Some(FusedResult {
455 id,
456 score: geometric,
457 dense_score: Some(d),
458 sparse_score: Some(s),
459 dense_rank: None,
460 sparse_rank: None,
461 })
462 }
463 (Some(d), None) => Some(FusedResult {
464 id,
465 score: d * self.config.dense_weight,
466 dense_score: Some(d),
467 sparse_score: None,
468 dense_rank: None,
469 sparse_rank: None,
470 }),
471 (None, Some(s)) => Some(FusedResult {
472 id,
473 score: s * self.config.sparse_weight,
474 dense_score: None,
475 sparse_score: Some(s),
476 dense_rank: None,
477 sparse_rank: None,
478 }),
479 _ => None,
480 },
481 )
482 .collect()
483 }
484
    /// Dispatches to the configured normalization method.
    /// An empty input short-circuits to an empty output so the individual
    /// normalizers never divide by a zero count or range from empty data.
    fn normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
        if results.is_empty() {
            return Vec::new();
        }

        match self.config.normalization_method {
            NormalizationMethod::MinMax => self.min_max_normalize(results),
            NormalizationMethod::ZScore => self.z_score_normalize(results),
            NormalizationMethod::Softmax => self.softmax_normalize(results),
            NormalizationMethod::Rank => self.rank_normalize(results),
            NormalizationMethod::None => results.to_vec(),
        }
    }
499
500 fn min_max_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
502 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
503 let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
504 let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
505
506 if (max - min).abs() < 1e-6 {
507 return results.iter().map(|(id, _)| (id.clone(), 1.0)).collect();
508 }
509
510 results
511 .iter()
512 .map(|(id, score)| {
513 let normalized = (score - min) / (max - min);
514 (id.clone(), normalized)
515 })
516 .collect()
517 }
518
519 fn z_score_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
521 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
522 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
523 let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
524 let std_dev = variance.sqrt();
525
526 if std_dev < 1e-6 {
527 return results.iter().map(|(id, _)| (id.clone(), 0.0)).collect();
528 }
529
530 results
531 .iter()
532 .map(|(id, score)| {
533 let normalized = (score - mean) / std_dev;
534 (id.clone(), normalized)
535 })
536 .collect()
537 }
538
539 fn softmax_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
541 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
542 let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
543
544 let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
546 let sum_exp: f32 = exp_scores.iter().sum();
547
548 results
549 .iter()
550 .enumerate()
551 .map(|(i, (id, _))| {
552 let normalized = exp_scores[i] / sum_exp;
553 (id.clone(), normalized)
554 })
555 .collect()
556 }
557
558 fn rank_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
560 let n = results.len() as f32;
561 results
562 .iter()
563 .enumerate()
564 .map(|(rank, (id, _))| {
565 let normalized = 1.0 - (rank as f32 / n);
566 (id.clone(), normalized)
567 })
568 .collect()
569 }
570
    /// Incremental (online) mean update: avg_n = avg_{n-1} + (x - avg_{n-1}) / n.
    /// `count` must be >= 1; `fuse` increments `total_fusions` before calling.
    fn update_avg(&self, old_avg: f64, new_val: f64, count: usize) -> f64 {
        old_avg + (new_val - old_avg) / count as f64
    }
575
    /// Returns a read-only view of the running fusion statistics.
    pub fn stats(&self) -> &HybridFusionStatistics {
        &self.stats
    }
580
    /// Resets all running statistics back to their zeroed defaults.
    pub fn reset_stats(&mut self) {
        self.stats = HybridFusionStatistics::default();
    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Weighted-sum results must come back in descending score order.
    #[test]
    fn test_weighted_sum_fusion() {
        let config = HybridFusionConfig {
            strategy: HybridFusionStrategy::WeightedSum,
            dense_weight: 0.6,
            sparse_weight: 0.4,
            normalize_scores: false,
            ..Default::default()
        };
        let mut fusion = HybridFusion::new(config);

        let dense = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.8)];
        let sparse = vec![("doc2".to_string(), 0.7), ("doc3".to_string(), 0.6)];

        let results = fusion.fuse(dense, sparse).unwrap();

        assert!(!results.is_empty());
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    /// Documents appearing in both lists should dominate an RRF ranking.
    #[test]
    fn test_rrf_fusion() {
        let config = HybridFusionConfig {
            strategy: HybridFusionStrategy::ReciprocalRankFusion,
            rrf_k: 60.0,
            ..Default::default()
        };
        let mut fusion = HybridFusion::new(config);

        let dense = vec![
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];
        let sparse = vec![
            ("doc2".to_string(), 0.85),
            ("doc3".to_string(), 0.75),
            ("doc4".to_string(), 0.65),
        ];

        let results = fusion.fuse(dense, sparse).unwrap();
        assert!(!results.is_empty());

        let top_two: Vec<&str> = results.iter().take(2).map(|r| r.id.as_str()).collect();
        assert!(top_two.contains(&"doc2") || top_two.contains(&"doc3"));
    }

    /// Min-max maps the extremes to 0.0 / 1.0 and the midpoint to ~0.5.
    #[test]
    fn test_normalization() {
        let config = HybridFusionConfig {
            normalize_scores: true,
            normalization_method: NormalizationMethod::MinMax,
            ..Default::default()
        };
        let fusion = HybridFusion::new(config);

        let scores = vec![
            ("doc1".to_string(), 10.0),
            ("doc2".to_string(), 20.0),
            ("doc3".to_string(), 30.0),
        ];

        let normalized = fusion.min_max_normalize(&scores);

        assert_eq!(normalized[0].1, 0.0);
        assert_eq!(normalized[2].1, 1.0);
        assert!((normalized[1].1 - 0.5).abs() < 0.01);
    }

    /// A document found by both retrievers should outrank single-source ones.
    #[test]
    fn test_harmonic_mean_fusion() {
        let config = HybridFusionConfig {
            strategy: HybridFusionStrategy::HarmonicMean,
            ..Default::default()
        };
        let mut fusion = HybridFusion::new(config);

        let dense = vec![("doc1".to_string(), 0.8), ("doc2".to_string(), 0.6)];
        let sparse = vec![("doc1".to_string(), 0.9), ("doc3".to_string(), 0.7)];

        let results = fusion.fuse(dense, sparse).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "doc1");
    }

    /// Running statistics accumulate across repeated fuse calls.
    #[test]
    fn test_statistics() {
        let mut fusion = HybridFusion::new(HybridFusionConfig::default());

        let dense = vec![("doc1".to_string(), 0.9)];
        let sparse = vec![("doc2".to_string(), 0.8)];

        fusion.fuse(dense.clone(), sparse.clone()).unwrap();
        fusion.fuse(dense, sparse).unwrap();

        let stats = fusion.stats();
        assert_eq!(stats.total_fusions, 2);
        assert!(stats.avg_dense_results > 0.0);
        assert!(stats.avg_sparse_results > 0.0);
    }
}