use super::types::DocumentScore;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Fuses text, vector, and spatial result lists into a single ranking.
///
/// The engine owns a [`FusionConfig`]: its `default_strategy` is used when
/// [`MultimodalFusion::fuse`] is called without an explicit strategy, and its
/// `score_normalization` selects how raw scores are rescaled.
#[derive(Debug, Clone)]
pub struct MultimodalFusion {
    /// Settings consulted on every fusion call.
    config: FusionConfig,
}
/// Settings that control how [`MultimodalFusion`] combines result lists.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FusionConfig {
    /// Strategy applied when a fusion call does not name one explicitly.
    pub default_strategy: FusionStrategy,
    /// How raw scores are rescaled by the score-based strategies
    /// (weighted, sequential, cascade).
    pub score_normalization: NormalizationMethod,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
default_strategy: FusionStrategy::RankFusion,
score_normalization: NormalizationMethod::MinMax,
}
}
}
/// The available ways of merging the text, vector, and spatial result lists.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum FusionStrategy {
    /// Weighted sum of normalized scores. `weights` must hold exactly three
    /// values, applied to text, vector, and spatial in that order.
    Weighted { weights: Vec<f64> },
    /// Filter-then-rank: documents must appear in the results of `order[0]`
    /// and are scored by `order[1]`. At least two entries are required;
    /// entries after the second are currently not consulted.
    Sequential { order: Vec<Modality> },
    /// Progressive filtering on normalized scores. `thresholds` must hold
    /// exactly three minimums, checked against text, then vector, then
    /// spatial; only documents passing all three stages are returned.
    Cascade { thresholds: Vec<f64> },
    /// Reciprocal rank fusion: scores depend only on each document's
    /// position within the input lists, not on its raw score.
    RankFusion,
}
/// Identifies which input result list a score came from.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum Modality {
    /// The text result list.
    Text,
    /// The vector result list.
    Vector,
    /// The spatial result list.
    Spatial,
}
/// Rescaling applied to a list of raw scores before they are combined.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
pub enum NormalizationMethod {
    /// `(s - min) / (max - min)`, with the range floored at `1e-10`.
    MinMax,
    /// `(s - mean) / std`, using the population standard deviation floored
    /// at `1e-10`.
    ZScore,
    /// Logistic squashing: `1 / (1 + e^(-s))`.
    Sigmoid,
}
/// One document in a fused ranking.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FusedResult {
    /// Identifier of the document this entry describes.
    pub uri: String,
    /// Score contributions recorded so far, keyed by modality.
    pub scores: HashMap<Modality, f64>,
    /// Overall score used to order fused results (higher ranks first).
    pub total_score: f64,
}
impl FusedResult {
pub fn new(uri: String) -> Self {
Self {
uri,
scores: HashMap::new(),
total_score: 0.0,
}
}
pub fn add_score(&mut self, modality: Modality, score: f64) {
*self.scores.entry(modality).or_insert(0.0) += score;
}
pub fn calculate_total(&mut self) {
self.total_score = self.scores.values().sum();
}
pub fn get_score(&self, modality: Modality) -> Option<f64> {
self.scores.get(&modality).copied()
}
}
impl MultimodalFusion {
pub fn new(config: FusionConfig) -> Self {
Self { config }
}
/// Fuses the three result lists into one list ordered by descending
/// `total_score`.
///
/// `strategy` overrides the configured default for this call only; pass
/// `None` to use `config.default_strategy`.
///
/// # Errors
///
/// Propagates the error of the selected strategy: a `Weighted` strategy
/// without exactly 3 weights, a `Sequential` strategy with fewer than 2
/// modalities, or a `Cascade` strategy without exactly 3 thresholds.
pub fn fuse(
    &self,
    text_results: &[DocumentScore],
    vector_results: &[DocumentScore],
    spatial_results: &[DocumentScore],
    strategy: Option<FusionStrategy>,
) -> Result<Vec<FusedResult>> {
    let chosen = match strategy {
        Some(explicit) => explicit,
        None => self.config.default_strategy.clone(),
    };

    match &chosen {
        FusionStrategy::Weighted { weights } => {
            self.fuse_weighted(text_results, vector_results, spatial_results, weights)
        }
        FusionStrategy::Sequential { order } => {
            self.fuse_sequential(text_results, vector_results, spatial_results, order)
        }
        FusionStrategy::Cascade { thresholds } => {
            self.fuse_cascade(text_results, vector_results, spatial_results, thresholds)
        }
        FusionStrategy::RankFusion => {
            self.fuse_rank(text_results, vector_results, spatial_results)
        }
    }
}
/// Weighted-sum fusion.
///
/// Each list is normalized with the configured method, every normalized
/// score is multiplied by its modality's weight (`weights[0]` text,
/// `weights[1]` vector, `weights[2]` spatial), and the weighted scores are
/// summed per document. Documents found in any list are returned, ordered
/// by descending total.
///
/// # Errors
///
/// Fails when `weights` does not contain exactly three values.
fn fuse_weighted(
    &self,
    text: &[DocumentScore],
    vector: &[DocumentScore],
    spatial: &[DocumentScore],
    weights: &[f64],
) -> Result<Vec<FusedResult>> {
    if weights.len() != 3 {
        anyhow::bail!("Weighted fusion requires exactly 3 weights (text, vector, spatial)");
    }

    // One entry per modality: which list to read and which weight to apply.
    let channels = [
        (Modality::Text, text, weights[0]),
        (Modality::Vector, vector, weights[1]),
        (Modality::Spatial, spatial, weights[2]),
    ];

    let mut by_doc: HashMap<String, FusedResult> = HashMap::new();
    for &(modality, hits, weight) in channels.iter() {
        let normalized = self.normalize_scores(hits)?;
        for (hit, norm) in hits.iter().zip(normalized) {
            by_doc
                .entry(hit.doc_id.clone())
                .or_insert_with(|| FusedResult::new(hit.doc_id.clone()))
                .add_score(modality, norm * weight);
        }
    }

    let mut ranked: Vec<FusedResult> = Vec::with_capacity(by_doc.len());
    for (_, mut fused) in by_doc {
        fused.calculate_total();
        ranked.push(fused);
    }

    // Highest total first; incomparable (NaN) totals are treated as equal.
    ranked.sort_by(|lhs, rhs| {
        match rhs.total_score.partial_cmp(&lhs.total_score) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        }
    });
    Ok(ranked)
}
/// Filter-then-rank fusion.
///
/// The list selected by `order[0]` acts purely as a membership filter; the
/// list selected by `order[1]` is normalized and supplies the score. A
/// document is returned only if it appears in both lists, and its result
/// carries a single score entry, for `order[1]`.
///
/// NOTE(review): entries of `order` beyond the second are accepted but not
/// consulted — confirm whether longer pipelines are meant to be supported.
///
/// # Errors
///
/// Fails when `order` contains fewer than two modalities.
fn fuse_sequential(
    &self,
    text: &[DocumentScore],
    vector: &[DocumentScore],
    spatial: &[DocumentScore],
    order: &[Modality],
) -> Result<Vec<FusedResult>> {
    if order.len() < 2 {
        anyhow::bail!("Sequential fusion requires at least 2 modalities in order");
    }
    let filter_modality = order[0];
    let rank_modality = order[1];

    let gate_hits = match filter_modality {
        Modality::Text => text,
        Modality::Vector => vector,
        Modality::Spatial => spatial,
    };
    // Borrowed ids are enough for a membership test; no per-document clone.
    let allowed: std::collections::HashSet<&str> =
        gate_hits.iter().map(|hit| hit.doc_id.as_str()).collect();

    let ranking_hits = match rank_modality {
        Modality::Text => text,
        Modality::Vector => vector,
        Modality::Spatial => spatial,
    };
    let normalized = self.normalize_scores(ranking_hits)?;

    let mut results: Vec<FusedResult> = ranking_hits
        .iter()
        .zip(normalized.iter())
        .filter(|(hit, _)| allowed.contains(hit.doc_id.as_str()))
        .map(|(hit, &score)| {
            let mut fused = FusedResult::new(hit.doc_id.clone());
            fused.add_score(rank_modality, score);
            fused.calculate_total();
            fused
        })
        .collect();

    // Highest total first; incomparable (NaN) totals are treated as equal.
    results.sort_by(|a, b| {
        b.total_score
            .partial_cmp(&a.total_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    Ok(results)
}
/// Cascade fusion: three successive threshold filters on normalized scores.
///
/// 1. Text hits scoring at least `thresholds[0]` become candidates.
/// 2. Candidates must also be vector hits scoring at least `thresholds[1]`.
/// 3. Remaining candidates must be spatial hits scoring at least
///    `thresholds[2]`.
///
/// Survivors carry their normalized score from all three modalities and are
/// ordered by descending total. An empty list is returned as soon as a
/// stage leaves no candidates.
///
/// # Errors
///
/// Fails when `thresholds` does not contain exactly three values.
fn fuse_cascade(
    &self,
    text: &[DocumentScore],
    vector: &[DocumentScore],
    spatial: &[DocumentScore],
    thresholds: &[f64],
) -> Result<Vec<FusedResult>> {
    if thresholds.len() != 3 {
        anyhow::bail!("Cascade fusion requires exactly 3 thresholds (text, vector, spatial)");
    }
    let (text_min, vector_min, spatial_min) = (thresholds[0], thresholds[1], thresholds[2]);

    // Stage 1: text threshold.
    let text_norm = self.normalize_scores(text)?;
    let mut survivors: HashMap<String, f64> = HashMap::new();
    for (hit, &score) in text.iter().zip(text_norm.iter()) {
        if score >= text_min {
            survivors.insert(hit.doc_id.clone(), score);
        }
    }
    if survivors.is_empty() {
        return Ok(Vec::new());
    }

    // Stage 2: vector threshold, restricted to stage-1 survivors.
    let vector_norm = self.normalize_scores(vector)?;
    let mut vector_scores: HashMap<String, f64> = HashMap::new();
    for (hit, &score) in vector.iter().zip(vector_norm.iter()) {
        if survivors.contains_key(&hit.doc_id) && score >= vector_min {
            vector_scores.insert(hit.doc_id.clone(), score);
        }
    }
    survivors.retain(|doc_id, _| vector_scores.contains_key(doc_id));
    if survivors.is_empty() {
        return Ok(Vec::new());
    }

    // Stage 3: spatial threshold; every hit that passes becomes a result.
    let spatial_norm = self.normalize_scores(spatial)?;
    let mut results: Vec<FusedResult> = Vec::new();
    for (hit, &score) in spatial.iter().zip(spatial_norm.iter()) {
        if survivors.contains_key(&hit.doc_id) && score >= spatial_min {
            let mut fused = FusedResult::new(hit.doc_id.clone());
            fused.add_score(Modality::Spatial, score);
            if let Some(&text_score) = survivors.get(&hit.doc_id) {
                fused.add_score(Modality::Text, text_score);
            }
            if let Some(&vector_score) = vector_scores.get(&hit.doc_id) {
                fused.add_score(Modality::Vector, vector_score);
            }
            fused.calculate_total();
            results.push(fused);
        }
    }

    // Highest total first; incomparable (NaN) totals are treated as equal.
    results.sort_by(|lhs, rhs| {
        match rhs.total_score.partial_cmp(&lhs.total_score) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        }
    });
    Ok(results)
}
/// Reciprocal rank fusion (RRF).
///
/// A document at zero-based position `p` of a list contributes
/// `1 / (K + p + 1)` with `K = 60`. Contributions are recorded under the
/// modality whose list produced them, and `total_score` is their running
/// sum accumulated in text, vector, spatial order.
///
/// Position is the index within each slice; the slices are assumed to be
/// ordered best-first already (the `rank` field of `DocumentScore` is not
/// read here — TODO confirm callers pre-sort).
///
/// Results are ordered by descending total. Equal totals are common with
/// RRF, so ties are broken by ascending `uri` to keep the output
/// independent of hash-map iteration order.
fn fuse_rank(
    &self,
    text: &[DocumentScore],
    vector: &[DocumentScore],
    spatial: &[DocumentScore],
) -> Result<Vec<FusedResult>> {
    const K: f64 = 60.0;

    let ranked_lists = [
        (Modality::Text, text),
        (Modality::Vector, vector),
        (Modality::Spatial, spatial),
    ];

    let mut fused: HashMap<String, FusedResult> = HashMap::new();
    for &(modality, hits) in ranked_lists.iter() {
        for (position, hit) in hits.iter().enumerate() {
            let contribution = 1.0 / (K + position as f64 + 1.0);
            let entry = fused
                .entry(hit.doc_id.clone())
                .or_insert_with(|| FusedResult::new(hit.doc_id.clone()));
            // Attribute the contribution to the list it actually came from.
            entry.add_score(modality, contribution);
            entry.total_score += contribution;
        }
    }

    let mut results: Vec<FusedResult> = fused.into_values().collect();
    results.sort_by(|a, b| {
        b.total_score
            .partial_cmp(&a.total_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.uri.cmp(&b.uri))
    });
    Ok(results)
}
/// Normalizes the raw scores of `results` with the configured
/// [`NormalizationMethod`].
///
/// The returned vector has one entry per input result, in the same order.
/// An empty input yields an empty vector.
pub fn normalize_scores(&self, results: &[DocumentScore]) -> Result<Vec<f64>> {
    if results.is_empty() {
        return Ok(Vec::new());
    }

    let mut raw: Vec<f64> = Vec::with_capacity(results.len());
    for hit in results {
        raw.push(hit.score as f64);
    }

    let method = self.config.score_normalization;
    match method {
        NormalizationMethod::MinMax => self.min_max_normalize(&raw),
        NormalizationMethod::ZScore => self.z_score_normalize(&raw),
        NormalizationMethod::Sigmoid => self.sigmoid_normalize(&raw),
    }
}
/// Min-max normalization: maps each score to `(s - min) / (max - min)`.
///
/// The range is floored at `1e-10` to avoid dividing by zero, so a list
/// whose scores are all equal (including a single-element list) maps to
/// all zeros.
fn min_max_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
    /// Orders two scores, treating incomparable (NaN) pairs as equal.
    fn order(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
    }

    if scores.is_empty() {
        return Ok(Vec::new());
    }

    let lowest = scores.iter().copied().min_by(order).unwrap_or(0.0);
    let highest = scores.iter().copied().max_by(order).unwrap_or(1.0);
    let span = (highest - lowest).max(1e-10);

    let mut normalized = Vec::with_capacity(scores.len());
    for &score in scores {
        normalized.push((score - lowest) / span);
    }
    Ok(normalized)
}
/// Z-score normalization: maps each score to `(s - mean) / std`.
///
/// Uses the population standard deviation (divides by `n`), floored at
/// `1e-10` so that a constant list does not divide by zero.
fn z_score_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
    if scores.is_empty() {
        return Ok(Vec::new());
    }

    let count = scores.len() as f64;
    let mean = scores.iter().sum::<f64>() / count;

    let squared_deviation_sum: f64 = scores
        .iter()
        .map(|&s| {
            let deviation = s - mean;
            deviation.powi(2)
        })
        .sum();
    let variance = squared_deviation_sum / count;
    let std_dev = variance.sqrt().max(1e-10);

    let standardized: Vec<f64> = scores.iter().map(|&s| (s - mean) / std_dev).collect();
    Ok(standardized)
}
/// Sigmoid normalization: maps each score to `1 / (1 + e^(-s))`.
fn sigmoid_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
    let mut squashed = Vec::with_capacity(scores.len());
    for &score in scores {
        let decay = (-score).exp();
        squashed.push(1.0 / (1.0 + decay));
    }
    Ok(squashed)
}
/// Returns the configuration currently in use.
pub fn config(&self) -> &FusionConfig {
    let current = &self.config;
    current
}
/// Replaces the configuration; subsequent fusion and normalization calls
/// use the new settings.
pub fn set_config(&mut self, config: FusionConfig) {
    let replacement = config;
    self.config = replacement;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three small, partially overlapping result lists
    /// (text, vector, spatial). Only `doc1` appears in all three.
    fn create_test_results() -> (Vec<DocumentScore>, Vec<DocumentScore>, Vec<DocumentScore>) {
        let text = vec![
            DocumentScore {
                doc_id: "doc1".to_string(),
                score: 10.0,
                rank: 0,
            },
            DocumentScore {
                doc_id: "doc2".to_string(),
                score: 8.0,
                rank: 1,
            },
            DocumentScore {
                doc_id: "doc3".to_string(),
                score: 5.0,
                rank: 2,
            },
        ];
        let vector = vec![
            DocumentScore {
                doc_id: "doc2".to_string(),
                score: 0.95,
                rank: 0,
            },
            DocumentScore {
                doc_id: "doc4".to_string(),
                score: 0.90,
                rank: 1,
            },
            DocumentScore {
                doc_id: "doc1".to_string(),
                score: 0.85,
                rank: 2,
            },
        ];
        let spatial = vec![
            DocumentScore {
                doc_id: "doc3".to_string(),
                score: 0.99,
                rank: 0,
            },
            DocumentScore {
                doc_id: "doc1".to_string(),
                score: 0.92,
                rank: 1,
            },
            DocumentScore {
                doc_id: "doc5".to_string(),
                score: 0.88,
                rank: 2,
            },
        ];
        (text, vector, spatial)
    }

    #[test]
    fn test_weighted_fusion() -> Result<()> {
        let (text, vector, spatial) = create_test_results();
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let weights = vec![0.4, 0.4, 0.2];
        let strategy = FusionStrategy::Weighted { weights };
        let results = fusion.fuse(&text, &vector, &spatial, Some(strategy))?;
        assert!(!results.is_empty());
        assert!(results[0].total_score > 0.0);
        let doc1 = results
            .iter()
            .find(|r| r.uri == "doc1")
            .expect("doc1 should be found");
        // doc1 is present in every input list, so all modalities contribute.
        assert_eq!(doc1.scores.len(), 3);
        Ok(())
    }

    #[test]
    fn test_weighted_fusion_rejects_wrong_weight_count() {
        let (text, vector, spatial) = create_test_results();
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let strategy = FusionStrategy::Weighted {
            weights: vec![0.5, 0.5],
        };
        assert!(fusion
            .fuse(&text, &vector, &spatial, Some(strategy))
            .is_err());
    }

    #[test]
    fn test_sequential_fusion() -> Result<()> {
        let (text, vector, spatial) = create_test_results();
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let order = vec![Modality::Text, Modality::Vector];
        let strategy = FusionStrategy::Sequential { order };
        let results = fusion.fuse(&text, &vector, &spatial, Some(strategy))?;
        assert!(!results.is_empty());
        // Only documents that passed the text filter may be returned.
        assert!(results
            .iter()
            .all(|r| ["doc1", "doc2", "doc3"].contains(&r.uri.as_str())));
        Ok(())
    }

    #[test]
    fn test_sequential_fusion_rejects_short_order() {
        let (text, vector, spatial) = create_test_results();
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let strategy = FusionStrategy::Sequential {
            order: vec![Modality::Text],
        };
        assert!(fusion
            .fuse(&text, &vector, &spatial, Some(strategy))
            .is_err());
    }

    #[test]
    fn test_cascade_fusion() -> Result<()> {
        let (text, vector, spatial) = create_test_results();
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let thresholds = vec![0.0, 0.0, 0.0];
        let strategy = FusionStrategy::Cascade { thresholds };
        let results = fusion.fuse(&text, &vector, &spatial, Some(strategy))?;
        assert!(!results.is_empty());
        // With zero thresholds doc1 (present in every list) must survive all
        // three stages and carry one score per modality.
        let doc1 = results
            .iter()
            .find(|r| r.uri == "doc1")
            .expect("doc1 should pass every cascade stage");
        assert_eq!(doc1.scores.len(), 3);
        Ok(())
    }

    #[test]
    fn test_rank_fusion() -> Result<()> {
        let (text, vector, spatial) = create_test_results();
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let strategy = FusionStrategy::RankFusion;
        let results = fusion.fuse(&text, &vector, &spatial, Some(strategy))?;
        assert!(!results.is_empty());
        let doc1 = results
            .iter()
            .find(|r| r.uri == "doc1")
            .expect("doc1 should be found");
        let doc4 = results
            .iter()
            .find(|r| r.uri == "doc4")
            .expect("doc4 should be found");
        // doc1 appears in three lists, doc4 in only one.
        assert!(doc1.total_score > doc4.total_score);
        Ok(())
    }

    #[test]
    fn test_min_max_normalization() -> Result<()> {
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let scores = vec![10.0, 5.0, 0.0];
        let normalized = fusion.min_max_normalize(&scores)?;
        assert!((normalized[0] - 1.0).abs() < 1e-6);
        assert!((normalized[1] - 0.5).abs() < 1e-6);
        assert!((normalized[2] - 0.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_z_score_normalization() -> Result<()> {
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let scores = vec![10.0, 5.0, 0.0];
        let normalized = fusion.z_score_normalize(&scores)?;
        // Standardized values are centered on zero.
        let mean: f64 = normalized.iter().sum::<f64>() / normalized.len() as f64;
        assert!(mean.abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_sigmoid_normalization() -> Result<()> {
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let scores = vec![0.0, 1.0, -1.0];
        let normalized = fusion.sigmoid_normalize(&scores)?;
        assert!((normalized[0] - 0.5).abs() < 1e-6);
        assert!(normalized.iter().all(|&s| s > 0.0 && s < 1.0));
        Ok(())
    }

    #[test]
    fn test_empty_results() -> Result<()> {
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let empty: Vec<DocumentScore> = Vec::new();
        let strategy = FusionStrategy::RankFusion;
        let results = fusion.fuse(&empty, &empty, &empty, Some(strategy))?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_fused_result_operations() {
        let mut result = FusedResult::new("test_doc".to_string());
        result.add_score(Modality::Text, 0.5);
        result.add_score(Modality::Vector, 0.3);
        result.add_score(Modality::Spatial, 0.2);
        assert_eq!(result.get_score(Modality::Text), Some(0.5));
        assert_eq!(result.get_score(Modality::Vector), Some(0.3));
        assert_eq!(result.get_score(Modality::Spatial), Some(0.2));
        result.calculate_total();
        assert!((result.total_score - 1.0).abs() < 1e-6);
    }
}