use super::{
EmbeddingFusionStrategy, EmbeddingWeights, ExtractedTable, FusionStrategy, MultiModalDocument,
MultiModalEmbeddings, ProcessedImage,
};
use crate::RragResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Default implementation of multi-modal embedding fusion.
///
/// Combines text, visual, and table embeddings into one vector according
/// to the configured [`FusionStrategy`].
pub struct DefaultFusionStrategy {
    /// Which fusion algorithm to apply.
    strategy: FusionStrategy,
    /// Tuning knobs: target dimension, normalization, adaptive weights.
    config: FusionConfig,
    /// Derives per-modality weights from content and embedding quality.
    weight_calculator: WeightCalculator,
    /// Resizes per-modality embeddings to a common dimension.
    dimension_normalizer: DimensionNormalizer,
    /// Present only when `strategy` is `FusionStrategy::Attention`.
    attention_mechanism: Option<AttentionMechanism>,
}
/// Configuration for embedding fusion.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Dimension every fused embedding is brought to.
    pub target_dimension: usize,
    /// L2-normalize the fused vector after weighted fusion.
    pub normalize_embeddings: bool,
    /// Recompute weights per document instead of using the stored ones.
    pub adaptive_weights: bool,
    /// Lower bound for a modality weight (not enforced in this file).
    pub min_weight: f32,
    /// Upper bound for a modality weight (not enforced in this file).
    pub max_weight: f32,
    /// Presumably reserved for learned fusion; unused in this module.
    pub learning_rate: f32,
}
/// Computes adaptive per-modality fusion weights from content importance
/// and embedding quality.
pub struct WeightCalculator {
    // Scores how important each modality's content is.
    content_analyzer: ContentAnalyzer,
    // Scores how trustworthy each modality's embedding looks.
    quality_assessor: QualityAssessor,
}
/// Resizes embeddings to a fixed target dimension.
pub struct DimensionNormalizer {
    // Output dimensionality for all normalized embeddings.
    target_dim: usize,
    // How to map input vectors onto `target_dim` components.
    strategy: NormalizationStrategy,
}
/// Simplified attention over modality embeddings (text acts as the query).
pub struct AttentionMechanism {
    // Cached attention weights keyed by identifier; never populated here.
    attention_weights: HashMap<String, Vec<f32>>,
    // Q/K/V projections; allocated but not applied in `apply_attention`.
    query_projection: AttentionProjection,
    key_projection: AttentionProjection,
    value_projection: AttentionProjection,
}
/// Bundles the per-modality importance scorers.
pub struct ContentAnalyzer {
    text_scorer: TextImportanceScorer,
    visual_scorer: VisualImportanceScorer,
    table_scorer: TableImportanceScorer,
}
/// Assesses embedding quality via a weighted set of metrics.
pub struct QualityAssessor {
    quality_metrics: Vec<QualityMetric>,
}
/// Strategy used to bring an embedding to the target dimension.
#[derive(Debug, Clone, Copy)]
pub enum NormalizationStrategy {
    /// Scale to unit L2 norm, then pad to the target dimension.
    L2Norm,
    /// Min-max scaling (not implemented in this file).
    MinMax,
    /// Standard-score scaling (not implemented in this file).
    ZScore,
    /// Truncate or zero-pad to the target dimension.
    LinearProjection,
    /// Principal component analysis (not implemented in this file).
    PCA,
}
/// Dense linear projection (weights + bias) for attention Q/K/V.
#[derive(Debug, Clone)]
pub struct AttentionProjection {
    /// Row-major weight matrix: `weights[output][input]`.
    pub weights: Vec<Vec<f32>>,
    /// One bias term per output component.
    pub bias: Vec<f32>,
}
/// Scores the importance of a document's text content.
pub struct TextImportanceScorer {
    // TF-IDF statistics; constructed but not yet consulted by the scorer.
    tfidf_calculator: TfIdfCalculator,
    // Named-entity heuristic contributing to the score.
    ner: NamedEntityRecognizer,
}
/// Scores the importance of a document's images.
pub struct VisualImportanceScorer {
    // Saliency detection stub; currently unused by the scorer.
    saliency_detector: SaliencyDetector,
    // Supplies the aesthetic component of the score.
    aesthetic_analyzer: AestheticAnalyzer,
}
/// Scores the importance of a document's tables.
pub struct TableImportanceScorer {
    // Measures the fraction of filled cells.
    density_calculator: InformationDensityCalculator,
}
/// A single named, weighted quality metric.
#[derive(Debug, Clone)]
pub struct QualityMetric {
    /// Human-readable metric name.
    pub name: String,
    /// Contribution of this metric to the overall quality score.
    pub weight: f32,
    /// Which computation to run for this metric.
    pub metric_type: QualityMetricType,
}
/// The quality computations supported by `QualityAssessor`.
#[derive(Debug, Clone, Copy)]
pub enum QualityMetricType {
    /// L2 norm of the embedding scaled by its length.
    EmbeddingNorm,
    /// Variance of the embedding components.
    Variance,
    /// Placeholder; currently a fixed score.
    Coherence,
    /// Placeholder; currently a fixed score.
    Distinctiveness,
}
/// TF-IDF corpus statistics (only constructed, never updated here).
pub struct TfIdfCalculator {
    // Number of documents containing each term.
    document_frequencies: HashMap<String, usize>,
    // Size of the corpus seen so far.
    total_documents: usize,
}
/// Stub named-entity recognizer returning a fixed score.
pub struct NamedEntityRecognizer;
/// Stub saliency detector; no behavior implemented yet.
pub struct SaliencyDetector;
/// Stub aesthetics analyzer returning a fixed score.
pub struct AestheticAnalyzer;
/// Computes the filled-cell density of extracted tables.
pub struct InformationDensityCalculator;
/// Full output of a detailed fusion run.
#[derive(Debug, Clone)]
pub struct FusionResult {
    /// The combined embedding vector.
    pub fused_embedding: Vec<f32>,
    /// The per-modality weights that were applied.
    pub weights: EmbeddingWeights,
    /// Overall confidence in the fusion, capped at 1.0.
    pub confidence: f32,
    /// Importance score computed per modality.
    pub modality_scores: ModalityScores,
}
/// Per-modality importance scores (absent modalities score 0.0).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityScores {
    pub text_score: f32,
    pub visual_score: f32,
    pub table_score: f32,
    pub chart_score: f32,
}
impl DefaultFusionStrategy {
    /// Creates a fusion strategy with the default `FusionConfig`.
    ///
    /// An `AttentionMechanism` is allocated only for
    /// `FusionStrategy::Attention`; every other strategy leaves it `None`.
    pub fn new(strategy: FusionStrategy) -> RragResult<Self> {
        let config = FusionConfig::default();
        let weight_calculator = WeightCalculator::new()?;
        let dimension_normalizer = DimensionNormalizer::new(config.target_dimension);
        let attention_mechanism = if matches!(strategy, FusionStrategy::Attention) {
            Some(AttentionMechanism::new(config.target_dimension)?)
        } else {
            None
        };
        Ok(Self {
            strategy,
            config,
            weight_calculator,
            dimension_normalizer,
            attention_mechanism,
        })
    }

    /// Fuses a document's modality embeddings and reports the weights,
    /// confidence, and per-modality scores alongside the fused vector.
    ///
    /// With `adaptive_weights` enabled, weights are recomputed from the
    /// document; otherwise the weights stored on the embeddings are reused.
    pub fn fuse_embeddings_detailed(
        &self,
        document: &MultiModalDocument,
    ) -> RragResult<FusionResult> {
        let weights = if self.config.adaptive_weights {
            self.calculate_weights(document)?
        } else {
            document.embeddings.weights.clone()
        };
        let modality_scores = self.calculate_modality_scores(document)?;
        // Bring every modality to the common target dimension first.
        let normalized_embeddings = self.normalize_embeddings(&document.embeddings)?;
        let fused_embedding = match self.strategy {
            FusionStrategy::Average => self.fuse_average(&normalized_embeddings, &weights)?,
            FusionStrategy::Weighted => self.fuse_weighted(&normalized_embeddings, &weights)?,
            FusionStrategy::Concatenate => self.fuse_concatenate(&normalized_embeddings)?,
            FusionStrategy::Attention => self.fuse_attention(&normalized_embeddings, &weights)?,
            FusionStrategy::Learned => self.fuse_learned(&normalized_embeddings, &weights)?,
        };
        let confidence = self.calculate_fusion_confidence(&fused_embedding, &modality_scores)?;
        Ok(FusionResult {
            fused_embedding,
            weights,
            confidence,
            modality_scores,
        })
    }

    /// Resizes each available modality embedding to the target dimension.
    fn normalize_embeddings(
        &self,
        embeddings: &MultiModalEmbeddings,
    ) -> RragResult<NormalizedEmbeddings> {
        let text_normalized = self
            .dimension_normalizer
            .normalize(&embeddings.text_embeddings)?;
        let visual_normalized = if let Some(ref visual) = embeddings.visual_embeddings {
            Some(self.dimension_normalizer.normalize(visual)?)
        } else {
            None
        };
        let table_normalized = if let Some(ref table) = embeddings.table_embeddings {
            Some(self.dimension_normalizer.normalize(table)?)
        } else {
            None
        };
        Ok(NormalizedEmbeddings {
            text: text_normalized,
            visual: visual_normalized,
            table: table_normalized,
        })
    }

    /// Unweighted element-wise mean of the available modality embeddings.
    /// `_weights` is accepted for signature parity but intentionally unused.
    fn fuse_average(
        &self,
        embeddings: &NormalizedEmbeddings,
        _weights: &EmbeddingWeights,
    ) -> RragResult<Vec<f32>> {
        let mut fused = embeddings.text.clone();
        let mut count = 1;
        if let Some(ref visual) = embeddings.visual {
            for (i, &val) in visual.iter().enumerate() {
                // Guard against length mismatch; extra components are dropped.
                if i < fused.len() {
                    fused[i] += val;
                }
            }
            count += 1;
        }
        if let Some(ref table) = embeddings.table {
            for (i, &val) in table.iter().enumerate() {
                if i < fused.len() {
                    fused[i] += val;
                }
            }
            count += 1;
        }
        for val in &mut fused {
            *val /= count as f32;
        }
        Ok(fused)
    }

    /// Weighted element-wise sum of the modality embeddings, optionally
    /// L2-normalized afterwards (per `config.normalize_embeddings`).
    fn fuse_weighted(
        &self,
        embeddings: &NormalizedEmbeddings,
        weights: &EmbeddingWeights,
    ) -> RragResult<Vec<f32>> {
        let mut fused = vec![0.0; self.config.target_dimension];
        for (i, &val) in embeddings.text.iter().enumerate() {
            if i < fused.len() {
                fused[i] += val * weights.text_weight;
            }
        }
        if let Some(ref visual) = embeddings.visual {
            for (i, &val) in visual.iter().enumerate() {
                if i < fused.len() {
                    fused[i] += val * weights.visual_weight;
                }
            }
        }
        if let Some(ref table) = embeddings.table {
            for (i, &val) in table.iter().enumerate() {
                if i < fused.len() {
                    fused[i] += val * weights.table_weight;
                }
            }
        }
        if self.config.normalize_embeddings {
            self.l2_normalize(&mut fused);
        }
        Ok(fused)
    }

    /// Concatenates modality embeddings, then truncates or zero-pads the
    /// result to the target dimension.
    fn fuse_concatenate(&self, embeddings: &NormalizedEmbeddings) -> RragResult<Vec<f32>> {
        let mut fused = embeddings.text.clone();
        if let Some(ref visual) = embeddings.visual {
            fused.extend_from_slice(visual);
        }
        if let Some(ref table) = embeddings.table {
            fused.extend_from_slice(table);
        }
        if fused.len() > self.config.target_dimension {
            fused.truncate(self.config.target_dimension);
        } else if fused.len() < self.config.target_dimension {
            fused.resize(self.config.target_dimension, 0.0);
        }
        Ok(fused)
    }

    /// Attention-based fusion; falls back to weighted fusion when no
    /// attention mechanism was constructed.
    fn fuse_attention(
        &self,
        embeddings: &NormalizedEmbeddings,
        _weights: &EmbeddingWeights,
    ) -> RragResult<Vec<f32>> {
        if let Some(ref attention) = self.attention_mechanism {
            attention.apply_attention(embeddings)
        } else {
            self.fuse_weighted(embeddings, _weights)
        }
    }

    /// Learned fusion is not implemented; currently delegates to weighted.
    fn fuse_learned(
        &self,
        embeddings: &NormalizedEmbeddings,
        weights: &EmbeddingWeights,
    ) -> RragResult<Vec<f32>> {
        self.fuse_weighted(embeddings, weights)
    }

    /// Scores each modality's importance for this document. Absent
    /// modalities score 0.0; charts get a fixed 0.7 when present.
    fn calculate_modality_scores(
        &self,
        document: &MultiModalDocument,
    ) -> RragResult<ModalityScores> {
        let text_score = self
            .weight_calculator
            .content_analyzer
            .text_scorer
            .calculate_text_score(&document.text_content)?;
        let visual_score = if !document.images.is_empty() {
            self.weight_calculator
                .content_analyzer
                .visual_scorer
                .calculate_visual_score(&document.images)?
        } else {
            0.0
        };
        let table_score = if !document.tables.is_empty() {
            self.weight_calculator
                .content_analyzer
                .table_scorer
                .calculate_table_score(&document.tables)?
        } else {
            0.0
        };
        let chart_score = if !document.charts.is_empty() {
            0.7
        } else {
            0.0
        };
        Ok(ModalityScores {
            text_score,
            visual_score,
            table_score,
            chart_score,
        })
    }

    /// Confidence = fixed blend of modality scores (0.4/0.3/0.2/0.1),
    /// boosted 10% per extra active modality and capped at 1.0.
    /// The fused embedding itself is not inspected yet.
    fn calculate_fusion_confidence(
        &self,
        _fused_embedding: &[f32],
        scores: &ModalityScores,
    ) -> RragResult<f32> {
        let mut confidence = 0.0;
        let mut active_modalities = 0;
        if scores.text_score > 0.0 {
            confidence += scores.text_score * 0.4;
            active_modalities += 1;
        }
        if scores.visual_score > 0.0 {
            confidence += scores.visual_score * 0.3;
            active_modalities += 1;
        }
        if scores.table_score > 0.0 {
            confidence += scores.table_score * 0.2;
            active_modalities += 1;
        }
        if scores.chart_score > 0.0 {
            confidence += scores.chart_score * 0.1;
            active_modalities += 1;
        }
        // Multi-modal agreement bonus: +10% per additional active modality.
        if active_modalities > 1 {
            confidence *= 1.0 + (active_modalities as f32 - 1.0) * 0.1;
        }
        Ok(confidence.min(1.0))
    }

    /// Scales `vector` to unit L2 norm in place; zero vectors are unchanged.
    fn l2_normalize(&self, vector: &mut [f32]) {
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in vector {
                *val /= norm;
            }
        }
    }
}
impl EmbeddingFusionStrategy for DefaultFusionStrategy {
    /// Quick fusion path used by the trait API.
    ///
    /// Fixes two defects of the previous version: the catch-all arm called
    /// `self.fuse_embeddings` recursively with `self.strategy` unchanged, so
    /// every strategy other than `Average`/`Weighted` recursed until stack
    /// overflow; and table embeddings were silently ignored even though
    /// `fuse_embeddings_detailed` includes them. The catch-all now performs
    /// a weighted sum with fixed default weights, and tables contribute to
    /// both the `Average` and `Weighted` paths.
    fn fuse_embeddings(&self, embeddings: &MultiModalEmbeddings) -> RragResult<Vec<f32>> {
        // Local helper: weighted element-wise sum of the available modality
        // embeddings, sized to the text embedding. Out-of-range components
        // of the auxiliary embeddings are dropped.
        fn weighted_sum(embeddings: &MultiModalEmbeddings, weights: &EmbeddingWeights) -> Vec<f32> {
            let mut fused = vec![0.0; embeddings.text_embeddings.len()];
            for (i, &val) in embeddings.text_embeddings.iter().enumerate() {
                fused[i] += val * weights.text_weight;
            }
            if let Some(ref visual) = embeddings.visual_embeddings {
                for (i, &val) in visual.iter().enumerate() {
                    if i < fused.len() {
                        fused[i] += val * weights.visual_weight;
                    }
                }
            }
            if let Some(ref table) = embeddings.table_embeddings {
                for (i, &val) in table.iter().enumerate() {
                    if i < fused.len() {
                        fused[i] += val * weights.table_weight;
                    }
                }
            }
            fused
        }
        match self.strategy {
            FusionStrategy::Average => {
                let mut fused = embeddings.text_embeddings.clone();
                let mut count = 1;
                if let Some(ref visual) = embeddings.visual_embeddings {
                    for (i, &val) in visual.iter().enumerate() {
                        if i < fused.len() {
                            fused[i] += val;
                        }
                    }
                    count += 1;
                }
                if let Some(ref table) = embeddings.table_embeddings {
                    for (i, &val) in table.iter().enumerate() {
                        if i < fused.len() {
                            fused[i] += val;
                        }
                    }
                    count += 1;
                }
                for val in &mut fused {
                    *val /= count as f32;
                }
                Ok(fused)
            }
            FusionStrategy::Weighted => Ok(weighted_sum(embeddings, &embeddings.weights)),
            // Concatenate/Attention/Learned have no quick path here; fall
            // back to a weighted sum with fixed default weights instead of
            // the previous infinite recursion.
            _ => Ok(weighted_sum(
                embeddings,
                &EmbeddingWeights {
                    text_weight: 0.6,
                    visual_weight: 0.3,
                    table_weight: 0.1,
                    chart_weight: 0.0,
                },
            )),
        }
    }

    /// Delegates adaptive weight computation to the weight calculator.
    fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights> {
        self.weight_calculator.calculate_weights(document)
    }
}
/// Per-modality embeddings after resizing to the common target dimension.
#[derive(Debug, Clone)]
pub struct NormalizedEmbeddings {
    // Always present: the normalized text embedding.
    text: Vec<f32>,
    // Present only when a visual embedding was supplied.
    visual: Option<Vec<f32>>,
    // Present only when a table embedding was supplied.
    table: Option<Vec<f32>>,
}
impl WeightCalculator {
    /// Builds a calculator with fresh content and quality analyzers.
    pub fn new() -> RragResult<Self> {
        let content_analyzer = ContentAnalyzer::new()?;
        Ok(Self {
            content_analyzer,
            quality_assessor: QualityAssessor::new(),
        })
    }

    /// Derives normalized per-modality fusion weights.
    ///
    /// Each raw weight is importance x embedding quality; the four raw
    /// weights are then normalized to sum to 1. When everything scores
    /// zero, a fixed fallback distribution (0.6/0.2/0.1/0.1) is returned.
    pub fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights> {
        let importance = self.content_analyzer.analyze_content(document)?;
        let quality = self.quality_assessor.assess_quality(&document.embeddings)?;
        let raw = [
            importance.text_importance * quality.text_quality,
            importance.visual_importance * quality.visual_quality,
            importance.table_importance * quality.table_quality,
            importance.chart_importance * quality.chart_quality,
        ];
        let total: f32 = raw.iter().sum();
        let weights = if total > 0.0 {
            EmbeddingWeights {
                text_weight: raw[0] / total,
                visual_weight: raw[1] / total,
                table_weight: raw[2] / total,
                chart_weight: raw[3] / total,
            }
        } else {
            EmbeddingWeights {
                text_weight: 0.6,
                visual_weight: 0.2,
                table_weight: 0.1,
                chart_weight: 0.1,
            }
        };
        Ok(weights)
    }
}
impl DimensionNormalizer {
    /// Creates a normalizer targeting `target_dim`, using linear projection.
    pub fn new(target_dim: usize) -> Self {
        Self {
            target_dim,
            strategy: NormalizationStrategy::LinearProjection,
        }
    }

    /// Maps `embedding` to exactly `target_dim` components.
    ///
    /// Bug fix: the previous catch-all arm recursed into `normalize` with
    /// the same (unchanged) strategy, so `MinMax`, `ZScore`, and `PCA`
    /// overflowed the stack. Unimplemented strategies now fall back to the
    /// linear-projection behavior (truncate or zero-pad) instead.
    pub fn normalize(&self, embedding: &[f32]) -> RragResult<Vec<f32>> {
        match self.strategy {
            NormalizationStrategy::L2Norm => {
                let mut normalized = embedding.to_vec();
                let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for val in &mut normalized {
                        *val /= norm;
                    }
                }
                if normalized.len() != self.target_dim {
                    normalized.resize(self.target_dim, 0.0);
                }
                Ok(normalized)
            }
            // LinearProjection, plus the fallback for strategies that are
            // not yet implemented (MinMax, ZScore, PCA). `Vec::resize` both
            // truncates longer inputs and zero-pads shorter ones.
            _ => {
                let mut projected = embedding.to_vec();
                projected.resize(self.target_dim, 0.0);
                Ok(projected)
            }
        }
    }
}
impl AttentionMechanism {
    /// Builds square query/key/value projections of dimension `dim`.
    pub fn new(dim: usize) -> RragResult<Self> {
        Ok(Self {
            attention_weights: HashMap::new(),
            query_projection: AttentionProjection::new(dim, dim)?,
            key_projection: AttentionProjection::new(dim, dim)?,
            value_projection: AttentionProjection::new(dim, dim)?,
        })
    }

    /// Fuses modalities by adding each auxiliary embedding scaled by its
    /// attention score against the text (query) embedding.
    pub fn apply_attention(&self, embeddings: &NormalizedEmbeddings) -> RragResult<Vec<f32>> {
        let query = &embeddings.text;
        let mut attended = query.clone();
        if let Some(ref visual) = embeddings.visual {
            let attention_score = self.compute_attention_score(query, visual)?;
            for (i, &val) in visual.iter().enumerate() {
                if i < attended.len() {
                    attended[i] += val * attention_score;
                }
            }
        }
        if let Some(ref table) = embeddings.table {
            let attention_score = self.compute_attention_score(query, table)?;
            for (i, &val) in table.iter().enumerate() {
                if i < attended.len() {
                    attended[i] += val * attention_score;
                }
            }
        }
        Ok(attended)
    }

    /// Scaled dot-product score squashed through a sigmoid into (0, 1).
    ///
    /// Fixes two numeric defects of the previous `exp(x) / (1 + exp(x))`
    /// form: `exp` overflowed to infinity for large positive scores,
    /// yielding NaN from inf/inf, and an empty query divided by sqrt(0).
    /// The sigmoid is now computed in its numerically stable branch form
    /// and the scale is guarded against zero length.
    fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> RragResult<f32> {
        let score: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
        let scale = (query.len().max(1) as f32).sqrt();
        let z = score / scale;
        let sigmoid = if z >= 0.0 {
            1.0 / (1.0 + (-z).exp())
        } else {
            let e = z.exp();
            e / (1.0 + e)
        };
        Ok(sigmoid)
    }
}
impl AttentionProjection {
    /// Allocates an `output_dim` x `input_dim` projection with constant
    /// 0.01 weights and a zero bias (placeholder initialization).
    pub fn new(input_dim: usize, output_dim: usize) -> RragResult<Self> {
        let row = vec![0.01_f32; input_dim];
        Ok(Self {
            weights: vec![row; output_dim],
            bias: vec![0.0; output_dim],
        })
    }
}
impl ContentAnalyzer {
    /// Builds an analyzer from the three per-modality scorers.
    pub fn new() -> RragResult<Self> {
        let text_scorer = TextImportanceScorer::new()?;
        Ok(Self {
            text_scorer,
            visual_scorer: VisualImportanceScorer::new(),
            table_scorer: TableImportanceScorer::new(),
        })
    }

    /// Scores each modality's importance for the given document.
    /// Charts have no scorer; they get a fixed 0.7 when present, else 0.0.
    pub fn analyze_content(&self, document: &MultiModalDocument) -> RragResult<ContentScores> {
        let chart_importance = if document.charts.is_empty() { 0.0 } else { 0.7 };
        Ok(ContentScores {
            text_importance: self
                .text_scorer
                .calculate_text_score(&document.text_content)?,
            visual_importance: self.visual_scorer.calculate_visual_score(&document.images)?,
            table_importance: self.table_scorer.calculate_table_score(&document.tables)?,
            chart_importance,
        })
    }
}
/// Importance of each modality derived from document content.
#[derive(Debug, Clone)]
pub struct ContentScores {
    pub text_importance: f32,
    pub visual_importance: f32,
    pub table_importance: f32,
    pub chart_importance: f32,
}
/// Quality of each modality's embedding vector.
#[derive(Debug, Clone)]
pub struct QualityScores {
    pub text_quality: f32,
    pub visual_quality: f32,
    pub table_quality: f32,
    /// Fixed prior — charts have no dedicated embedding to assess.
    pub chart_quality: f32,
}
impl TextImportanceScorer {
    /// Builds a scorer with fresh TF-IDF statistics and the stub NER.
    pub fn new() -> RragResult<Self> {
        Ok(Self {
            tfidf_calculator: TfIdfCalculator::new(),
            ner: NamedEntityRecognizer,
        })
    }

    /// Scores text importance: 70% length (word count / 1000, capped at 1)
    /// plus 30% named-entity score.
    pub fn calculate_text_score(&self, text: &str) -> RragResult<f32> {
        let entity_score = self.ner.calculate_entity_score(text)?;
        let words = text.split_whitespace().count() as f32;
        let length_score = (words / 1000.0).min(1.0);
        Ok(length_score * 0.7 + entity_score * 0.3)
    }
}
impl VisualImportanceScorer {
pub fn new() -> Self {
Self {
saliency_detector: SaliencyDetector,
aesthetic_analyzer: AestheticAnalyzer,
}
}
pub fn calculate_visual_score(&self, images: &[ProcessedImage]) -> RragResult<f32> {
if images.is_empty() {
return Ok(0.0);
}
let mut total_score = 0.0;
for image in images {
let quality_score = image
.features
.as_ref()
.map(|f| (f.quality.sharpness + f.quality.contrast) / 2.0)
.unwrap_or(0.5);
let aesthetic_score = self.aesthetic_analyzer.analyze_aesthetics(image)?;
total_score += quality_score * 0.6 + aesthetic_score * 0.4;
}
Ok(total_score / images.len() as f32)
}
}
impl TableImportanceScorer {
    /// Builds a scorer backed by the cell-density calculator.
    pub fn new() -> Self {
        Self {
            density_calculator: InformationDensityCalculator,
        }
    }

    /// Mean per-table score: half from table size (cell count / 100,
    /// capped at 1) and half from filled-cell density. An empty slice
    /// scores 0.0.
    pub fn calculate_table_score(&self, tables: &[ExtractedTable]) -> RragResult<f32> {
        if tables.is_empty() {
            return Ok(0.0);
        }
        let mut accumulated = 0.0_f32;
        for table in tables {
            let cell_count = table.rows.len() * table.headers.len();
            let size_score = (cell_count as f32 / 100.0).min(1.0);
            let density = self.density_calculator.calculate_density(table)?;
            accumulated += size_score * 0.5 + density * 0.5;
        }
        Ok(accumulated / tables.len() as f32)
    }
}
impl QualityAssessor {
    /// Creates an assessor with a fixed set of weighted metrics:
    /// norm (0.3), variance (0.4), coherence (0.3).
    pub fn new() -> Self {
        Self {
            quality_metrics: vec![
                QualityMetric {
                    name: "norm".to_string(),
                    weight: 0.3,
                    metric_type: QualityMetricType::EmbeddingNorm,
                },
                QualityMetric {
                    name: "variance".to_string(),
                    weight: 0.4,
                    metric_type: QualityMetricType::Variance,
                },
                QualityMetric {
                    name: "coherence".to_string(),
                    weight: 0.3,
                    metric_type: QualityMetricType::Coherence,
                },
            ],
        }
    }

    /// Scores each modality's embedding quality; absent modalities get 0.0.
    pub fn assess_quality(&self, embeddings: &MultiModalEmbeddings) -> RragResult<QualityScores> {
        let text_quality = self.calculate_embedding_quality(&embeddings.text_embeddings)?;
        let visual_quality = match embeddings.visual_embeddings {
            Some(ref visual) => self.calculate_embedding_quality(visual)?,
            None => 0.0,
        };
        let table_quality = match embeddings.table_embeddings {
            Some(ref table) => self.calculate_embedding_quality(table)?,
            None => 0.0,
        };
        Ok(QualityScores {
            text_quality,
            visual_quality,
            table_quality,
            // Charts have no dedicated embedding; use a fixed prior.
            chart_quality: 0.7,
        })
    }

    /// Weighted combination of the configured metrics.
    ///
    /// Bug fix: an empty embedding previously divided by `len() == 0` in
    /// the norm and variance metrics, producing NaN that then poisoned the
    /// weight normalization. Empty embeddings now score 0.0.
    fn calculate_embedding_quality(&self, embedding: &[f32]) -> RragResult<f32> {
        if embedding.is_empty() {
            return Ok(0.0);
        }
        let mut quality_score = 0.0;
        for metric in &self.quality_metrics {
            let score = match metric.metric_type {
                QualityMetricType::EmbeddingNorm => {
                    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
                    (norm / embedding.len() as f32).min(1.0)
                }
                QualityMetricType::Variance => {
                    let mean = embedding.iter().sum::<f32>() / embedding.len() as f32;
                    let variance = embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
                        / embedding.len() as f32;
                    variance.min(1.0)
                }
                // Placeholders until real coherence/distinctiveness metrics
                // are implemented.
                QualityMetricType::Coherence => 0.8,
                QualityMetricType::Distinctiveness => 0.7,
            };
            quality_score += score * metric.weight;
        }
        Ok(quality_score)
    }
}
impl TfIdfCalculator {
    /// Creates a calculator with an empty corpus (no term statistics yet).
    pub fn new() -> Self {
        Self {
            total_documents: 0,
            document_frequencies: HashMap::default(),
        }
    }
}
impl NamedEntityRecognizer {
    /// Placeholder entity score: always 0.6 regardless of input.
    pub fn calculate_entity_score(&self, _text: &str) -> RragResult<f32> {
        Ok(0.6)
    }
}
// No behavior yet; kept so the visual scorer can hold an instance.
impl SaliencyDetector {}
impl AestheticAnalyzer {
    /// Placeholder aesthetic score: always 0.7 until a real model is wired in.
    pub fn analyze_aesthetics(&self, _image: &ProcessedImage) -> RragResult<f32> {
        Ok(0.7)
    }
}
impl InformationDensityCalculator {
    /// Fraction of non-empty (after trim) cells in `table`, in [0, 1].
    ///
    /// Bug fix: a degenerate table with no rows or no headers previously
    /// divided by zero and returned NaN; it now scores 0.0.
    pub fn calculate_density(&self, table: &ExtractedTable) -> RragResult<f32> {
        let total_cells = table.rows.len() * table.headers.len();
        if total_cells == 0 {
            return Ok(0.0);
        }
        let filled_cells = table
            .rows
            .iter()
            .flatten()
            .filter(|cell| !cell.value.trim().is_empty())
            .count();
        Ok(filled_cells as f32 / total_cells as f32)
    }
}
impl Default for FusionConfig {
    /// Defaults: 768-dimensional output with normalization and adaptive
    /// weighting enabled.
    fn default() -> Self {
        Self {
            target_dimension: 768,
            normalize_embeddings: true,
            adaptive_weights: true,
            // Weight clamping bounds; not currently enforced by the fusers.
            min_weight: 0.01,
            max_weight: 0.99,
            // Reserved for learned fusion; unused in this module.
            learning_rate: 0.001,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Constructing a Weighted strategy succeeds and records the strategy.
    #[test]
    fn test_fusion_strategy_creation() {
        let strategy = DefaultFusionStrategy::new(FusionStrategy::Weighted).unwrap();
        assert!(matches!(strategy.strategy, FusionStrategy::Weighted));
    }
    // A short embedding is zero-padded to the target dimension while the
    // original components stay in place (linear projection behavior).
    #[test]
    fn test_dimension_normalization() {
        let normalizer = DimensionNormalizer::new(512);
        let embedding = vec![1.0, 2.0, 3.0];
        let normalized = normalizer.normalize(&embedding).unwrap();
        assert_eq!(normalized.len(), 512);
        assert_eq!(&normalized[..3], &[1.0, 2.0, 3.0]);
    }
    // Weight calculation on a text-only document yields a positive text weight.
    #[test]
    fn test_weight_calculation() {
        let calculator = WeightCalculator::new().unwrap();
        // Minimal document: text only, no images/tables/charts.
        let document = MultiModalDocument {
            id: "test".to_string(),
            text_content: "Test content".to_string(),
            images: vec![],
            tables: vec![],
            charts: vec![],
            layout: super::super::DocumentLayout {
                pages: 1,
                sections: vec![],
                reading_order: vec![],
                columns: None,
                document_type: super::super::DocumentType::PlainText,
            },
            embeddings: MultiModalEmbeddings {
                text_embeddings: vec![0.1, 0.2, 0.3],
                visual_embeddings: None,
                table_embeddings: None,
                fused_embedding: vec![],
                weights: EmbeddingWeights {
                    text_weight: 1.0,
                    visual_weight: 0.0,
                    table_weight: 0.0,
                    chart_weight: 0.0,
                },
            },
            metadata: super::super::DocumentMetadata {
                title: None,
                author: None,
                creation_date: None,
                modification_date: None,
                page_count: 1,
                word_count: 2,
                language: "en".to_string(),
                format: super::super::DocumentType::PlainText,
            },
        };
        let weights = calculator.calculate_weights(&document).unwrap();
        assert!(weights.text_weight > 0.0);
    }
}