Skip to main content

oxirs_vec/hybrid_search/
multimodal_fusion.rs

1//! Multimodal search fusion for combining text, vector, and spatial search results
2//!
3//! This module provides advanced fusion strategies for combining results from multiple
4//! search modalities: text (keyword/BM25), vector (semantic similarity), and spatial
5//! (geographic queries). It implements four fusion strategies with score normalization.
6
7use super::types::DocumentScore;
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Multimodal fusion engine for combining text, vector, and spatial search
13pub struct MultimodalFusion {
14    config: FusionConfig,
15}
16
/// Settings that control how [`MultimodalFusion`] combines results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Strategy applied when `fuse` is called without an explicit strategy
    pub default_strategy: FusionStrategy,
    /// How raw per-modality scores are rescaled before they are combined
    pub score_normalization: NormalizationMethod,
}
25
26impl Default for FusionConfig {
27    fn default() -> Self {
28        Self {
29            default_strategy: FusionStrategy::RankFusion,
30            score_normalization: NormalizationMethod::MinMax,
31        }
32    }
33}
34
/// Strategy used to combine the text, vector, and spatial result lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Weighted linear combination of normalized scores.
    ///
    /// `weights` must hold exactly three values, in the order text, vector,
    /// spatial; any other length makes fusion return an error.
    Weighted { weights: Vec<f64> },
    /// Sequential filtering: the first modality in `order` selects the
    /// candidate documents and the second one ranks them.
    ///
    /// At least two modalities are required; entries after the second are
    /// not consulted.
    Sequential { order: Vec<Modality> },
    /// Cascade: progressive filtering with thresholds (fast → expensive).
    ///
    /// `thresholds` must hold exactly three values (text, vector, spatial),
    /// each compared against the normalized score of its stage.
    Cascade { thresholds: Vec<f64> },
    /// Reciprocal Rank Fusion (RRF): position-based fusion that looks only at
    /// where a document sits in each list, not at its raw score.
    RankFusion,
}
47
/// Search modality that a result list or score contribution belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Modality {
    /// Text/keyword search (e.g. BM25, TF-IDF)
    Text,
    /// Vector/semantic search (embedding similarity)
    Vector,
    /// Spatial/geographic search (e.g. GeoSPARQL)
    Spatial,
}
58
/// Method used to rescale raw scores before they are combined.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Min-max scaling: `(x - min) / (max - min)`, mapping scores into [0, 1]
    MinMax,
    /// Z-score standardization: `(x - mean) / std`; the output is centred on
    /// zero and is not confined to [0, 1]
    ZScore,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`, mapping scores into (0, 1)
    Sigmoid,
}
69
/// A single document produced by multimodal fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedResult {
    /// Resource URI (taken from the `doc_id` of the fused search hits)
    pub uri: String,
    /// Score contribution recorded for each modality
    pub scores: HashMap<Modality, f64>,
    /// Final combined score used to order the fused results
    pub total_score: f64,
}
80
81impl FusedResult {
82    /// Create a new fused result
83    pub fn new(uri: String) -> Self {
84        Self {
85            uri,
86            scores: HashMap::new(),
87            total_score: 0.0,
88        }
89    }
90
91    /// Add a score for a specific modality
92    pub fn add_score(&mut self, modality: Modality, score: f64) {
93        *self.scores.entry(modality).or_insert(0.0) += score;
94    }
95
96    /// Calculate total score from individual scores
97    pub fn calculate_total(&mut self) {
98        self.total_score = self.scores.values().sum();
99    }
100
101    /// Get score for a specific modality
102    pub fn get_score(&self, modality: Modality) -> Option<f64> {
103        self.scores.get(&modality).copied()
104    }
105}
106
107impl MultimodalFusion {
108    /// Create a new multimodal fusion engine with default configuration
109    pub fn new(config: FusionConfig) -> Self {
110        Self { config }
111    }
112
113    /// Fuse results from multiple modalities
114    ///
115    /// # Arguments
116    /// * `text_results` - Results from text/keyword search
117    /// * `vector_results` - Results from vector/semantic search
118    /// * `spatial_results` - Results from spatial/geographic search
119    /// * `strategy` - Optional fusion strategy (uses default if None)
120    ///
121    /// # Returns
122    /// Fused results sorted by combined score (descending)
123    pub fn fuse(
124        &self,
125        text_results: &[DocumentScore],
126        vector_results: &[DocumentScore],
127        spatial_results: &[DocumentScore],
128        strategy: Option<FusionStrategy>,
129    ) -> Result<Vec<FusedResult>> {
130        let strat = strategy.unwrap_or_else(|| self.config.default_strategy.clone());
131
132        match strat {
133            FusionStrategy::Weighted { weights } => {
134                self.fuse_weighted(text_results, vector_results, spatial_results, &weights)
135            }
136            FusionStrategy::Sequential { order } => {
137                self.fuse_sequential(text_results, vector_results, spatial_results, &order)
138            }
139            FusionStrategy::Cascade { thresholds } => {
140                self.fuse_cascade(text_results, vector_results, spatial_results, &thresholds)
141            }
142            FusionStrategy::RankFusion => {
143                self.fuse_rank(text_results, vector_results, spatial_results)
144            }
145        }
146    }
147
148    /// Weighted fusion: Linear combination of normalized scores
149    ///
150    /// Formula: score(d) = w1·norm(text(d)) + w2·norm(vector(d)) + w3·norm(spatial(d))
151    fn fuse_weighted(
152        &self,
153        text: &[DocumentScore],
154        vector: &[DocumentScore],
155        spatial: &[DocumentScore],
156        weights: &[f64],
157    ) -> Result<Vec<FusedResult>> {
158        if weights.len() != 3 {
159            anyhow::bail!("Weighted fusion requires exactly 3 weights (text, vector, spatial)");
160        }
161
162        // Normalize scores to [0, 1]
163        let text_norm = self.normalize_scores(text)?;
164        let vector_norm = self.normalize_scores(vector)?;
165        let spatial_norm = self.normalize_scores(spatial)?;
166
167        // Merge by entity URI
168        let mut combined: HashMap<String, FusedResult> = HashMap::new();
169
170        // Add text scores
171        for (result, score) in text.iter().zip(text_norm.iter()) {
172            combined
173                .entry(result.doc_id.clone())
174                .or_insert_with(|| FusedResult::new(result.doc_id.clone()))
175                .add_score(Modality::Text, score * weights[0]);
176        }
177
178        // Add vector scores
179        for (result, score) in vector.iter().zip(vector_norm.iter()) {
180            combined
181                .entry(result.doc_id.clone())
182                .or_insert_with(|| FusedResult::new(result.doc_id.clone()))
183                .add_score(Modality::Vector, score * weights[1]);
184        }
185
186        // Add spatial scores
187        for (result, score) in spatial.iter().zip(spatial_norm.iter()) {
188            combined
189                .entry(result.doc_id.clone())
190                .or_insert_with(|| FusedResult::new(result.doc_id.clone()))
191                .add_score(Modality::Spatial, score * weights[2]);
192        }
193
194        // Calculate total scores and sort
195        let mut results: Vec<FusedResult> = combined
196            .into_values()
197            .map(|mut r| {
198                r.calculate_total();
199                r
200            })
201            .collect();
202
203        results.sort_by(|a, b| {
204            b.total_score
205                .partial_cmp(&a.total_score)
206                .unwrap_or(std::cmp::Ordering::Equal)
207        });
208
209        Ok(results)
210    }
211
212    /// Sequential fusion: Filter with one modality, rank with another
213    ///
214    /// Example: Filter with text (fast), rank with vector (accurate)
215    fn fuse_sequential(
216        &self,
217        text: &[DocumentScore],
218        vector: &[DocumentScore],
219        spatial: &[DocumentScore],
220        order: &[Modality],
221    ) -> Result<Vec<FusedResult>> {
222        if order.len() < 2 {
223            anyhow::bail!("Sequential fusion requires at least 2 modalities in order");
224        }
225
226        // Get filter results (first modality)
227        let filter_results = match order[0] {
228            Modality::Text => text,
229            Modality::Vector => vector,
230            Modality::Spatial => spatial,
231        };
232
233        // Create candidate set from filter
234        let candidates: HashMap<String, ()> = filter_results
235            .iter()
236            .map(|r| (r.doc_id.clone(), ()))
237            .collect();
238
239        // Get rank results (second modality)
240        let rank_results = match order[1] {
241            Modality::Text => text,
242            Modality::Vector => vector,
243            Modality::Spatial => spatial,
244        };
245
246        // Normalize ranking scores
247        let rank_norm = self.normalize_scores(rank_results)?;
248
249        // Filter and create results
250        let mut results: Vec<FusedResult> = rank_results
251            .iter()
252            .zip(rank_norm.iter())
253            .filter(|(r, _)| candidates.contains_key(&r.doc_id))
254            .map(|(r, score)| {
255                let mut result = FusedResult::new(r.doc_id.clone());
256                result.add_score(order[1], *score);
257                result.calculate_total();
258                result
259            })
260            .collect();
261
262        results.sort_by(|a, b| {
263            b.total_score
264                .partial_cmp(&a.total_score)
265                .unwrap_or(std::cmp::Ordering::Equal)
266        });
267
268        Ok(results)
269    }
270
271    /// Cascade fusion: Progressive filtering (fast → expensive)
272    ///
273    /// Example: Text (threshold 0.5) → Vector (threshold 0.7) → Spatial (threshold 0.8)
274    fn fuse_cascade(
275        &self,
276        text: &[DocumentScore],
277        vector: &[DocumentScore],
278        spatial: &[DocumentScore],
279        thresholds: &[f64],
280    ) -> Result<Vec<FusedResult>> {
281        if thresholds.len() != 3 {
282            anyhow::bail!("Cascade fusion requires exactly 3 thresholds (text, vector, spatial)");
283        }
284
285        // Stage 1: Fast text search with threshold
286        let text_norm = self.normalize_scores(text)?;
287        let mut candidates: HashMap<String, f64> = text
288            .iter()
289            .zip(text_norm.iter())
290            .filter(|(_, score)| **score >= thresholds[0])
291            .map(|(r, score)| (r.doc_id.clone(), *score))
292            .collect();
293
294        if candidates.is_empty() {
295            return Ok(Vec::new());
296        }
297
298        // Stage 2: Vector search on candidates with threshold
299        let vector_norm = self.normalize_scores(vector)?;
300        let vector_map: HashMap<String, f64> = vector
301            .iter()
302            .zip(vector_norm.iter())
303            .filter(|(r, score)| candidates.contains_key(&r.doc_id) && **score >= thresholds[1])
304            .map(|(r, score)| (r.doc_id.clone(), *score))
305            .collect();
306
307        // Keep only candidates that passed vector threshold
308        candidates.retain(|uri, _| vector_map.contains_key(uri));
309
310        if candidates.is_empty() {
311            return Ok(Vec::new());
312        }
313
314        // Stage 3: Expensive spatial search on finalists with threshold
315        let spatial_norm = self.normalize_scores(spatial)?;
316        let mut results: Vec<FusedResult> = spatial
317            .iter()
318            .zip(spatial_norm.iter())
319            .filter(|(r, score)| candidates.contains_key(&r.doc_id) && **score >= thresholds[2])
320            .map(|(r, score)| {
321                let mut result = FusedResult::new(r.doc_id.clone());
322                result.add_score(Modality::Spatial, *score);
323                if let Some(&text_score) = candidates.get(&r.doc_id) {
324                    result.add_score(Modality::Text, text_score);
325                }
326                if let Some(&vec_score) = vector_map.get(&r.doc_id) {
327                    result.add_score(Modality::Vector, vec_score);
328                }
329                result.calculate_total();
330                result
331            })
332            .collect();
333
334        results.sort_by(|a, b| {
335            b.total_score
336                .partial_cmp(&a.total_score)
337                .unwrap_or(std::cmp::Ordering::Equal)
338        });
339
340        Ok(results)
341    }
342
343    /// Reciprocal Rank Fusion (RRF)
344    ///
345    /// Formula: RRF(d) = Σ 1/(K + rank(d))
346    /// where K=60 is a standard constant
347    fn fuse_rank(
348        &self,
349        text: &[DocumentScore],
350        vector: &[DocumentScore],
351        spatial: &[DocumentScore],
352    ) -> Result<Vec<FusedResult>> {
353        const K: f64 = 60.0; // Standard RRF constant
354
355        let mut rrf_scores: HashMap<String, f64> = HashMap::new();
356
357        // Add RRF scores from text results
358        for (rank, result) in text.iter().enumerate() {
359            *rrf_scores.entry(result.doc_id.clone()).or_insert(0.0) +=
360                1.0 / (K + rank as f64 + 1.0);
361        }
362
363        // Add RRF scores from vector results
364        for (rank, result) in vector.iter().enumerate() {
365            *rrf_scores.entry(result.doc_id.clone()).or_insert(0.0) +=
366                1.0 / (K + rank as f64 + 1.0);
367        }
368
369        // Add RRF scores from spatial results
370        for (rank, result) in spatial.iter().enumerate() {
371            *rrf_scores.entry(result.doc_id.clone()).or_insert(0.0) +=
372                1.0 / (K + rank as f64 + 1.0);
373        }
374
375        let mut results: Vec<FusedResult> = rrf_scores
376            .into_iter()
377            .map(|(uri, score)| {
378                let mut result = FusedResult::new(uri);
379                result.total_score = score;
380                // RRF produces a unified score, store it as Text modality for consistency
381                result.scores.insert(Modality::Text, score);
382                result
383            })
384            .collect();
385
386        results.sort_by(|a, b| {
387            b.total_score
388                .partial_cmp(&a.total_score)
389                .unwrap_or(std::cmp::Ordering::Equal)
390        });
391
392        Ok(results)
393    }
394
395    /// Normalize scores to [0, 1] range using configured method
396    pub fn normalize_scores(&self, results: &[DocumentScore]) -> Result<Vec<f64>> {
397        if results.is_empty() {
398            return Ok(Vec::new());
399        }
400
401        let scores: Vec<f64> = results.iter().map(|r| r.score as f64).collect();
402
403        match self.config.score_normalization {
404            NormalizationMethod::MinMax => self.min_max_normalize(&scores),
405            NormalizationMethod::ZScore => self.z_score_normalize(&scores),
406            NormalizationMethod::Sigmoid => self.sigmoid_normalize(&scores),
407        }
408    }
409
410    /// Min-max normalization: (x - min) / (max - min)
411    fn min_max_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
412        if scores.is_empty() {
413            return Ok(Vec::new());
414        }
415
416        let min_score = scores
417            .iter()
418            .copied()
419            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
420            .unwrap_or(0.0);
421
422        let max_score = scores
423            .iter()
424            .copied()
425            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
426            .unwrap_or(1.0);
427
428        let range = (max_score - min_score).max(1e-10); // Avoid division by zero
429
430        Ok(scores.iter().map(|&s| (s - min_score) / range).collect())
431    }
432
433    /// Z-score normalization: (x - mean) / std
434    fn z_score_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
435        if scores.is_empty() {
436            return Ok(Vec::new());
437        }
438
439        let n = scores.len() as f64;
440        let mean = scores.iter().sum::<f64>() / n;
441
442        let variance = scores.iter().map(|&s| (s - mean).powi(2)).sum::<f64>() / n;
443        let std = variance.sqrt().max(1e-10); // Avoid division by zero
444
445        Ok(scores.iter().map(|&s| (s - mean) / std).collect())
446    }
447
448    /// Sigmoid normalization: 1 / (1 + exp(-x))
449    fn sigmoid_normalize(&self, scores: &[f64]) -> Result<Vec<f64>> {
450        Ok(scores.iter().map(|&s| 1.0 / (1.0 + (-s).exp())).collect())
451    }
452
453    /// Get the current configuration
454    pub fn config(&self) -> &FusionConfig {
455        &self.config
456    }
457
458    /// Update the configuration
459    pub fn set_config(&mut self, config: FusionConfig) {
460        self.config = config;
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine with the default configuration, shared by all tests.
    fn default_engine() -> MultimodalFusion {
        MultimodalFusion::new(FusionConfig::default())
    }

    /// Three overlapping result lists (text, vector, spatial).
    ///
    /// `doc1` occurs in all three lists, `doc2` and `doc3` in two, and `doc4`
    /// and `doc5` in one each.
    fn create_test_results() -> (Vec<DocumentScore>, Vec<DocumentScore>, Vec<DocumentScore>) {
        let hit = |id: &str, score, rank| DocumentScore {
            doc_id: id.to_string(),
            score,
            rank,
        };

        let text = vec![
            hit("doc1", 10.0, 0),
            hit("doc2", 8.0, 1),
            hit("doc3", 5.0, 2),
        ];
        let vector = vec![
            hit("doc2", 0.95, 0),
            hit("doc4", 0.90, 1),
            hit("doc1", 0.85, 2),
        ];
        let spatial = vec![
            hit("doc3", 0.99, 0),
            hit("doc1", 0.92, 1),
            hit("doc5", 0.88, 2),
        ];

        (text, vector, spatial)
    }

    #[test]
    fn test_weighted_fusion() {
        let (text, vector, spatial) = create_test_results();
        let fusion = default_engine();

        // Weights are given as text, vector, spatial
        let strategy = FusionStrategy::Weighted {
            weights: vec![0.4, 0.4, 0.2],
        };
        let results = fusion
            .fuse(&text, &vector, &spatial, Some(strategy))
            .unwrap();

        assert!(!results.is_empty());
        assert!(results[0].total_score > 0.0);

        // doc1 is present in every list, so it must carry three scores
        let doc1 = results.iter().find(|r| r.uri == "doc1").unwrap();
        assert_eq!(doc1.scores.len(), 3);
    }

    #[test]
    fn test_sequential_fusion() {
        let (text, vector, spatial) = create_test_results();
        let fusion = default_engine();

        let strategy = FusionStrategy::Sequential {
            order: vec![Modality::Text, Modality::Vector],
        };
        let results = fusion
            .fuse(&text, &vector, &spatial, Some(strategy))
            .unwrap();

        assert!(!results.is_empty());
        // Only documents that were in the text list may survive the filter
        for fused in &results {
            assert!(matches!(fused.uri.as_str(), "doc1" | "doc2" | "doc3"));
        }
    }

    #[test]
    fn test_cascade_fusion() {
        let (text, vector, spatial) = create_test_results();
        let fusion = default_engine();

        // Zero thresholds let every normalized score through
        let strategy = FusionStrategy::Cascade {
            thresholds: vec![0.0, 0.0, 0.0],
        };
        let results = fusion
            .fuse(&text, &vector, &spatial, Some(strategy))
            .unwrap();

        assert!(!results.is_empty());
        // Survivors should carry scores from more than one modality
        if let Some(doc1) = results.iter().find(|r| r.uri == "doc1") {
            assert!(doc1.scores.len() >= 2);
        }
    }

    #[test]
    fn test_rank_fusion() {
        let (text, vector, spatial) = create_test_results();
        let fusion = default_engine();

        let results = fusion
            .fuse(&text, &vector, &spatial, Some(FusionStrategy::RankFusion))
            .unwrap();
        assert!(!results.is_empty());

        let total_of = |uri: &str| {
            results
                .iter()
                .find(|r| r.uri == uri)
                .unwrap()
                .total_score
        };

        // doc1 is in all three lists, doc4 only in the vector list, so doc1
        // must end up with the higher RRF score
        assert!(total_of("doc1") > total_of("doc4"));
    }

    #[test]
    fn test_min_max_normalization() {
        let fusion = default_engine();

        let normalized = fusion.min_max_normalize(&[10.0, 5.0, 0.0]).unwrap();

        let expected = [1.0, 0.5, 0.0];
        for (got, want) in normalized.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_z_score_normalization() {
        let fusion = default_engine();

        let normalized = fusion.z_score_normalize(&[10.0, 5.0, 0.0]).unwrap();

        // The input mean is 5.0, so the z-scores must average to ~0
        let mean: f64 = normalized.iter().sum::<f64>() / normalized.len() as f64;
        assert!(mean.abs() < 1e-6);
    }

    #[test]
    fn test_sigmoid_normalization() {
        let fusion = default_engine();

        let normalized = fusion.sigmoid_normalize(&[0.0, 1.0, -1.0]).unwrap();

        // sigmoid(0) is exactly one half
        assert!((normalized[0] - 0.5).abs() < 1e-6);
        // Every output lies strictly between 0 and 1
        for &value in &normalized {
            assert!(value > 0.0 && value < 1.0);
        }
    }

    #[test]
    fn test_empty_results() {
        let fusion = default_engine();
        let none: Vec<DocumentScore> = Vec::new();

        let results = fusion
            .fuse(&none, &none, &none, Some(FusionStrategy::RankFusion))
            .unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_fused_result_operations() {
        let mut result = FusedResult::new("test_doc".to_string());

        let contributions = [
            (Modality::Text, 0.5),
            (Modality::Vector, 0.3),
            (Modality::Spatial, 0.2),
        ];
        for &(modality, score) in contributions.iter() {
            result.add_score(modality, score);
        }
        for &(modality, score) in contributions.iter() {
            assert_eq!(result.get_score(modality), Some(score));
        }

        result.calculate_total();
        assert!((result.total_score - 1.0).abs() < 1e-6);
    }
}
666}