oxirs_embed/vision_language_graph/
transformer.rs1use super::*;
4use anyhow::Result;
5use scirs2_core::ndarray_ext::{Array1, Array2};
6use scirs2_core::random::{Random, Rng};
7use std::collections::HashMap;
8#[derive(Debug)]
9pub struct MultiModalTransformer {
10 pub config: MultiModalTransformerConfig,
11 pub cross_attention_params: HashMap<String, Array2<f32>>,
13 pub fusion_params: HashMap<String, Array2<f32>>,
15 pub modality_embeddings: Array2<f32>,
17}
18
19impl MultiModalTransformer {
20 pub fn new(config: MultiModalTransformerConfig) -> Self {
21 let mut cross_attention_params = HashMap::new();
22 let mut fusion_params = HashMap::new();
23
24 for layer in 0..config.num_fusion_layers {
26 for modality_pair in &["vision_language", "language_graph", "vision_graph"] {
27 let mut random = Random::default();
28 cross_attention_params.insert(
29 format!("{modality_pair}_{layer}"),
30 Array2::from_shape_fn((config.unified_dim, config.unified_dim), |_| {
31 (random.random::<f32>() - 0.5) * 0.1
32 }),
33 );
34 }
35 }
36
37 let mut random = Random::default();
39 fusion_params.insert(
40 "tri_modal_fusion".to_string(),
41 Array2::from_shape_fn((config.unified_dim, config.unified_dim * 3), |_| {
42 (random.random::<f32>() - 0.5) * 0.1
43 }),
44 );
45
46 let mut random = Random::default();
48 let modality_embeddings = Array2::from_shape_fn(
49 (3, config.unified_dim), |_| (random.random::<f32>() - 0.5) * 0.1,
51 );
52
53 Self {
54 config,
55 cross_attention_params,
56 fusion_params,
57 modality_embeddings,
58 }
59 }
60
61 pub fn fuse_embeddings(
63 &self,
64 vision_emb: &Array1<f32>,
65 language_emb: &Array1<f32>,
66 graph_emb: &Array1<f32>,
67 ) -> Result<Array1<f32>> {
68 match self.config.fusion_strategy {
69 FusionStrategy::EarlyFusion => self.early_fusion(vision_emb, language_emb, graph_emb),
70 FusionStrategy::CrossAttention => {
71 self.cross_attention_fusion(vision_emb, language_emb, graph_emb)
72 }
73 FusionStrategy::TensorFusion => self.tensor_fusion(vision_emb, language_emb, graph_emb),
74 _ => self.early_fusion(vision_emb, language_emb, graph_emb),
75 }
76 }
77
78 fn early_fusion(
80 &self,
81 vision_emb: &Array1<f32>,
82 language_emb: &Array1<f32>,
83 graph_emb: &Array1<f32>,
84 ) -> Result<Array1<f32>> {
85 let mut concatenated = Vec::new();
86 concatenated.extend_from_slice(vision_emb.as_slice().expect("array should be contiguous"));
87 concatenated
88 .extend_from_slice(language_emb.as_slice().expect("array should be contiguous"));
89 concatenated.extend_from_slice(graph_emb.as_slice().expect("array should be contiguous"));
90
91 let concat_array = Array1::from_vec(concatenated);
92
93 if let Some(fusion_matrix) = self.fusion_params.get("tri_modal_fusion") {
94 Ok(fusion_matrix.dot(&concat_array))
95 } else {
96 let avg_len = vision_emb
98 .len()
99 .min(language_emb.len())
100 .min(graph_emb.len());
101 let mut averaged = Array1::zeros(avg_len);
102
103 for i in 0..avg_len {
104 averaged[i] = (vision_emb[i] + language_emb[i] + graph_emb[i]) / 3.0;
105 }
106
107 Ok(averaged)
108 }
109 }
110
111 fn cross_attention_fusion(
113 &self,
114 vision_emb: &Array1<f32>,
115 language_emb: &Array1<f32>,
116 graph_emb: &Array1<f32>,
117 ) -> Result<Array1<f32>> {
118 let mut fused = vision_emb.clone();
120
121 if let Some(vl_attention) = self.cross_attention_params.get("vision_language_0") {
123 let vl_attended = vl_attention.dot(language_emb);
124 fused = &fused + &vl_attended;
125 }
126
127 if let Some(vg_attention) = self.cross_attention_params.get("vision_graph_0") {
129 let vg_attended = vg_attention.dot(graph_emb);
130 fused = &fused + &vg_attended;
131 }
132
133 let norm = fused.dot(&fused).sqrt();
135 if norm > 0.0 {
136 fused /= norm;
137 }
138
139 Ok(fused)
140 }
141
142 fn tensor_fusion(
144 &self,
145 vision_emb: &Array1<f32>,
146 language_emb: &Array1<f32>,
147 graph_emb: &Array1<f32>,
148 ) -> Result<Array1<f32>> {
149 let min_dim = vision_emb
151 .len()
152 .min(language_emb.len())
153 .min(graph_emb.len());
154 let mut fused = Array1::zeros(min_dim);
155
156 for i in 0..min_dim {
157 fused[i] = vision_emb[i] * language_emb[i] * graph_emb[i];
158 }
159
160 Ok(fused)
161 }
162}