use crate::error::{NeuralError, Result};
use crate::layers::{Dense, Dropout, Layer, LayerNorm, Sequential};
use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::SeedableRng;
use scirs2_core::simd_ops::SimdUnifiedOps;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Strategy used to combine aligned features from multiple modalities.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FusionMethod {
    /// Concatenate aligned features along the feature dimension.
    Concatenation,
    /// Element-wise sum of aligned features.
    Sum,
    /// Element-wise (Hadamard) product of aligned features.
    Product,
    /// Cross-modal attention between two modalities.
    Attention,
    /// Low-rank bilinear interaction between two modalities.
    Bilinear,
    /// Feature-wise linear modulation (FiLM), where one modality conditions the other.
    FiLM,
}
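/// Configuration for a [`FeatureFusion`] model.
///
/// # Example
///
/// A minimal configuration sketch (illustrative values; not compiled):
///
/// ```ignore
/// let config = FeatureFusionConfig {
///     input_dims: vec![512, 128],
///     hidden_dim: 256,
///     fusion_method: FusionMethod::Attention,
///     dropout_rate: 0.1,
///     num_classes: 10,
///     include_head: true,
/// };
/// let model = FeatureFusion::<f32>::new(config)?;
/// ```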
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureFusionConfig {
    /// Feature dimension of each input modality, in order.
    pub input_dims: Vec<usize>,
    /// Shared hidden dimension that every modality is aligned to.
    pub hidden_dim: usize,
    /// Strategy used to combine the aligned features.
    pub fusion_method: FusionMethod,
    /// Dropout rate applied in the post-fusion MLP (0.0 disables dropout).
    pub dropout_rate: f64,
    /// Number of output classes for the optional classification head.
    pub num_classes: usize,
    /// Whether to append a classification head after the post-fusion MLP.
    pub include_head: bool,
}
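/// Projects a single modality's features into the shared hidden dimension and
/// applies layer normalization, so that heterogeneous inputs become comparable
/// before fusion.
///
/// # Example
///
/// A minimal usage sketch (not compiled; assumes `f32` satisfies the required
/// numeric and SIMD trait bounds):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
///
/// let aligner = FeatureAlignment::<f32>::new(512, 256, Some("text_aligner"))?;
/// let input = Array::zeros(IxDyn(&[8, 512])); // [batch, input_dim]
/// let aligned = aligner.forward(&input)?;     // [batch, 256]
/// ```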
#[derive(Debug, Clone)]
pub struct FeatureAlignment<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
where
F: SimdUnifiedOps,
{
pub input_dim: usize,
pub output_dim: usize,
pub projection: Dense<F>,
pub norm: LayerNorm<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FeatureAlignment<F>
where
F: SimdUnifiedOps,
{
    /// Creates an alignment block that projects from `input_dim` to `output_dim`.
    /// The `_name` argument is currently ignored.
    pub fn new(input_dim: usize, output_dim: usize, _name: Option<&str>) -> Result<Self> {
        // Fixed seed keeps weight initialization deterministic across runs.
        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let projection = Dense::<F>::new(input_dim, output_dim, None, &mut rng)?;
let norm = LayerNorm::<F>::new(output_dim, 1e-6, &mut rng)?;
Ok(Self {
input_dim,
output_dim,
projection,
norm,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FeatureAlignment<F>
where
F: SimdUnifiedOps,
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let x = self.projection.forward(input)?;
let x = self.norm.forward(&x)?;
Ok(x)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn backward(
&self,
input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
        // Recompute the projection output, which the norm's backward pass takes
        // as its forward input.
        let proj_output = self.projection.forward(input)?;
let grad_proj = self.norm.backward(&proj_output, grad_output)?;
let grad_input = self.projection.backward(input, &grad_proj)?;
Ok(grad_input)
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.projection.update(learning_rate)?;
self.norm.update(learning_rate)?;
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut params = Vec::new();
params.extend(self.projection.params());
params.extend(self.norm.params());
params
}
fn set_training(&mut self, training: bool) {
self.projection.set_training(training);
self.norm.set_training(training);
}
fn is_training(&self) -> bool {
self.projection.is_training()
}
}
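/// Single-head scaled dot-product attention in which one modality provides the
/// queries and another provides the keys and values.
///
/// # Example
///
/// A minimal usage sketch (not compiled; shapes are illustrative and assume
/// inputs of shape `[batch, seq_len, dim]`):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
///
/// let attn = CrossModalAttention::<f32>::new(256, 128, 64)?;
/// let query = Array::zeros(IxDyn(&[2, 10, 256]));   // e.g. text tokens
/// let context = Array::zeros(IxDyn(&[2, 49, 128])); // e.g. image patches
/// let attended = attn.forward(&query, &context)?;   // [2, 10, 256]
/// ```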
#[derive(Debug, Clone)]
pub struct CrossModalAttention<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
pub query_proj: Dense<F>,
pub key_proj: Dense<F>,
pub value_proj: Dense<F>,
pub output_proj: Dense<F>,
pub hidden_dim: usize,
    /// Precomputed `1 / sqrt(hidden_dim)` attention scaling factor.
    pub scale: F,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> CrossModalAttention<F> {
pub fn new(query_dim: usize, key_dim: usize, hidden_dim: usize) -> Result<Self> {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let query_proj = Dense::<F>::new(query_dim, hidden_dim, None, &mut rng)?;
let key_proj = Dense::<F>::new(key_dim, hidden_dim, None, &mut rng)?;
let value_proj = Dense::<F>::new(key_dim, hidden_dim, None, &mut rng)?;
let output_proj = Dense::<F>::new(hidden_dim, query_dim, None, &mut rng)?;
        let scale = F::from(1.0 / (hidden_dim as f64).sqrt())
            .expect("attention scale must be representable in the float type F");
Ok(Self {
query_proj,
key_proj,
value_proj,
output_proj,
hidden_dim,
scale,
})
}
    /// Attends from `query` over `context` and projects the result back to the
    /// query dimension. Both inputs are expected to have shape
    /// `[batch, seq_len, dim]`.
    pub fn forward(
&self,
query: &Array<F, IxDyn>,
context: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let q = self.query_proj.forward(query)?;
let k = self.key_proj.forward(context)?;
let v = self.value_proj.forward(context)?;
        // Attention is computed independently per batch element so that
        // examples in a batch do not attend to one another.
        let batch_size = q.shape()[0];
        let query_len = q.shape()[1];
        let context_len = k.shape()[1];
        let q_3d = q.into_shape_with_order((batch_size, query_len, self.hidden_dim))?;
        let k_3d = k.into_shape_with_order((batch_size, context_len, self.hidden_dim))?;
        let v_3d = v.into_shape_with_order((batch_size, context_len, self.hidden_dim))?;
        let mut context_vec = Array::<F, _>::zeros((batch_size, query_len, self.hidden_dim));
        for b in 0..batch_size {
            let q_b = q_3d.slice(scirs2_core::ndarray::s![b, .., ..]).to_owned();
            let k_b = k_3d.slice(scirs2_core::ndarray::s![b, .., ..]).to_owned();
            let v_b = v_3d.slice(scirs2_core::ndarray::s![b, .., ..]).to_owned();
            // Scaled dot-product scores: [query_len, context_len]
            let mut scores = q_b.dot(&k_b.t()) * self.scale;
            // Numerically stable softmax over the context dimension.
            for qi in 0..query_len {
                let mut row = scores.row_mut(qi);
                let max_val = row.fold(F::neg_infinity(), |m: F, &x: &F| m.max(x));
                let mut exp_sum = F::zero();
                for ci in 0..context_len {
                    row[ci] = (row[ci] - max_val).exp();
                    exp_sum += row[ci];
                }
                if exp_sum > F::zero() {
                    for ci in 0..context_len {
                        row[ci] /= exp_sum;
                    }
                }
            }
            // Attention-weighted sum of values: [query_len, hidden_dim]
            let context_b = scores.dot(&v_b);
            context_vec
                .slice_mut(scirs2_core::ndarray::s![b, .., ..])
                .assign(&context_b);
        }
        let output = self.output_proj.forward(&context_vec.into_dyn())?;
        Ok(output)
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F>
for CrossModalAttention<F>
{
fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
Err(NeuralError::ValidationError(
"CrossModalAttention requires separate query and context inputs. Use the dedicated forward method."
.to_string(),
))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
        // NOTE: placeholder backward; the gradient is passed through unchanged
        // rather than propagated through the attention computation.
        Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.query_proj.update(learning_rate)?;
self.key_proj.update(learning_rate)?;
self.value_proj.update(learning_rate)?;
self.output_proj.update(learning_rate)?;
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut params = Vec::new();
params.extend(self.query_proj.params());
params.extend(self.key_proj.params());
params.extend(self.value_proj.params());
params.extend(self.output_proj.params());
params
}
fn set_training(&mut self, training: bool) {
self.query_proj.set_training(training);
self.key_proj.set_training(training);
self.value_proj.set_training(training);
self.output_proj.set_training(training);
}
fn is_training(&self) -> bool {
self.query_proj.is_training()
}
}
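/// Feature-wise linear modulation (FiLM): a conditioning input produces a
/// per-feature scale (`gamma`) and shift (`beta`) that are applied to the
/// feature tensor as `gamma * features + beta`.
///
/// # Example
///
/// A minimal usage sketch (not compiled; shapes are illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
///
/// let film = FiLMModule::<f32>::new(256, 64)?;
/// let features = Array::zeros(IxDyn(&[8, 256]));    // [batch, feature_dim]
/// let conditioning = Array::zeros(IxDyn(&[8, 64])); // [batch, cond_dim]
/// let modulated = film.forward(&features, &conditioning)?; // [8, 256]
/// ```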
#[derive(Debug, Clone)]
pub struct FiLMModule<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
pub feature_dim: usize,
pub cond_dim: usize,
pub gamma_proj: Dense<F>,
pub beta_proj: Dense<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FiLMModule<F> {
pub fn new(feature_dim: usize, cond_dim: usize) -> Result<Self> {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let gamma_proj = Dense::<F>::new(cond_dim, feature_dim, None, &mut rng)?;
let beta_proj = Dense::<F>::new(cond_dim, feature_dim, None, &mut rng)?;
Ok(Self {
feature_dim,
cond_dim,
gamma_proj,
beta_proj,
})
}
pub fn forward(
&self,
features: &Array<F, IxDyn>,
conditioning: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let gamma = self.gamma_proj.forward(conditioning)?;
let beta = self.beta_proj.forward(conditioning)?;
        let modulated = &gamma * features + &beta;
Ok(modulated)
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FiLMModule<F> {
fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
Err(NeuralError::ValidationError(
"FiLMModule requires separate feature and conditioning inputs. Use the dedicated forward method."
.to_string(),
))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
        // NOTE: placeholder backward; the gradient is passed through unchanged
        // rather than propagated through the FiLM modulation.
        Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.gamma_proj.update(learning_rate)?;
self.beta_proj.update(learning_rate)?;
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut params = Vec::new();
params.extend(self.gamma_proj.params());
params.extend(self.beta_proj.params());
params
}
fn set_training(&mut self, training: bool) {
self.gamma_proj.set_training(training);
self.beta_proj.set_training(training);
}
fn is_training(&self) -> bool {
self.gamma_proj.is_training()
}
}
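/// Low-rank bilinear fusion: both inputs are projected into a shared rank-`r`
/// space, combined with an element-wise product, and projected to the output
/// dimension. This approximates a full bilinear interaction at a fraction of
/// the parameter cost.
///
/// # Example
///
/// A minimal usage sketch (not compiled; shapes are illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
///
/// let fusion = BilinearFusion::<f32>::new(256, 128, 256, 64)?;
/// let a = Array::zeros(IxDyn(&[8, 256]));
/// let b = Array::zeros(IxDyn(&[8, 128]));
/// let fused = fusion.forward(&a, &b)?; // [8, 256]
/// ```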
#[derive(Debug, Clone)]
pub struct BilinearFusion<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
pub dim_a: usize,
pub dim_b: usize,
pub output_dim: usize,
pub proj_a: Dense<F>,
pub proj_b: Dense<F>,
pub low_rank_proj: Dense<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> BilinearFusion<F> {
    /// Creates a low-rank bilinear fusion block; `rank` sets the dimensionality
    /// of the shared interaction space.
    pub fn new(dim_a: usize, dim_b: usize, output_dim: usize, rank: usize) -> Result<Self> {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let proj_a = Dense::<F>::new(dim_a, rank, None, &mut rng)?;
let proj_b = Dense::<F>::new(dim_b, rank, None, &mut rng)?;
let low_rank_proj = Dense::<F>::new(rank, output_dim, None, &mut rng)?;
Ok(Self {
dim_a,
dim_b,
output_dim,
proj_a,
proj_b,
low_rank_proj,
})
}
pub fn forward(
&self,
features_a: &Array<F, IxDyn>,
features_b: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let a_proj = self.proj_a.forward(features_a)?;
let b_proj = self.proj_b.forward(features_b)?;
let bilinear = &a_proj * &b_proj;
let output = self.low_rank_proj.forward(&bilinear)?;
Ok(output)
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for BilinearFusion<F> {
fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
Err(NeuralError::ValidationError(
"BilinearFusion requires separate feature inputs. Use the dedicated forward method."
.to_string(),
))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
        // NOTE: placeholder backward; the gradient is passed through unchanged
        // rather than propagated through the bilinear interaction.
        Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.proj_a.update(learning_rate)?;
self.proj_b.update(learning_rate)?;
self.low_rank_proj.update(learning_rate)?;
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut params = Vec::new();
params.extend(self.proj_a.params());
params.extend(self.proj_b.params());
params.extend(self.low_rank_proj.params());
params
}
fn set_training(&mut self, training: bool) {
self.proj_a.set_training(training);
self.proj_b.set_training(training);
self.low_rank_proj.set_training(training);
}
fn is_training(&self) -> bool {
self.proj_a.is_training()
}
}
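/// Multi-modal feature fusion model: per-modality alignment layers, a fusion
/// step selected by [`FusionMethod`], a post-fusion MLP, and an optional
/// classification head.
///
/// # Example
///
/// A minimal usage sketch (not compiled; assumes `f32` satisfies the required
/// numeric and SIMD trait bounds):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
///
/// // Two modalities (e.g. image features of dim 512 and text features of dim 128)
/// // fused by concatenation into a 256-dimensional hidden space.
/// let model = FeatureFusion::<f32>::create_early_fusion(512, 128, 256, 10, true)?;
/// let image = Array::zeros(IxDyn(&[4, 512]));
/// let text = Array::zeros(IxDyn(&[4, 128]));
/// let logits = model.forward_multi(&[image, text])?; // [4, 10]
/// ```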
pub struct FeatureFusion<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
where
F: SimdUnifiedOps,
{
pub aligners: Vec<FeatureAlignment<F>>,
pub fusion_module: Option<Box<dyn Layer<F> + Send + Sync>>,
pub post_fusion: Sequential<F>,
pub classifier: Option<Dense<F>>,
pub config: FeatureFusionConfig,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for FeatureFusion<F>
where
F: SimdUnifiedOps,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FeatureFusion")
.field("aligners", &self.aligners)
.field(
"fusion_module",
&"<Box<dyn Layer<F> + Send + Sync>>".to_string(),
)
.field("post_fusion", &self.post_fusion)
.field("classifier", &self.classifier)
.field("config", &self.config)
.finish()
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Clone for FeatureFusion<F>
where
F: SimdUnifiedOps,
{
fn clone(&self) -> Self {
Self {
aligners: self.aligners.clone(),
            // The boxed fusion module is a trait object and cannot be cloned;
            // cloned instances fall back to `None`.
            fusion_module: None,
            post_fusion: self.post_fusion.clone(),
classifier: self.classifier.clone(),
config: self.config.clone(),
}
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FeatureFusion<F>
where
F: SimdUnifiedOps,
{
    /// Builds the fusion model described by `config`, validating that the chosen
    /// fusion method is compatible with the number of input modalities.
    pub fn new(config: FeatureFusionConfig) -> Result<Self> {
let mut aligners = Vec::with_capacity(config.input_dims.len());
for (i, &dim) in config.input_dims.iter().enumerate() {
aligners.push(FeatureAlignment::<F>::new(
dim,
config.hidden_dim,
Some(&format!("aligner_{}", i)),
)?);
}
let fusion_module: Option<Box<dyn Layer<F> + Send + Sync>> = match config.fusion_method {
FusionMethod::Attention => {
if config.input_dims.len() < 2 {
return Err(NeuralError::ValidationError(
"Attention fusion requires at least two modalities".to_string(),
));
}
let attn = CrossModalAttention::<F>::new(
config.hidden_dim,
config.hidden_dim,
config.hidden_dim,
)?;
Some(Box::new(attn))
}
FusionMethod::Bilinear => {
if config.input_dims.len() != 2 {
return Err(NeuralError::ValidationError(
"Bilinear fusion requires exactly two modalities".to_string(),
));
}
let bilinear = BilinearFusion::<F>::new(
config.hidden_dim,
config.hidden_dim,
config.hidden_dim,
                    config.hidden_dim / 4, // low-rank factorization rank
                )?;
Some(Box::new(bilinear))
}
FusionMethod::FiLM => {
if config.input_dims.len() != 2 {
return Err(NeuralError::ValidationError(
"FiLM fusion requires exactly two modalities".to_string(),
));
}
let film = FiLMModule::<F>::new(config.hidden_dim, config.hidden_dim)?;
Some(Box::new(film))
}
_ => None,
};
let mut post_fusion = Sequential::new();
let post_fusion_input_dim = match config.fusion_method {
FusionMethod::Concatenation => config.hidden_dim * config.input_dims.len(),
_ => config.hidden_dim,
};
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
post_fusion.add(Dense::<F>::new(
post_fusion_input_dim,
config.hidden_dim * 2,
Some("gelu"),
&mut rng,
)?);
if config.dropout_rate > 0.0 {
post_fusion.add(Dropout::<F>::new(config.dropout_rate, &mut rng)?);
}
post_fusion.add(Dense::<F>::new(
config.hidden_dim * 2,
config.hidden_dim,
Some("gelu"),
&mut rng,
)?);
let classifier = if config.include_head {
Some(Dense::<F>::new(
config.hidden_dim,
config.num_classes,
None,
&mut rng,
)?)
} else {
None
};
Ok(Self {
aligners,
fusion_module,
post_fusion,
classifier,
config,
})
}
    /// Runs the fusion pipeline on one input tensor per modality, in the same
    /// order as `config.input_dims`. Returns class logits when a classification
    /// head is configured, otherwise the fused feature representation.
    pub fn forward_multi(&self, inputs: &[Array<F, IxDyn>]) -> Result<Array<F, IxDyn>> {
if inputs.len() != self.config.input_dims.len() {
return Err(NeuralError::ValidationError(format!(
"Expected {} inputs, got {}",
self.config.input_dims.len(),
inputs.len()
)));
}
let mut aligned_features = Vec::with_capacity(inputs.len());
for (i, input) in inputs.iter().enumerate() {
aligned_features.push(self.aligners[i].forward(input)?);
}
let fused = match self.config.fusion_method {
FusionMethod::Concatenation => {
                // Concatenate each example's aligned features along the feature
                // dimension, producing [batch, hidden_dim * num_modalities].
                let batch_size = aligned_features[0].shape()[0];
let mut concatenated = Vec::new();
for batch_idx in 0..batch_size {
for features in &aligned_features {
let batch_features = features.slice_axis(
Axis(0),
scirs2_core::ndarray::Slice::from(batch_idx..batch_idx + 1),
);
concatenated.extend(batch_features.iter().cloned());
}
}
Array::from_shape_vec(
[batch_size, self.config.hidden_dim * aligned_features.len()],
concatenated,
)?
.into_dyn()
}
FusionMethod::Sum => {
let mut result = aligned_features[0].clone();
for features in &aligned_features[1..] {
result += features;
}
result
}
FusionMethod::Product => {
let mut result = aligned_features[0].clone();
for features in &aligned_features[1..] {
result *= features;
}
result
}
FusionMethod::Attention => {
if let Some(ref module) = self.fusion_module {
if let Some(attn) = module.as_any().downcast_ref::<CrossModalAttention<F>>() {
                        // The first aligned modality provides the queries and the
                        // second provides keys/values; any additional modalities
                        // are ignored by this fusion path.
                        attn.forward(&aligned_features[0], &aligned_features[1])?
} else {
return Err(NeuralError::InferenceError(
"Failed to cast fusion module to CrossModalAttention".to_string(),
));
}
} else {
return Err(NeuralError::InferenceError(
"Attention fusion module not initialized".to_string(),
));
}
}
FusionMethod::Bilinear => {
if let Some(ref module) = self.fusion_module {
if let Some(bilinear) = module.as_any().downcast_ref::<BilinearFusion<F>>() {
bilinear.forward(&aligned_features[0], &aligned_features[1])?
} else {
return Err(NeuralError::InferenceError(
"Failed to cast fusion module to BilinearFusion".to_string(),
));
}
} else {
return Err(NeuralError::InferenceError(
"Bilinear fusion module not initialized".to_string(),
));
}
}
FusionMethod::FiLM => {
if let Some(ref module) = self.fusion_module {
if let Some(film) = module.as_any().downcast_ref::<FiLMModule<F>>() {
film.forward(&aligned_features[0], &aligned_features[1])?
} else {
return Err(NeuralError::InferenceError(
"Failed to cast fusion module to FiLMModule".to_string(),
));
}
} else {
return Err(NeuralError::InferenceError(
"FiLM fusion module not initialized".to_string(),
));
}
}
};
let features = self.post_fusion.forward(&fused)?;
if let Some(ref classifier) = self.classifier {
classifier.forward(&features)
} else {
Ok(features)
}
}
    /// Convenience constructor: two-modality fusion by feature concatenation
    /// ("early fusion") with a default dropout rate of 0.1.
    pub fn create_early_fusion(
dim_a: usize,
dim_b: usize,
hidden_dim: usize,
num_classes: usize,
include_head: bool,
) -> Result<Self> {
let config = FeatureFusionConfig {
input_dims: vec![dim_a, dim_b],
hidden_dim,
fusion_method: FusionMethod::Concatenation,
dropout_rate: 0.1,
num_classes,
include_head,
};
Self::new(config)
}
    /// Convenience constructor: two-modality fusion via cross-modal attention
    /// with a default dropout rate of 0.1.
    pub fn create_attention_fusion(
dim_a: usize,
dim_b: usize,
hidden_dim: usize,
num_classes: usize,
include_head: bool,
) -> Result<Self> {
let config = FeatureFusionConfig {
input_dims: vec![dim_a, dim_b],
hidden_dim,
fusion_method: FusionMethod::Attention,
dropout_rate: 0.1,
num_classes,
include_head,
};
Self::new(config)
}
    /// Convenience constructor: two-modality fusion via FiLM conditioning with a
    /// default dropout rate of 0.1.
    pub fn create_film_fusion(
dim_a: usize,
dim_b: usize,
hidden_dim: usize,
num_classes: usize,
include_head: bool,
) -> Result<Self> {
let config = FeatureFusionConfig {
input_dims: vec![dim_a, dim_b],
hidden_dim,
fusion_method: FusionMethod::FiLM,
dropout_rate: 0.1,
num_classes,
include_head,
};
Self::new(config)
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FeatureFusion<F>
where
F: SimdUnifiedOps,
{
fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
Err(NeuralError::ValidationError(
"FeatureFusion requires multiple inputs. Use forward_multi method instead.".to_string(),
))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
        // NOTE: placeholder backward; the gradient is passed through unchanged
        // rather than propagated through the fusion pipeline.
        Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
for aligner in &mut self.aligners {
aligner.update(learning_rate)?;
}
if let Some(ref mut module) = self.fusion_module {
module.update(learning_rate)?;
}
self.post_fusion.update(learning_rate)?;
if let Some(ref mut classifier) = self.classifier {
classifier.update(learning_rate)?;
}
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut params = Vec::new();
for aligner in &self.aligners {
params.extend(aligner.params());
}
if let Some(ref module) = self.fusion_module {
params.extend(module.params());
}
params.extend(self.post_fusion.params());
if let Some(ref classifier) = self.classifier {
params.extend(classifier.params());
}
params
}
fn set_training(&mut self, training: bool) {
for aligner in &mut self.aligners {
aligner.set_training(training);
}
if let Some(ref mut module) = self.fusion_module {
module.set_training(training);
}
self.post_fusion.set_training(training);
if let Some(ref mut classifier) = self.classifier {
classifier.set_training(training);
}
}
fn is_training(&self) -> bool {
self.aligners[0].is_training()
}
}
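
// Smoke tests sketching the intended usage of the fusion model. These are
// illustrative: they assume `f32` satisfies the numeric/SIMD trait bounds used
// throughout this module and only check coarse output shapes.
#[cfg(test)]
mod fusion_smoke_tests {
    use super::*;

    #[test]
    fn early_fusion_produces_logits_per_example() {
        let model = FeatureFusion::<f32>::create_early_fusion(8, 12, 16, 3, true)
            .expect("model construction should succeed");
        let modality_a = Array::<f32, _>::zeros(IxDyn(&[2, 8]));
        let modality_b = Array::<f32, _>::zeros(IxDyn(&[2, 12]));
        let logits = model
            .forward_multi(&[modality_a, modality_b])
            .expect("forward pass should succeed");
        // One row of class logits per input example.
        assert_eq!(logits.shape()[0], 2);
        assert_eq!(logits.shape()[1], 3);
    }

    #[test]
    fn forward_multi_rejects_wrong_number_of_inputs() {
        let model = FeatureFusion::<f32>::create_early_fusion(8, 12, 16, 3, false)
            .expect("model construction should succeed");
        let only_one = Array::<f32, _>::zeros(IxDyn(&[2, 8]));
        assert!(model.forward_multi(&[only_one]).is_err());
    }
}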