use super::*;
use anyhow::{ensure, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, RngExt};
use std::collections::HashMap;
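/// Transformer that fuses vision, language, and graph embeddings into a
/// single unified representation, using the strategy selected in its config.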
#[derive(Debug)]
pub struct MultiModalTransformer {
pub config: MultiModalTransformerConfig,
pub cross_attention_params: HashMap<String, Array2<f32>>,
pub fusion_params: HashMap<String, Array2<f32>>,
pub modality_embeddings: Array2<f32>,
}
impl MultiModalTransformer {
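    /// Creates a transformer with randomly initialized cross-attention,
    /// tri-modal fusion, and modality-embedding parameters.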
pub fn new(config: MultiModalTransformerConfig) -> Self {
        let mut cross_attention_params = HashMap::new();
        let mut fusion_params = HashMap::new();
        // Reuse one RNG for all parameter initialization so every matrix
        // draws distinct values instead of re-seeding per modality pair.
        let mut random = Random::default();
        for layer in 0..config.num_fusion_layers {
            for modality_pair in &["vision_language", "language_graph", "vision_graph"] {
                cross_attention_params.insert(
                    format!("{modality_pair}_{layer}"),
                    Array2::from_shape_fn((config.unified_dim, config.unified_dim), |_| {
                        (random.random::<f32>() - 0.5) * 0.1
                    }),
                );
            }
        }
        // Projection from the concatenated tri-modal embedding back down to
        // the unified dimension.
        fusion_params.insert(
            "tri_modal_fusion".to_string(),
            Array2::from_shape_fn((config.unified_dim, config.unified_dim * 3), |_| {
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );
        // One embedding row per modality: vision, language, graph.
        let modality_embeddings = Array2::from_shape_fn((3, config.unified_dim), |_| {
            (random.random::<f32>() - 0.5) * 0.1
        });
Self {
config,
cross_attention_params,
fusion_params,
modality_embeddings,
}
}
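    /// Fuses the three modality embeddings according to the configured
    /// `FusionStrategy`, falling back to early fusion for unhandled variants.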
pub fn fuse_embeddings(
&self,
vision_emb: &Array1<f32>,
language_emb: &Array1<f32>,
graph_emb: &Array1<f32>,
) -> Result<Array1<f32>> {
match self.config.fusion_strategy {
FusionStrategy::EarlyFusion => self.early_fusion(vision_emb, language_emb, graph_emb),
FusionStrategy::CrossAttention => {
self.cross_attention_fusion(vision_emb, language_emb, graph_emb)
}
FusionStrategy::TensorFusion => self.tensor_fusion(vision_emb, language_emb, graph_emb),
_ => self.early_fusion(vision_emb, language_emb, graph_emb),
}
}
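    /// Concatenates the three embeddings and projects them back to the
    /// unified dimension; averages element-wise if no projection exists.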
fn early_fusion(
&self,
vision_emb: &Array1<f32>,
language_emb: &Array1<f32>,
graph_emb: &Array1<f32>,
) -> Result<Array1<f32>> {
        // Concatenate the three embeddings; iterating avoids a panic on
        // non-contiguous arrays.
        let concatenated: Vec<f32> = vision_emb
            .iter()
            .chain(language_emb.iter())
            .chain(graph_emb.iter())
            .copied()
            .collect();
        let concat_array = Array1::from_vec(concatenated);
        if let Some(fusion_matrix) = self.fusion_params.get("tri_modal_fusion") {
            ensure!(
                concat_array.len() == fusion_matrix.ncols(),
                "concatenated embedding length {} does not match fusion matrix columns {}",
                concat_array.len(),
                fusion_matrix.ncols()
            );
            Ok(fusion_matrix.dot(&concat_array))
        } else {
            // Fallback when no projection is available: element-wise average
            // over the shared prefix of the three embeddings.
            let avg_len = vision_emb
                .len()
                .min(language_emb.len())
                .min(graph_emb.len());
            let averaged = Array1::from_shape_fn(avg_len, |i| {
                (vision_emb[i] + language_emb[i] + graph_emb[i]) / 3.0
            });
            Ok(averaged)
        }
}
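    /// Adds first-layer cross-attended language and graph signals to the
    /// vision embedding, then L2-normalizes the result.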
fn cross_attention_fusion(
&self,
vision_emb: &Array1<f32>,
language_emb: &Array1<f32>,
graph_emb: &Array1<f32>,
) -> Result<Array1<f32>> {
        // Start from the vision embedding and add cross-attended language
        // and graph signals from the first fusion layer.
        let mut fused = vision_emb.clone();
        if let Some(vl_attention) = self.cross_attention_params.get("vision_language_0") {
            ensure!(
                language_emb.len() == vl_attention.ncols() && fused.len() == vl_attention.nrows(),
                "language embedding does not match vision-language attention dimensions"
            );
            let vl_attended = vl_attention.dot(language_emb);
            fused = &fused + &vl_attended;
        }
        if let Some(vg_attention) = self.cross_attention_params.get("vision_graph_0") {
            ensure!(
                graph_emb.len() == vg_attention.ncols() && fused.len() == vg_attention.nrows(),
                "graph embedding does not match vision-graph attention dimensions"
            );
            let vg_attended = vg_attention.dot(graph_emb);
            fused = &fused + &vg_attended;
        }
        // L2-normalize the fused embedding.
        let norm = fused.dot(&fused).sqrt();
        if norm > 0.0 {
            fused /= norm;
        }
Ok(fused)
}
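    /// Element-wise product of the three embeddings over their shared prefix.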
fn tensor_fusion(
&self,
vision_emb: &Array1<f32>,
language_emb: &Array1<f32>,
graph_emb: &Array1<f32>,
) -> Result<Array1<f32>> {
        // Element-wise (Hadamard) product over the shared prefix of the
        // three embeddings; a simplified stand-in for a full outer product.
        let min_dim = vision_emb
            .len()
            .min(language_emb.len())
            .min(graph_emb.len());
        let fused = Array1::from_shape_fn(min_dim, |i| {
            vision_emb[i] * language_emb[i] * graph_emb[i]
        });
Ok(fused)
}
}
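
// A minimal usage sketch, not part of the original module. It assumes
// `MultiModalTransformerConfig` can be built with at least the
// `unified_dim`, `num_fusion_layers`, and `fusion_strategy` fields referenced
// above (any additional fields would need to be filled in) and that
// `FusionStrategy::EarlyFusion` exists as matched in `fuse_embeddings`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn early_fusion_produces_unified_dim_output() {
        let config = MultiModalTransformerConfig {
            unified_dim: 4,
            num_fusion_layers: 1,
            fusion_strategy: FusionStrategy::EarlyFusion,
        };
        let transformer = MultiModalTransformer::new(config);
        // Each modality embedding must match `unified_dim` for the
        // tri-modal projection to apply.
        let vision = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let language = Array1::from_vec(vec![0.4, 0.3, 0.2, 0.1]);
        let graph = Array1::from_vec(vec![0.5; 4]);
        let fused = transformer
            .fuse_embeddings(&vision, &language, &graph)
            .expect("fusion should succeed for matching dimensions");
        assert_eq!(fused.len(), 4);
    }
}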