1use super::{
6 EmbeddingFusionStrategy, EmbeddingWeights, ExtractedTable, FusionStrategy, MultiModalDocument,
7 MultiModalEmbeddings, ProcessedImage,
8};
9use crate::RragResult;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13pub struct DefaultFusionStrategy {
15 strategy: FusionStrategy,
17
18 config: FusionConfig,
20
21 weight_calculator: WeightCalculator,
23
24 dimension_normalizer: DimensionNormalizer,
26
27 attention_mechanism: Option<AttentionMechanism>,
29}
30
/// Parameters controlling how `DefaultFusionStrategy` fuses embeddings.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Length every modality embedding is resized to before fusion.
    pub target_dimension: usize,
    /// Whether the weighted fusion result is L2-normalized.
    pub normalize_embeddings: bool,
    /// Whether weights are recomputed from document content instead of taken
    /// from the document's stored weights.
    pub adaptive_weights: bool,
    /// Presumably a lower bound for modality weights; not read by the fusion
    /// code in this module.
    pub min_weight: f32,
    /// Presumably an upper bound for modality weights; not read by the fusion
    /// code in this module.
    pub max_weight: f32,
    /// Presumably the step size for learned fusion; not read by the fusion code
    /// in this module.
    pub learning_rate: f32,
}
52
53pub struct WeightCalculator {
55 content_analyzer: ContentAnalyzer,
57
58 quality_assessor: QualityAssessor,
60}
61
62pub struct DimensionNormalizer {
64 target_dim: usize,
66
67 strategy: NormalizationStrategy,
69}
70
71pub struct AttentionMechanism {
73 attention_weights: HashMap<String, Vec<f32>>,
75
76 query_projection: AttentionProjection,
78
79 key_projection: AttentionProjection,
81
82 value_projection: AttentionProjection,
84}
85
86pub struct ContentAnalyzer {
88 text_scorer: TextImportanceScorer,
90
91 visual_scorer: VisualImportanceScorer,
93
94 table_scorer: TableImportanceScorer,
96}
97
98pub struct QualityAssessor {
100 quality_metrics: Vec<QualityMetric>,
102}
103
/// Value transformation a `DimensionNormalizer` applies before resizing an
/// embedding to its target dimension.
///
/// Equality and hashing are derived so strategies can be compared and used as
/// map keys (e.g. in configuration tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// `PCA` is part of the public API; renaming it to satisfy the lint would break callers.
#[allow(clippy::upper_case_acronyms)]
pub enum NormalizationStrategy {
    /// Scale the vector to unit Euclidean length.
    L2Norm,
    /// Min-max rescaling of the components.
    MinMax,
    /// Standardize components to zero mean and unit variance.
    ZScore,
    /// Truncate or zero-pad to the target dimension.
    LinearProjection,
    /// Principal component analysis projection.
    PCA,
}
122
/// Affine projection used by the attention mechanism.
#[derive(Debug, Clone)]
pub struct AttentionProjection {
    /// Weight matrix stored row-major: one row per output dimension.
    pub weights: Vec<Vec<f32>>,
    /// Bias vector, one entry per output dimension.
    pub bias: Vec<f32>,
}
132
133pub struct TextImportanceScorer {
135 tfidf_calculator: TfIdfCalculator,
137
138 ner: NamedEntityRecognizer,
140}
141
142pub struct VisualImportanceScorer {
144 saliency_detector: SaliencyDetector,
146
147 aesthetic_analyzer: AestheticAnalyzer,
149}
150
151pub struct TableImportanceScorer {
153 density_calculator: InformationDensityCalculator,
155}
156
157#[derive(Debug, Clone)]
159pub struct QualityMetric {
160 pub name: String,
162
163 pub weight: f32,
165
166 pub metric_type: QualityMetricType,
168}
169
/// Statistic evaluated by a `QualityMetric`.
///
/// Equality and hashing are derived so metric kinds can be compared and
/// deduplicated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QualityMetricType {
    /// Based on the embedding's Euclidean norm.
    EmbeddingNorm,
    /// Based on the spread of the embedding's components.
    Variance,
    /// Coherence of the embedding.
    Coherence,
    /// Distinctiveness of the embedding.
    Distinctiveness,
}
185
/// Holder for TF-IDF corpus statistics.
pub struct TfIdfCalculator {
    /// Presumably the number of documents containing each term; starts empty.
    document_frequencies: HashMap<String, usize>,
    /// Presumably the corpus size; starts at zero.
    total_documents: usize,
}
194
/// Stateless entity scorer used by `TextImportanceScorer`.
pub struct NamedEntityRecognizer;
197
/// Stateless saliency component held by `VisualImportanceScorer`.
pub struct SaliencyDetector;
200
/// Stateless aesthetic scorer used by `VisualImportanceScorer`.
pub struct AestheticAnalyzer;
203
/// Stateless calculator of how densely a table's cells are filled.
pub struct InformationDensityCalculator;
206
207#[derive(Debug, Clone)]
209pub struct FusionResult {
210 pub fused_embedding: Vec<f32>,
212
213 pub weights: EmbeddingWeights,
215
216 pub confidence: f32,
218
219 pub modality_scores: ModalityScores,
221}
222
223#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct ModalityScores {
226 pub text_score: f32,
228
229 pub visual_score: f32,
231
232 pub table_score: f32,
234
235 pub chart_score: f32,
237}
238
239impl DefaultFusionStrategy {
240 pub fn new(strategy: FusionStrategy) -> RragResult<Self> {
242 let config = FusionConfig::default();
243 let weight_calculator = WeightCalculator::new()?;
244 let dimension_normalizer = DimensionNormalizer::new(config.target_dimension);
245
246 let attention_mechanism = if matches!(strategy, FusionStrategy::Attention) {
247 Some(AttentionMechanism::new(config.target_dimension)?)
248 } else {
249 None
250 };
251
252 Ok(Self {
253 strategy,
254 config,
255 weight_calculator,
256 dimension_normalizer,
257 attention_mechanism,
258 })
259 }
260
261 pub fn fuse_embeddings_detailed(
263 &self,
264 document: &MultiModalDocument,
265 ) -> RragResult<FusionResult> {
266 let weights = if self.config.adaptive_weights {
268 self.calculate_weights(document)?
269 } else {
270 document.embeddings.weights.clone()
271 };
272
273 let modality_scores = self.calculate_modality_scores(document)?;
275
276 let normalized_embeddings = self.normalize_embeddings(&document.embeddings)?;
278
279 let fused_embedding = match self.strategy {
281 FusionStrategy::Average => self.fuse_average(&normalized_embeddings, &weights)?,
282 FusionStrategy::Weighted => self.fuse_weighted(&normalized_embeddings, &weights)?,
283 FusionStrategy::Concatenate => self.fuse_concatenate(&normalized_embeddings)?,
284 FusionStrategy::Attention => self.fuse_attention(&normalized_embeddings, &weights)?,
285 FusionStrategy::Learned => self.fuse_learned(&normalized_embeddings, &weights)?,
286 };
287
288 let confidence = self.calculate_fusion_confidence(&fused_embedding, &modality_scores)?;
290
291 Ok(FusionResult {
292 fused_embedding,
293 weights,
294 confidence,
295 modality_scores,
296 })
297 }
298
299 fn normalize_embeddings(
301 &self,
302 embeddings: &MultiModalEmbeddings,
303 ) -> RragResult<NormalizedEmbeddings> {
304 let text_normalized = self
305 .dimension_normalizer
306 .normalize(&embeddings.text_embeddings)?;
307
308 let visual_normalized = if let Some(ref visual) = embeddings.visual_embeddings {
309 Some(self.dimension_normalizer.normalize(visual)?)
310 } else {
311 None
312 };
313
314 let table_normalized = if let Some(ref table) = embeddings.table_embeddings {
315 Some(self.dimension_normalizer.normalize(table)?)
316 } else {
317 None
318 };
319
320 Ok(NormalizedEmbeddings {
321 text: text_normalized,
322 visual: visual_normalized,
323 table: table_normalized,
324 })
325 }
326
327 fn fuse_average(
329 &self,
330 embeddings: &NormalizedEmbeddings,
331 _weights: &EmbeddingWeights,
332 ) -> RragResult<Vec<f32>> {
333 let mut fused = embeddings.text.clone();
334 let mut count = 1;
335
336 if let Some(ref visual) = embeddings.visual {
337 for (i, &val) in visual.iter().enumerate() {
338 if i < fused.len() {
339 fused[i] += val;
340 }
341 }
342 count += 1;
343 }
344
345 if let Some(ref table) = embeddings.table {
346 for (i, &val) in table.iter().enumerate() {
347 if i < fused.len() {
348 fused[i] += val;
349 }
350 }
351 count += 1;
352 }
353
354 for val in &mut fused {
356 *val /= count as f32;
357 }
358
359 Ok(fused)
360 }
361
362 fn fuse_weighted(
364 &self,
365 embeddings: &NormalizedEmbeddings,
366 weights: &EmbeddingWeights,
367 ) -> RragResult<Vec<f32>> {
368 let mut fused = vec![0.0; self.config.target_dimension];
369
370 for (i, &val) in embeddings.text.iter().enumerate() {
372 if i < fused.len() {
373 fused[i] += val * weights.text_weight;
374 }
375 }
376
377 if let Some(ref visual) = embeddings.visual {
378 for (i, &val) in visual.iter().enumerate() {
379 if i < fused.len() {
380 fused[i] += val * weights.visual_weight;
381 }
382 }
383 }
384
385 if let Some(ref table) = embeddings.table {
386 for (i, &val) in table.iter().enumerate() {
387 if i < fused.len() {
388 fused[i] += val * weights.table_weight;
389 }
390 }
391 }
392
393 if self.config.normalize_embeddings {
395 self.l2_normalize(&mut fused);
396 }
397
398 Ok(fused)
399 }
400
401 fn fuse_concatenate(&self, embeddings: &NormalizedEmbeddings) -> RragResult<Vec<f32>> {
403 let mut fused = embeddings.text.clone();
404
405 if let Some(ref visual) = embeddings.visual {
406 fused.extend_from_slice(visual);
407 }
408
409 if let Some(ref table) = embeddings.table {
410 fused.extend_from_slice(table);
411 }
412
413 if fused.len() > self.config.target_dimension {
415 fused.truncate(self.config.target_dimension);
416 } else if fused.len() < self.config.target_dimension {
417 fused.resize(self.config.target_dimension, 0.0);
418 }
419
420 Ok(fused)
421 }
422
423 fn fuse_attention(
425 &self,
426 embeddings: &NormalizedEmbeddings,
427 _weights: &EmbeddingWeights,
428 ) -> RragResult<Vec<f32>> {
429 if let Some(ref attention) = self.attention_mechanism {
430 attention.apply_attention(embeddings)
431 } else {
432 self.fuse_weighted(embeddings, _weights)
434 }
435 }
436
437 fn fuse_learned(
439 &self,
440 embeddings: &NormalizedEmbeddings,
441 weights: &EmbeddingWeights,
442 ) -> RragResult<Vec<f32>> {
443 self.fuse_weighted(embeddings, weights)
445 }
446
447 fn calculate_modality_scores(
449 &self,
450 document: &MultiModalDocument,
451 ) -> RragResult<ModalityScores> {
452 let text_score = self
453 .weight_calculator
454 .content_analyzer
455 .text_scorer
456 .calculate_text_score(&document.text_content)?;
457
458 let visual_score = if !document.images.is_empty() {
459 self.weight_calculator
460 .content_analyzer
461 .visual_scorer
462 .calculate_visual_score(&document.images)?
463 } else {
464 0.0
465 };
466
467 let table_score = if !document.tables.is_empty() {
468 self.weight_calculator
469 .content_analyzer
470 .table_scorer
471 .calculate_table_score(&document.tables)?
472 } else {
473 0.0
474 };
475
476 let chart_score = if !document.charts.is_empty() {
477 0.7
479 } else {
480 0.0
481 };
482
483 Ok(ModalityScores {
484 text_score,
485 visual_score,
486 table_score,
487 chart_score,
488 })
489 }
490
491 fn calculate_fusion_confidence(
493 &self,
494 _fused_embedding: &[f32],
495 scores: &ModalityScores,
496 ) -> RragResult<f32> {
497 let mut confidence = 0.0;
499 let mut active_modalities = 0;
500
501 if scores.text_score > 0.0 {
502 confidence += scores.text_score * 0.4;
503 active_modalities += 1;
504 }
505
506 if scores.visual_score > 0.0 {
507 confidence += scores.visual_score * 0.3;
508 active_modalities += 1;
509 }
510
511 if scores.table_score > 0.0 {
512 confidence += scores.table_score * 0.2;
513 active_modalities += 1;
514 }
515
516 if scores.chart_score > 0.0 {
517 confidence += scores.chart_score * 0.1;
518 active_modalities += 1;
519 }
520
521 if active_modalities > 1 {
523 confidence *= 1.0 + (active_modalities as f32 - 1.0) * 0.1;
524 }
525
526 Ok(confidence.min(1.0))
527 }
528
529 fn l2_normalize(&self, vector: &mut [f32]) {
531 let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
532 if norm > 0.0 {
533 for val in vector {
534 *val /= norm;
535 }
536 }
537 }
538}
539
540impl EmbeddingFusionStrategy for DefaultFusionStrategy {
541 fn fuse_embeddings(&self, embeddings: &MultiModalEmbeddings) -> RragResult<Vec<f32>> {
542 match self.strategy {
543 FusionStrategy::Average => {
544 let mut fused = embeddings.text_embeddings.clone();
545 let mut count = 1;
546
547 if let Some(ref visual) = embeddings.visual_embeddings {
548 for (i, &val) in visual.iter().enumerate() {
549 if i < fused.len() {
550 fused[i] += val;
551 }
552 }
553 count += 1;
554 }
555
556 for val in &mut fused {
557 *val /= count as f32;
558 }
559
560 Ok(fused)
561 }
562
563 FusionStrategy::Weighted => {
564 let mut fused = vec![0.0; embeddings.text_embeddings.len()];
565 let weights = &embeddings.weights;
566
567 for (i, &val) in embeddings.text_embeddings.iter().enumerate() {
568 fused[i] += val * weights.text_weight;
569 }
570
571 if let Some(ref visual) = embeddings.visual_embeddings {
572 for (i, &val) in visual.iter().enumerate() {
573 if i < fused.len() {
574 fused[i] += val * weights.visual_weight;
575 }
576 }
577 }
578
579 Ok(fused)
580 }
581
582 _ => {
583 self.fuse_embeddings(&MultiModalEmbeddings {
585 text_embeddings: embeddings.text_embeddings.clone(),
586 visual_embeddings: embeddings.visual_embeddings.clone(),
587 table_embeddings: embeddings.table_embeddings.clone(),
588 fused_embedding: vec![],
589 weights: EmbeddingWeights {
590 text_weight: 0.6,
591 visual_weight: 0.3,
592 table_weight: 0.1,
593 chart_weight: 0.0,
594 },
595 })
596 }
597 }
598 }
599
600 fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights> {
601 self.weight_calculator.calculate_weights(document)
602 }
603}
604
/// Modality embeddings after resizing to a common target dimension.
#[derive(Debug, Clone)]
pub struct NormalizedEmbeddings {
    /// Resized text embedding (always present).
    text: Vec<f32>,
    /// Resized visual embedding, if the document had one.
    visual: Option<Vec<f32>>,
    /// Resized table embedding, if the document had one.
    table: Option<Vec<f32>>,
}
612
613impl WeightCalculator {
614 pub fn new() -> RragResult<Self> {
616 Ok(Self {
617 content_analyzer: ContentAnalyzer::new()?,
618 quality_assessor: QualityAssessor::new(),
619 })
620 }
621
622 pub fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights> {
624 let scores = self.content_analyzer.analyze_content(document)?;
625 let quality_scores = self.quality_assessor.assess_quality(&document.embeddings)?;
626
627 let text_weight = scores.text_importance * quality_scores.text_quality;
629 let visual_weight = scores.visual_importance * quality_scores.visual_quality;
630 let table_weight = scores.table_importance * quality_scores.table_quality;
631 let chart_weight = scores.chart_importance * quality_scores.chart_quality;
632
633 let total = text_weight + visual_weight + table_weight + chart_weight;
635
636 if total > 0.0 {
637 Ok(EmbeddingWeights {
638 text_weight: text_weight / total,
639 visual_weight: visual_weight / total,
640 table_weight: table_weight / total,
641 chart_weight: chart_weight / total,
642 })
643 } else {
644 Ok(EmbeddingWeights {
646 text_weight: 0.6,
647 visual_weight: 0.2,
648 table_weight: 0.1,
649 chart_weight: 0.1,
650 })
651 }
652 }
653}
654
655impl DimensionNormalizer {
656 pub fn new(target_dim: usize) -> Self {
658 Self {
659 target_dim,
660 strategy: NormalizationStrategy::LinearProjection,
661 }
662 }
663
664 pub fn normalize(&self, embedding: &[f32]) -> RragResult<Vec<f32>> {
666 match self.strategy {
667 NormalizationStrategy::LinearProjection => {
668 if embedding.len() == self.target_dim {
669 Ok(embedding.to_vec())
670 } else if embedding.len() > self.target_dim {
671 Ok(embedding[..self.target_dim].to_vec())
673 } else {
674 let mut normalized = embedding.to_vec();
676 normalized.resize(self.target_dim, 0.0);
677 Ok(normalized)
678 }
679 }
680
681 NormalizationStrategy::L2Norm => {
682 let mut normalized = embedding.to_vec();
683 let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
684 if norm > 0.0 {
685 for val in &mut normalized {
686 *val /= norm;
687 }
688 }
689
690 if normalized.len() != self.target_dim {
692 normalized.resize(self.target_dim, 0.0);
693 }
694
695 Ok(normalized)
696 }
697
698 _ => {
699 self.normalize(embedding)
701 }
702 }
703 }
704}
705
706impl AttentionMechanism {
707 pub fn new(dim: usize) -> RragResult<Self> {
709 Ok(Self {
710 attention_weights: HashMap::new(),
711 query_projection: AttentionProjection::new(dim, dim)?,
712 key_projection: AttentionProjection::new(dim, dim)?,
713 value_projection: AttentionProjection::new(dim, dim)?,
714 })
715 }
716
717 pub fn apply_attention(&self, embeddings: &NormalizedEmbeddings) -> RragResult<Vec<f32>> {
719 let query = &embeddings.text;
723 let mut attended = query.clone();
724
725 if let Some(ref visual) = embeddings.visual {
726 let attention_score = self.compute_attention_score(query, visual)?;
727 for (i, &val) in visual.iter().enumerate() {
728 if i < attended.len() {
729 attended[i] += val * attention_score;
730 }
731 }
732 }
733
734 if let Some(ref table) = embeddings.table {
735 let attention_score = self.compute_attention_score(query, table)?;
736 for (i, &val) in table.iter().enumerate() {
737 if i < attended.len() {
738 attended[i] += val * attention_score;
739 }
740 }
741 }
742
743 Ok(attended)
744 }
745
746 fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> RragResult<f32> {
748 let score: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
750
751 let normalized_score = score / (query.len() as f32).sqrt();
753
754 Ok(normalized_score.exp() / (1.0 + normalized_score.exp()))
756 }
757}
758
759impl AttentionProjection {
760 pub fn new(input_dim: usize, output_dim: usize) -> RragResult<Self> {
762 let weights = vec![vec![0.01; input_dim]; output_dim];
764 let bias = vec![0.0; output_dim];
765
766 Ok(Self { weights, bias })
767 }
768}
769
770impl ContentAnalyzer {
771 pub fn new() -> RragResult<Self> {
773 Ok(Self {
774 text_scorer: TextImportanceScorer::new()?,
775 visual_scorer: VisualImportanceScorer::new(),
776 table_scorer: TableImportanceScorer::new(),
777 })
778 }
779
780 pub fn analyze_content(&self, document: &MultiModalDocument) -> RragResult<ContentScores> {
782 let text_importance = self
783 .text_scorer
784 .calculate_text_score(&document.text_content)?;
785 let visual_importance = self
786 .visual_scorer
787 .calculate_visual_score(&document.images)?;
788 let table_importance = self.table_scorer.calculate_table_score(&document.tables)?;
789 let chart_importance = if !document.charts.is_empty() {
790 0.7
791 } else {
792 0.0
793 };
794
795 Ok(ContentScores {
796 text_importance,
797 visual_importance,
798 table_importance,
799 chart_importance,
800 })
801 }
802}
803
/// Per-modality content importance produced by `ContentAnalyzer::analyze_content`.
#[derive(Debug, Clone)]
pub struct ContentScores {
    /// Importance of the plain text.
    pub text_importance: f32,
    /// Importance of the images.
    pub visual_importance: f32,
    /// Importance of the tables.
    pub table_importance: f32,
    /// Importance of the charts.
    pub chart_importance: f32,
}
812
/// Per-modality embedding quality produced by `QualityAssessor::assess_quality`.
#[derive(Debug, Clone)]
pub struct QualityScores {
    /// Quality of the text embedding.
    pub text_quality: f32,
    /// Quality of the visual embedding (0 when absent).
    pub visual_quality: f32,
    /// Quality of the table embedding (0 when absent).
    pub table_quality: f32,
    /// Quality assigned to chart content.
    pub chart_quality: f32,
}
821
822impl TextImportanceScorer {
823 pub fn new() -> RragResult<Self> {
824 Ok(Self {
825 tfidf_calculator: TfIdfCalculator::new(),
826 ner: NamedEntityRecognizer,
827 })
828 }
829
830 pub fn calculate_text_score(&self, text: &str) -> RragResult<f32> {
831 let word_count = text.split_whitespace().count();
832 let entity_score = self.ner.calculate_entity_score(text)?;
833
834 let length_score = (word_count as f32 / 1000.0).min(1.0);
836 Ok(length_score * 0.7 + entity_score * 0.3)
837 }
838}
839
840impl VisualImportanceScorer {
841 pub fn new() -> Self {
842 Self {
843 saliency_detector: SaliencyDetector,
844 aesthetic_analyzer: AestheticAnalyzer,
845 }
846 }
847
848 pub fn calculate_visual_score(&self, images: &[ProcessedImage]) -> RragResult<f32> {
849 if images.is_empty() {
850 return Ok(0.0);
851 }
852
853 let mut total_score = 0.0;
854 for image in images {
855 let quality_score = image
856 .features
857 .as_ref()
858 .map(|f| (f.quality.sharpness + f.quality.contrast) / 2.0)
859 .unwrap_or(0.5);
860
861 let aesthetic_score = self.aesthetic_analyzer.analyze_aesthetics(image)?;
862 total_score += quality_score * 0.6 + aesthetic_score * 0.4;
863 }
864
865 Ok(total_score / images.len() as f32)
866 }
867}
868
869impl TableImportanceScorer {
870 pub fn new() -> Self {
871 Self {
872 density_calculator: InformationDensityCalculator,
873 }
874 }
875
876 pub fn calculate_table_score(&self, tables: &[ExtractedTable]) -> RragResult<f32> {
877 if tables.is_empty() {
878 return Ok(0.0);
879 }
880
881 let mut total_score = 0.0;
882 for table in tables {
883 let size_score = (table.rows.len() * table.headers.len()) as f32 / 100.0;
884 let density_score = self.density_calculator.calculate_density(table)?;
885 total_score += size_score.min(1.0) * 0.5 + density_score * 0.5;
886 }
887
888 Ok(total_score / tables.len() as f32)
889 }
890}
891
892impl QualityAssessor {
893 pub fn new() -> Self {
894 Self {
895 quality_metrics: vec![
896 QualityMetric {
897 name: "norm".to_string(),
898 weight: 0.3,
899 metric_type: QualityMetricType::EmbeddingNorm,
900 },
901 QualityMetric {
902 name: "variance".to_string(),
903 weight: 0.4,
904 metric_type: QualityMetricType::Variance,
905 },
906 QualityMetric {
907 name: "coherence".to_string(),
908 weight: 0.3,
909 metric_type: QualityMetricType::Coherence,
910 },
911 ],
912 }
913 }
914
915 pub fn assess_quality(&self, embeddings: &MultiModalEmbeddings) -> RragResult<QualityScores> {
916 let text_quality = self.calculate_embedding_quality(&embeddings.text_embeddings)?;
917
918 let visual_quality = if let Some(ref visual) = embeddings.visual_embeddings {
919 self.calculate_embedding_quality(visual)?
920 } else {
921 0.0
922 };
923
924 let table_quality = if let Some(ref table) = embeddings.table_embeddings {
925 self.calculate_embedding_quality(table)?
926 } else {
927 0.0
928 };
929
930 Ok(QualityScores {
931 text_quality,
932 visual_quality,
933 table_quality,
934 chart_quality: 0.7, })
936 }
937
938 fn calculate_embedding_quality(&self, embedding: &[f32]) -> RragResult<f32> {
939 let mut quality_score = 0.0;
940
941 for metric in &self.quality_metrics {
942 let score = match metric.metric_type {
943 QualityMetricType::EmbeddingNorm => {
944 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
945 (norm / embedding.len() as f32).min(1.0)
946 }
947 QualityMetricType::Variance => {
948 let mean = embedding.iter().sum::<f32>() / embedding.len() as f32;
949 let variance = embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
950 / embedding.len() as f32;
951 variance.min(1.0)
952 }
953 QualityMetricType::Coherence => 0.8, QualityMetricType::Distinctiveness => 0.7, };
956
957 quality_score += score * metric.weight;
958 }
959
960 Ok(quality_score)
961 }
962}
963
964impl TfIdfCalculator {
966 pub fn new() -> Self {
967 Self {
968 document_frequencies: HashMap::new(),
969 total_documents: 0,
970 }
971 }
972}
973
974impl NamedEntityRecognizer {
975 pub fn calculate_entity_score(&self, _text: &str) -> RragResult<f32> {
976 Ok(0.6)
978 }
979}
980
981impl SaliencyDetector {}
982
983impl AestheticAnalyzer {
984 pub fn analyze_aesthetics(&self, _image: &ProcessedImage) -> RragResult<f32> {
985 Ok(0.7)
987 }
988}
989
990impl InformationDensityCalculator {
991 pub fn calculate_density(&self, table: &ExtractedTable) -> RragResult<f32> {
992 let total_cells = table.rows.len() * table.headers.len();
993 let filled_cells = table
994 .rows
995 .iter()
996 .flatten()
997 .filter(|cell| !cell.value.trim().is_empty())
998 .count();
999
1000 Ok(filled_cells as f32 / total_cells as f32)
1001 }
1002}
1003
1004impl Default for FusionConfig {
1005 fn default() -> Self {
1006 Self {
1007 target_dimension: 768,
1008 normalize_embeddings: true,
1009 adaptive_weights: true,
1010 min_weight: 0.01,
1011 max_weight: 0.99,
1012 learning_rate: 0.001,
1013 }
1014 }
1015}
1016
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal plain-text document.
    ///
    /// It has a three-component text embedding and no images, tables or charts.
    fn plain_text_document() -> MultiModalDocument {
        MultiModalDocument {
            id: "test".to_string(),
            text_content: "Test content".to_string(),
            images: vec![],
            tables: vec![],
            charts: vec![],
            layout: super::super::DocumentLayout {
                pages: 1,
                sections: vec![],
                reading_order: vec![],
                columns: None,
                document_type: super::super::DocumentType::PlainText,
            },
            embeddings: MultiModalEmbeddings {
                text_embeddings: vec![0.1, 0.2, 0.3],
                visual_embeddings: None,
                table_embeddings: None,
                fused_embedding: vec![],
                weights: EmbeddingWeights {
                    text_weight: 1.0,
                    visual_weight: 0.0,
                    table_weight: 0.0,
                    chart_weight: 0.0,
                },
            },
            metadata: super::super::DocumentMetadata {
                title: None,
                author: None,
                creation_date: None,
                modification_date: None,
                page_count: 1,
                word_count: 2,
                language: "en".to_string(),
                format: super::super::DocumentType::PlainText,
            },
        }
    }

    #[test]
    fn test_fusion_strategy_creation() {
        let fusion = DefaultFusionStrategy::new(FusionStrategy::Weighted).unwrap();
        assert!(matches!(fusion.strategy, FusionStrategy::Weighted));
    }

    #[test]
    fn test_dimension_normalization() {
        let normalizer = DimensionNormalizer::new(512);
        let normalized = normalizer.normalize(&[1.0, 2.0, 3.0]).unwrap();

        assert_eq!(normalized.len(), 512);
        assert_eq!(&normalized[..3], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_weight_calculation() {
        let calculator = WeightCalculator::new().unwrap();
        let weights = calculator
            .calculate_weights(&plain_text_document())
            .unwrap();
        assert!(weights.text_weight > 0.0);
    }
}
1083}