rexis_rag/multimodal/
embedding_fusion.rs

1//! # Embedding Fusion
2//!
3//! Advanced multi-modal embedding fusion strategies for unified representation.
4
5use super::{
6    EmbeddingFusionStrategy, EmbeddingWeights, ExtractedTable, FusionStrategy, MultiModalDocument,
7    MultiModalEmbeddings, ProcessedImage,
8};
9use crate::RragResult;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// Default embedding fusion implementation
14pub struct DefaultFusionStrategy {
15    /// Fusion strategy
16    strategy: FusionStrategy,
17
18    /// Fusion configuration
19    config: FusionConfig,
20
21    /// Weight calculator
22    weight_calculator: WeightCalculator,
23
24    /// Dimension normalizer
25    dimension_normalizer: DimensionNormalizer,
26
27    /// Attention mechanism (for attention-based fusion)
28    attention_mechanism: Option<AttentionMechanism>,
29}
30
/// Tunables that control how embeddings are fused.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Length every modality embedding (and the fused result) is sized to.
    pub target_dimension: usize,

    /// When `true`, the weighted fusion result is L2-normalised.
    pub normalize_embeddings: bool,

    /// When `true`, weights are derived per document instead of taken from
    /// the document's stored weights.
    pub adaptive_weights: bool,

    /// Lower bound intended for adaptive weights.
    /// NOTE(review): not read by the fusion code in this module yet.
    pub min_weight: f32,

    /// Upper bound intended for adaptive weights.
    /// NOTE(review): not read by the fusion code in this module yet.
    pub max_weight: f32,

    /// Step size reserved for learned/adaptive fusion.
    /// NOTE(review): not read by the fusion code in this module yet.
    pub learning_rate: f32,
}
52
53/// Weight calculation strategies
54pub struct WeightCalculator {
55    /// Content analysis
56    content_analyzer: ContentAnalyzer,
57
58    /// Quality assessor
59    quality_assessor: QualityAssessor,
60}
61
62/// Dimension normalization utility
63pub struct DimensionNormalizer {
64    /// Target dimension
65    target_dim: usize,
66
67    /// Normalization strategy
68    strategy: NormalizationStrategy,
69}
70
71/// Attention mechanism for fusion
72pub struct AttentionMechanism {
73    /// Attention weights
74    attention_weights: HashMap<String, Vec<f32>>,
75
76    /// Query projection
77    query_projection: AttentionProjection,
78
79    /// Key projection
80    key_projection: AttentionProjection,
81
82    /// Value projection
83    value_projection: AttentionProjection,
84}
85
86/// Content analyzer for weight calculation
87pub struct ContentAnalyzer {
88    /// Text importance scorer
89    text_scorer: TextImportanceScorer,
90
91    /// Visual importance scorer
92    visual_scorer: VisualImportanceScorer,
93
94    /// Table importance scorer
95    table_scorer: TableImportanceScorer,
96}
97
98/// Quality assessment for embeddings
99pub struct QualityAssessor {
100    /// Embedding quality metrics
101    quality_metrics: Vec<QualityMetric>,
102}
103
/// How an embedding is transformed when brought to the target dimension.
#[derive(Debug, Clone, Copy)]
pub enum NormalizationStrategy {
    /// Scale to unit Euclidean length, then resize.
    L2Norm,

    /// Rescale values into a min-max range.
    MinMax,

    /// Standardise values to zero mean and unit variance.
    ZScore,

    /// Truncate or zero-pad to the target length.
    LinearProjection,

    /// Principal-component reduction.
    PCA,
}
122
/// Dense projection layer (weight matrix plus bias) used by attention.
#[derive(Debug, Clone)]
pub struct AttentionProjection {
    /// Row-major weight matrix: one inner `Vec` per output unit.
    pub weights: Vec<Vec<f32>>,

    /// One bias term per output unit.
    pub bias: Vec<f32>,
}
132
133/// Text importance scoring
134pub struct TextImportanceScorer {
135    /// TF-IDF calculator
136    tfidf_calculator: TfIdfCalculator,
137
138    /// Named entity recognizer
139    ner: NamedEntityRecognizer,
140}
141
142/// Visual importance scoring
143pub struct VisualImportanceScorer {
144    /// Saliency detector
145    saliency_detector: SaliencyDetector,
146
147    /// Aesthetic analyzer
148    aesthetic_analyzer: AestheticAnalyzer,
149}
150
151/// Table importance scoring
152pub struct TableImportanceScorer {
153    /// Information density calculator
154    density_calculator: InformationDensityCalculator,
155}
156
157/// Quality metrics for embeddings
158#[derive(Debug, Clone)]
159pub struct QualityMetric {
160    /// Metric name
161    pub name: String,
162
163    /// Metric weight
164    pub weight: f32,
165
166    /// Metric function
167    pub metric_type: QualityMetricType,
168}
169
/// Kinds of embedding-quality metric.
#[derive(Debug, Clone, Copy)]
pub enum QualityMetricType {
    /// Based on the embedding's Euclidean norm.
    EmbeddingNorm,

    /// Based on the spread of the embedding's values.
    Variance,

    /// Coherence of the embedding.
    Coherence,

    /// How distinct the embedding is.
    Distinctiveness,
}
185
/// Term-frequency / inverse-document-frequency bookkeeping.
pub struct TfIdfCalculator {
    /// Number of documents each term has been seen in.
    document_frequencies: HashMap<String, usize>,

    /// Number of documents observed so far.
    total_documents: usize,
}

/// Placeholder named-entity recogniser.
pub struct NamedEntityRecognizer;

/// Placeholder image saliency detector.
pub struct SaliencyDetector;

/// Placeholder image aesthetics analyser.
pub struct AestheticAnalyzer;

/// Computes how densely populated a table is.
pub struct InformationDensityCalculator;
206
207/// Fusion result
208#[derive(Debug, Clone)]
209pub struct FusionResult {
210    /// Fused embedding
211    pub fused_embedding: Vec<f32>,
212
213    /// Final weights used
214    pub weights: EmbeddingWeights,
215
216    /// Fusion confidence
217    pub confidence: f32,
218
219    /// Individual modality scores
220    pub modality_scores: ModalityScores,
221}
222
223/// Scores for each modality
224#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct ModalityScores {
226    /// Text quality score
227    pub text_score: f32,
228
229    /// Visual quality score
230    pub visual_score: f32,
231
232    /// Table quality score
233    pub table_score: f32,
234
235    /// Chart quality score
236    pub chart_score: f32,
237}
238
239impl DefaultFusionStrategy {
240    /// Create new fusion strategy
241    pub fn new(strategy: FusionStrategy) -> RragResult<Self> {
242        let config = FusionConfig::default();
243        let weight_calculator = WeightCalculator::new()?;
244        let dimension_normalizer = DimensionNormalizer::new(config.target_dimension);
245
246        let attention_mechanism = if matches!(strategy, FusionStrategy::Attention) {
247            Some(AttentionMechanism::new(config.target_dimension)?)
248        } else {
249            None
250        };
251
252        Ok(Self {
253            strategy,
254            config,
255            weight_calculator,
256            dimension_normalizer,
257            attention_mechanism,
258        })
259    }
260
261    /// Fuse embeddings with detailed analysis
262    pub fn fuse_embeddings_detailed(
263        &self,
264        document: &MultiModalDocument,
265    ) -> RragResult<FusionResult> {
266        // Calculate optimal weights
267        let weights = if self.config.adaptive_weights {
268            self.calculate_weights(document)?
269        } else {
270            document.embeddings.weights.clone()
271        };
272
273        // Score individual modalities
274        let modality_scores = self.calculate_modality_scores(document)?;
275
276        // Normalize embeddings
277        let normalized_embeddings = self.normalize_embeddings(&document.embeddings)?;
278
279        // Perform fusion based on strategy
280        let fused_embedding = match self.strategy {
281            FusionStrategy::Average => self.fuse_average(&normalized_embeddings, &weights)?,
282            FusionStrategy::Weighted => self.fuse_weighted(&normalized_embeddings, &weights)?,
283            FusionStrategy::Concatenate => self.fuse_concatenate(&normalized_embeddings)?,
284            FusionStrategy::Attention => self.fuse_attention(&normalized_embeddings, &weights)?,
285            FusionStrategy::Learned => self.fuse_learned(&normalized_embeddings, &weights)?,
286        };
287
288        // Calculate fusion confidence
289        let confidence = self.calculate_fusion_confidence(&fused_embedding, &modality_scores)?;
290
291        Ok(FusionResult {
292            fused_embedding,
293            weights,
294            confidence,
295            modality_scores,
296        })
297    }
298
299    /// Normalize embeddings to consistent dimensions
300    fn normalize_embeddings(
301        &self,
302        embeddings: &MultiModalEmbeddings,
303    ) -> RragResult<NormalizedEmbeddings> {
304        let text_normalized = self
305            .dimension_normalizer
306            .normalize(&embeddings.text_embeddings)?;
307
308        let visual_normalized = if let Some(ref visual) = embeddings.visual_embeddings {
309            Some(self.dimension_normalizer.normalize(visual)?)
310        } else {
311            None
312        };
313
314        let table_normalized = if let Some(ref table) = embeddings.table_embeddings {
315            Some(self.dimension_normalizer.normalize(table)?)
316        } else {
317            None
318        };
319
320        Ok(NormalizedEmbeddings {
321            text: text_normalized,
322            visual: visual_normalized,
323            table: table_normalized,
324        })
325    }
326
327    /// Average fusion
328    fn fuse_average(
329        &self,
330        embeddings: &NormalizedEmbeddings,
331        _weights: &EmbeddingWeights,
332    ) -> RragResult<Vec<f32>> {
333        let mut fused = embeddings.text.clone();
334        let mut count = 1;
335
336        if let Some(ref visual) = embeddings.visual {
337            for (i, &val) in visual.iter().enumerate() {
338                if i < fused.len() {
339                    fused[i] += val;
340                }
341            }
342            count += 1;
343        }
344
345        if let Some(ref table) = embeddings.table {
346            for (i, &val) in table.iter().enumerate() {
347                if i < fused.len() {
348                    fused[i] += val;
349                }
350            }
351            count += 1;
352        }
353
354        // Average
355        for val in &mut fused {
356            *val /= count as f32;
357        }
358
359        Ok(fused)
360    }
361
362    /// Weighted fusion
363    fn fuse_weighted(
364        &self,
365        embeddings: &NormalizedEmbeddings,
366        weights: &EmbeddingWeights,
367    ) -> RragResult<Vec<f32>> {
368        let mut fused = vec![0.0; self.config.target_dimension];
369
370        // Weighted combination
371        for (i, &val) in embeddings.text.iter().enumerate() {
372            if i < fused.len() {
373                fused[i] += val * weights.text_weight;
374            }
375        }
376
377        if let Some(ref visual) = embeddings.visual {
378            for (i, &val) in visual.iter().enumerate() {
379                if i < fused.len() {
380                    fused[i] += val * weights.visual_weight;
381                }
382            }
383        }
384
385        if let Some(ref table) = embeddings.table {
386            for (i, &val) in table.iter().enumerate() {
387                if i < fused.len() {
388                    fused[i] += val * weights.table_weight;
389                }
390            }
391        }
392
393        // Normalize if configured
394        if self.config.normalize_embeddings {
395            self.l2_normalize(&mut fused);
396        }
397
398        Ok(fused)
399    }
400
401    /// Concatenation fusion
402    fn fuse_concatenate(&self, embeddings: &NormalizedEmbeddings) -> RragResult<Vec<f32>> {
403        let mut fused = embeddings.text.clone();
404
405        if let Some(ref visual) = embeddings.visual {
406            fused.extend_from_slice(visual);
407        }
408
409        if let Some(ref table) = embeddings.table {
410            fused.extend_from_slice(table);
411        }
412
413        // Resize to target dimension if needed
414        if fused.len() > self.config.target_dimension {
415            fused.truncate(self.config.target_dimension);
416        } else if fused.len() < self.config.target_dimension {
417            fused.resize(self.config.target_dimension, 0.0);
418        }
419
420        Ok(fused)
421    }
422
423    /// Attention-based fusion
424    fn fuse_attention(
425        &self,
426        embeddings: &NormalizedEmbeddings,
427        _weights: &EmbeddingWeights,
428    ) -> RragResult<Vec<f32>> {
429        if let Some(ref attention) = self.attention_mechanism {
430            attention.apply_attention(embeddings)
431        } else {
432            // Fallback to weighted fusion
433            self.fuse_weighted(embeddings, _weights)
434        }
435    }
436
437    /// Learned fusion (placeholder for ML model)
438    fn fuse_learned(
439        &self,
440        embeddings: &NormalizedEmbeddings,
441        weights: &EmbeddingWeights,
442    ) -> RragResult<Vec<f32>> {
443        // For now, use weighted fusion with learned weights
444        self.fuse_weighted(embeddings, weights)
445    }
446
447    /// Calculate modality scores
448    fn calculate_modality_scores(
449        &self,
450        document: &MultiModalDocument,
451    ) -> RragResult<ModalityScores> {
452        let text_score = self
453            .weight_calculator
454            .content_analyzer
455            .text_scorer
456            .calculate_text_score(&document.text_content)?;
457
458        let visual_score = if !document.images.is_empty() {
459            self.weight_calculator
460                .content_analyzer
461                .visual_scorer
462                .calculate_visual_score(&document.images)?
463        } else {
464            0.0
465        };
466
467        let table_score = if !document.tables.is_empty() {
468            self.weight_calculator
469                .content_analyzer
470                .table_scorer
471                .calculate_table_score(&document.tables)?
472        } else {
473            0.0
474        };
475
476        let chart_score = if !document.charts.is_empty() {
477            // Simplified chart scoring
478            0.7
479        } else {
480            0.0
481        };
482
483        Ok(ModalityScores {
484            text_score,
485            visual_score,
486            table_score,
487            chart_score,
488        })
489    }
490
491    /// Calculate fusion confidence
492    fn calculate_fusion_confidence(
493        &self,
494        _fused_embedding: &[f32],
495        scores: &ModalityScores,
496    ) -> RragResult<f32> {
497        // Confidence based on modality diversity and quality
498        let mut confidence = 0.0;
499        let mut active_modalities = 0;
500
501        if scores.text_score > 0.0 {
502            confidence += scores.text_score * 0.4;
503            active_modalities += 1;
504        }
505
506        if scores.visual_score > 0.0 {
507            confidence += scores.visual_score * 0.3;
508            active_modalities += 1;
509        }
510
511        if scores.table_score > 0.0 {
512            confidence += scores.table_score * 0.2;
513            active_modalities += 1;
514        }
515
516        if scores.chart_score > 0.0 {
517            confidence += scores.chart_score * 0.1;
518            active_modalities += 1;
519        }
520
521        // Bonus for multi-modal content
522        if active_modalities > 1 {
523            confidence *= 1.0 + (active_modalities as f32 - 1.0) * 0.1;
524        }
525
526        Ok(confidence.min(1.0))
527    }
528
529    /// L2 normalization
530    fn l2_normalize(&self, vector: &mut [f32]) {
531        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
532        if norm > 0.0 {
533            for val in vector {
534                *val /= norm;
535            }
536        }
537    }
538}
539
540impl EmbeddingFusionStrategy for DefaultFusionStrategy {
541    fn fuse_embeddings(&self, embeddings: &MultiModalEmbeddings) -> RragResult<Vec<f32>> {
542        match self.strategy {
543            FusionStrategy::Average => {
544                let mut fused = embeddings.text_embeddings.clone();
545                let mut count = 1;
546
547                if let Some(ref visual) = embeddings.visual_embeddings {
548                    for (i, &val) in visual.iter().enumerate() {
549                        if i < fused.len() {
550                            fused[i] += val;
551                        }
552                    }
553                    count += 1;
554                }
555
556                for val in &mut fused {
557                    *val /= count as f32;
558                }
559
560                Ok(fused)
561            }
562
563            FusionStrategy::Weighted => {
564                let mut fused = vec![0.0; embeddings.text_embeddings.len()];
565                let weights = &embeddings.weights;
566
567                for (i, &val) in embeddings.text_embeddings.iter().enumerate() {
568                    fused[i] += val * weights.text_weight;
569                }
570
571                if let Some(ref visual) = embeddings.visual_embeddings {
572                    for (i, &val) in visual.iter().enumerate() {
573                        if i < fused.len() {
574                            fused[i] += val * weights.visual_weight;
575                        }
576                    }
577                }
578
579                Ok(fused)
580            }
581
582            _ => {
583                // Fallback to weighted
584                self.fuse_embeddings(&MultiModalEmbeddings {
585                    text_embeddings: embeddings.text_embeddings.clone(),
586                    visual_embeddings: embeddings.visual_embeddings.clone(),
587                    table_embeddings: embeddings.table_embeddings.clone(),
588                    fused_embedding: vec![],
589                    weights: EmbeddingWeights {
590                        text_weight: 0.6,
591                        visual_weight: 0.3,
592                        table_weight: 0.1,
593                        chart_weight: 0.0,
594                    },
595                })
596            }
597        }
598    }
599
600    fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights> {
601        self.weight_calculator.calculate_weights(document)
602    }
603}
604
/// Per-modality embeddings after dimension normalisation.
///
/// `visual` and `table` are `None` when the source had no embedding for that modality.
#[derive(Debug, Clone)]
pub struct NormalizedEmbeddings {
    /// Normalised text embedding (always present).
    text: Vec<f32>,
    /// Normalised visual embedding, if any.
    visual: Option<Vec<f32>>,
    /// Normalised table embedding, if any.
    table: Option<Vec<f32>>,
}
612
613impl WeightCalculator {
614    /// Create new weight calculator
615    pub fn new() -> RragResult<Self> {
616        Ok(Self {
617            content_analyzer: ContentAnalyzer::new()?,
618            quality_assessor: QualityAssessor::new(),
619        })
620    }
621
622    /// Calculate optimal weights for document
623    pub fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights> {
624        let scores = self.content_analyzer.analyze_content(document)?;
625        let quality_scores = self.quality_assessor.assess_quality(&document.embeddings)?;
626
627        // Combine content importance with quality scores
628        let text_weight = scores.text_importance * quality_scores.text_quality;
629        let visual_weight = scores.visual_importance * quality_scores.visual_quality;
630        let table_weight = scores.table_importance * quality_scores.table_quality;
631        let chart_weight = scores.chart_importance * quality_scores.chart_quality;
632
633        // Normalize weights to sum to 1.0
634        let total = text_weight + visual_weight + table_weight + chart_weight;
635
636        if total > 0.0 {
637            Ok(EmbeddingWeights {
638                text_weight: text_weight / total,
639                visual_weight: visual_weight / total,
640                table_weight: table_weight / total,
641                chart_weight: chart_weight / total,
642            })
643        } else {
644            // Fallback to default weights
645            Ok(EmbeddingWeights {
646                text_weight: 0.6,
647                visual_weight: 0.2,
648                table_weight: 0.1,
649                chart_weight: 0.1,
650            })
651        }
652    }
653}
654
655impl DimensionNormalizer {
656    /// Create new dimension normalizer
657    pub fn new(target_dim: usize) -> Self {
658        Self {
659            target_dim,
660            strategy: NormalizationStrategy::LinearProjection,
661        }
662    }
663
664    /// Normalize embedding to target dimension
665    pub fn normalize(&self, embedding: &[f32]) -> RragResult<Vec<f32>> {
666        match self.strategy {
667            NormalizationStrategy::LinearProjection => {
668                if embedding.len() == self.target_dim {
669                    Ok(embedding.to_vec())
670                } else if embedding.len() > self.target_dim {
671                    // Truncate
672                    Ok(embedding[..self.target_dim].to_vec())
673                } else {
674                    // Pad with zeros
675                    let mut normalized = embedding.to_vec();
676                    normalized.resize(self.target_dim, 0.0);
677                    Ok(normalized)
678                }
679            }
680
681            NormalizationStrategy::L2Norm => {
682                let mut normalized = embedding.to_vec();
683                let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
684                if norm > 0.0 {
685                    for val in &mut normalized {
686                        *val /= norm;
687                    }
688                }
689
690                // Resize to target dimension
691                if normalized.len() != self.target_dim {
692                    normalized.resize(self.target_dim, 0.0);
693                }
694
695                Ok(normalized)
696            }
697
698            _ => {
699                // Fallback to linear projection
700                self.normalize(embedding)
701            }
702        }
703    }
704}
705
706impl AttentionMechanism {
707    /// Create new attention mechanism
708    pub fn new(dim: usize) -> RragResult<Self> {
709        Ok(Self {
710            attention_weights: HashMap::new(),
711            query_projection: AttentionProjection::new(dim, dim)?,
712            key_projection: AttentionProjection::new(dim, dim)?,
713            value_projection: AttentionProjection::new(dim, dim)?,
714        })
715    }
716
717    /// Apply attention to embeddings
718    pub fn apply_attention(&self, embeddings: &NormalizedEmbeddings) -> RragResult<Vec<f32>> {
719        // Simplified attention mechanism
720        // In practice, this would implement proper multi-head attention
721
722        let query = &embeddings.text;
723        let mut attended = query.clone();
724
725        if let Some(ref visual) = embeddings.visual {
726            let attention_score = self.compute_attention_score(query, visual)?;
727            for (i, &val) in visual.iter().enumerate() {
728                if i < attended.len() {
729                    attended[i] += val * attention_score;
730                }
731            }
732        }
733
734        if let Some(ref table) = embeddings.table {
735            let attention_score = self.compute_attention_score(query, table)?;
736            for (i, &val) in table.iter().enumerate() {
737                if i < attended.len() {
738                    attended[i] += val * attention_score;
739                }
740            }
741        }
742
743        Ok(attended)
744    }
745
746    /// Compute attention score between query and key
747    fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> RragResult<f32> {
748        // Dot product attention
749        let score: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
750
751        // Normalize by sqrt of dimension
752        let normalized_score = score / (query.len() as f32).sqrt();
753
754        // Apply softmax (simplified)
755        Ok(normalized_score.exp() / (1.0 + normalized_score.exp()))
756    }
757}
758
759impl AttentionProjection {
760    /// Create new attention projection
761    pub fn new(input_dim: usize, output_dim: usize) -> RragResult<Self> {
762        // Initialize with small random values (simplified)
763        let weights = vec![vec![0.01; input_dim]; output_dim];
764        let bias = vec![0.0; output_dim];
765
766        Ok(Self { weights, bias })
767    }
768}
769
770impl ContentAnalyzer {
771    /// Create new content analyzer
772    pub fn new() -> RragResult<Self> {
773        Ok(Self {
774            text_scorer: TextImportanceScorer::new()?,
775            visual_scorer: VisualImportanceScorer::new(),
776            table_scorer: TableImportanceScorer::new(),
777        })
778    }
779
780    /// Analyze content importance
781    pub fn analyze_content(&self, document: &MultiModalDocument) -> RragResult<ContentScores> {
782        let text_importance = self
783            .text_scorer
784            .calculate_text_score(&document.text_content)?;
785        let visual_importance = self
786            .visual_scorer
787            .calculate_visual_score(&document.images)?;
788        let table_importance = self.table_scorer.calculate_table_score(&document.tables)?;
789        let chart_importance = if !document.charts.is_empty() {
790            0.7
791        } else {
792            0.0
793        };
794
795        Ok(ContentScores {
796            text_importance,
797            visual_importance,
798            table_importance,
799            chart_importance,
800        })
801    }
802}
803
/// How important each modality's content is to a document.
#[derive(Debug, Clone)]
pub struct ContentScores {
    /// Importance of the text.
    pub text_importance: f32,
    /// Importance of the images.
    pub visual_importance: f32,
    /// Importance of the tables.
    pub table_importance: f32,
    /// Importance of the charts.
    pub chart_importance: f32,
}

/// Quality of each modality's embedding.
#[derive(Debug, Clone)]
pub struct QualityScores {
    /// Quality of the text embedding.
    pub text_quality: f32,
    /// Quality of the visual embedding.
    pub visual_quality: f32,
    /// Quality of the table embedding.
    pub table_quality: f32,
    /// Quality attributed to chart content.
    pub chart_quality: f32,
}
821
822impl TextImportanceScorer {
823    pub fn new() -> RragResult<Self> {
824        Ok(Self {
825            tfidf_calculator: TfIdfCalculator::new(),
826            ner: NamedEntityRecognizer,
827        })
828    }
829
830    pub fn calculate_text_score(&self, text: &str) -> RragResult<f32> {
831        let word_count = text.split_whitespace().count();
832        let entity_score = self.ner.calculate_entity_score(text)?;
833
834        // Combine length and entity density
835        let length_score = (word_count as f32 / 1000.0).min(1.0);
836        Ok(length_score * 0.7 + entity_score * 0.3)
837    }
838}
839
840impl VisualImportanceScorer {
841    pub fn new() -> Self {
842        Self {
843            saliency_detector: SaliencyDetector,
844            aesthetic_analyzer: AestheticAnalyzer,
845        }
846    }
847
848    pub fn calculate_visual_score(&self, images: &[ProcessedImage]) -> RragResult<f32> {
849        if images.is_empty() {
850            return Ok(0.0);
851        }
852
853        let mut total_score = 0.0;
854        for image in images {
855            let quality_score = image
856                .features
857                .as_ref()
858                .map(|f| (f.quality.sharpness + f.quality.contrast) / 2.0)
859                .unwrap_or(0.5);
860
861            let aesthetic_score = self.aesthetic_analyzer.analyze_aesthetics(image)?;
862            total_score += quality_score * 0.6 + aesthetic_score * 0.4;
863        }
864
865        Ok(total_score / images.len() as f32)
866    }
867}
868
869impl TableImportanceScorer {
870    pub fn new() -> Self {
871        Self {
872            density_calculator: InformationDensityCalculator,
873        }
874    }
875
876    pub fn calculate_table_score(&self, tables: &[ExtractedTable]) -> RragResult<f32> {
877        if tables.is_empty() {
878            return Ok(0.0);
879        }
880
881        let mut total_score = 0.0;
882        for table in tables {
883            let size_score = (table.rows.len() * table.headers.len()) as f32 / 100.0;
884            let density_score = self.density_calculator.calculate_density(table)?;
885            total_score += size_score.min(1.0) * 0.5 + density_score * 0.5;
886        }
887
888        Ok(total_score / tables.len() as f32)
889    }
890}
891
892impl QualityAssessor {
893    pub fn new() -> Self {
894        Self {
895            quality_metrics: vec![
896                QualityMetric {
897                    name: "norm".to_string(),
898                    weight: 0.3,
899                    metric_type: QualityMetricType::EmbeddingNorm,
900                },
901                QualityMetric {
902                    name: "variance".to_string(),
903                    weight: 0.4,
904                    metric_type: QualityMetricType::Variance,
905                },
906                QualityMetric {
907                    name: "coherence".to_string(),
908                    weight: 0.3,
909                    metric_type: QualityMetricType::Coherence,
910                },
911            ],
912        }
913    }
914
915    pub fn assess_quality(&self, embeddings: &MultiModalEmbeddings) -> RragResult<QualityScores> {
916        let text_quality = self.calculate_embedding_quality(&embeddings.text_embeddings)?;
917
918        let visual_quality = if let Some(ref visual) = embeddings.visual_embeddings {
919            self.calculate_embedding_quality(visual)?
920        } else {
921            0.0
922        };
923
924        let table_quality = if let Some(ref table) = embeddings.table_embeddings {
925            self.calculate_embedding_quality(table)?
926        } else {
927            0.0
928        };
929
930        Ok(QualityScores {
931            text_quality,
932            visual_quality,
933            table_quality,
934            chart_quality: 0.7, // Simplified
935        })
936    }
937
938    fn calculate_embedding_quality(&self, embedding: &[f32]) -> RragResult<f32> {
939        let mut quality_score = 0.0;
940
941        for metric in &self.quality_metrics {
942            let score = match metric.metric_type {
943                QualityMetricType::EmbeddingNorm => {
944                    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
945                    (norm / embedding.len() as f32).min(1.0)
946                }
947                QualityMetricType::Variance => {
948                    let mean = embedding.iter().sum::<f32>() / embedding.len() as f32;
949                    let variance = embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
950                        / embedding.len() as f32;
951                    variance.min(1.0)
952                }
953                QualityMetricType::Coherence => 0.8, // Simplified
954                QualityMetricType::Distinctiveness => 0.7, // Simplified
955            };
956
957            quality_score += score * metric.weight;
958        }
959
960        Ok(quality_score)
961    }
962}
963
964// Simplified implementations for helper components
965impl TfIdfCalculator {
966    pub fn new() -> Self {
967        Self {
968            document_frequencies: HashMap::new(),
969            total_documents: 0,
970        }
971    }
972}
973
974impl NamedEntityRecognizer {
975    pub fn calculate_entity_score(&self, _text: &str) -> RragResult<f32> {
976        // Simplified entity scoring
977        Ok(0.6)
978    }
979}
980
981impl SaliencyDetector {}
982
983impl AestheticAnalyzer {
984    pub fn analyze_aesthetics(&self, _image: &ProcessedImage) -> RragResult<f32> {
985        // Simplified aesthetic analysis
986        Ok(0.7)
987    }
988}
989
990impl InformationDensityCalculator {
991    pub fn calculate_density(&self, table: &ExtractedTable) -> RragResult<f32> {
992        let total_cells = table.rows.len() * table.headers.len();
993        let filled_cells = table
994            .rows
995            .iter()
996            .flatten()
997            .filter(|cell| !cell.value.trim().is_empty())
998            .count();
999
1000        Ok(filled_cells as f32 / total_cells as f32)
1001    }
1002}
1003
1004impl Default for FusionConfig {
1005    fn default() -> Self {
1006        Self {
1007            target_dimension: 768,
1008            normalize_embeddings: true,
1009            adaptive_weights: true,
1010            min_weight: 0.01,
1011            max_weight: 0.99,
1012            learning_rate: 0.001,
1013        }
1014    }
1015}
1016
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds text/visual-only embeddings with the given text and visual weights.
    fn embeddings_with(
        text: Vec<f32>,
        visual: Option<Vec<f32>>,
        text_weight: f32,
        visual_weight: f32,
    ) -> MultiModalEmbeddings {
        MultiModalEmbeddings {
            text_embeddings: text,
            visual_embeddings: visual,
            table_embeddings: None,
            fused_embedding: vec![],
            weights: EmbeddingWeights {
                text_weight,
                visual_weight,
                table_weight: 0.0,
                chart_weight: 0.0,
            },
        }
    }

    #[test]
    fn test_fusion_strategy_creation() {
        let strategy = DefaultFusionStrategy::new(FusionStrategy::Weighted).unwrap();
        assert!(matches!(strategy.strategy, FusionStrategy::Weighted));
    }

    #[test]
    fn test_dimension_normalization() {
        let normalizer = DimensionNormalizer::new(512);

        let embedding = vec![1.0, 2.0, 3.0];
        let normalized = normalizer.normalize(&embedding).unwrap();

        assert_eq!(normalized.len(), 512);
        assert_eq!(&normalized[..3], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dimension_normalization_truncates() {
        let normalizer = DimensionNormalizer::new(2);
        let normalized = normalizer.normalize(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(normalized, vec![1.0_f32, 2.0]);
    }

    #[test]
    fn test_average_fusion_of_text_and_visual() {
        let strategy = DefaultFusionStrategy::new(FusionStrategy::Average).unwrap();
        let embeddings = embeddings_with(vec![2.0, 4.0], Some(vec![4.0, 8.0]), 1.0, 0.0);
        assert_eq!(strategy.fuse_embeddings(&embeddings).unwrap(), vec![3.0_f32, 6.0]);
    }

    #[test]
    fn test_weighted_fusion_of_text_and_visual() {
        let strategy = DefaultFusionStrategy::new(FusionStrategy::Weighted).unwrap();
        let embeddings = embeddings_with(vec![1.0, 2.0], Some(vec![10.0, 20.0]), 0.5, 0.25);
        assert_eq!(strategy.fuse_embeddings(&embeddings).unwrap(), vec![3.0_f32, 6.0]);
    }

    #[test]
    fn test_weight_calculation() {
        let calculator = WeightCalculator::new().unwrap();

        // Create test document
        let document = MultiModalDocument {
            id: "test".to_string(),
            text_content: "Test content".to_string(),
            images: vec![],
            tables: vec![],
            charts: vec![],
            layout: super::super::DocumentLayout {
                pages: 1,
                sections: vec![],
                reading_order: vec![],
                columns: None,
                document_type: super::super::DocumentType::PlainText,
            },
            embeddings: MultiModalEmbeddings {
                text_embeddings: vec![0.1, 0.2, 0.3],
                visual_embeddings: None,
                table_embeddings: None,
                fused_embedding: vec![],
                weights: EmbeddingWeights {
                    text_weight: 1.0,
                    visual_weight: 0.0,
                    table_weight: 0.0,
                    chart_weight: 0.0,
                },
            },
            metadata: super::super::DocumentMetadata {
                title: None,
                author: None,
                creation_date: None,
                modification_date: None,
                page_count: 1,
                word_count: 2,
                language: "en".to_string(),
                format: super::super::DocumentType::PlainText,
            },
        };

        let weights = calculator.calculate_weights(&document).unwrap();
        assert!(weights.text_weight > 0.0);
    }
}