use crate::ChunkId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// How ranked result lists from different retrievers are merged into one list.
///
/// Variant and field names are part of the serde representation, so they are
/// spelled exactly as stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion: each list adds `1 / (k + rank)` per chunk; input scores are ignored.
    RRF { k: f32 },
    /// Weighted sum of min-max-normalized scores; the sparse list gets `1 - dense_weight`.
    Linear { dense_weight: f32 },
    /// Convex combination, fused exactly like `Linear` with `dense_weight = alpha`.
    Convex { alpha: f32 },
    /// Sum of z-score-normalized scores from both lists.
    DBSF,
    /// Every chunk from either list, in rank order (dense list first), keeping its raw score.
    Union,
    /// Only chunks present in both lists, scored by the mean of their two raw scores.
    Intersection,
    /// Weighted sum of min-max-normalized scores over dense, sparse and multi-vector lists.
    #[cfg(feature = "multivector")]
    ThreeWay { dense_weight: f32, sparse_weight: f32, multivector_weight: f32 },
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::RRF { k: 60.0 }
}
}
impl FusionStrategy {
    /// Fuses a dense and a sparse result list into a single ranked list.
    ///
    /// Both inputs are treated as ranked best-first: rank-based strategies
    /// (`RRF`, `Union`) use only each entry's position. Every chunk appears at
    /// most once in the output. Score-sorted strategies return the highest
    /// fused score first and break ties by first appearance (dense list, then
    /// sparse list), so repeated calls return the same order. `Union` is
    /// ordered by rank instead of score.
    ///
    /// With the `multivector` feature, `ThreeWay` has no two-list form here.
    /// It falls back to RRF with `k = 60` and its weights are ignored; use
    /// `fuse_three` for weighted three-way fusion.
    #[must_use]
    pub fn fuse(
        &self,
        dense_results: &[(ChunkId, f32)],
        sparse_results: &[(ChunkId, f32)],
    ) -> Vec<(ChunkId, f32)> {
        match self {
            FusionStrategy::RRF { k } => {
                Self::reciprocal_rank_fusion(dense_results, sparse_results, *k)
            }
            FusionStrategy::Linear { dense_weight } => {
                Self::linear_fusion(dense_results, sparse_results, *dense_weight)
            }
            FusionStrategy::Convex { alpha } => {
                Self::convex_fusion(dense_results, sparse_results, *alpha)
            }
            FusionStrategy::DBSF => Self::dbsf_fusion(dense_results, sparse_results),
            FusionStrategy::Union => Self::union_fusion(dense_results, sparse_results),
            FusionStrategy::Intersection => {
                Self::intersection_fusion(dense_results, sparse_results)
            }
            #[cfg(feature = "multivector")]
            FusionStrategy::ThreeWay { .. } => {
                Self::reciprocal_rank_fusion(dense_results, sparse_results, 60.0)
            }
        }
    }

    /// Reciprocal Rank Fusion: every list adds `1 / (k + rank)` (1-based rank)
    /// to each chunk it contains, and contributions are summed.
    ///
    /// Input scores are ignored. `k` should be greater than `-1`. At `k = -1`
    /// the top-ranked entry contributes an infinite score, and below that
    /// contributions can become negative.
    fn reciprocal_rank_fusion(
        dense: &[(ChunkId, f32)],
        sparse: &[(ChunkId, f32)],
        k: f32,
    ) -> Vec<(ChunkId, f32)> {
        let mut scores: HashMap<ChunkId, f32> = HashMap::with_capacity(dense.len() + sparse.len());
        for (rank, &(id, _)) in dense.iter().enumerate().chain(sparse.iter().enumerate()) {
            *scores.entry(id).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
        }
        Self::sort_by_score(scores, dense.iter().chain(sparse))
    }

    /// Weighted sum of per-list min-max-normalized scores.
    ///
    /// The sparse list is weighted by `1 - dense_weight`. The weight is not
    /// clamped, so values outside `[0, 1]` give one list a negative weight. A
    /// chunk missing from one list receives nothing from that list.
    fn linear_fusion(
        dense: &[(ChunkId, f32)],
        sparse: &[(ChunkId, f32)],
        dense_weight: f32,
    ) -> Vec<(ChunkId, f32)> {
        let sparse_weight = 1.0 - dense_weight;
        let mut scores: HashMap<ChunkId, f32> = HashMap::with_capacity(dense.len() + sparse.len());
        Self::add_weighted(&mut scores, Self::min_max_normalize(dense), dense_weight);
        Self::add_weighted(&mut scores, Self::min_max_normalize(sparse), sparse_weight);
        Self::sort_by_score(scores, dense.iter().chain(sparse))
    }

    /// Convex combination; identical to [`Self::linear_fusion`] with `dense_weight = alpha`.
    fn convex_fusion(
        dense: &[(ChunkId, f32)],
        sparse: &[(ChunkId, f32)],
        alpha: f32,
    ) -> Vec<(ChunkId, f32)> {
        Self::linear_fusion(dense, sparse, alpha)
    }

    /// Sums per-list z-score-normalized scores (population standard deviation).
    ///
    /// NOTE(review): some systems define DBSF as min-max scaling within
    /// `mean ± 3σ`. This implementation sums plain z-scores; confirm that this
    /// is intended.
    fn dbsf_fusion(dense: &[(ChunkId, f32)], sparse: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
        let mut scores: HashMap<ChunkId, f32> = HashMap::with_capacity(dense.len() + sparse.len());
        Self::add_weighted(&mut scores, Self::z_score_normalize(dense), 1.0);
        Self::add_weighted(&mut scores, Self::z_score_normalize(sparse), 1.0);
        Self::sort_by_score(scores, dense.iter().chain(sparse))
    }

    /// Every chunk from either list, in rank order: all dense entries first,
    /// then sparse entries not already seen.
    ///
    /// Each chunk keeps the raw score of its first (best-ranked) occurrence.
    /// Earlier versions let a later duplicate overwrite a dense entry but not
    /// a sparse one.
    fn union_fusion(dense: &[(ChunkId, f32)], sparse: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
        let mut seen: std::collections::HashSet<ChunkId> =
            std::collections::HashSet::with_capacity(dense.len() + sparse.len());
        dense
            .iter()
            .chain(sparse)
            .copied()
            .filter(|&(id, _)| seen.insert(id))
            .collect()
    }

    /// Only chunks present in both lists, scored by the mean of their raw
    /// dense and sparse scores and sorted highest first.
    ///
    /// If a list contains a chunk more than once, its first (best-ranked)
    /// occurrence is used, matching [`Self::union_fusion`].
    fn intersection_fusion(
        dense: &[(ChunkId, f32)],
        sparse: &[(ChunkId, f32)],
    ) -> Vec<(ChunkId, f32)> {
        let mut sparse_scores: HashMap<ChunkId, f32> = HashMap::with_capacity(sparse.len());
        for &(id, score) in sparse {
            sparse_scores.entry(id).or_insert(score);
        }
        let mut scores: HashMap<ChunkId, f32> = HashMap::new();
        for &(id, dense_score) in dense {
            if let Some(&sparse_score) = sparse_scores.get(&id) {
                scores.entry(id).or_insert((dense_score + sparse_score) / 2.0);
            }
        }
        Self::sort_by_score(scores, dense)
    }

    /// Adds `weight * score` for every entry of `normalized` into `scores`.
    fn add_weighted(
        scores: &mut HashMap<ChunkId, f32>,
        normalized: Vec<(ChunkId, f32)>,
        weight: f32,
    ) {
        for (id, score) in normalized {
            *scores.entry(id).or_insert(0.0) += weight * score;
        }
    }

    /// Rescales scores with `(s - min) / (max - min)`, mapping the minimum to
    /// `0.0` and the maximum to `1.0`.
    ///
    /// If the range is below `f32::EPSILON` (all scores effectively equal),
    /// every entry maps to `1.0` instead of dividing by ~0.
    fn min_max_normalize(results: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
        if results.is_empty() {
            return Vec::new();
        }
        let (min, max) = results
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &(_, s)| (lo.min(s), hi.max(s)));
        let range = max - min;
        if range.abs() < f32::EPSILON {
            return results.iter().map(|&(id, _)| (id, 1.0)).collect();
        }
        results.iter().map(|&(id, s)| (id, (s - min) / range)).collect()
    }

    /// Standardizes scores to `(s - mean) / std_dev` using the population
    /// standard deviation.
    ///
    /// If the standard deviation is below `f32::EPSILON`, every entry maps to
    /// `0.0`.
    fn z_score_normalize(results: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
        if results.is_empty() {
            return Vec::new();
        }
        let n = results.len() as f32;
        let mean = results.iter().map(|&(_, s)| s).sum::<f32>() / n;
        let variance = results.iter().map(|&(_, s)| (s - mean).powi(2)).sum::<f32>() / n;
        let std_dev = variance.sqrt();
        if std_dev.abs() < f32::EPSILON {
            return results.iter().map(|&(id, _)| (id, 0.0)).collect();
        }
        results.iter().map(|&(id, s)| (id, (s - mean) / std_dev)).collect()
    }

    /// Orders accumulated scores highest first.
    ///
    /// `order` should yield every id present in `scores`; repeated ids are
    /// ignored. Its first-appearance order breaks ties. `HashMap` iteration
    /// order is randomized per map, so tied chunks would otherwise come back
    /// in a different order on every call. NaN scores sort last, and the
    /// comparator is a total preorder, as `sort_by` requires.
    fn sort_by_score<'a, I>(mut scores: HashMap<ChunkId, f32>, order: I) -> Vec<(ChunkId, f32)>
    where
        I: IntoIterator<Item = &'a (ChunkId, f32)>,
    {
        let mut results = Vec::with_capacity(scores.len());
        for &(id, _) in order {
            if let Some(score) = scores.remove(&id) {
                results.push((id, score));
            }
        }
        debug_assert!(scores.is_empty(), "every scored id must appear in `order`");
        // Defensive: never silently drop a scored chunk if a caller passes an incomplete order.
        results.extend(scores);
        // Stable sort, so equal scores keep their first-appearance order.
        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        });
        results
    }

    /// Fuses dense, sparse and multi-vector results.
    ///
    /// `ThreeWay` computes a weighted sum of min-max-normalized scores across
    /// all three lists. Any other strategy is applied pairwise: dense and
    /// sparse are fused first, then that fused list is fused with
    /// `multivector`.
    #[cfg(feature = "multivector")]
    #[must_use]
    pub fn fuse_three(
        &self,
        dense: &[(ChunkId, f32)],
        sparse: &[(ChunkId, f32)],
        multivector: &[(ChunkId, f32)],
    ) -> Vec<(ChunkId, f32)> {
        if let FusionStrategy::ThreeWay { dense_weight, sparse_weight, multivector_weight } = self {
            Self::three_way_linear(
                dense,
                sparse,
                multivector,
                *dense_weight,
                *sparse_weight,
                *multivector_weight,
            )
        } else {
            let dense_sparse = self.fuse(dense, sparse);
            self.fuse(&dense_sparse, multivector)
        }
    }

    /// Weighted sum of per-list min-max-normalized scores over three lists.
    ///
    /// Ties are broken by first appearance (dense, then sparse, then multi-vector).
    #[cfg(feature = "multivector")]
    fn three_way_linear(
        dense: &[(ChunkId, f32)],
        sparse: &[(ChunkId, f32)],
        multivector: &[(ChunkId, f32)],
        w_dense: f32,
        w_sparse: f32,
        w_multi: f32,
    ) -> Vec<(ChunkId, f32)> {
        let mut scores: HashMap<ChunkId, f32> =
            HashMap::with_capacity(dense.len() + sparse.len() + multivector.len());
        Self::add_weighted(&mut scores, Self::min_max_normalize(dense), w_dense);
        Self::add_weighted(&mut scores, Self::min_max_normalize(sparse), w_sparse);
        Self::add_weighted(&mut scores, Self::min_max_normalize(multivector), w_multi);
        Self::sort_by_score(scores, dense.iter().chain(sparse).chain(multivector))
    }

    /// Builds a `ThreeWay` strategy from the three per-list weights.
    #[cfg(feature = "multivector")]
    #[must_use]
    pub fn three_way(dense_weight: f32, sparse_weight: f32, multivector_weight: f32) -> Self {
        Self::ThreeWay { dense_weight, sparse_weight, multivector_weight }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a deterministic `ChunkId` from an integer.
    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    /// Extracts the ids in result order, for compact ordering assertions.
    fn ids(results: &[(ChunkId, f32)]) -> Vec<ChunkId> {
        results.iter().map(|(id, _)| *id).collect()
    }

    #[test]
    fn test_fusion_strategy_default() {
        let strategy = FusionStrategy::default();
        match strategy {
            FusionStrategy::RRF { k } => assert!((k - 60.0).abs() < 0.01),
            _ => panic!("Expected RRF"),
        }
    }

    #[test]
    fn test_fusion_strategy_serialization() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.7 };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        match deserialized {
            FusionStrategy::Linear { dense_weight } => {
                assert!((dense_weight - 0.7).abs() < 0.01);
            }
            _ => panic!("Wrong strategy type"),
        }
    }

    #[test]
    fn test_rrf_empty() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let results = strategy.fuse(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rrf_dense_only() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let results = strategy.fuse(&dense, &[]);
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
    }

    #[test]
    fn test_rrf_sparse_only() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let sparse = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let results = strategy.fuse(&[], &sparse);
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
        assert!((results[0].1 - 1.0 / 61.0).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_combines_ranks() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let sparse = vec![(chunk_id(1), 0.9), (chunk_id(3), 0.8)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, chunk_id(1));
    }

    #[test]
    fn test_rrf_score_calculation() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let dense = vec![(chunk_id(1), 1.0)];
        let sparse = vec![(chunk_id(1), 1.0)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(results.len(), 1);
        let expected = 2.0 / 61.0;
        assert!((results[0].1 - expected).abs() < 0.001);
    }

    #[test]
    fn test_linear_empty() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.5 };
        let results = strategy.fuse(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_linear_dense_only() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.7 };
        let dense = vec![(chunk_id(1), 1.0), (chunk_id(2), 0.5)];
        let results = strategy.fuse(&dense, &[]);
        // Dense scores normalize to [1.0, 0.0] and are weighted by 0.7.
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
        assert!((results[0].1 - 0.7).abs() < 1e-6);
        assert!(results[1].1.abs() < 1e-6);
    }

    #[test]
    fn test_linear_equal_weights() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.5 };
        let dense = vec![(chunk_id(1), 1.0)];
        let sparse = vec![(chunk_id(1), 1.0)];
        let results = strategy.fuse(&dense, &sparse);
        assert!((results[0].1 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_linear_weight_preference() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.9 };
        let dense = vec![(chunk_id(1), 1.0), (chunk_id(2), 0.0)];
        let sparse = vec![(chunk_id(2), 1.0), (chunk_id(1), 0.0)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(results[0].0, chunk_id(1));
        assert!((results[0].1 - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_convex_same_as_linear() {
        let linear = FusionStrategy::Linear { dense_weight: 0.6 };
        let convex = FusionStrategy::Convex { alpha: 0.6 };
        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.5)];
        let sparse = vec![(chunk_id(2), 0.8), (chunk_id(3), 0.4)];
        let linear_results = linear.fuse(&dense, &sparse);
        let convex_results = convex.fuse(&dense, &sparse);
        assert_eq!(linear_results.len(), convex_results.len());
        // Fused scores are distinct (0.6, 0.4, 0.0), so the order is fixed.
        for ((lin_id, lin_score), (conv_id, conv_score)) in
            linear_results.iter().zip(convex_results.iter())
        {
            assert_eq!(lin_id, conv_id);
            assert!((lin_score - conv_score).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dbsf_empty() {
        let strategy = FusionStrategy::DBSF;
        let results = strategy.fuse(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_dbsf_z_score() {
        let strategy = FusionStrategy::DBSF;
        let dense = vec![(chunk_id(1), 10.0), (chunk_id(2), 5.0), (chunk_id(3), 0.0)];
        let sparse = vec![(chunk_id(1), 100.0), (chunk_id(2), 50.0), (chunk_id(3), 0.0)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2), chunk_id(3)]);
        // The middle chunk sits exactly at the mean of both lists.
        assert!(results[1].1.abs() < 1e-6);
    }

    #[test]
    fn test_union_combines_all() {
        let strategy = FusionStrategy::Union;
        let dense = vec![(chunk_id(1), 0.9)];
        let sparse = vec![(chunk_id(2), 0.8)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
    }

    #[test]
    fn test_union_deduplicates() {
        let strategy = FusionStrategy::Union;
        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let sparse = vec![(chunk_id(1), 0.7), (chunk_id(3), 0.6)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2), chunk_id(3)]);
        assert!((results[0].1 - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_union_prefers_dense_rank() {
        let strategy = FusionStrategy::Union;
        let dense = vec![(chunk_id(1), 0.9)];
        let sparse = vec![(chunk_id(1), 0.5)];
        let results = strategy.fuse(&dense, &sparse);
        assert!((results[0].1 - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_intersection_empty_no_overlap() {
        let strategy = FusionStrategy::Intersection;
        let dense = vec![(chunk_id(1), 0.9)];
        let sparse = vec![(chunk_id(2), 0.8)];
        let results = strategy.fuse(&dense, &sparse);
        assert!(results.is_empty());
    }

    #[test]
    fn test_intersection_keeps_overlap() {
        let strategy = FusionStrategy::Intersection;
        let dense = vec![(chunk_id(1), 0.8), (chunk_id(2), 0.6)];
        let sparse = vec![(chunk_id(2), 0.9), (chunk_id(3), 0.5)];
        let results = strategy.fuse(&dense, &sparse);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, chunk_id(2));
    }

    #[test]
    fn test_intersection_averages_scores() {
        let strategy = FusionStrategy::Intersection;
        let dense = vec![(chunk_id(1), 0.8)];
        let sparse = vec![(chunk_id(1), 0.4)];
        let results = strategy.fuse(&dense, &sparse);
        assert!((results[0].1 - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_min_max_normalize_empty() {
        let normalized = FusionStrategy::min_max_normalize(&[]);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_min_max_normalize_single() {
        let results = vec![(chunk_id(1), 5.0)];
        let normalized = FusionStrategy::min_max_normalize(&results);
        assert_eq!(normalized.len(), 1);
        assert!((normalized[0].1 - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_min_max_normalize_range() {
        let results = vec![(chunk_id(1), 10.0), (chunk_id(2), 5.0), (chunk_id(3), 0.0)];
        let normalized = FusionStrategy::min_max_normalize(&results);
        assert!((normalized[0].1 - 1.0).abs() < 0.001);
        assert!((normalized[2].1 - 0.0).abs() < 0.001);
        assert!((normalized[1].1 - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_z_score_normalize_empty() {
        let normalized = FusionStrategy::z_score_normalize(&[]);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_z_score_normalize_same_values() {
        let results = vec![(chunk_id(1), 5.0), (chunk_id(2), 5.0), (chunk_id(3), 5.0)];
        let normalized = FusionStrategy::z_score_normalize(&results);
        for (_, score) in normalized {
            assert!(score.abs() < 0.001);
        }
    }

    #[test]
    fn test_z_score_normalize_values() {
        // mean = 2, population std dev = 1, so the z-scores are -1 and +1.
        let results = vec![(chunk_id(1), 1.0), (chunk_id(2), 3.0)];
        let normalized = FusionStrategy::z_score_normalize(&results);
        assert!((normalized[0].1 + 1.0).abs() < 1e-6);
        assert!((normalized[1].1 - 1.0).abs() < 1e-6);
    }

    proptest! {
        #[test]
        fn prop_rrf_scores_positive(
            n_dense in 1usize..10,
            n_sparse in 1usize..10
        ) {
            let dense: Vec<_> = (0..n_dense)
                .map(|i| (chunk_id(i as u128), 1.0 - i as f32 * 0.1))
                .collect();
            let sparse: Vec<_> = (100..100 + n_sparse)
                .map(|i| (chunk_id(i as u128), 1.0 - (i - 100) as f32 * 0.1))
                .collect();
            let strategy = FusionStrategy::RRF { k: 60.0 };
            let results = strategy.fuse(&dense, &sparse);
            for (_, score) in results {
                prop_assert!(score > 0.0);
            }
        }

        #[test]
        fn prop_linear_weights_sum_to_one(dense_weight in 0.0f32..1.0) {
            let dense = vec![(chunk_id(1), 1.0)];
            let sparse = vec![(chunk_id(1), 1.0)];
            let strategy = FusionStrategy::Linear { dense_weight };
            let results = strategy.fuse(&dense, &sparse);
            prop_assert!((results[0].1 - 1.0).abs() < 0.01);
        }

        #[test]
        fn prop_intersection_subset_of_inputs(
            dense_ids in prop::collection::vec(0u128..100, 1..10),
            sparse_ids in prop::collection::vec(0u128..100, 1..10)
        ) {
            let dense: Vec<_> = dense_ids.iter().map(|&i| (chunk_id(i), 1.0)).collect();
            let sparse: Vec<_> = sparse_ids.iter().map(|&i| (chunk_id(i), 1.0)).collect();
            let strategy = FusionStrategy::Intersection;
            let results = strategy.fuse(&dense, &sparse);
            let dense_set: std::collections::HashSet<_> = dense_ids.iter().copied().collect();
            let sparse_set: std::collections::HashSet<_> = sparse_ids.iter().copied().collect();
            for (id, _) in results {
                let id_val = id.0.as_u128();
                prop_assert!(dense_set.contains(&id_val) && sparse_set.contains(&id_val));
            }
        }

        #[test]
        fn prop_fusion_deterministic(
            n in 1usize..5
        ) {
            let dense: Vec<_> = (0..n).map(|i| (chunk_id(i as u128), 1.0)).collect();
            let sparse: Vec<_> = (0..n).map(|i| (chunk_id(i as u128), 0.5)).collect();
            let strategy = FusionStrategy::RRF { k: 60.0 };
            let results1 = strategy.fuse(&dense, &sparse);
            let results2 = strategy.fuse(&dense, &sparse);
            prop_assert_eq!(results1.len(), results2.len());
            for ((id1, s1), (id2, s2)) in results1.iter().zip(results2.iter()) {
                prop_assert_eq!(id1, id2);
                prop_assert!((s1 - s2).abs() < 0.0001);
            }
        }
    }
}