1use super::types::DocumentScore;
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12pub struct MultimodalFusion {
14 config: FusionConfig,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct FusionConfig {
20 pub default_strategy: FusionStrategy,
22 pub score_normalization: NormalizationMethod,
24}
25
26impl Default for FusionConfig {
27 fn default() -> Self {
28 Self {
29 default_strategy: FusionStrategy::RankFusion,
30 score_normalization: NormalizationMethod::MinMax,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub enum FusionStrategy {
38 Weighted { weights: Vec<f64> },
40 Sequential { order: Vec<Modality> },
42 Cascade { thresholds: Vec<f64> },
44 RankFusion,
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50pub enum Modality {
51 Text,
53 Vector,
55 Spatial,
57}
58
59#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
61pub enum NormalizationMethod {
62 MinMax,
64 ZScore,
66 Sigmoid,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct FusedResult {
73 pub uri: String,
75 pub scores: HashMap<Modality, f64>,
77 pub total_score: f64,
79}
80
81impl FusedResult {
82 pub fn new(uri: String) -> Self {
84 Self {
85 uri,
86 scores: HashMap::new(),
87 total_score: 0.0,
88 }
89 }
90
91 pub fn add_score(&mut self, modality: Modality, score: f64) {
93 *self.scores.entry(modality).or_insert(0.0) += score;
94 }
95
96 pub fn calculate_total(&mut self) {
98 self.total_score = self.scores.values().sum();
99 }
100
101 pub fn get_score(&self, modality: Modality) -> Option<f64> {
103 self.scores.get(&modality).copied()
104 }
105}
106
107impl MultimodalFusion {
108 pub fn new(config: FusionConfig) -> Self {
110 Self { config }
111 }
112
113 pub fn fuse(
124 &self,
125 text_results: &[DocumentScore],
126 vector_results: &[DocumentScore],
127 spatial_results: &[DocumentScore],
128 strategy: Option<FusionStrategy>,
129 ) -> Result<Vec<FusedResult>> {
130 let strat = strategy.unwrap_or_else(|| self.config.default_strategy.clone());
131
132 match strat {
133 FusionStrategy::Weighted { weights } => {
134 self.fuse_weighted(text_results, vector_results, spatial_results, &weights)
135 }
136 FusionStrategy::Sequential { order } => {
137 self.fuse_sequential(text_results, vector_results, spatial_results, &order)
138 }
139 FusionStrategy::Cascade { thresholds } => {
140 self.fuse_cascade(text_results, vector_results, spatial_results, &thresholds)
141 }
142 FusionStrategy::RankFusion => {
143 self.fuse_rank(text_results, vector_results, spatial_results)
144 }
145 }
146 }
147
148 fn fuse_weighted(
152 &self,
153 text: &[DocumentScore],
154 vector: &[DocumentScore],
155 spatial: &[DocumentScore],
156 weights: &[f64],
157 ) -> Result<Vec<FusedResult>> {
158 if weights.len() != 3 {
159 anyhow::bail!("Weighted fusion requires exactly 3 weights (text, vector, spatial)");
160 }
161
162 let text_norm = self.normalize_scores(text)?;
164 let vector_norm = self.normalize_scores(vector)?;
165 let spatial_norm = self.normalize_scores(spatial)?;
166
167 let mut combined: HashMap<String, FusedResult> = HashMap::new();
169
170 for (result, score) in text.iter().zip(text_norm.iter()) {
172 combined
173 .entry(result.doc_id.clone())
174 .or_insert_with(|| FusedResult::new(result.doc_id.clone()))
175 .add_score(Modality::Text, score * weights[0]);
176 }
177
178 for (result, score) in vector.iter().zip(vector_norm.iter()) {
180 combined
181 .entry(result.doc_id.clone())
182 .or_insert_with(|| FusedResult::new(result.doc_id.clone()))
183 .add_score(Modality::Vector, score * weights[1]);
184 }
185
186 for (result, score) in spatial.iter().zip(spatial_norm.iter()) {
188 combined
189 .entry(result.doc_id.clone())
190 .or_insert_with(|| FusedResult::new(result.doc_id.clone()))
191 .add_score(Modality::Spatial, score * weights[2]);
192 }
193
194 let mut results: Vec<FusedResult> = combined
196 .into_values()
197 .map(|mut r| {
198 r.calculate_total();
199 r
200 })
201 .collect();
202
203 results.sort_by(|a, b| {
204 b.total_score
205 .partial_cmp(&a.total_score)
206 .unwrap_or(std::cmp::Ordering::Equal)
207 });
208
209 Ok(results)
210 }
211
212 fn fuse_sequential(
216 &self,
217 text: &[DocumentScore],
218 vector: &[DocumentScore],
219 spatial: &[DocumentScore],
220 order: &[Modality],
221 ) -> Result<Vec<FusedResult>> {
222 if order.len() < 2 {
223 anyhow::bail!("Sequential fusion requires at least 2 modalities in order");
224 }
225
226 let filter_results = match order[0] {
228 Modality::Text => text,
229 Modality::Vector => vector,
230 Modality::Spatial => spatial,
231 };
232
233 let candidates: HashMap<String, ()> = filter_results
235 .iter()
236 .map(|r| (r.doc_id.clone(), ()))
237 .collect();
238
239 let rank_results = match order[1] {
241 Modality::Text => text,
242 Modality::Vector => vector,
243 Modality::Spatial => spatial,
244 };
245
246 let rank_norm = self.normalize_scores(rank_results)?;
248
249 let mut results: Vec<FusedResult> = rank_results
251 .iter()
252 .zip(rank_norm.iter())
253 .filter(|(r, _)| candidates.contains_key(&r.doc_id))
254 .map(|(r, score)| {
255 let mut result = FusedResult::new(r.doc_id.clone());
256 result.add_score(order[1], *score);
257 result.calculate_total();
258 result
259 })
260 .collect();
261
262 results.sort_by(|a, b| {
263 b.total_score
264 .partial_cmp(&a.total_score)
265 .unwrap_or(std::cmp::Ordering::Equal)
266 });
267
268 Ok(results)
269 }
270
271 fn fuse_cascade(
275 &self,
276 text: &[DocumentScore],
277 vector: &[DocumentScore],
278 spatial: &[DocumentScore],
279 thresholds: &[f64],
280 ) -> Result<Vec<FusedResult>> {
281 if thresholds.len() != 3 {
282 anyhow::bail!("Cascade fusion requires exactly 3 thresholds (text, vector, spatial)");
283 }
284
285 let text_norm = self.normalize_scores(text)?;
287 let mut candidates: HashMap<String, f64> = text
288 .iter()
289 .zip(text_norm.iter())
290 .filter(|(_, score)| **score >= thresholds[0])
291 .map(|(r, score)| (r.doc_id.clone(), *score))
292 .collect();
293
294 if candidates.is_empty() {
295 return Ok(Vec::new());
296 }
297
298 let vector_norm = self.normalize_scores(vector)?;
300 let vector_map: HashMap<String, f64> = vector
301 .iter()
302 .zip(vector_norm.iter())
303 .filter(|(r, score)| candidates.contains_key(&r.doc_id) && **score >= thresholds[1])
304 .map(|(r, score)| (r.doc_id.clone(), *score))
305 .collect();
306
307 candidates.retain(|uri, _| vector_map.contains_key(uri));
309
310 if candidates.is_empty() {
311 return Ok(Vec::new());
312 }
313
314 let spatial_norm = self.normalize_scores(spatial)?;
316 let mut results: Vec<FusedResult> = spatial
317 .iter()
318 .zip(spatial_norm.iter())
319 .filter(|(r, score)| candidates.contains_key(&r.doc_id) && **score >= thresholds[2])
320 .map(|(r, score)| {
321 let mut result = FusedResult::new(r.doc_id.clone());
322 result.add_score(Modality::Spatial, *score);
323 if let Some(&text_score) = candidates.get(&r.doc_id) {
324 result.add_score(Modality::Text, text_score);
325 }
326 if let Some(&vec_score) = vector_map.get(&r.doc_id) {
327 result.add_score(Modality::Vector, vec_score);
328 }
329 result.calculate_total();
330 result
331 })
332 .collect();
333
334 results.sort_by(|a, b| {
335 b.total_score
336 .partial_cmp(&a.total_score)
337 .unwrap_or(std::cmp::Ordering::Equal)
338 });
339
340 Ok(results)
341 }
342
343 fn fuse_rank(
348 &self,
349 text: &[DocumentScore],
350 vector: &[DocumentScore],
351 spatial: &[DocumentScore],
352 ) -> Result<Vec<FusedResult>> {
353 const K: f64 = 60.0; let mut rrf_scores: HashMap<String, f64> = HashMap::new();
356
357 for (rank, result) in text.iter().enumerate() {
359 *rrf_scores.entry(result.doc_id.clone()).or_insert(0.0) +=
360 1.0 / (K + rank as f64 + 1.0);
361 }
362
363 for (rank, result) in vector.iter().enumerate() {
365 *rrf_scores.entry(result.doc_id.clone()).or_insert(0.0) +=
366 1.0 / (K + rank as f64 + 1.0);
367 }
368
369 for (rank, result) in spatial.iter().enumerate() {
371 *rrf_scores.entry(result.doc_id.clone()).or_insert(0.0) +=
372 1.0 / (K + rank as f64 + 1.0);
373 }
374
375 let mut results: Vec<FusedResult> = rrf_scores
376 .into_iter()
377 .map(|(uri, score)| {
378 let mut result = FusedResult::new(uri);
379 result.total_score = score;
380 result.scores.insert(Modality::Text, score);
382 result
383 })
384 .collect();
385
386 results.sort_by(|a, b| {
387 b.total_score
388 .partial_cmp(&a.total_score)
389 .unwrap_or(std::cmp::Ordering::Equal)
390 });
391
392 Ok(results)
393 }
394
395 pub fn normalize_scores(&self, results: &[DocumentScore]) -> Result<Vec<f64>> {
397 if results.is_empty() {
398 return Ok(Vec::new());
399 }
400
401 let scores: Vec<f64> = results.iter().map(|r| r.score as f64).collect();
402
403 match self.config.score_normalization {
404 NormalizationMethod::MinMax => self.min_max_normalize(&scores),
405 NormalizationMethod::ZScore => self.z_score_normalize(&scores),
406 NormalizationMethod::Sigmoid => self.sigmoid_normalize(&scores),
407 }
408 }
409
410 fn min_max_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
412 if scores.is_empty() {
413 return Ok(Vec::new());
414 }
415
416 let min_score = scores
417 .iter()
418 .copied()
419 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
420 .unwrap_or(0.0);
421
422 let max_score = scores
423 .iter()
424 .copied()
425 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
426 .unwrap_or(1.0);
427
428 let range = (max_score - min_score).max(1e-10); Ok(scores.iter().map(|&s| (s - min_score) / range).collect())
431 }
432
433 fn z_score_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
435 if scores.is_empty() {
436 return Ok(Vec::new());
437 }
438
439 let n = scores.len() as f64;
440 let mean = scores.iter().sum::<f64>() / n;
441
442 let variance = scores.iter().map(|&s| (s - mean).powi(2)).sum::<f64>() / n;
443 let std = variance.sqrt().max(1e-10); Ok(scores.iter().map(|&s| (s - mean) / std).collect())
446 }
447
448 fn sigmoid_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
450 Ok(scores.iter().map(|&s| 1.0 / (1.0 + (-s).exp())).collect())
451 }
452
453 pub fn config(&self) -> &FusionConfig {
455 &self.config
456 }
457
458 pub fn set_config(&mut self, config: FusionConfig) {
460 self.config = config;
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture: three overlapping result lists (text, vector, spatial).
    fn create_test_results() -> (Vec<DocumentScore>, Vec<DocumentScore>, Vec<DocumentScore>) {
        // The score and rank types are inferred from the `DocumentScore` fields.
        let doc = |doc_id: &str, score, rank| DocumentScore {
            doc_id: doc_id.to_string(),
            score,
            rank,
        };

        let text = vec![
            doc("doc1", 10.0, 0),
            doc("doc2", 8.0, 1),
            doc("doc3", 5.0, 2),
        ];
        let vector = vec![
            doc("doc2", 0.95, 0),
            doc("doc4", 0.90, 1),
            doc("doc1", 0.85, 2),
        ];
        let spatial = vec![
            doc("doc3", 0.99, 0),
            doc("doc1", 0.92, 1),
            doc("doc5", 0.88, 2),
        ];

        (text, vector, spatial)
    }

    /// Runs `strategy` over the shared fixture with the default configuration.
    fn fuse_fixture(strategy: FusionStrategy) -> Vec<FusedResult> {
        let (text, vector, spatial) = create_test_results();
        MultimodalFusion::new(FusionConfig::default())
            .fuse(&text, &vector, &spatial, Some(strategy))
            .unwrap()
    }

    /// Looks up the fused entry for `uri`, if any.
    fn find<'a>(results: &'a [FusedResult], uri: &str) -> Option<&'a FusedResult> {
        results.iter().find(|r| r.uri == uri)
    }

    #[test]
    fn test_weighted_fusion() {
        let results = fuse_fixture(FusionStrategy::Weighted {
            weights: vec![0.4, 0.4, 0.2],
        });

        assert!(!results.is_empty());
        assert!(results[0].total_score > 0.0);

        // doc1 appears in all three lists, so it must carry three modality scores.
        let doc1 = find(&results, "doc1").unwrap();
        assert_eq!(doc1.scores.len(), 3);
    }

    #[test]
    fn test_sequential_fusion() {
        let results = fuse_fixture(FusionStrategy::Sequential {
            order: vec![Modality::Text, Modality::Vector],
        });

        assert!(!results.is_empty());
        // Only documents from the text list may survive the filter step.
        for fused in &results {
            assert!(matches!(fused.uri.as_str(), "doc1" | "doc2" | "doc3"));
        }
    }

    #[test]
    fn test_cascade_fusion() {
        let results = fuse_fixture(FusionStrategy::Cascade {
            thresholds: vec![0.0, 0.0, 0.0],
        });

        assert!(!results.is_empty());
        if let Some(doc1) = find(&results, "doc1") {
            assert!(doc1.scores.len() >= 2);
        }
    }

    #[test]
    fn test_rank_fusion() {
        let results = fuse_fixture(FusionStrategy::RankFusion);

        assert!(!results.is_empty());

        // doc1 is present in all three lists, doc4 only in one.
        let doc1 = find(&results, "doc1").unwrap();
        let doc4 = find(&results, "doc4").unwrap();
        assert!(doc1.total_score > doc4.total_score);
    }

    #[test]
    fn test_min_max_normalization() {
        let fusion = MultimodalFusion::new(FusionConfig::default());

        let normalized = fusion.min_max_normalize(&[10.0, 5.0, 0.0]).unwrap();

        let expected = [1.0, 0.5, 0.0];
        for (index, want) in expected.iter().enumerate() {
            assert!((normalized[index] - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_z_score_normalization() {
        let fusion = MultimodalFusion::new(FusionConfig::default());

        let normalized = fusion.z_score_normalize(&[10.0, 5.0, 0.0]).unwrap();

        // Standardized values must average to zero.
        let count = normalized.len() as f64;
        let mean: f64 = normalized.iter().sum::<f64>() / count;
        assert!(mean.abs() < 1e-6);
    }

    #[test]
    fn test_sigmoid_normalization() {
        let fusion = MultimodalFusion::new(FusionConfig::default());

        let normalized = fusion.sigmoid_normalize(&[0.0, 1.0, -1.0]).unwrap();

        // sigmoid(0) is exactly one half.
        assert!((normalized[0] - 0.5).abs() < 1e-6);
        for &value in &normalized {
            assert!(value > 0.0 && value < 1.0);
        }
    }

    #[test]
    fn test_empty_results() {
        let fusion = MultimodalFusion::new(FusionConfig::default());
        let none: Vec<DocumentScore> = Vec::new();

        let results = fusion
            .fuse(&none, &none, &none, Some(FusionStrategy::RankFusion))
            .unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_fused_result_operations() {
        let mut result = FusedResult::new("test_doc".to_string());

        let entries = [
            (Modality::Text, 0.5),
            (Modality::Vector, 0.3),
            (Modality::Spatial, 0.2),
        ];
        for &(modality, score) in &entries {
            result.add_score(modality, score);
        }

        assert_eq!(result.get_score(Modality::Text), Some(0.5));
        assert_eq!(result.get_score(Modality::Vector), Some(0.3));
        assert_eq!(result.get_score(Modality::Spatial), Some(0.2));

        result.calculate_total();
        assert!((result.total_score - 1.0).abs() < 1e-6);
    }
}