use serde::{Deserialize, Serialize};
/// Fusion strategy selectable in a [`FusionClause`].
///
/// The derived `Default` yields [`FusionStrategyType::Rrf`]. Variant order is
/// significant for any index-based serde format, so new variants should be
/// appended rather than inserted.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Default,
    Serialize,
    Deserialize,
)]
pub enum FusionStrategyType {
    /// The default strategy (presumably Reciprocal Rank Fusion — confirm with
    /// the fusion implementation).
    #[default]
    Rrf,
    /// Presumably a weighted combination of per-source scores — confirm.
    Weighted,
    /// Presumably keeps the maximum per-source score — confirm.
    Maximum,
    /// NOTE(review): expansion of "Rsf" is not shown here (possibly Relative
    /// Score Fusion) — confirm.
    Rsf,
}
/// Fusion settings: which strategy to use plus its optional tuning knobs.
///
/// Serialization notes:
/// - `k`, `vector_weight` and `graph_weight` are always serialized, as `null`
///   when `None`.
/// - `dense_weight` and `sparse_weight` are omitted when `None` and default to
///   `None` when absent from the input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FusionClause {
    /// Strategy to apply.
    pub strategy: FusionStrategyType,
    /// Rank constant; `FusionClause::default()` sets this to `Some(60)`.
    pub k: Option<u32>,
    /// Optional weight for vector-search results (usage defined by the
    /// consumer — confirm).
    pub vector_weight: Option<f64>,
    /// Optional weight for graph results (usage defined by the consumer —
    /// confirm).
    pub graph_weight: Option<f64>,
    /// Optional dense-retrieval weight. Note: `f32`, unlike the `f64` weights
    /// above.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dense_weight: Option<f32>,
    /// Optional sparse-retrieval weight (`f32`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sparse_weight: Option<f32>,
}
impl Default for FusionClause {
    /// Returns an `Rrf` clause with `k = Some(60)` and every weight unset.
    fn default() -> Self {
        // Same value that `FusionConfig::rrf` stores under its "k" key.
        const DEFAULT_K: u32 = 60;

        let k = Some(DEFAULT_K);
        Self {
            k,
            strategy: FusionStrategyType::Rrf,
            dense_weight: None,
            sparse_weight: None,
            vector_weight: None,
            graph_weight: None,
        }
    }
}
/// String-keyed fusion configuration: a strategy name plus numeric parameters.
///
/// The constructors defined alongside this type use the strategy names
/// `"rrf"` (param key `"k"`) and `"weighted"` (param keys `"avg_weight"`,
/// `"max_weight"`, `"hit_weight"`). Other names and keys are not checked here.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Strategy name, e.g. `"rrf"` or `"weighted"`.
    pub strategy: String,
    /// Strategy-specific numeric parameters, keyed by name.
    pub params: std::collections::HashMap<String, f64>,
}
impl Default for FusionConfig {
    /// Returns an `"rrf"` config with an *empty* `params` map.
    ///
    /// Unlike `FusionConfig::rrf()`, no `"k"` entry is inserted; consumers
    /// presumably apply their own fallback for a missing `k` — confirm.
    fn default() -> Self {
        let strategy = String::from("rrf");
        let params = std::collections::HashMap::default();
        Self { strategy, params }
    }
}
/// Convenience constructors for [`FusionConfig`].
impl FusionConfig {
    /// `k` value stored by [`FusionConfig::rrf`].
    const DEFAULT_RRF_K: f64 = 60.0;

    /// Maximum allowed distance between `1.0` and the sum of the weights
    /// passed to [`FusionConfig::weighted`] / [`FusionConfig::try_weighted`].
    const WEIGHT_SUM_TOLERANCE: f64 = 0.001;

    /// Builds an `"rrf"` config with `params["k"] = 60.0`.
    ///
    /// Note: this differs from `FusionConfig::default()`, which also uses
    /// `"rrf"` but leaves `params` empty.
    #[must_use]
    pub fn rrf() -> Self {
        Self::rrf_with_k(Self::DEFAULT_RRF_K)
    }

    /// Builds an `"rrf"` config storing the caller-chosen `k` under `"k"`.
    ///
    /// `k` is stored as-is and is not validated here.
    #[must_use]
    pub fn rrf_with_k(k: f64) -> Self {
        let mut params = std::collections::HashMap::new();
        params.insert("k".to_string(), k);
        Self {
            strategy: "rrf".to_string(),
            params,
        }
    }

    /// Builds a `"weighted"` config storing the three weights under the keys
    /// `"avg_weight"`, `"max_weight"` and `"hit_weight"`.
    ///
    /// # Panics
    ///
    /// Panics if any weight is negative or NaN, or if the weights do not sum to
    /// `1.0` within an absolute tolerance of `0.001` (this includes infinite
    /// weights). Use [`FusionConfig::try_weighted`] to receive an `Err`
    /// instead, e.g. when the weights come from user input.
    #[must_use]
    pub fn weighted(avg_weight: f64, max_weight: f64, hit_weight: f64) -> Self {
        Self::try_weighted(avg_weight, max_weight, hit_weight)
            .unwrap_or_else(|msg| panic!("{}", msg))
    }

    /// Fallible counterpart of [`FusionConfig::weighted`].
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message (the same text that
    /// `weighted` panics with) when any weight is negative or NaN, or when the
    /// weights do not sum to `1.0` within an absolute tolerance of `0.001`.
    pub fn try_weighted(
        avg_weight: f64,
        max_weight: f64,
        hit_weight: f64,
    ) -> Result<Self, String> {
        // `w >= 0.0` is false for NaN, so NaN weights are rejected here as well.
        let all_non_negative = [avg_weight, max_weight, hit_weight]
            .iter()
            .all(|&w| w >= 0.0);
        if !all_non_negative {
            return Err(format!(
                "FusionConfig::weighted: all weights must be non-negative, got avg={}, max={}, hit={}",
                avg_weight, max_weight, hit_weight
            ));
        }

        let sum = avg_weight + max_weight + hit_weight;
        // An infinite weight makes `sum` infinite (or NaN), which fails this check.
        let sums_to_one = (sum - 1.0).abs() < Self::WEIGHT_SUM_TOLERANCE;
        if !sums_to_one {
            return Err(format!(
                "FusionConfig::weighted: weights must sum to 1.0, got sum={}",
                sum
            ));
        }

        let mut params = std::collections::HashMap::new();
        params.insert("avg_weight".to_string(), avg_weight);
        params.insert("max_weight".to_string(), max_weight);
        params.insert("hit_weight".to_string(), hit_weight);
        Ok(Self {
            strategy: "weighted".to_string(),
            params,
        })
    }
}