oxirs_embed/vision_language_graph/encoders.rs

use super::*;
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2, Array3, Array4, Axis};
use scirs2_core::random::{Random, Rng};
use std::collections::HashMap;

#[derive(Debug, Clone)]
pub struct VisionEncoder {
    /// Encoder configuration (architecture, dimensions, patch size).
    pub config: VisionEncoderConfig,
    /// Convolutional filter weights, keyed by layer name (e.g. "conv_0").
    pub cnn_parameters: HashMap<String, Array4<f32>>,
    /// Vision-transformer weights, keyed by name (e.g. "patch_embedding").
    pub vit_parameters: HashMap<String, Array2<f32>>,
    /// Output projection matrix of shape (vision_dim, vision_dim).
    pub projection: Array2<f32>,
}

impl VisionEncoder {
    pub fn new(config: VisionEncoderConfig) -> Self {
        let mut cnn_parameters = HashMap::new();
        let mut vit_parameters = HashMap::new();

        // One 3x3 convolution per configured filter size; the first layer
        // takes the raw image channels as input. All weights are initialized
        // uniformly in [-0.05, 0.05].
        for (i, &filter_size) in config.cnn_config.filter_sizes.iter().enumerate() {
            let layer_name = format!("conv_{i}");
            let weight_shape = (
                filter_size,
                if i == 0 {
                    config.channels
                } else {
                    config.cnn_config.filter_sizes[i - 1]
                },
                3,
                3,
            );
            let mut random = Random::default();
            cnn_parameters.insert(
                layer_name,
                Array4::from_shape_fn(weight_shape, |_| (random.random::<f32>() - 0.5) * 0.1),
            );
        }

        let mut random = Random::default();
        vit_parameters.insert(
            "patch_embedding".to_string(),
            Array2::from_shape_fn(
                (
                    config.channels * config.patch_size.0 * config.patch_size.1,
                    config.vision_dim,
                ),
                |_| (random.random::<f32>() - 0.5) * 0.1,
            ),
        );

        let mut random = Random::default();
        let projection = Array2::from_shape_fn((config.vision_dim, config.vision_dim), |_| {
            (random.random::<f32>() - 0.5) * 0.1
        });

        Self {
            config,
            cnn_parameters,
            vit_parameters,
            projection,
        }
    }

    pub fn encode_image(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
        match self.config.architecture {
            VisionArchitecture::VisionTransformer => self.encode_with_vit(image),
            VisionArchitecture::ResNet => self.encode_with_cnn(image),
            // Any other architecture falls back to the ViT path.
            _ => self.encode_with_vit(image),
        }
    }

    fn encode_with_vit(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
        let (h, w, c) = image.dim();
        let (patch_h, patch_w) = self.config.patch_size;

        // Image dimensions are assumed to be multiples of the patch size;
        // any remainder rows/columns are dropped.
        let num_patches_h = h / patch_h;
        let num_patches_w = w / patch_w;
        let num_patches = num_patches_h * num_patches_w;

        let mut patch_embeddings = Array2::zeros((num_patches, self.config.vision_dim));

        for i in 0..num_patches_h {
            for j in 0..num_patches_w {
                let patch_idx = i * num_patches_w + j;

                let patch = image.slice(scirs2_core::ndarray_ext::s![
                    i * patch_h..(i + 1) * patch_h,
                    j * patch_w..(j + 1) * patch_w,
                    ..
                ]);

                // Flatten the (patch_h, patch_w, c) patch into a vector.
                let patch_owned = patch.to_owned();
                let flattened_patch = patch_owned
                    .into_shape_with_order(c * patch_h * patch_w)
                    .unwrap();

                if let Some(patch_embedding_matrix) = self.vit_parameters.get("patch_embedding") {
                    let embedding = flattened_patch.dot(patch_embedding_matrix);
                    patch_embeddings.row_mut(patch_idx).assign(&embedding);
                }
            }
        }

        // Mean-pool the patch embeddings into one global image embedding.
        let global_embedding = patch_embeddings.mean_axis(Axis(0)).unwrap();

        Ok(global_embedding)
    }
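    // Worked example for `encode_with_vit` (illustrative numbers, not config
    // defaults): a 64x64 RGB image with 16x16 patches yields 4 * 4 = 16
    // patches, each flattened to 3 * 16 * 16 = 768 values and projected to
    // `vision_dim` by the "patch_embedding" matrix before mean pooling.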

    fn encode_with_cnn(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
        let mut features = image.clone();

        // Simplified CNN: at most two stages, each halving the spatial
        // resolution with a 2x2 window averaged over the input channels.
        for i in 0..self.config.cnn_config.num_layers.min(2) {
            let (h, w, c) = features.dim();
            let new_h = h / 2;
            let new_w = w / 2;
            let new_c = self.config.cnn_config.filter_sizes[i];

            let mut new_features = Array3::zeros((new_h, new_w, new_c));

            for new_i in 0..new_h {
                for new_j in 0..new_w {
                    for new_k in 0..new_c {
                        let old_i = new_i * 2;
                        let old_j = new_j * 2;

                        // Defensive bounds check; always true given the
                        // halved output dimensions.
                        if old_i < h && old_j < w {
                            let mut sum = 0.0;
                            let mut count = 0;
                            for di in 0..2 {
                                for dj in 0..2 {
                                    if old_i + di < h && old_j + dj < w {
                                        for k in 0..c.min(new_c) {
                                            sum += features[[old_i + di, old_j + dj, k]];
                                            count += 1;
                                        }
                                    }
                                }
                            }
                            new_features[[new_i, new_j, new_k]] = sum / count as f32;
                        }
                    }
                }
            }

            features = new_features;
        }

        // Flatten, then truncate or zero-pad to the target vision dimension.
        let features_len = features.len();
        let flattened = features.into_shape_with_order(features_len).unwrap();
        let mut global_features = vec![0.0; self.config.vision_dim];

        for i in 0..global_features.len().min(flattened.len()) {
            global_features[i] = flattened[i];
        }

        Ok(Array1::from_vec(global_features))
    }
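    // Shape sketch for `encode_with_cnn` (illustrative): with
    // filter_sizes = [32, 64], a (64, 64, 3) image becomes (32, 32, 32)
    // after stage 0 and (16, 16, 64) after stage 1, then is flattened and
    // truncated or zero-padded to `vision_dim`.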
}

#[derive(Debug, Clone)]
pub struct LanguageEncoder {
    /// Encoder configuration (vocabulary, dimensions, transformer settings).
    pub config: LanguageEncoderConfig,
    /// Token embedding table of shape (vocab_size, language_dim).
    pub token_embeddings: Array2<f32>,
    /// Absolute position embeddings of shape (max_seq_length, language_dim).
    pub position_embeddings: Array2<f32>,
    /// Per-layer transformer weights, keyed by name (e.g. "attention_weights_0").
    pub transformer_parameters: HashMap<String, Array2<f32>>,
}

impl LanguageEncoder {
    pub fn new(config: LanguageEncoderConfig) -> Self {
        let mut random = Random::default();
        let token_embeddings =
            Array2::from_shape_fn((config.vocab_size, config.language_dim), |_| {
                (random.random::<f32>() - 0.5) * 0.1
            });

        let mut random = Random::default();
        let position_embeddings =
            Array2::from_shape_fn((config.max_seq_length, config.language_dim), |_| {
                (random.random::<f32>() - 0.5) * 0.1
            });

        let mut transformer_parameters = HashMap::new();

        for layer in 0..config.transformer_config.num_layers {
            let mut random = Random::default();
            transformer_parameters.insert(
                format!("attention_weights_{layer}"),
                Array2::from_shape_fn((config.language_dim, config.language_dim), |_| {
                    (random.random::<f32>() - 0.5) * 0.1
                }),
            );

            let mut random = Random::default();
            transformer_parameters.insert(
                format!("feed_forward_{layer}"),
                Array2::from_shape_fn(
                    (
                        config.transformer_config.intermediate_dim,
                        config.language_dim,
                    ),
                    |_| (random.random::<f32>() - 0.5) * 0.1,
                ),
            );
        }

        Self {
            config,
            token_embeddings,
            position_embeddings,
            transformer_parameters,
        }
    }

    pub fn encode_text(&self, text: &str) -> Result<Array1<f32>> {
        let tokens = self.tokenize(text);

        let mut sequence_embeddings = Array2::zeros((tokens.len(), self.config.language_dim));

        // Sum token and (position-clamped) positional embeddings per token.
        for (i, &token_id) in tokens.iter().enumerate() {
            if token_id < self.token_embeddings.nrows() {
                let token_emb = self.token_embeddings.row(token_id);
                let pos_emb = self
                    .position_embeddings
                    .row(i.min(self.config.max_seq_length - 1));

                let combined = &token_emb + &pos_emb;
                sequence_embeddings.row_mut(i).assign(&combined);
            }
        }

        let mut hidden_states = sequence_embeddings;

        // Simplified transformer stack (at most two layers): a linear map by
        // the attention weights followed by per-row normalization; no softmax
        // attention or feed-forward block is applied here.
        for layer in 0..self.config.transformer_config.num_layers.min(2) {
            if let Some(attention_weights) = self
                .transformer_parameters
                .get(&format!("attention_weights_{layer}"))
            {
                hidden_states = hidden_states.dot(attention_weights);

                for mut row in hidden_states.rows_mut() {
                    let mean = row.mean().unwrap_or(0.0);
                    let var = row.var(0.0);
                    row.mapv_inplace(|x| (x - mean) / (var + 1e-8).sqrt());
                }
            }
        }

        // Mean-pool token states into a single sentence embedding.
        let sentence_embedding = hidden_states.mean_axis(Axis(0)).unwrap();

        Ok(sentence_embedding)
    }
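    // Note on `encode_text`: an empty input yields zero tokens, and
    // `mean_axis` over an empty axis returns `None`, so the final `unwrap`
    // panics; callers are expected to pass non-empty text.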

    fn tokenize(&self, text: &str) -> Vec<usize> {
        // Hash-based tokenizer: each whitespace-separated word is mapped to
        // a deterministic id in [0, vocab_size) with a 31x rolling hash.
        text.split_whitespace()
            .map(|word| {
                let mut hash = 0usize;
                for byte in word.bytes() {
                    hash = hash.wrapping_mul(31).wrapping_add(byte as usize);
                }
                hash % self.config.vocab_size
            })
            .collect()
    }
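    // Example for `tokenize`: "cat" maps to
    // (((0 * 31 + b'c') * 31 + b'a') * 31 + b't') % vocab_size, so repeated
    // words always get the same id; distinct words may collide, which this
    // simplified scheme accepts.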
}

#[derive(Debug, Clone)]
pub struct GraphEncoder {
    /// Encoder configuration (layer count, dimensions, readout function).
    pub config: GraphEncoderConfig,
    /// Per-layer node transformation weights (e.g. "node_transform_0").
    pub node_parameters: HashMap<String, Array2<f32>>,
    /// Per-layer edge transformation weights (e.g. "edge_transform_0").
    pub edge_parameters: HashMap<String, Array2<f32>>,
    /// Graph-level weights: the "readout" attention vector and "graph_projection".
    pub graph_parameters: HashMap<String, Array2<f32>>,
}

impl GraphEncoder {
    pub fn new(config: GraphEncoderConfig) -> Self {
        let mut node_parameters = HashMap::new();
        let mut edge_parameters = HashMap::new();
        let mut graph_parameters = HashMap::new();

        for layer in 0..config.num_layers {
            let mut random = Random::default();
            node_parameters.insert(
                format!("node_transform_{layer}"),
                Array2::from_shape_fn((config.node_dim, config.node_dim), |_| {
                    (random.random::<f32>() - 0.5) * 0.1
                }),
            );
        }

        for layer in 0..config.num_layers {
            let mut random = Random::default();
            edge_parameters.insert(
                format!("edge_transform_{layer}"),
                Array2::from_shape_fn((config.edge_dim, config.edge_dim), |_| {
                    (random.random::<f32>() - 0.5) * 0.1
                }),
            );
        }

        let mut random = Random::default();
        graph_parameters.insert(
            "readout".to_string(),
            Array2::from_shape_fn((config.node_dim, 1), |_| {
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );

        let mut random = Random::default();
        graph_parameters.insert(
            "graph_projection".to_string(),
            Array2::from_shape_fn((config.node_dim, config.graph_dim), |_| {
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );

        Self {
            config,
            node_parameters,
            edge_parameters,
            graph_parameters,
        }
    }

    pub fn encode_graph(
        &self,
        node_features: &Array2<f32>,
        edge_features: &Array2<f32>,
        adjacency_matrix: &Array2<f32>,
    ) -> Result<Array1<f32>> {
        let mut node_embeddings = node_features.clone();

        // At most two rounds of message passing, then pool to a graph vector.
        for layer in 0..self.config.num_layers.min(2) {
            node_embeddings =
                self.apply_gnn_layer(&node_embeddings, edge_features, adjacency_matrix, layer)?;
        }

        let graph_embedding = self.graph_readout(&node_embeddings)?;

        Ok(graph_embedding)
    }
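    // Expected shapes for `encode_graph`: `node_features` is
    // (num_nodes, node_dim), `adjacency_matrix` is (num_nodes, num_nodes),
    // and the result has length `graph_dim` (or `node_dim` when no
    // "graph_projection" matrix is present). `edge_features` is accepted but
    // unused by the simplified layers below.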

    fn apply_gnn_layer(
        &self,
        node_embeddings: &Array2<f32>,
        _edge_features: &Array2<f32>,
        adjacency_matrix: &Array2<f32>,
        layer: usize,
    ) -> Result<Array2<f32>> {
        let transform_key = format!("node_transform_{layer}");

        if let Some(transform_matrix) = self.node_parameters.get(&transform_key) {
            // Aggregate neighbor features, transform, then apply ReLU.
            let aggregated = adjacency_matrix.dot(node_embeddings);
            let transformed = aggregated.dot(transform_matrix);
            let activated = transformed.mapv(|x| x.max(0.0));

            Ok(activated)
        } else {
            // No weights for this layer: pass the embeddings through unchanged.
            Ok(node_embeddings.clone())
        }
    }
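    // In matrix form each layer computes H^(l+1) = ReLU(A · H^(l) · W_l), a
    // GCN-style update without the degree normalization D^(-1/2) A D^(-1/2)
    // used by standard graph convolutional networks.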

    fn graph_readout(&self, node_embeddings: &Array2<f32>) -> Result<Array1<f32>> {
        let node_level_embedding = match self.config.readout {
            ReadoutFunction::GlobalMean => node_embeddings.mean_axis(Axis(0)).unwrap(),
            ReadoutFunction::GlobalMax => {
                node_embeddings.fold_axis(Axis(0), f32::NEG_INFINITY, |&a, &b| a.max(b))
            }
            ReadoutFunction::GlobalSum => node_embeddings.sum_axis(Axis(0)),
            ReadoutFunction::GlobalAttention => {
                if let Some(readout_matrix) = self.graph_parameters.get("readout") {
                    // Score each node, softmax the scores, and take the
                    // attention-weighted sum of node embeddings.
                    let attention_scores = node_embeddings.dot(readout_matrix);
                    let attention_scores_1d = attention_scores.column(0).to_owned();
                    let attention_weights = self.softmax_1d(&attention_scores_1d);

                    let mut weighted_sum = Array1::zeros(node_embeddings.ncols());
                    for (i, &weight) in attention_weights.iter().enumerate() {
                        let node_emb = node_embeddings.row(i);
                        weighted_sum = weighted_sum + weight * &node_emb;
                    }
                    weighted_sum
                } else {
                    node_embeddings.mean_axis(Axis(0)).unwrap()
                }
            }
            _ => node_embeddings.mean_axis(Axis(0)).unwrap(),
        };

        // Project from node_dim down to graph_dim when a projection exists.
        if let Some(projection_matrix) = self.graph_parameters.get("graph_projection") {
            Ok(projection_matrix.t().dot(&node_level_embedding))
        } else {
            Ok(node_level_embedding)
        }
    }

    /// Row-wise softmax, shifted by the row maximum for numerical stability.
    fn softmax_2d(&self, x: &Array2<f32>) -> Array2<f32> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            let max_val = row.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|v| (v - max_val).exp());
            let sum = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }
        result
    }

    /// Softmax over a vector, shifted by the maximum for numerical stability.
    fn softmax_1d(&self, x: &Array1<f32>) -> Array1<f32> {
        let max_val = x.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let mut result = x.mapv(|v| (v - max_val).exp());
        let sum = result.sum();
        if sum > 0.0 {
            result /= sum;
        }
        result
    }
}
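
// ---------------------------------------------------------------------------
// Usage sketch (added for illustration; not part of the original test suite).
// It assumes the `*Config` types from `super` implement `Default`, and that
// the default patch size divides the test image dimensions; substitute the
// project's real constructors and sizes if these assumptions do not hold.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn vision_embedding_has_vision_dim() {
        // Hypothetical: `VisionEncoderConfig::default()` is assumed to exist.
        let config = VisionEncoderConfig::default();
        let encoder = VisionEncoder::new(config.clone());
        // (H, W, C) image; H and W should be multiples of the patch size.
        let image = Array3::<f32>::zeros((64, 64, config.channels));
        let embedding = encoder.encode_image(&image).unwrap();
        assert_eq!(embedding.len(), config.vision_dim);
    }

    #[test]
    fn text_embedding_has_language_dim() {
        // Hypothetical: `LanguageEncoderConfig::default()` is assumed to exist.
        let config = LanguageEncoderConfig::default();
        let encoder = LanguageEncoder::new(config.clone());
        // Non-empty text is required: `encode_text` mean-pools token states.
        let embedding = encoder.encode_text("graph embeddings").unwrap();
        assert_eq!(embedding.len(), config.language_dim);
    }
}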