use crate::store::{
config::AdapterConfig,
error::{Error as StoreError, Result as StoreResult},
};
/// Fusion method selectable through [`HybridConfig::fusion`].
///
/// The fusion algorithms themselves are not implemented in this file; the
/// variant descriptions below are inferred from their names.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FusionMethod {
    /// Default variant.
    ///
    /// NOTE(review): presumably reciprocal rank fusion, parameterised by
    /// `HybridConfig::rrf_k` — confirm against the fusion code.
    #[default]
    Rrf,
    /// NOTE(review): presumably a weighted linear combination using
    /// `HybridConfig::bm25_weight` / `HybridConfig::vector_weight` — confirm.
    Linear,
}
/// Normalisation method selectable through [`HybridConfig::normalisation`].
///
/// The normalisation code is not part of this file; the variant descriptions
/// below are inferred from their names.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NormalisationMethod {
    /// Default variant.
    ///
    /// NOTE(review): presumably min-max scaling — confirm.
    #[default]
    MinMax,
    /// NOTE(review): presumably z-score standardisation — confirm.
    ZScore,
    /// NOTE(review): presumably leaves values untouched — confirm.
    ///
    /// Always write this as `NormalisationMethod::None` so it is not confused
    /// with `Option::None`.
    None,
}
/// Settings for the `"hybrid"` adapter.
///
/// The constraints quoted on each field are the ones enforced by this type's
/// `AdapterConfig::validate` implementation; the defaults are the ones set by
/// its `Default` implementation. How the values are consumed is decided by
/// code outside this file.
#[derive(Clone, Debug)]
pub struct HybridConfig {
    /// Selected fusion method. Default: [`FusionMethod::Rrf`].
    pub fusion: FusionMethod,
    /// Weight for the BM25 side. Must lie in `[0.0, 1.0]`. Default: `0.5`.
    pub bm25_weight: f32,
    /// Weight for the vector side. Must lie in `[0.0, 1.0]`. Default: `0.5`.
    pub vector_weight: f32,
    /// Selected normalisation method. Default: [`NormalisationMethod::MinMax`].
    pub normalisation: NormalisationMethod,
    /// Must be greater than zero. Default: `60`.
    ///
    /// NOTE(review): presumably the `k` constant of reciprocal rank fusion —
    /// confirm against the fusion code.
    pub rrf_k: u32,
    /// Must be greater than zero. Default: `200`.
    ///
    /// NOTE(review): presumably the number of BM25 candidates fetched before
    /// fusion — confirm.
    pub bm25_candidates: usize,
    /// Must be greater than zero. Default: `200`.
    ///
    /// NOTE(review): presumably the number of vector candidates fetched
    /// before fusion — confirm.
    pub vector_candidates: usize,
    /// Optional threshold; when present it must lie in `[0.0, 1.0]`.
    /// Default: `None`.
    pub min_similarity: Option<f32>,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
fusion: FusionMethod::Rrf,
bm25_weight: 0.5,
vector_weight: 0.5,
normalisation: NormalisationMethod::MinMax,
rrf_k: 60,
bm25_candidates: 200,
vector_candidates: 200,
min_similarity: None,
}
}
}
impl AdapterConfig for HybridConfig {
    /// Returns the fixed adapter identifier, `"hybrid"`.
    fn adapter_name(&self) -> &'static str {
        "hybrid"
    }

    /// Checks every field against its allowed range.
    ///
    /// # Errors
    ///
    /// Returns a configuration error for the first rule that is violated,
    /// checked in this order:
    ///
    /// 1. `bm25_weight` outside `[0.0, 1.0]`
    /// 2. `vector_weight` outside `[0.0, 1.0]`
    /// 3. `rrf_k` equal to zero
    /// 4. `bm25_candidates` equal to zero
    /// 5. `vector_candidates` equal to zero
    /// 6. `min_similarity` present and outside `[0.0, 1.0]`
    fn validate(&self) -> StoreResult<()> {
        // NaN counts as "outside": every comparison with NaN is false, so the
        // range never contains it.
        let outside_unit_interval = |value: f32| !(0.0..=1.0).contains(&value);

        // Find the first broken rule (if any); the error is built once below.
        let violation = if outside_unit_interval(self.bm25_weight) {
            Some("bm25_weight must be in [0.0, 1.0]")
        } else if outside_unit_interval(self.vector_weight) {
            Some("vector_weight must be in [0.0, 1.0]")
        } else if self.rrf_k == 0 {
            Some("rrf_k must be greater than zero")
        } else if self.bm25_candidates == 0 {
            Some("bm25_candidates must be greater than zero")
        } else if self.vector_candidates == 0 {
            Some("vector_candidates must be greater than zero")
        } else if self.min_similarity.is_some_and(outside_unit_interval) {
            Some("min_similarity must be in [0.0, 1.0]")
        } else {
            None
        };

        match violation {
            Some(message) => Err(StoreError::config(message)),
            None => Ok(()),
        }
    }
}
/// Unit tests for the `HybridConfig` defaults and its validation rules.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- defaults -------------------------------------------------------

    #[test]
    fn default_fusion_is_rrf() {
        assert_eq!(HybridConfig::default().fusion, FusionMethod::Rrf);
    }

    #[test]
    fn default_weights_are_equal() {
        let c = HybridConfig::default();
        assert!((c.bm25_weight - 0.5).abs() < 1e-6);
        assert!((c.vector_weight - 0.5).abs() < 1e-6);
    }

    #[test]
    fn default_rrf_k_is_60() {
        assert_eq!(HybridConfig::default().rrf_k, 60);
    }

    #[test]
    fn default_candidates_are_200() {
        let c = HybridConfig::default();
        assert_eq!(c.bm25_candidates, 200);
        assert_eq!(c.vector_candidates, 200);
    }

    #[test]
    fn default_min_similarity_is_none() {
        assert!(HybridConfig::default().min_similarity.is_none());
    }

    #[test]
    fn normalisation_method_default_is_min_max() {
        assert_eq!(HybridConfig::default().normalisation, NormalisationMethod::MinMax);
    }

    #[test]
    fn adapter_name_is_hybrid() {
        assert_eq!(HybridConfig::default().adapter_name(), "hybrid");
    }

    // ---- validation: accepted configurations ----------------------------

    #[test]
    fn validate_passes_for_valid_config() {
        assert!(HybridConfig::default().validate().is_ok());
    }

    #[test]
    fn validate_passes_with_valid_min_similarity() {
        let c = HybridConfig { min_similarity: Some(0.7), ..Default::default() };
        assert!(c.validate().is_ok());
    }

    // The allowed ranges are inclusive, so both end points must be accepted.
    #[test]
    fn validate_passes_at_range_boundaries() {
        let low = HybridConfig {
            bm25_weight: 0.0,
            vector_weight: 0.0,
            min_similarity: Some(0.0),
            ..Default::default()
        };
        assert!(low.validate().is_ok());

        let high = HybridConfig {
            bm25_weight: 1.0,
            vector_weight: 1.0,
            min_similarity: Some(1.0),
            ..Default::default()
        };
        assert!(high.validate().is_ok());
    }

    // ---- validation: rejected configurations ----------------------------

    #[test]
    fn validate_fails_for_bm25_weight_above_one() {
        let c = HybridConfig { bm25_weight: 1.1, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_bm25_weight_below_zero() {
        let c = HybridConfig { bm25_weight: -0.1, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_vector_weight_above_one() {
        let c = HybridConfig { vector_weight: 1.5, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_vector_weight_below_zero() {
        let c = HybridConfig { vector_weight: -0.1, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_zero_rrf_k() {
        let c = HybridConfig { rrf_k: 0, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_zero_bm25_candidates() {
        let c = HybridConfig { bm25_candidates: 0, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_zero_vector_candidates() {
        let c = HybridConfig { vector_candidates: 0, ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_min_similarity_above_one() {
        let c = HybridConfig { min_similarity: Some(1.5), ..Default::default() };
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_fails_for_min_similarity_below_zero() {
        let c = HybridConfig { min_similarity: Some(-0.1), ..Default::default() };
        assert!(c.validate().is_err());
    }

    // NaN is never inside a range, so it must be rejected wherever a
    // `[0.0, 1.0]` bound applies.
    #[test]
    fn validate_fails_for_nan_values() {
        let bm25 = HybridConfig { bm25_weight: f32::NAN, ..Default::default() };
        assert!(bm25.validate().is_err());

        let vector = HybridConfig { vector_weight: f32::NAN, ..Default::default() };
        assert!(vector.validate().is_err());

        let similarity = HybridConfig { min_similarity: Some(f32::NAN), ..Default::default() };
        assert!(similarity.validate().is_err());
    }
}