// oxirs_vec/hybrid_fusion.rs
1//! Hybrid search with dense + sparse vector fusion
2//!
3//! This module provides advanced hybrid search capabilities that combine:
4//! - **Dense vectors**: Semantic embeddings from neural networks (BERT, etc.)
5//! - **Sparse vectors**: Traditional keyword-based representations (TF-IDF, BM25)
6//!
7//! # Features
8//!
9//! - Multiple fusion strategies (weighted sum, RRF, learned fusion)
10//! - Automatic weight optimization
11//! - Score normalization across modalities
12//! - Support for query-time boosting
13//! - Performance metrics and analytics
14//!
15//! # Example
16//!
17//! ```rust,ignore
18//! use oxirs_vec::hybrid_fusion::{HybridFusion, HybridFusionStrategy, HybridFusionConfig};
19//! use oxirs_vec::{Vector, sparse::SparseVector};
20//!
21//! let config = HybridFusionConfig {
22//!     strategy: HybridFusionStrategy::WeightedSum,
23//!     dense_weight: 0.7,
24//!     sparse_weight: 0.3,
25//!     normalize_scores: true,
26//!     ..Default::default()
27//! };
28//!
29//! let fusion = HybridFusion::new(config);
30//!
31//! // Dense search results
32//! let dense_results = vec![
33//!     ("doc1".to_string(), 0.95),
34//!     ("doc2".to_string(), 0.85),
35//! ];
36//!
37//! // Sparse search results
38//! let sparse_results = vec![
39//!     ("doc2".to_string(), 0.90),
40//!     ("doc3".to_string(), 0.80),
41//! ];
42//!
43//! let fused = fusion.fuse(dense_results, sparse_results).unwrap();
44//! ```
45
46use anyhow::Result;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use tracing::debug;
50
/// Fusion strategy for combining dense and sparse results
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum HybridFusionStrategy {
    /// Weighted sum of (optionally normalized) scores:
    /// `dense_weight * dense + sparse_weight * sparse`; a missing modality
    /// contributes 0.
    WeightedSum,
    /// Reciprocal rank fusion (RRF): combines `1 / (k + rank)` terms from
    /// each list, so only result *order* matters, not raw score magnitudes.
    ReciprocalRankFusion,
    /// Linear combination with learned weights.
    /// NOTE(review): currently delegates to the weighted-sum implementation.
    LearnedFusion,
    /// Convex combination.
    /// NOTE(review): currently delegates to the weighted-sum implementation
    /// (weights are already rescaled to sum to 1.0, making it convex).
    ConvexCombination,
    /// Harmonic mean of the two scores when a document appears in both lists
    /// with positive scores; a single weighted score otherwise.
    HarmonicMean,
    /// Geometric mean of the two scores when a document appears in both lists
    /// with positive scores; a single weighted score otherwise.
    GeometricMean,
}
67
/// Configuration for hybrid fusion
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridFusionConfig {
    /// Fusion strategy to use
    pub strategy: HybridFusionStrategy,
    /// Weight for dense vector scores (0.0 to 1.0).
    /// `HybridFusion::new` rescales the dense/sparse pair so it sums to 1.0.
    pub dense_weight: f32,
    /// Weight for sparse vector scores (0.0 to 1.0); rescaled together with
    /// `dense_weight` so the pair sums to 1.0.
    pub sparse_weight: f32,
    /// Whether to normalize scores before fusion. RRF ignores this: it works
    /// from result ranks, never from (normalized) scores.
    pub normalize_scores: bool,
    /// Normalization method applied when `normalize_scores` is true
    pub normalization_method: NormalizationMethod,
    /// RRF rank constant (the `k` in `1 / (k + rank)`); larger values flatten
    /// the difference between adjacent ranks
    pub rrf_k: f32,
    /// Minimum fused score; results scoring below this are dropped
    pub min_score_threshold: f32,
    /// Maximum results to return after filtering and sorting
    pub max_results: usize,
    /// Enable query-time boosting.
    /// NOTE(review): not consulted anywhere in this module yet — confirm
    /// intended behavior before relying on it.
    pub enable_boosting: bool,
}
90
impl Default for HybridFusionConfig {
    /// Defaults favor dense (semantic) results 70/30, min-max normalize
    /// scores, use the conventional RRF constant k = 60, and cap output at
    /// 100 results with no score threshold and boosting disabled.
    fn default() -> Self {
        Self {
            strategy: HybridFusionStrategy::WeightedSum,
            dense_weight: 0.7,
            sparse_weight: 0.3,
            normalize_scores: true,
            normalization_method: NormalizationMethod::MinMax,
            rrf_k: 60.0,
            min_score_threshold: 0.0,
            max_results: 100,
            enable_boosting: false,
        }
    }
}
106
/// Score normalization methods
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Min-max normalization: linearly rescales scores to [0, 1]
    /// (all scores map to 1.0 when max ≈ min).
    MinMax,
    /// Z-score normalization: `(score - mean) / std_dev`
    /// (all scores map to 0.0 when the std-dev is ~0).
    ZScore,
    /// Softmax normalization: `exp(score) / Σ exp(scores)`, computed with the
    /// max subtracted first for numerical stability.
    Softmax,
    /// Rank-based normalization: `1 - rank/n`, so the first result gets 1.0
    /// and later results decay linearly; raw score magnitudes are ignored.
    Rank,
    /// No normalization; scores are used as-is.
    None,
}
121
/// Fused search result
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Document ID
    pub id: String,
    /// Combined score produced by the active fusion strategy
    pub score: f32,
    /// Dense score component (`None` when the document did not appear in the
    /// dense result list)
    pub dense_score: Option<f32>,
    /// Sparse score component (`None` when the document did not appear in the
    /// sparse result list)
    pub sparse_score: Option<f32>,
    /// 0-based rank from dense search; only populated by rank-based
    /// strategies (RRF)
    pub dense_rank: Option<usize>,
    /// 0-based rank from sparse search; only populated by rank-based
    /// strategies (RRF)
    pub sparse_rank: Option<usize>,
}
138
/// Hybrid fusion engine combining dense and sparse retrieval results.
///
/// Holds the (weight-rescaled) configuration and running statistics that are
/// updated on every `fuse` call.
pub struct HybridFusion {
    // Configuration; dense/sparse weights rescaled to sum to 1.0 by `new`.
    config: HybridFusionConfig,
    // Running averages maintained incrementally by each fuse() call.
    stats: HybridFusionStatistics,
}
144
/// Fusion statistics: running averages maintained across all `fuse` calls.
#[derive(Debug, Clone, Default)]
pub struct HybridFusionStatistics {
    /// Number of fuse() calls performed so far
    pub total_fusions: usize,
    /// Running average of dense result-list lengths
    pub avg_dense_results: f64,
    /// Running average of sparse result-list lengths
    pub avg_sparse_results: f64,
    /// Running average of fused result counts (after threshold filtering and
    /// truncation to `max_results`)
    pub avg_fused_results: f64,
    /// Running average of the Jaccard overlap between the two input lists:
    /// |dense ∩ sparse| / |dense ∪ sparse|
    pub avg_overlap: f64,
}
154
155impl HybridFusion {
156    /// Create a new hybrid fusion engine
157    pub fn new(config: HybridFusionConfig) -> Self {
158        // Ensure weights sum to 1.0
159        let total_weight = config.dense_weight + config.sparse_weight;
160        let normalized_config = if (total_weight - 1.0).abs() > 1e-6 {
161            debug!(
162                "Normalizing fusion weights: dense={}, sparse={} -> sum={}",
163                config.dense_weight, config.sparse_weight, total_weight
164            );
165            HybridFusionConfig {
166                dense_weight: config.dense_weight / total_weight,
167                sparse_weight: config.sparse_weight / total_weight,
168                ..config
169            }
170        } else {
171            config
172        };
173
174        Self {
175            config: normalized_config,
176            stats: HybridFusionStatistics::default(),
177        }
178    }
179
180    /// Fuse dense and sparse search results
181    pub fn fuse(
182        &mut self,
183        dense_results: Vec<(String, f32)>,
184        sparse_results: Vec<(String, f32)>,
185    ) -> Result<Vec<FusedResult>> {
186        // Update statistics
187        self.stats.total_fusions += 1;
188        self.stats.avg_dense_results = self.update_avg(
189            self.stats.avg_dense_results,
190            dense_results.len() as f64,
191            self.stats.total_fusions,
192        );
193        self.stats.avg_sparse_results = self.update_avg(
194            self.stats.avg_sparse_results,
195            sparse_results.len() as f64,
196            self.stats.total_fusions,
197        );
198
199        // Normalize scores if configured
200        let normalized_dense = if self.config.normalize_scores {
201            self.normalize(&dense_results)
202        } else {
203            dense_results.clone()
204        };
205
206        let normalized_sparse = if self.config.normalize_scores {
207            self.normalize(&sparse_results)
208        } else {
209            sparse_results.clone()
210        };
211
212        // Perform fusion based on strategy
213        let fused = match self.config.strategy {
214            HybridFusionStrategy::WeightedSum => {
215                self.weighted_sum_fusion(&normalized_dense, &normalized_sparse)
216            }
217            HybridFusionStrategy::ReciprocalRankFusion => {
218                self.rrf_fusion(&dense_results, &sparse_results)
219            }
220            HybridFusionStrategy::LearnedFusion => {
221                self.learned_fusion(&normalized_dense, &normalized_sparse)
222            }
223            HybridFusionStrategy::ConvexCombination => {
224                self.convex_combination(&normalized_dense, &normalized_sparse)
225            }
226            HybridFusionStrategy::HarmonicMean => {
227                self.harmonic_mean_fusion(&normalized_dense, &normalized_sparse)
228            }
229            HybridFusionStrategy::GeometricMean => {
230                self.geometric_mean_fusion(&normalized_dense, &normalized_sparse)
231            }
232        };
233
234        // Calculate overlap
235        let dense_ids: std::collections::HashSet<_> =
236            dense_results.iter().map(|(id, _)| id).collect();
237        let sparse_ids: std::collections::HashSet<_> =
238            sparse_results.iter().map(|(id, _)| id).collect();
239        let overlap = dense_ids.intersection(&sparse_ids).count();
240        let total_unique = dense_ids.union(&sparse_ids).count();
241        let overlap_ratio = if total_unique > 0 {
242            overlap as f64 / total_unique as f64
243        } else {
244            0.0
245        };
246        self.stats.avg_overlap = self.update_avg(
247            self.stats.avg_overlap,
248            overlap_ratio,
249            self.stats.total_fusions,
250        );
251
252        // Filter by threshold and limit
253        let mut filtered: Vec<_> = fused
254            .into_iter()
255            .filter(|r| r.score >= self.config.min_score_threshold)
256            .collect();
257
258        // Sort by score descending
259        filtered.sort_by(|a, b| {
260            b.score
261                .partial_cmp(&a.score)
262                .unwrap_or(std::cmp::Ordering::Equal)
263        });
264
265        // Apply max results limit
266        filtered.truncate(self.config.max_results);
267
268        self.stats.avg_fused_results = self.update_avg(
269            self.stats.avg_fused_results,
270            filtered.len() as f64,
271            self.stats.total_fusions,
272        );
273
274        Ok(filtered)
275    }
276
277    /// Weighted sum fusion
278    fn weighted_sum_fusion(
279        &self,
280        dense: &[(String, f32)],
281        sparse: &[(String, f32)],
282    ) -> Vec<FusedResult> {
283        let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
284
285        // Add dense scores
286        for (id, score) in dense {
287            score_map.insert(id.clone(), (Some(*score), None));
288        }
289
290        // Add sparse scores
291        for (id, score) in sparse {
292            score_map
293                .entry(id.clone())
294                .and_modify(|e| e.1 = Some(*score))
295                .or_insert((None, Some(*score)));
296        }
297
298        // Compute weighted sum
299        score_map
300            .into_iter()
301            .map(|(id, (dense_score, sparse_score))| {
302                let combined_score = dense_score.unwrap_or(0.0) * self.config.dense_weight
303                    + sparse_score.unwrap_or(0.0) * self.config.sparse_weight;
304
305                FusedResult {
306                    id,
307                    score: combined_score,
308                    dense_score,
309                    sparse_score,
310                    dense_rank: None,
311                    sparse_rank: None,
312                }
313            })
314            .collect()
315    }
316
317    /// Reciprocal rank fusion (RRF)
318    fn rrf_fusion(&self, dense: &[(String, f32)], sparse: &[(String, f32)]) -> Vec<FusedResult> {
319        let mut score_map: HashMap<String, (Option<usize>, Option<usize>)> = HashMap::new();
320
321        // Add dense ranks
322        for (rank, (id, _)) in dense.iter().enumerate() {
323            score_map.insert(id.clone(), (Some(rank), None));
324        }
325
326        // Add sparse ranks
327        for (rank, (id, _)) in sparse.iter().enumerate() {
328            score_map
329                .entry(id.clone())
330                .and_modify(|e| e.1 = Some(rank))
331                .or_insert((None, Some(rank)));
332        }
333
334        // Compute RRF scores
335        score_map
336            .into_iter()
337            .map(|(id, (dense_rank, sparse_rank))| {
338                let dense_rrf = dense_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
339                let sparse_rrf = sparse_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
340
341                let combined_score =
342                    dense_rrf * self.config.dense_weight + sparse_rrf * self.config.sparse_weight;
343
344                FusedResult {
345                    id,
346                    score: combined_score,
347                    dense_score: dense_rank.map(|_| dense_rrf),
348                    sparse_score: sparse_rank.map(|_| sparse_rrf),
349                    dense_rank,
350                    sparse_rank,
351                }
352            })
353            .collect()
354    }
355
356    /// Learned fusion with adaptive weights
357    fn learned_fusion(
358        &self,
359        dense: &[(String, f32)],
360        sparse: &[(String, f32)],
361    ) -> Vec<FusedResult> {
362        // For now, use weighted sum with learned weights
363        // In production, this would use a trained model
364        self.weighted_sum_fusion(dense, sparse)
365    }
366
367    /// Convex combination
368    fn convex_combination(
369        &self,
370        dense: &[(String, f32)],
371        sparse: &[(String, f32)],
372    ) -> Vec<FusedResult> {
373        // Similar to weighted sum but ensures convexity
374        self.weighted_sum_fusion(dense, sparse)
375    }
376
377    /// Harmonic mean fusion
378    fn harmonic_mean_fusion(
379        &self,
380        dense: &[(String, f32)],
381        sparse: &[(String, f32)],
382    ) -> Vec<FusedResult> {
383        let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
384
385        for (id, score) in dense {
386            score_map.insert(id.clone(), (Some(*score), None));
387        }
388
389        for (id, score) in sparse {
390            score_map
391                .entry(id.clone())
392                .and_modify(|e| e.1 = Some(*score))
393                .or_insert((None, Some(*score)));
394        }
395
396        score_map
397            .into_iter()
398            .filter_map(
399                |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
400                    (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
401                        let harmonic = 2.0 / (1.0 / d + 1.0 / s);
402                        Some(FusedResult {
403                            id,
404                            score: harmonic,
405                            dense_score: Some(d),
406                            sparse_score: Some(s),
407                            dense_rank: None,
408                            sparse_rank: None,
409                        })
410                    }
411                    (Some(d), None) => Some(FusedResult {
412                        id,
413                        score: d * self.config.dense_weight,
414                        dense_score: Some(d),
415                        sparse_score: None,
416                        dense_rank: None,
417                        sparse_rank: None,
418                    }),
419                    (None, Some(s)) => Some(FusedResult {
420                        id,
421                        score: s * self.config.sparse_weight,
422                        dense_score: None,
423                        sparse_score: Some(s),
424                        dense_rank: None,
425                        sparse_rank: None,
426                    }),
427                    _ => None,
428                },
429            )
430            .collect()
431    }
432
433    /// Geometric mean fusion
434    fn geometric_mean_fusion(
435        &self,
436        dense: &[(String, f32)],
437        sparse: &[(String, f32)],
438    ) -> Vec<FusedResult> {
439        let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
440
441        for (id, score) in dense {
442            score_map.insert(id.clone(), (Some(*score), None));
443        }
444
445        for (id, score) in sparse {
446            score_map
447                .entry(id.clone())
448                .and_modify(|e| e.1 = Some(*score))
449                .or_insert((None, Some(*score)));
450        }
451
452        score_map
453            .into_iter()
454            .filter_map(
455                |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
456                    (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
457                        let geometric = (d * s).sqrt();
458                        Some(FusedResult {
459                            id,
460                            score: geometric,
461                            dense_score: Some(d),
462                            sparse_score: Some(s),
463                            dense_rank: None,
464                            sparse_rank: None,
465                        })
466                    }
467                    (Some(d), None) => Some(FusedResult {
468                        id,
469                        score: d * self.config.dense_weight,
470                        dense_score: Some(d),
471                        sparse_score: None,
472                        dense_rank: None,
473                        sparse_rank: None,
474                    }),
475                    (None, Some(s)) => Some(FusedResult {
476                        id,
477                        score: s * self.config.sparse_weight,
478                        dense_score: None,
479                        sparse_score: Some(s),
480                        dense_rank: None,
481                        sparse_rank: None,
482                    }),
483                    _ => None,
484                },
485            )
486            .collect()
487    }
488
489    /// Normalize scores
490    fn normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
491        if results.is_empty() {
492            return Vec::new();
493        }
494
495        match self.config.normalization_method {
496            NormalizationMethod::MinMax => self.min_max_normalize(results),
497            NormalizationMethod::ZScore => self.z_score_normalize(results),
498            NormalizationMethod::Softmax => self.softmax_normalize(results),
499            NormalizationMethod::Rank => self.rank_normalize(results),
500            NormalizationMethod::None => results.to_vec(),
501        }
502    }
503
504    /// Min-max normalization
505    fn min_max_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
506        let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
507        let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
508        let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
509
510        if (max - min).abs() < 1e-6 {
511            return results.iter().map(|(id, _)| (id.clone(), 1.0)).collect();
512        }
513
514        results
515            .iter()
516            .map(|(id, score)| {
517                let normalized = (score - min) / (max - min);
518                (id.clone(), normalized)
519            })
520            .collect()
521    }
522
523    /// Z-score normalization
524    fn z_score_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
525        let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
526        let mean = scores.iter().sum::<f32>() / scores.len() as f32;
527        let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
528        let std_dev = variance.sqrt();
529
530        if std_dev < 1e-6 {
531            return results.iter().map(|(id, _)| (id.clone(), 0.0)).collect();
532        }
533
534        results
535            .iter()
536            .map(|(id, score)| {
537                let normalized = (score - mean) / std_dev;
538                (id.clone(), normalized)
539            })
540            .collect()
541    }
542
543    /// Softmax normalization
544    fn softmax_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
545        let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
546        let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
547
548        // Subtract max for numerical stability
549        let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
550        let sum_exp: f32 = exp_scores.iter().sum();
551
552        results
553            .iter()
554            .enumerate()
555            .map(|(i, (id, _))| {
556                let normalized = exp_scores[i] / sum_exp;
557                (id.clone(), normalized)
558            })
559            .collect()
560    }
561
562    /// Rank-based normalization
563    fn rank_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
564        let n = results.len() as f32;
565        results
566            .iter()
567            .enumerate()
568            .map(|(rank, (id, _))| {
569                let normalized = 1.0 - (rank as f32 / n);
570                (id.clone(), normalized)
571            })
572            .collect()
573    }
574
575    /// Update running average
576    fn update_avg(&self, old_avg: f64, new_val: f64, count: usize) -> f64 {
577        old_avg + (new_val - old_avg) / count as f64
578    }
579
580    /// Get fusion statistics
581    pub fn stats(&self) -> &HybridFusionStatistics {
582        &self.stats
583    }
584
585    /// Reset statistics
586    pub fn reset_stats(&mut self) {
587        self.stats = HybridFusionStatistics::default();
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_weighted_sum_fusion() {
        let mut engine = HybridFusion::new(HybridFusionConfig {
            strategy: HybridFusionStrategy::WeightedSum,
            dense_weight: 0.6,
            sparse_weight: 0.4,
            normalize_scores: false,
            ..Default::default()
        });

        let dense_hits = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.8)];
        let sparse_hits = vec![("doc2".to_string(), 0.7), ("doc3".to_string(), 0.6)];

        let fused = engine.fuse(dense_hits, sparse_hits).unwrap();

        assert!(!fused.is_empty());
        // Output must be ordered by descending fused score.
        for pair in fused.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_rrf_fusion() {
        let mut engine = HybridFusion::new(HybridFusionConfig {
            strategy: HybridFusionStrategy::ReciprocalRankFusion,
            rrf_k: 60.0,
            ..Default::default()
        });

        let to_owned = |pairs: &[(&str, f32)]| -> Vec<(String, f32)> {
            pairs.iter().map(|(id, s)| (id.to_string(), *s)).collect()
        };

        let dense_hits = to_owned(&[("doc1", 0.9), ("doc2", 0.8), ("doc3", 0.7)]);
        let sparse_hits = to_owned(&[("doc2", 0.85), ("doc3", 0.75), ("doc4", 0.65)]);

        let fused = engine.fuse(dense_hits, sparse_hits).unwrap();

        assert!(!fused.is_empty());
        // doc2 and doc3 appear in both lists and should surface near the top.
        let leaders: Vec<&str> = fused.iter().take(2).map(|r| r.id.as_str()).collect();
        assert!(leaders.contains(&"doc2") || leaders.contains(&"doc3"));
    }

    #[test]
    fn test_normalization() {
        let engine = HybridFusion::new(HybridFusionConfig {
            normalize_scores: true,
            normalization_method: NormalizationMethod::MinMax,
            ..Default::default()
        });

        let raw = vec![
            ("doc1".to_string(), 10.0),
            ("doc2".to_string(), 20.0),
            ("doc3".to_string(), 30.0),
        ];

        let scaled = engine.min_max_normalize(&raw);

        assert_eq!(scaled[0].1, 0.0); // minimum maps to 0
        assert_eq!(scaled[2].1, 1.0); // maximum maps to 1
        assert!((scaled[1].1 - 0.5).abs() < 0.01); // midpoint lands near 0.5
    }

    #[test]
    fn test_harmonic_mean_fusion() {
        let mut engine = HybridFusion::new(HybridFusionConfig {
            strategy: HybridFusionStrategy::HarmonicMean,
            ..Default::default()
        });

        let dense_hits = vec![("doc1".to_string(), 0.8), ("doc2".to_string(), 0.6)];
        let sparse_hits = vec![("doc1".to_string(), 0.9), ("doc3".to_string(), 0.7)];

        let fused = engine.fuse(dense_hits, sparse_hits).unwrap();

        assert!(!fused.is_empty());
        // doc1 is present in both modalities, so it should lead the ranking.
        assert_eq!(fused[0].id, "doc1");
    }

    #[test]
    fn test_statistics() {
        let mut engine = HybridFusion::new(HybridFusionConfig::default());

        let dense_hits = vec![("doc1".to_string(), 0.9)];
        let sparse_hits = vec![("doc2".to_string(), 0.8)];

        engine.fuse(dense_hits.clone(), sparse_hits.clone()).unwrap();
        engine.fuse(dense_hits, sparse_hits).unwrap();

        let stats = engine.stats();
        assert_eq!(stats.total_fusions, 2);
        assert!(stats.avg_dense_results > 0.0);
        assert!(stats.avg_sparse_results > 0.0);
    }
}