// oxirs_vec/hybrid_fusion.rs

1//! Hybrid search with dense + sparse vector fusion
2//!
3//! This module provides advanced hybrid search capabilities that combine:
4//! - **Dense vectors**: Semantic embeddings from neural networks (BERT, etc.)
5//! - **Sparse vectors**: Traditional keyword-based representations (TF-IDF, BM25)
6//!
7//! # Features
8//!
9//! - Multiple fusion strategies (weighted sum, RRF, learned fusion)
10//! - Automatic weight optimization
11//! - Score normalization across modalities
12//! - Support for query-time boosting
13//! - Performance metrics and analytics
14//!
15//! # Example
16//!
17//! ```rust,ignore
18//! use oxirs_vec::hybrid_fusion::{HybridFusion, HybridFusionStrategy, HybridFusionConfig};
19//! use oxirs_vec::{Vector, sparse::SparseVector};
20//!
21//! let config = HybridFusionConfig {
22//!     strategy: HybridFusionStrategy::WeightedSum,
23//!     dense_weight: 0.7,
24//!     sparse_weight: 0.3,
25//!     normalize_scores: true,
26//!     ..Default::default()
27//! };
28//!
29//! let fusion = HybridFusion::new(config);
30//!
31//! // Dense search results
32//! let dense_results = vec![
33//!     ("doc1".to_string(), 0.95),
34//!     ("doc2".to_string(), 0.85),
35//! ];
36//!
37//! // Sparse search results
38//! let sparse_results = vec![
39//!     ("doc2".to_string(), 0.90),
40//!     ("doc3".to_string(), 0.80),
41//! ];
42//!
43//! let fused = fusion.fuse(dense_results, sparse_results).unwrap();
44//! ```
45
46use anyhow::Result;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use tracing::debug;
50
/// Fusion strategy for combining dense and sparse results
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum HybridFusionStrategy {
    /// Weighted sum of normalized scores
    WeightedSum,
    /// Reciprocal rank fusion (RRF): combines per-modality `1/(k + rank)`
    /// contributions, ignoring raw score magnitudes
    ReciprocalRankFusion,
    /// Linear combination with learned weights (currently delegates to the
    /// weighted-sum implementation)
    LearnedFusion,
    /// Convex combination (currently delegates to the weighted-sum
    /// implementation)
    ConvexCombination,
    /// Harmonic mean of the two scores; rewards documents found by both
    /// modalities
    HarmonicMean,
    /// Geometric mean of the two scores; rewards documents found by both
    /// modalities
    GeometricMean,
}
67
/// Configuration for hybrid fusion
///
/// Note: `HybridFusion::new` renormalizes `dense_weight`/`sparse_weight`
/// so that they sum to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridFusionConfig {
    /// Fusion strategy to use
    pub strategy: HybridFusionStrategy,
    /// Weight for dense vector scores (0.0 to 1.0)
    pub dense_weight: f32,
    /// Weight for sparse vector scores (0.0 to 1.0)
    pub sparse_weight: f32,
    /// Whether to normalize scores before fusion (RRF always operates on
    /// raw ranks regardless of this flag)
    pub normalize_scores: bool,
    /// Normalization method
    pub normalization_method: NormalizationMethod,
    /// RRF rank constant (k parameter); larger values flatten the
    /// difference between adjacent ranks
    pub rrf_k: f32,
    /// Minimum score threshold; fused results scoring below this are dropped
    pub min_score_threshold: f32,
    /// Maximum results to return
    pub max_results: usize,
    /// Enable query-time boosting
    // NOTE(review): nothing in this module reads this flag yet — confirm
    // it is consumed elsewhere before relying on it.
    pub enable_boosting: bool,
}
90
91impl Default for HybridFusionConfig {
92    fn default() -> Self {
93        Self {
94            strategy: HybridFusionStrategy::WeightedSum,
95            dense_weight: 0.7,
96            sparse_weight: 0.3,
97            normalize_scores: true,
98            normalization_method: NormalizationMethod::MinMax,
99            rrf_k: 60.0,
100            min_score_threshold: 0.0,
101            max_results: 100,
102            enable_boosting: false,
103        }
104    }
105}
106
/// Score normalization methods
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Min-max normalization: rescale scores into [0, 1]
    MinMax,
    /// Z-score normalization: zero mean, unit variance (outputs may be
    /// negative)
    ZScore,
    /// Softmax normalization: scores become a probability distribution
    Softmax,
    /// Rank-based normalization: `1 - rank/n`, ignores score magnitudes
    Rank,
    /// No normalization
    None,
}
121
/// Fused search result
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Document ID
    pub id: String,
    /// Combined score produced by the fusion strategy
    pub score: f32,
    /// Dense score component (`None` when the document did not appear in
    /// the dense result set)
    pub dense_score: Option<f32>,
    /// Sparse score component (`None` when absent from the sparse set)
    pub sparse_score: Option<f32>,
    /// Rank from dense search (populated only by rank-based strategies
    /// such as RRF)
    pub dense_rank: Option<usize>,
    /// Rank from sparse search (populated only by rank-based strategies)
    pub sparse_rank: Option<usize>,
}
138
/// Hybrid fusion engine
///
/// Holds the (weight-normalized) configuration plus running statistics
/// that are updated on every `fuse` call.
pub struct HybridFusion {
    config: HybridFusionConfig,
    stats: HybridFusionStatistics,
}
144
/// Fusion statistics
///
/// All `avg_*` fields are running averages across every `fuse` call since
/// construction (or the last `reset_stats`).
#[derive(Debug, Clone, Default)]
pub struct HybridFusionStatistics {
    // Number of fuse() calls recorded.
    pub total_fusions: usize,
    // Average number of dense candidates per fusion.
    pub avg_dense_results: f64,
    // Average number of sparse candidates per fusion.
    pub avg_sparse_results: f64,
    // Average number of results returned after filtering/truncation.
    pub avg_fused_results: f64,
    // Average Jaccard overlap (|intersection| / |union|) between the dense
    // and sparse document-ID sets.
    pub avg_overlap: f64,
}
154
155impl HybridFusion {
156    /// Create a new hybrid fusion engine
157    pub fn new(config: HybridFusionConfig) -> Self {
158        // Ensure weights sum to 1.0
159        let total_weight = config.dense_weight + config.sparse_weight;
160        let normalized_config = if (total_weight - 1.0).abs() > 1e-6 {
161            debug!(
162                "Normalizing fusion weights: dense={}, sparse={} -> sum={}",
163                config.dense_weight, config.sparse_weight, total_weight
164            );
165            HybridFusionConfig {
166                dense_weight: config.dense_weight / total_weight,
167                sparse_weight: config.sparse_weight / total_weight,
168                ..config
169            }
170        } else {
171            config
172        };
173
174        Self {
175            config: normalized_config,
176            stats: HybridFusionStatistics::default(),
177        }
178    }
179
180    /// Fuse dense and sparse search results
181    pub fn fuse(
182        &mut self,
183        dense_results: Vec<(String, f32)>,
184        sparse_results: Vec<(String, f32)>,
185    ) -> Result<Vec<FusedResult>> {
186        // Update statistics
187        self.stats.total_fusions += 1;
188        self.stats.avg_dense_results = self.update_avg(
189            self.stats.avg_dense_results,
190            dense_results.len() as f64,
191            self.stats.total_fusions,
192        );
193        self.stats.avg_sparse_results = self.update_avg(
194            self.stats.avg_sparse_results,
195            sparse_results.len() as f64,
196            self.stats.total_fusions,
197        );
198
199        // Normalize scores if configured
200        let normalized_dense = if self.config.normalize_scores {
201            self.normalize(&dense_results)
202        } else {
203            dense_results.clone()
204        };
205
206        let normalized_sparse = if self.config.normalize_scores {
207            self.normalize(&sparse_results)
208        } else {
209            sparse_results.clone()
210        };
211
212        // Perform fusion based on strategy
213        let fused = match self.config.strategy {
214            HybridFusionStrategy::WeightedSum => {
215                self.weighted_sum_fusion(&normalized_dense, &normalized_sparse)
216            }
217            HybridFusionStrategy::ReciprocalRankFusion => {
218                self.rrf_fusion(&dense_results, &sparse_results)
219            }
220            HybridFusionStrategy::LearnedFusion => {
221                self.learned_fusion(&normalized_dense, &normalized_sparse)
222            }
223            HybridFusionStrategy::ConvexCombination => {
224                self.convex_combination(&normalized_dense, &normalized_sparse)
225            }
226            HybridFusionStrategy::HarmonicMean => {
227                self.harmonic_mean_fusion(&normalized_dense, &normalized_sparse)
228            }
229            HybridFusionStrategy::GeometricMean => {
230                self.geometric_mean_fusion(&normalized_dense, &normalized_sparse)
231            }
232        };
233
234        // Calculate overlap
235        let dense_ids: std::collections::HashSet<_> =
236            dense_results.iter().map(|(id, _)| id).collect();
237        let sparse_ids: std::collections::HashSet<_> =
238            sparse_results.iter().map(|(id, _)| id).collect();
239        let overlap = dense_ids.intersection(&sparse_ids).count();
240        let total_unique = dense_ids.union(&sparse_ids).count();
241        let overlap_ratio = if total_unique > 0 {
242            overlap as f64 / total_unique as f64
243        } else {
244            0.0
245        };
246        self.stats.avg_overlap = self.update_avg(
247            self.stats.avg_overlap,
248            overlap_ratio,
249            self.stats.total_fusions,
250        );
251
252        // Filter by threshold and limit
253        let mut filtered: Vec<_> = fused
254            .into_iter()
255            .filter(|r| r.score >= self.config.min_score_threshold)
256            .collect();
257
258        // Sort by score descending
259        filtered.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
260
261        // Apply max results limit
262        filtered.truncate(self.config.max_results);
263
264        self.stats.avg_fused_results = self.update_avg(
265            self.stats.avg_fused_results,
266            filtered.len() as f64,
267            self.stats.total_fusions,
268        );
269
270        Ok(filtered)
271    }
272
273    /// Weighted sum fusion
274    fn weighted_sum_fusion(
275        &self,
276        dense: &[(String, f32)],
277        sparse: &[(String, f32)],
278    ) -> Vec<FusedResult> {
279        let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
280
281        // Add dense scores
282        for (id, score) in dense {
283            score_map.insert(id.clone(), (Some(*score), None));
284        }
285
286        // Add sparse scores
287        for (id, score) in sparse {
288            score_map
289                .entry(id.clone())
290                .and_modify(|e| e.1 = Some(*score))
291                .or_insert((None, Some(*score)));
292        }
293
294        // Compute weighted sum
295        score_map
296            .into_iter()
297            .map(|(id, (dense_score, sparse_score))| {
298                let combined_score = dense_score.unwrap_or(0.0) * self.config.dense_weight
299                    + sparse_score.unwrap_or(0.0) * self.config.sparse_weight;
300
301                FusedResult {
302                    id,
303                    score: combined_score,
304                    dense_score,
305                    sparse_score,
306                    dense_rank: None,
307                    sparse_rank: None,
308                }
309            })
310            .collect()
311    }
312
313    /// Reciprocal rank fusion (RRF)
314    fn rrf_fusion(&self, dense: &[(String, f32)], sparse: &[(String, f32)]) -> Vec<FusedResult> {
315        let mut score_map: HashMap<String, (Option<usize>, Option<usize>)> = HashMap::new();
316
317        // Add dense ranks
318        for (rank, (id, _)) in dense.iter().enumerate() {
319            score_map.insert(id.clone(), (Some(rank), None));
320        }
321
322        // Add sparse ranks
323        for (rank, (id, _)) in sparse.iter().enumerate() {
324            score_map
325                .entry(id.clone())
326                .and_modify(|e| e.1 = Some(rank))
327                .or_insert((None, Some(rank)));
328        }
329
330        // Compute RRF scores
331        score_map
332            .into_iter()
333            .map(|(id, (dense_rank, sparse_rank))| {
334                let dense_rrf = dense_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
335                let sparse_rrf = sparse_rank.map_or(0.0, |r| 1.0 / (self.config.rrf_k + r as f32));
336
337                let combined_score =
338                    dense_rrf * self.config.dense_weight + sparse_rrf * self.config.sparse_weight;
339
340                FusedResult {
341                    id,
342                    score: combined_score,
343                    dense_score: dense_rank.map(|_| dense_rrf),
344                    sparse_score: sparse_rank.map(|_| sparse_rrf),
345                    dense_rank,
346                    sparse_rank,
347                }
348            })
349            .collect()
350    }
351
352    /// Learned fusion with adaptive weights
353    fn learned_fusion(
354        &self,
355        dense: &[(String, f32)],
356        sparse: &[(String, f32)],
357    ) -> Vec<FusedResult> {
358        // For now, use weighted sum with learned weights
359        // In production, this would use a trained model
360        self.weighted_sum_fusion(dense, sparse)
361    }
362
363    /// Convex combination
364    fn convex_combination(
365        &self,
366        dense: &[(String, f32)],
367        sparse: &[(String, f32)],
368    ) -> Vec<FusedResult> {
369        // Similar to weighted sum but ensures convexity
370        self.weighted_sum_fusion(dense, sparse)
371    }
372
373    /// Harmonic mean fusion
374    fn harmonic_mean_fusion(
375        &self,
376        dense: &[(String, f32)],
377        sparse: &[(String, f32)],
378    ) -> Vec<FusedResult> {
379        let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
380
381        for (id, score) in dense {
382            score_map.insert(id.clone(), (Some(*score), None));
383        }
384
385        for (id, score) in sparse {
386            score_map
387                .entry(id.clone())
388                .and_modify(|e| e.1 = Some(*score))
389                .or_insert((None, Some(*score)));
390        }
391
392        score_map
393            .into_iter()
394            .filter_map(
395                |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
396                    (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
397                        let harmonic = 2.0 / (1.0 / d + 1.0 / s);
398                        Some(FusedResult {
399                            id,
400                            score: harmonic,
401                            dense_score: Some(d),
402                            sparse_score: Some(s),
403                            dense_rank: None,
404                            sparse_rank: None,
405                        })
406                    }
407                    (Some(d), None) => Some(FusedResult {
408                        id,
409                        score: d * self.config.dense_weight,
410                        dense_score: Some(d),
411                        sparse_score: None,
412                        dense_rank: None,
413                        sparse_rank: None,
414                    }),
415                    (None, Some(s)) => Some(FusedResult {
416                        id,
417                        score: s * self.config.sparse_weight,
418                        dense_score: None,
419                        sparse_score: Some(s),
420                        dense_rank: None,
421                        sparse_rank: None,
422                    }),
423                    _ => None,
424                },
425            )
426            .collect()
427    }
428
429    /// Geometric mean fusion
430    fn geometric_mean_fusion(
431        &self,
432        dense: &[(String, f32)],
433        sparse: &[(String, f32)],
434    ) -> Vec<FusedResult> {
435        let mut score_map: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
436
437        for (id, score) in dense {
438            score_map.insert(id.clone(), (Some(*score), None));
439        }
440
441        for (id, score) in sparse {
442            score_map
443                .entry(id.clone())
444                .and_modify(|e| e.1 = Some(*score))
445                .or_insert((None, Some(*score)));
446        }
447
448        score_map
449            .into_iter()
450            .filter_map(
451                |(id, (dense_score, sparse_score))| match (dense_score, sparse_score) {
452                    (Some(d), Some(s)) if d > 0.0 && s > 0.0 => {
453                        let geometric = (d * s).sqrt();
454                        Some(FusedResult {
455                            id,
456                            score: geometric,
457                            dense_score: Some(d),
458                            sparse_score: Some(s),
459                            dense_rank: None,
460                            sparse_rank: None,
461                        })
462                    }
463                    (Some(d), None) => Some(FusedResult {
464                        id,
465                        score: d * self.config.dense_weight,
466                        dense_score: Some(d),
467                        sparse_score: None,
468                        dense_rank: None,
469                        sparse_rank: None,
470                    }),
471                    (None, Some(s)) => Some(FusedResult {
472                        id,
473                        score: s * self.config.sparse_weight,
474                        dense_score: None,
475                        sparse_score: Some(s),
476                        dense_rank: None,
477                        sparse_rank: None,
478                    }),
479                    _ => None,
480                },
481            )
482            .collect()
483    }
484
485    /// Normalize scores
486    fn normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
487        if results.is_empty() {
488            return Vec::new();
489        }
490
491        match self.config.normalization_method {
492            NormalizationMethod::MinMax => self.min_max_normalize(results),
493            NormalizationMethod::ZScore => self.z_score_normalize(results),
494            NormalizationMethod::Softmax => self.softmax_normalize(results),
495            NormalizationMethod::Rank => self.rank_normalize(results),
496            NormalizationMethod::None => results.to_vec(),
497        }
498    }
499
500    /// Min-max normalization
501    fn min_max_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
502        let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
503        let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
504        let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
505
506        if (max - min).abs() < 1e-6 {
507            return results.iter().map(|(id, _)| (id.clone(), 1.0)).collect();
508        }
509
510        results
511            .iter()
512            .map(|(id, score)| {
513                let normalized = (score - min) / (max - min);
514                (id.clone(), normalized)
515            })
516            .collect()
517    }
518
519    /// Z-score normalization
520    fn z_score_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
521        let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
522        let mean = scores.iter().sum::<f32>() / scores.len() as f32;
523        let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
524        let std_dev = variance.sqrt();
525
526        if std_dev < 1e-6 {
527            return results.iter().map(|(id, _)| (id.clone(), 0.0)).collect();
528        }
529
530        results
531            .iter()
532            .map(|(id, score)| {
533                let normalized = (score - mean) / std_dev;
534                (id.clone(), normalized)
535            })
536            .collect()
537    }
538
539    /// Softmax normalization
540    fn softmax_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
541        let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
542        let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
543
544        // Subtract max for numerical stability
545        let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
546        let sum_exp: f32 = exp_scores.iter().sum();
547
548        results
549            .iter()
550            .enumerate()
551            .map(|(i, (id, _))| {
552                let normalized = exp_scores[i] / sum_exp;
553                (id.clone(), normalized)
554            })
555            .collect()
556    }
557
558    /// Rank-based normalization
559    fn rank_normalize(&self, results: &[(String, f32)]) -> Vec<(String, f32)> {
560        let n = results.len() as f32;
561        results
562            .iter()
563            .enumerate()
564            .map(|(rank, (id, _))| {
565                let normalized = 1.0 - (rank as f32 / n);
566                (id.clone(), normalized)
567            })
568            .collect()
569    }
570
571    /// Update running average
572    fn update_avg(&self, old_avg: f64, new_val: f64, count: usize) -> f64 {
573        old_avg + (new_val - old_avg) / count as f64
574    }
575
    /// Get fusion statistics
    ///
    /// Returns the running averages accumulated across all `fuse` calls
    /// since construction or the last `reset_stats`.
    pub fn stats(&self) -> &HybridFusionStatistics {
        &self.stats
    }
580
581    /// Reset statistics
582    pub fn reset_stats(&mut self) {
583        self.stats = HybridFusionStatistics::default();
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    // Weighted-sum fusion on raw (un-normalized) scores: the output must
    // be non-empty and sorted by descending combined score.
    #[test]
    fn test_weighted_sum_fusion() {
        let config = HybridFusionConfig {
            strategy: HybridFusionStrategy::WeightedSum,
            dense_weight: 0.6,
            sparse_weight: 0.4,
            normalize_scores: false,
            ..Default::default()
        };

        let mut fusion = HybridFusion::new(config);

        let dense = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.8)];

        let sparse = vec![("doc2".to_string(), 0.7), ("doc3".to_string(), 0.6)];

        let results = fusion.fuse(dense, sparse).unwrap();

        assert!(!results.is_empty());
        // Results should be sorted by score
        for i in 1..results.len() {
            assert!(results[i - 1].score >= results[i].score);
        }
    }

    // RRF: documents returned by both modalities accumulate two rank
    // contributions and should surface near the top.
    #[test]
    fn test_rrf_fusion() {
        let config = HybridFusionConfig {
            strategy: HybridFusionStrategy::ReciprocalRankFusion,
            rrf_k: 60.0,
            ..Default::default()
        };

        let mut fusion = HybridFusion::new(config);

        let dense = vec![
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];

        let sparse = vec![
            ("doc2".to_string(), 0.85),
            ("doc3".to_string(), 0.75),
            ("doc4".to_string(), 0.65),
        ];

        let results = fusion.fuse(dense, sparse).unwrap();

        assert!(!results.is_empty());
        // doc2 and doc3 should rank high (appear in both)
        let top_ids: Vec<_> = results.iter().take(2).map(|r| r.id.as_str()).collect();
        assert!(top_ids.contains(&"doc2") || top_ids.contains(&"doc3"));
    }

    // Min-max normalization maps the minimum score to 0.0, the maximum to
    // 1.0, and intermediate scores linearly in between.
    #[test]
    fn test_normalization() {
        let config = HybridFusionConfig {
            normalize_scores: true,
            normalization_method: NormalizationMethod::MinMax,
            ..Default::default()
        };

        let fusion = HybridFusion::new(config);

        let results = vec![
            ("doc1".to_string(), 10.0),
            ("doc2".to_string(), 20.0),
            ("doc3".to_string(), 30.0),
        ];

        let normalized = fusion.min_max_normalize(&results);

        assert_eq!(normalized[0].1, 0.0); // Min
        assert_eq!(normalized[2].1, 1.0); // Max
        assert!((normalized[1].1 - 0.5).abs() < 0.01); // Middle
    }

    // Harmonic mean rewards cross-modality agreement: the only document
    // present in both result sets must rank first.
    #[test]
    fn test_harmonic_mean_fusion() {
        let config = HybridFusionConfig {
            strategy: HybridFusionStrategy::HarmonicMean,
            ..Default::default()
        };

        let mut fusion = HybridFusion::new(config);

        let dense = vec![("doc1".to_string(), 0.8), ("doc2".to_string(), 0.6)];

        let sparse = vec![("doc1".to_string(), 0.9), ("doc3".to_string(), 0.7)];

        let results = fusion.fuse(dense, sparse).unwrap();

        assert!(!results.is_empty());
        // doc1 appears in both, should have high score
        assert_eq!(results[0].id, "doc1");
    }

    // Running statistics accumulate across multiple fuse() calls.
    #[test]
    fn test_statistics() {
        let config = HybridFusionConfig::default();
        let mut fusion = HybridFusion::new(config);

        let dense = vec![("doc1".to_string(), 0.9)];
        let sparse = vec![("doc2".to_string(), 0.8)];

        fusion.fuse(dense.clone(), sparse.clone()).unwrap();
        fusion.fuse(dense, sparse).unwrap();

        let stats = fusion.stats();
        assert_eq!(stats.total_fusions, 2);
        assert!(stats.avg_dense_results > 0.0);
        assert!(stats.avg_sparse_results > 0.0);
    }
}