use super::*;
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2, Array3, Array4, Axis};
use scirs2_core::random::{Random, RngExt};
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct VisionEncoder {
pub config: VisionEncoderConfig,
pub cnn_parameters: HashMap<String, Array4<f32>>,
pub vit_parameters: HashMap<String, Array2<f32>>,
pub projection: Array2<f32>,
}
impl VisionEncoder {
pub fn new(config: VisionEncoderConfig) -> Self {
let mut cnn_parameters = HashMap::new();
let mut vit_parameters = HashMap::new();
for (i, &filter_size) in config.cnn_config.filter_sizes.iter().enumerate() {
let layer_name = format!("conv_{i}");
let weight_shape = (
filter_size,
if i == 0 {
config.channels
} else {
config.cnn_config.filter_sizes[i - 1]
},
3,
3,
);
let mut random = Random::default();
cnn_parameters.insert(
layer_name,
Array4::from_shape_fn(weight_shape, |_| (random.random::<f32>() - 0.5) * 0.1),
);
}
let mut random = Random::default();
vit_parameters.insert(
"patch_embedding".to_string(),
Array2::from_shape_fn(
(
config.channels * config.patch_size.0 * config.patch_size.1,
config.vision_dim,
),
|_| (random.random::<f32>() - 0.5) * 0.1,
),
);
let mut random = Random::default();
let projection = Array2::from_shape_fn((config.vision_dim, config.vision_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
});
Self {
config,
cnn_parameters,
vit_parameters,
projection,
}
}
pub fn encode_image(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
match self.config.architecture {
VisionArchitecture::VisionTransformer => self.encode_with_vit(image),
VisionArchitecture::ResNet => self.encode_with_cnn(image),
_ => self.encode_with_vit(image), }
}
fn encode_with_vit(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
let (h, w, c) = image.dim();
let (patch_h, patch_w) = self.config.patch_size;
let num_patches_h = h / patch_h;
let num_patches_w = w / patch_w;
let num_patches = num_patches_h * num_patches_w;
let mut patch_embeddings = Array2::zeros((num_patches, self.config.vision_dim));
for i in 0..num_patches_h {
for j in 0..num_patches_w {
let patch_idx = i * num_patches_w + j;
let patch = image.slice(scirs2_core::ndarray_ext::s![
i * patch_h..(i + 1) * patch_h,
j * patch_w..(j + 1) * patch_w,
..
]);
let patch_owned = patch.to_owned();
let flattened_patch = patch_owned
.into_shape_with_order(c * patch_h * patch_w)
.expect("reshape should succeed for valid patch dimensions");
if let Some(patch_embedding_matrix) = self.vit_parameters.get("patch_embedding") {
let embedding = flattened_patch.dot(patch_embedding_matrix);
patch_embeddings.row_mut(patch_idx).assign(&embedding);
}
}
}
let global_embedding = patch_embeddings
.mean_axis(Axis(0))
.expect("mean_axis should succeed for non-empty array");
Ok(global_embedding)
}
fn encode_with_cnn(&self, image: &Array3<f32>) -> Result<Array1<f32>> {
let mut features = image.clone();
for i in 0..self.config.cnn_config.num_layers.min(2) {
let (h, w, c) = features.dim();
let new_h = h / 2; let new_w = w / 2;
let new_c = self.config.cnn_config.filter_sizes[i];
let mut new_features = Array3::zeros((new_h, new_w, new_c));
for new_i in 0..new_h {
for new_j in 0..new_w {
for new_k in 0..new_c {
let old_i = new_i * 2;
let old_j = new_j * 2;
if old_i < h && old_j < w {
let mut sum = 0.0;
let mut count = 0;
for di in 0..2 {
for dj in 0..2 {
if old_i + di < h && old_j + dj < w {
for k in 0..c.min(new_c) {
sum += features[[old_i + di, old_j + dj, k]];
count += 1;
}
}
}
}
new_features[[new_i, new_j, new_k]] = sum / count as f32;
}
}
}
}
features = new_features;
}
let features_len = features.len();
let flattened = features
.into_shape_with_order(features_len)
.expect("reshape should succeed for valid features dimensions");
let mut global_features = vec![0.0; self.config.vision_dim];
for i in 0..global_features.len().min(flattened.len()) {
global_features[i] = flattened[i];
}
Ok(Array1::from_vec(global_features))
}
}
#[derive(Debug, Clone)]
pub struct LanguageEncoder {
pub config: LanguageEncoderConfig,
pub token_embeddings: Array2<f32>,
pub position_embeddings: Array2<f32>,
pub transformer_parameters: HashMap<String, Array2<f32>>,
}
impl LanguageEncoder {
pub fn new(config: LanguageEncoderConfig) -> Self {
let mut random = Random::default();
let token_embeddings =
Array2::from_shape_fn((config.vocab_size, config.language_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
});
let mut random = Random::default();
let position_embeddings =
Array2::from_shape_fn((config.max_seq_length, config.language_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
});
let mut transformer_parameters = HashMap::new();
for layer in 0..config.transformer_config.num_layers {
let mut random = Random::default();
transformer_parameters.insert(
format!("attention_weights_{layer}"),
Array2::from_shape_fn((config.language_dim, config.language_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
}),
);
let mut random = Random::default();
transformer_parameters.insert(
format!("feed_forward_{layer}"),
Array2::from_shape_fn(
(
config.transformer_config.intermediate_dim,
config.language_dim,
),
|_| (random.random::<f32>() - 0.5) * 0.1,
),
);
}
Self {
config,
token_embeddings,
position_embeddings,
transformer_parameters,
}
}
pub fn encode_text(&self, text: &str) -> Result<Array1<f32>> {
let tokens = self.tokenize(text);
let mut sequence_embeddings = Array2::zeros((tokens.len(), self.config.language_dim));
for (i, &token_id) in tokens.iter().enumerate() {
if token_id < self.token_embeddings.nrows() {
let token_emb = self.token_embeddings.row(token_id);
let pos_emb = self
.position_embeddings
.row(i.min(self.config.max_seq_length - 1));
let combined = &token_emb + &pos_emb;
sequence_embeddings.row_mut(i).assign(&combined);
}
}
let mut hidden_states = sequence_embeddings;
for layer in 0..self.config.transformer_config.num_layers.min(2) {
if let Some(attention_weights) = self
.transformer_parameters
.get(&format!("attention_weights_{layer}"))
{
hidden_states = hidden_states.dot(attention_weights);
for mut row in hidden_states.rows_mut() {
let mean = row.mean().unwrap_or(0.0);
let var = row.var(0.0);
row.mapv_inplace(|x| (x - mean) / (var + 1e-8).sqrt());
}
}
}
let sentence_embedding = hidden_states
.mean_axis(Axis(0))
.expect("mean_axis should succeed for non-empty array");
Ok(sentence_embedding)
}
fn tokenize(&self, text: &str) -> Vec<usize> {
text.split_whitespace()
.map(|word| {
let mut hash = 0usize;
for byte in word.bytes() {
hash = hash.wrapping_mul(31).wrapping_add(byte as usize);
}
hash % self.config.vocab_size
})
.collect()
}
}
#[derive(Debug, Clone)]
pub struct GraphEncoder {
pub config: GraphEncoderConfig,
pub node_parameters: HashMap<String, Array2<f32>>,
pub edge_parameters: HashMap<String, Array2<f32>>,
pub graph_parameters: HashMap<String, Array2<f32>>,
}
impl GraphEncoder {
pub fn new(config: GraphEncoderConfig) -> Self {
let mut node_parameters = HashMap::new();
let mut edge_parameters = HashMap::new();
let mut graph_parameters = HashMap::new();
for layer in 0..config.num_layers {
let mut random = Random::default();
node_parameters.insert(
format!("node_transform_{layer}"),
Array2::from_shape_fn((config.node_dim, config.node_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
}),
);
}
for layer in 0..config.num_layers {
let mut random = Random::default();
edge_parameters.insert(
format!("edge_transform_{layer}"),
Array2::from_shape_fn((config.edge_dim, config.edge_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
}),
);
}
let mut random = Random::default();
graph_parameters.insert(
"readout".to_string(),
Array2::from_shape_fn(
(config.node_dim, 1), |_| (random.random::<f32>() - 0.5) * 0.1,
),
);
let mut random = Random::default();
graph_parameters.insert(
"graph_projection".to_string(),
Array2::from_shape_fn((config.node_dim, config.graph_dim), |_| {
(random.random::<f32>() - 0.5) * 0.1
}),
);
Self {
config,
node_parameters,
edge_parameters,
graph_parameters,
}
}
pub fn encode_graph(
&self,
node_features: &Array2<f32>,
edge_features: &Array2<f32>,
adjacency_matrix: &Array2<f32>,
) -> Result<Array1<f32>> {
let mut node_embeddings = node_features.clone();
for layer in 0..self.config.num_layers.min(2) {
node_embeddings =
self.apply_gnn_layer(&node_embeddings, edge_features, adjacency_matrix, layer)?;
}
let graph_embedding = self.graph_readout(&node_embeddings)?;
Ok(graph_embedding)
}
fn apply_gnn_layer(
&self,
node_embeddings: &Array2<f32>,
_edge_features: &Array2<f32>,
adjacency_matrix: &Array2<f32>,
layer: usize,
) -> Result<Array2<f32>> {
let transform_key = format!("node_transform_{layer}");
if let Some(transform_matrix) = self.node_parameters.get(&transform_key) {
let aggregated = adjacency_matrix.dot(node_embeddings);
let transformed = aggregated.dot(transform_matrix);
let activated = transformed.mapv(|x| x.max(0.0));
Ok(activated)
} else {
Ok(node_embeddings.clone())
}
}
fn graph_readout(&self, node_embeddings: &Array2<f32>) -> Result<Array1<f32>> {
let node_level_embedding = match self.config.readout {
ReadoutFunction::GlobalMean => node_embeddings
.mean_axis(Axis(0))
.expect("mean_axis should succeed for non-empty array"),
ReadoutFunction::GlobalMax => {
node_embeddings.fold_axis(Axis(0), f32::NEG_INFINITY, |&a, &b| a.max(b))
}
ReadoutFunction::GlobalSum => node_embeddings.sum_axis(Axis(0)),
ReadoutFunction::GlobalAttention => {
if let Some(readout_matrix) = self.graph_parameters.get("readout") {
let attention_scores = node_embeddings.dot(readout_matrix); let attention_scores_1d = attention_scores.column(0).to_owned(); let attention_weights = self.softmax_1d(&attention_scores_1d);
let mut weighted_sum = Array1::zeros(node_embeddings.ncols());
for (i, &weight) in attention_weights.iter().enumerate() {
let node_emb = node_embeddings.row(i);
weighted_sum = weighted_sum + weight * &node_emb;
}
weighted_sum
} else {
node_embeddings
.mean_axis(Axis(0))
.expect("mean_axis should succeed for non-empty array")
}
}
_ => node_embeddings
.mean_axis(Axis(0))
.expect("mean_axis should succeed for non-empty array"),
};
if let Some(projection_matrix) = self.graph_parameters.get("graph_projection") {
Ok(projection_matrix.t().dot(&node_level_embedding))
} else {
Ok(node_level_embedding)
}
}
fn softmax_2d(&self, x: &Array2<f32>) -> Array2<f32> {
let mut result = x.clone();
for mut row in result.rows_mut() {
let max_val = row.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
row.mapv_inplace(|v| (v - max_val).exp());
let sum = row.sum();
if sum > 0.0 {
row /= sum;
}
}
result
}
fn softmax_1d(&self, x: &Array1<f32>) -> Array1<f32> {
let max_val = x.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mut result = x.mapv(|v| (v - max_val).exp());
let sum = result.sum();
if sum > 0.0 {
result /= sum;
}
result
}
}