// oxirs_embed/vision_language_graph/encoders.rs

use super::*;
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2, Array3, Array4, Axis};
use scirs2_core::random::{Random, Rng};
use std::collections::HashMap;
/// Vision encoder that maps an (H, W, C) image to a fixed-size embedding.
///
/// Holds randomly initialized parameters for two backbone families — a
/// small CNN (per-layer 3x3 filter banks) and a Vision Transformer
/// (patch-embedding matrix) — plus a square projection in vision space.
#[derive(Debug, Clone)]
pub struct VisionEncoder {
    /// Architecture selection and dimension settings (patch size, channels, vision_dim, ...).
    pub config: VisionEncoderConfig,
    /// CNN filter weights keyed by layer name (e.g. "conv_0"), shape (filters, in_channels, 3, 3).
    pub cnn_parameters: HashMap<String, Array4<f32>>,
    /// ViT parameters keyed by name; "patch_embedding" maps a flattened patch to vision_dim.
    pub vit_parameters: HashMap<String, Array2<f32>>,
    /// Final (vision_dim x vision_dim) projection matrix.
    pub projection: Array2<f32>,
}
19
20impl VisionEncoder {
21 pub fn new(config: VisionEncoderConfig) -> Self {
22 let mut cnn_parameters = HashMap::new();
23 let mut vit_parameters = HashMap::new();
24
25 for (i, &filter_size) in config.cnn_config.filter_sizes.iter().enumerate() {
27 let layer_name = format!("conv_{i}");
28 let weight_shape = (
29 filter_size,
30 if i == 0 {
31 config.channels
32 } else {
33 config.cnn_config.filter_sizes[i - 1]
34 },
35 3,
36 3,
37 );
38 let mut random = Random::default();
39 cnn_parameters.insert(
40 layer_name,
41 Array4::from_shape_fn(weight_shape, |_| (random.random::<f32>() - 0.5) * 0.1),
42 );
43 }
44
45 let mut random = Random::default();
47 vit_parameters.insert(
48 "patch_embedding".to_string(),
49 Array2::from_shape_fn(
50 (
51 config.channels * config.patch_size.0 * config.patch_size.1,
52 config.vision_dim,
53 ),
54 |_| (random.random::<f32>() - 0.5) * 0.1,
55 ),
56 );
57
58 let mut random = Random::default();
60 let projection = Array2::from_shape_fn((config.vision_dim, config.vision_dim), |_| {
61 (random.random::<f32>() - 0.5) * 0.1
62 });
63
64 Self {
65 config,
66 cnn_parameters,
67 vit_parameters,
68 projection,
69 }
70 }
71
72 pub fn encode_image(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
74 match self.config.architecture {
75 VisionArchitecture::VisionTransformer => self.encode_with_vit(image),
76 VisionArchitecture::ResNet => self.encode_with_cnn(image),
77 _ => self.encode_with_vit(image), }
79 }
80
81 fn encode_with_vit(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
83 let (h, w, c) = image.dim();
85 let (patch_h, patch_w) = self.config.patch_size;
86
87 let num_patches_h = h / patch_h;
88 let num_patches_w = w / patch_w;
89 let num_patches = num_patches_h * num_patches_w;
90
91 let mut patch_embeddings = Array2::zeros((num_patches, self.config.vision_dim));
93
94 for i in 0..num_patches_h {
95 for j in 0..num_patches_w {
96 let patch_idx = i * num_patches_w + j;
97
98 let patch = image.slice(scirs2_core::ndarray_ext::s![
100 i * patch_h..(i + 1) * patch_h,
101 j * patch_w..(j + 1) * patch_w,
102 ..
103 ]);
104
105 let patch_owned = patch.to_owned();
107 let flattened_patch = patch_owned
108 .into_shape_with_order(c * patch_h * patch_w)
109 .expect("reshape should succeed for valid patch dimensions");
110
111 if let Some(patch_embedding_matrix) = self.vit_parameters.get("patch_embedding") {
113 let embedding = flattened_patch.dot(patch_embedding_matrix);
114 patch_embeddings.row_mut(patch_idx).assign(&embedding);
115 }
116 }
117 }
118
119 let global_embedding = patch_embeddings
121 .mean_axis(Axis(0))
122 .expect("mean_axis should succeed for non-empty array");
123
124 Ok(global_embedding)
125 }
126
127 fn encode_with_cnn(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
129 let mut features = image.clone();
131
132 for i in 0..self.config.cnn_config.num_layers.min(2) {
134 let (h, w, c) = features.dim();
137 let new_h = h / 2; let new_w = w / 2;
139 let new_c = self.config.cnn_config.filter_sizes[i];
140
141 let mut new_features = Array3::zeros((new_h, new_w, new_c));
142
143 for new_i in 0..new_h {
145 for new_j in 0..new_w {
146 for new_k in 0..new_c {
147 let old_i = new_i * 2;
148 let old_j = new_j * 2;
149
150 if old_i < h && old_j < w {
151 let mut sum = 0.0;
153 let mut count = 0;
154 for di in 0..2 {
155 for dj in 0..2 {
156 if old_i + di < h && old_j + dj < w {
157 for k in 0..c.min(new_c) {
158 sum += features[[old_i + di, old_j + dj, k]];
159 count += 1;
160 }
161 }
162 }
163 }
164 new_features[[new_i, new_j, new_k]] = sum / count as f32;
165 }
166 }
167 }
168 }
169
170 features = new_features;
171 }
172
173 let features_len = features.len();
175 let flattened = features
176 .into_shape_with_order(features_len)
177 .expect("reshape should succeed for valid features dimensions");
178 let mut global_features = vec![0.0; self.config.vision_dim];
179
180 for i in 0..global_features.len().min(flattened.len()) {
181 global_features[i] = flattened[i];
182 }
183
184 Ok(Array1::from_vec(global_features))
185 }
186}
187
/// Language encoder that maps text to a fixed-size sentence embedding.
///
/// Holds randomly initialized token and position embedding tables plus
/// simplified per-layer transformer parameters (attention weights and
/// feed-forward matrices).
#[derive(Debug, Clone)]
pub struct LanguageEncoder {
    /// Vocabulary size, sequence length, dimensions, and transformer settings.
    pub config: LanguageEncoderConfig,
    /// (vocab_size x language_dim) token embedding table.
    pub token_embeddings: Array2<f32>,
    /// (max_seq_length x language_dim) learned position embedding table.
    pub position_embeddings: Array2<f32>,
    /// Per-layer matrices keyed by "attention_weights_{layer}" / "feed_forward_{layer}".
    pub transformer_parameters: HashMap<String, Array2<f32>>,
}
199
200impl LanguageEncoder {
201 pub fn new(config: LanguageEncoderConfig) -> Self {
202 let mut random = Random::default();
204 let token_embeddings =
205 Array2::from_shape_fn((config.vocab_size, config.language_dim), |_| {
206 (random.random::<f32>() - 0.5) * 0.1
207 });
208
209 let mut random = Random::default();
210 let position_embeddings =
211 Array2::from_shape_fn((config.max_seq_length, config.language_dim), |_| {
212 (random.random::<f32>() - 0.5) * 0.1
213 });
214
215 let mut transformer_parameters = HashMap::new();
216
217 for layer in 0..config.transformer_config.num_layers {
219 let mut random = Random::default();
220 transformer_parameters.insert(
221 format!("attention_weights_{layer}"),
222 Array2::from_shape_fn((config.language_dim, config.language_dim), |_| {
223 (random.random::<f32>() - 0.5) * 0.1
224 }),
225 );
226
227 let mut random = Random::default();
228 transformer_parameters.insert(
229 format!("feed_forward_{layer}"),
230 Array2::from_shape_fn(
231 (
232 config.transformer_config.intermediate_dim,
233 config.language_dim,
234 ),
235 |_| (random.random::<f32>() - 0.5) * 0.1,
236 ),
237 );
238 }
239
240 Self {
241 config,
242 token_embeddings,
243 position_embeddings,
244 transformer_parameters,
245 }
246 }
247
248 pub fn encode_text(&self, text: &str) -> Result<Array1<f32>> {
250 let tokens = self.tokenize(text);
252
253 let mut sequence_embeddings = Array2::zeros((tokens.len(), self.config.language_dim));
255
256 for (i, &token_id) in tokens.iter().enumerate() {
257 if token_id < self.token_embeddings.nrows() {
258 let token_emb = self.token_embeddings.row(token_id);
259 let pos_emb = self
260 .position_embeddings
261 .row(i.min(self.config.max_seq_length - 1));
262
263 let combined = &token_emb + &pos_emb;
265 sequence_embeddings.row_mut(i).assign(&combined);
266 }
267 }
268
269 let mut hidden_states = sequence_embeddings;
271
272 for layer in 0..self.config.transformer_config.num_layers.min(2) {
273 if let Some(attention_weights) = self
275 .transformer_parameters
276 .get(&format!("attention_weights_{layer}"))
277 {
278 hidden_states = hidden_states.dot(attention_weights);
280
281 for mut row in hidden_states.rows_mut() {
283 let mean = row.mean().unwrap_or(0.0);
284 let var = row.var(0.0);
285 row.mapv_inplace(|x| (x - mean) / (var + 1e-8).sqrt());
286 }
287 }
288 }
289
290 let sentence_embedding = hidden_states
292 .mean_axis(Axis(0))
293 .expect("mean_axis should succeed for non-empty array");
294
295 Ok(sentence_embedding)
296 }
297
298 fn tokenize(&self, text: &str) -> Vec<usize> {
300 text.split_whitespace()
301 .map(|word| {
302 let mut hash = 0usize;
304 for byte in word.bytes() {
305 hash = hash.wrapping_mul(31).wrapping_add(byte as usize);
306 }
307 hash % self.config.vocab_size
308 })
309 .collect()
310 }
311}
312
/// Graph encoder that maps node/edge features plus an adjacency matrix
/// to a fixed-size graph embedding via simplified GNN layers and a
/// configurable readout.
#[derive(Debug, Clone)]
pub struct GraphEncoder {
    /// Layer count, dimensions, and readout function selection.
    pub config: GraphEncoderConfig,
    /// Per-layer node transforms keyed by "node_transform_{layer}".
    pub node_parameters: HashMap<String, Array2<f32>>,
    /// Per-layer edge transforms keyed by "edge_transform_{layer}".
    pub edge_parameters: HashMap<String, Array2<f32>>,
    /// Graph-level parameters: "readout" (attention scoring column) and
    /// "graph_projection" (node_dim -> graph_dim).
    pub graph_parameters: HashMap<String, Array2<f32>>,
}
324
325impl GraphEncoder {
326 pub fn new(config: GraphEncoderConfig) -> Self {
327 let mut node_parameters = HashMap::new();
328 let mut edge_parameters = HashMap::new();
329 let mut graph_parameters = HashMap::new();
330
331 for layer in 0..config.num_layers {
333 let mut random = Random::default();
334 node_parameters.insert(
335 format!("node_transform_{layer}"),
336 Array2::from_shape_fn((config.node_dim, config.node_dim), |_| {
337 (random.random::<f32>() - 0.5) * 0.1
338 }),
339 );
340 }
341
342 for layer in 0..config.num_layers {
344 let mut random = Random::default();
345 edge_parameters.insert(
346 format!("edge_transform_{layer}"),
347 Array2::from_shape_fn((config.edge_dim, config.edge_dim), |_| {
348 (random.random::<f32>() - 0.5) * 0.1
349 }),
350 );
351 }
352
353 let mut random = Random::default();
355 graph_parameters.insert(
356 "readout".to_string(),
357 Array2::from_shape_fn(
358 (config.node_dim, 1), |_| (random.random::<f32>() - 0.5) * 0.1,
360 ),
361 );
362
363 let mut random = Random::default();
365 graph_parameters.insert(
366 "graph_projection".to_string(),
367 Array2::from_shape_fn((config.node_dim, config.graph_dim), |_| {
368 (random.random::<f32>() - 0.5) * 0.1
369 }),
370 );
371
372 Self {
373 config,
374 node_parameters,
375 edge_parameters,
376 graph_parameters,
377 }
378 }
379
380 pub fn encode_graph(
382 &self,
383 node_features: &Array2<f32>,
384 edge_features: &Array2<f32>,
385 adjacency_matrix: &Array2<f32>,
386 ) -> Result<Array1<f32>> {
387 let mut node_embeddings = node_features.clone();
388
389 for layer in 0..self.config.num_layers.min(2) {
391 node_embeddings =
393 self.apply_gnn_layer(&node_embeddings, edge_features, adjacency_matrix, layer)?;
394 }
395
396 let graph_embedding = self.graph_readout(&node_embeddings)?;
398
399 Ok(graph_embedding)
400 }
401
402 fn apply_gnn_layer(
404 &self,
405 node_embeddings: &Array2<f32>,
406 _edge_features: &Array2<f32>,
407 adjacency_matrix: &Array2<f32>,
408 layer: usize,
409 ) -> Result<Array2<f32>> {
410 let transform_key = format!("node_transform_{layer}");
411
412 if let Some(transform_matrix) = self.node_parameters.get(&transform_key) {
413 let aggregated = adjacency_matrix.dot(node_embeddings);
415
416 let transformed = aggregated.dot(transform_matrix);
418
419 let activated = transformed.mapv(|x| x.max(0.0));
421
422 Ok(activated)
423 } else {
424 Ok(node_embeddings.clone())
425 }
426 }
427
428 fn graph_readout(&self, node_embeddings: &Array2<f32>) -> Result<Array1<f32>> {
430 let node_level_embedding = match self.config.readout {
431 ReadoutFunction::GlobalMean => node_embeddings
432 .mean_axis(Axis(0))
433 .expect("mean_axis should succeed for non-empty array"),
434 ReadoutFunction::GlobalMax => {
435 node_embeddings.fold_axis(Axis(0), f32::NEG_INFINITY, |&a, &b| a.max(b))
436 }
437 ReadoutFunction::GlobalSum => node_embeddings.sum_axis(Axis(0)),
438 ReadoutFunction::GlobalAttention => {
439 if let Some(readout_matrix) = self.graph_parameters.get("readout") {
440 let attention_scores = node_embeddings.dot(readout_matrix); let attention_scores_1d = attention_scores.column(0).to_owned(); let attention_weights = self.softmax_1d(&attention_scores_1d); let mut weighted_sum = Array1::zeros(node_embeddings.ncols());
447 for (i, &weight) in attention_weights.iter().enumerate() {
448 let node_emb = node_embeddings.row(i);
449 weighted_sum = weighted_sum + weight * &node_emb;
450 }
451 weighted_sum
452 } else {
453 node_embeddings
454 .mean_axis(Axis(0))
455 .expect("mean_axis should succeed for non-empty array")
456 }
457 }
458 _ => node_embeddings
459 .mean_axis(Axis(0))
460 .expect("mean_axis should succeed for non-empty array"),
461 };
462
463 if let Some(projection_matrix) = self.graph_parameters.get("graph_projection") {
465 Ok(projection_matrix.t().dot(&node_level_embedding))
466 } else {
467 Ok(node_level_embedding)
468 }
469 }
470
471 fn softmax_2d(&self, x: &Array2<f32>) -> Array2<f32> {
473 let mut result = x.clone();
474 for mut row in result.rows_mut() {
475 let max_val = row.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
476 row.mapv_inplace(|v| (v - max_val).exp());
477 let sum = row.sum();
478 if sum > 0.0 {
479 row /= sum;
480 }
481 }
482 result
483 }
484
485 fn softmax_1d(&self, x: &Array1<f32>) -> Array1<f32> {
486 let max_val = x.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
487 let mut result = x.mapv(|v| (v - max_val).exp());
488 let sum = result.sum();
489 if sum > 0.0 {
490 result /= sum;
491 }
492 result
493 }
494}