use serde::{Deserialize, Serialize};
/// Default smoothing constant `k` for reciprocal rank fusion.
///
/// Used by [`FusionStrategy::default`] and [`FusionStrategy::rrf`]; callers
/// that need a different value should use [`FusionStrategy::rrf_with_k`].
pub const DEFAULT_RRF_K: usize = 60;
/// How ranked result lists from different sources are combined into one.
///
/// Serialized with `snake_case` variant names (`rrf`, `weighted`, `union`,
/// `vector_only`, `keyword_only`); the `PascalCase` spellings are accepted as
/// aliases when deserializing.
///
/// NOTE(review): deserialization constructs variants directly, so a payload
/// such as `{"rrf": {"k": 0}}` yields `Rrf { k: 0 }` — it does not go through
/// the `k >= 1` clamp applied by `FusionStrategy::rrf_with_k`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionStrategy {
    /// Reciprocal rank fusion with smoothing constant `k`.
    #[serde(alias = "Rrf")]
    Rrf { k: usize },

    /// Weighted fusion; `weights` are stored exactly as given.
    /// Presumably one weight per fused source — confirm against the fusion code.
    #[serde(alias = "Weighted")]
    Weighted { weights: Vec<f64> },

    /// Union of the input result lists.
    #[serde(alias = "Union")]
    Union,

    /// Use only the vector-search results (per the variant name).
    #[serde(alias = "VectorOnly")]
    VectorOnly,

    /// Use only the keyword-search results (per the variant name).
    #[serde(alias = "KeywordOnly")]
    KeywordOnly,
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::Rrf { k: DEFAULT_RRF_K }
}
}
impl FusionStrategy {
    /// Reciprocal rank fusion with [`DEFAULT_RRF_K`] as the smoothing constant.
    #[inline]
    #[must_use]
    pub fn rrf() -> Self {
        Self::Rrf { k: DEFAULT_RRF_K }
    }

    /// Reciprocal rank fusion with a caller-chosen smoothing constant `k`.
    ///
    /// `k` is clamped to a minimum of 1, so `rrf_with_k(0)` returns
    /// `Rrf { k: 1 }` rather than `Rrf { k: 0 }`.
    #[inline]
    #[must_use]
    pub fn rrf_with_k(k: usize) -> Self {
        // Clamp instead of rejecting. Presumably this keeps the RRF
        // denominator away from a degenerate zero constant — confirm against
        // the scoring code.
        Self::Rrf { k: k.max(1) }
    }

    /// Weighted fusion using `weights`.
    ///
    /// The weights are stored as given: they are not normalized, and their
    /// count or sign is not checked here.
    #[inline]
    #[must_use]
    pub fn weighted(weights: Vec<f64>) -> Self {
        Self::Weighted { weights }
    }

    /// Union of the input result lists.
    #[inline]
    #[must_use]
    pub fn union() -> Self {
        Self::Union
    }

    /// Use only the vector-search results. Equivalent to `FusionStrategy::VectorOnly`.
    #[inline]
    #[must_use]
    pub fn vector_only() -> Self {
        Self::VectorOnly
    }

    /// Use only the keyword-search results. Equivalent to `FusionStrategy::KeywordOnly`.
    #[inline]
    #[must_use]
    pub fn keyword_only() -> Self {
        Self::KeywordOnly
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fusion_strategy_default() {
        let default = FusionStrategy::default();
        assert_eq!(default, FusionStrategy::Rrf { k: 60 });
        // The `Default` impl and the `rrf()` builder must stay in agreement.
        assert_eq!(default, FusionStrategy::rrf());
        assert_eq!(DEFAULT_RRF_K, 60);
    }

    #[test]
    fn test_fusion_strategy_builders() {
        assert_eq!(FusionStrategy::rrf(), FusionStrategy::Rrf { k: 60 });
        assert_eq!(
            FusionStrategy::rrf_with_k(20),
            FusionStrategy::Rrf { k: 20 }
        );
        // `k` is clamped to at least 1; 1 itself is left unchanged.
        assert_eq!(FusionStrategy::rrf_with_k(0), FusionStrategy::Rrf { k: 1 });
        assert_eq!(FusionStrategy::rrf_with_k(1), FusionStrategy::Rrf { k: 1 });
        assert_eq!(
            FusionStrategy::weighted(vec![0.5, 0.5]),
            FusionStrategy::Weighted {
                weights: vec![0.5, 0.5]
            }
        );
        assert_eq!(FusionStrategy::union(), FusionStrategy::Union);
    }
}