oxirs_embed/vision_language_graph/transformer.rs
use super::*;
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, Rng};
use std::collections::HashMap;
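
/// Tri-modal (vision / language / graph) fusion transformer. Holds
/// per-layer cross-attention matrices keyed by `"<modality_pair>_<layer>"`,
/// a tri-modal fusion projection, and one embedding row per modality.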
#[derive(Debug)]
pub struct MultiModalTransformer {
    pub config: MultiModalTransformerConfig,
    pub cross_attention_params: HashMap<String, Array2<f32>>,
    pub fusion_params: HashMap<String, Array2<f32>>,
    pub modality_embeddings: Array2<f32>,
}

impl MultiModalTransformer {
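    /// Builds a transformer, initializing the cross-attention matrices, the
    /// tri-modal fusion projection, and the per-modality embeddings with
    /// small uniform random values in [-0.05, 0.05).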
    pub fn new(config: MultiModalTransformerConfig) -> Self {
        let mut cross_attention_params = HashMap::new();
        let mut fusion_params = HashMap::new();
        let mut random = Random::default();

        // One cross-attention matrix per modality pair per fusion layer.
        for layer in 0..config.num_fusion_layers {
            for modality_pair in &["vision_language", "language_graph", "vision_graph"] {
                cross_attention_params.insert(
                    format!("{modality_pair}_{layer}"),
                    Array2::from_shape_fn((config.unified_dim, config.unified_dim), |_| {
                        (random.random::<f32>() - 0.5) * 0.1
                    }),
                );
            }
        }

        // Projection from the concatenated tri-modal vector back to `unified_dim`.
        fusion_params.insert(
            "tri_modal_fusion".to_string(),
            Array2::from_shape_fn((config.unified_dim, config.unified_dim * 3), |_| {
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );

        // One embedding row per modality (vision, language, graph).
        let modality_embeddings = Array2::from_shape_fn((3, config.unified_dim), |_| {
            (random.random::<f32>() - 0.5) * 0.1
        });

        Self {
            config,
            cross_attention_params,
            fusion_params,
            modality_embeddings,
        }
    }

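    /// Fuses the three modality embeddings according to the configured
    /// strategy; strategies without a dedicated path fall back to early fusion.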
    pub fn fuse_embeddings(
        &self,
        vision_emb: &Array1<f32>,
        language_emb: &Array1<f32>,
        graph_emb: &Array1<f32>,
    ) -> Result<Array1<f32>> {
        match self.config.fusion_strategy {
            FusionStrategy::EarlyFusion => self.early_fusion(vision_emb, language_emb, graph_emb),
            FusionStrategy::CrossAttention => {
                self.cross_attention_fusion(vision_emb, language_emb, graph_emb)
            }
            FusionStrategy::TensorFusion => self.tensor_fusion(vision_emb, language_emb, graph_emb),
            _ => self.early_fusion(vision_emb, language_emb, graph_emb),
        }
    }

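    /// Concatenates the three embeddings and projects the result back to
    /// `unified_dim`; averages element-wise if the projection is missing.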
    fn early_fusion(
        &self,
        vision_emb: &Array1<f32>,
        language_emb: &Array1<f32>,
        graph_emb: &Array1<f32>,
    ) -> Result<Array1<f32>> {
        // Concatenate [vision | language | graph] via iterators, which also
        // works for non-contiguous arrays (unlike `as_slice().unwrap()`,
        // which would panic on a non-standard layout).
        let mut concatenated =
            Vec::with_capacity(vision_emb.len() + language_emb.len() + graph_emb.len());
        concatenated.extend(vision_emb.iter());
        concatenated.extend(language_emb.iter());
        concatenated.extend(graph_emb.iter());

        let concat_array = Array1::from_vec(concatenated);

        if let Some(fusion_matrix) = self.fusion_params.get("tri_modal_fusion") {
            Ok(fusion_matrix.dot(&concat_array))
        } else {
            // Fallback: element-wise average over the shared prefix.
            let avg_len = vision_emb.len().min(language_emb.len()).min(graph_emb.len());
            let mut averaged = Array1::zeros(avg_len);

            for i in 0..avg_len {
                averaged[i] = (vision_emb[i] + language_emb[i] + graph_emb[i]) / 3.0;
            }

            Ok(averaged)
        }
    }

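    /// Attends from the vision embedding to the language and graph embeddings
    /// using the first-layer cross-attention matrices, then L2-normalizes.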
    fn cross_attention_fusion(
        &self,
        vision_emb: &Array1<f32>,
        language_emb: &Array1<f32>,
        graph_emb: &Array1<f32>,
    ) -> Result<Array1<f32>> {
        let mut fused = vision_emb.clone();

        // Vision attends to language.
        if let Some(vl_attention) = self.cross_attention_params.get("vision_language_0") {
            let vl_attended = vl_attention.dot(language_emb);
            fused = &fused + &vl_attended;
        }

        // Vision attends to graph.
        if let Some(vg_attention) = self.cross_attention_params.get("vision_graph_0") {
            let vg_attended = vg_attention.dot(graph_emb);
            fused = &fused + &vg_attended;
        }

        // L2-normalize the fused vector.
        let norm = fused.dot(&fused).sqrt();
        if norm > 0.0 {
            fused /= norm;
        }

        Ok(fused)
    }

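    /// Element-wise (Hadamard) product of the three embeddings over their
    /// shared prefix, a lightweight stand-in for full outer-product tensor
    /// fusion.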
    fn tensor_fusion(
        &self,
        vision_emb: &Array1<f32>,
        language_emb: &Array1<f32>,
        graph_emb: &Array1<f32>,
    ) -> Result<Array1<f32>> {
        let min_dim = vision_emb.len().min(language_emb.len()).min(graph_emb.len());
        let mut fused = Array1::zeros(min_dim);

        for i in 0..min_dim {
            fused[i] = vision_emb[i] * language_emb[i] * graph_emb[i];
        }

        Ok(fused)
    }
}
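
// A minimal usage sketch, not part of the original file: it assumes that
// `MultiModalTransformerConfig` (defined in the parent module) implements
// `Default` and that all three input embeddings have length `unified_dim`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fused_embedding_matches_unified_dim() {
        let config = MultiModalTransformerConfig::default();
        let dim = config.unified_dim;
        let transformer = MultiModalTransformer::new(config);

        let vision = Array1::from_elem(dim, 0.1_f32);
        let language = Array1::from_elem(dim, 0.2_f32);
        let graph = Array1::from_elem(dim, 0.3_f32);

        // Every strategy implemented above returns a vector of length
        // `unified_dim` when the inputs share that dimension.
        let fused = transformer
            .fuse_embeddings(&vision, &language, &graph)
            .expect("fusion should not fail");
        assert_eq!(fused.len(), dim);
    }
}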