use crate::error::{NeuralError, Result};
use crate::layers::{Dense, Layer};
use scirs2_core::ndarray::prelude::*;
use scirs2_core::ndarray::IxDyn;
use scirs2_core::random::rng;
use std::collections::HashMap;
/// Shared feature-extraction backbone used by every task head in a
/// multi-task network: a stack of dense layers applied before any
/// task-specific processing.
pub struct SharedBackbone {
    // Sequential dense layers; all constructed with ReLU activation (see `new`).
    layers: Vec<Box<dyn Layer<f32> + Send + Sync>>,
    // Dimensionality of the raw input features.
    input_dim: usize,
    // Dimensionality of the backbone output: the last layer's size, or
    // `input_dim` when no layers were requested.
    output_dim: usize,
}
impl SharedBackbone {
    /// Builds a backbone of ReLU-activated dense layers.
    ///
    /// `layer_sizes` lists the width of each layer in order. With an empty
    /// slice the backbone has no layers (identity pass-through) and its
    /// output dimension equals `input_dim`.
    ///
    /// # Errors
    /// Propagates any layer-construction failure from `Dense::new`.
    pub fn new(input_dim: usize, layer_sizes: &[usize]) -> Result<Self> {
        let mut layers: Vec<Box<dyn Layer<f32> + Send + Sync>> =
            Vec::with_capacity(layer_sizes.len());
        let mut current_dim = input_dim;
        let mut rng = rng();
        for &layer_size in layer_sizes {
            let dense_layer = Dense::<f32>::new(current_dim, layer_size, Some("relu"), &mut rng)?;
            layers.push(Box::new(dense_layer));
            current_dim = layer_size;
        }
        // With no layers, the backbone output is just the raw input.
        let output_dim = layer_sizes.last().copied().unwrap_or(input_dim);
        Ok(Self {
            layers,
            input_dim,
            output_dim,
        })
    }

    /// Runs a `(batch, input_dim)` batch through every layer and returns
    /// features of shape `(batch, output_dim)`.
    ///
    /// # Errors
    /// Propagates layer forward errors; returns `ShapeMismatch` if the final
    /// dynamic-rank output cannot be converted back to 2-D.
    pub fn forward(&self, input: &ArrayView2<f32>) -> Result<Array2<f32>> {
        let mut current_output = input.to_owned().into_dyn();
        for layer in &self.layers {
            // Fixed garbled reference (`¤t_output`) — pass `&current_output`.
            current_output = layer.forward(&current_output)?;
        }
        current_output
            .into_dimensionality()
            .map_err(|e| NeuralError::ShapeMismatch(format!("Shape conversion error: {:?}", e)))
    }

    /// Width of the feature vector produced by `forward`.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

    /// Collects every layer's parameter arrays, in layer order.
    pub fn parameters(&self) -> Vec<Array<f32, IxDyn>> {
        self.layers
            .iter()
            .flat_map(|layer| layer.params())
            .collect()
    }

    /// Writes `parameters` back into the layers, consuming arrays in the same
    /// order `parameters()` produces them. Surplus arrays are ignored.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when fewer arrays are supplied than the
    /// layers require.
    pub fn set_parameters(&mut self, parameters: &[Array<f32, IxDyn>]) -> Result<()> {
        let mut param_idx = 0;
        for layer in &mut self.layers {
            let num_layer_params = layer.params().len();
            if param_idx + num_layer_params > parameters.len() {
                return Err(NeuralError::InvalidArgument(
                    "Insufficient parameters provided".to_string(),
                ));
            }
            // Fixed garbled reference (`¶meters`) — slice `&parameters[..]`.
            layer.set_params(&parameters[param_idx..param_idx + num_layer_params])?;
            param_idx += num_layer_params;
        }
        Ok(())
    }
}
/// Task-specific output head attached on top of the shared backbone.
pub struct TaskSpecificHead {
    // Name identifying the task this head serves.
    task_name: String,
    // Hidden layers plus the final output layer (see `new`).
    layers: Vec<Box<dyn Layer<f32> + Send + Sync>>,
    // Expected feature dimension coming in from the backbone.
    input_dim: usize,
    // Width of this head's output.
    output_dim: usize,
    // Determines the output activation and which loss `compute_loss` uses.
    task_type: TaskType,
}
/// Kind of prediction a task head produces; selects the head's output
/// activation and the loss used by `TaskSpecificHead::compute_loss`.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Single-label classification over `num_classes` classes
    /// (softmax output, cross-entropy loss).
    Classification { num_classes: usize },
    /// Real-valued prediction of `output_dim` values (linear output, MSE loss).
    Regression { output_dim: usize },
    /// Independent binary decisions over `num_labels` labels
    /// (sigmoid output, binary cross-entropy loss).
    MultiLabel { num_labels: usize },
    /// Structured output flattened to a vector; `output_shape` records the
    /// intended shape (linear output, MSE loss).
    Structured { output_shape: Vec<usize> },
}
impl TaskSpecificHead {
    /// Builds a task head: ReLU-activated hidden layers (one per entry in
    /// `layer_sizes`) followed by an output layer whose activation matches
    /// the task type — softmax for classification, sigmoid for multi-label,
    /// linear for regression and structured outputs.
    ///
    /// # Errors
    /// Propagates any layer-construction failure from `Dense::new`.
    pub fn new(
        task_name: String,
        input_dim: usize,
        layer_sizes: &[usize],
        output_dim: usize,
        task_type: TaskType,
    ) -> Result<Self> {
        let mut layers: Vec<Box<dyn Layer<f32> + Send + Sync>> =
            Vec::with_capacity(layer_sizes.len() + 1);
        let mut current_dim = input_dim;
        let mut rng = rng();
        for &layer_size in layer_sizes {
            let dense_layer = Dense::<f32>::new(current_dim, layer_size, Some("relu"), &mut rng)?;
            layers.push(Box::new(dense_layer));
            current_dim = layer_size;
        }
        // Output activation is dictated by the task type; None = linear.
        let output_activation = match &task_type {
            TaskType::Classification { .. } => Some("softmax"),
            TaskType::MultiLabel { .. } => Some("sigmoid"),
            TaskType::Regression { .. } => None,
            TaskType::Structured { .. } => None,
        };
        let output_layer = Dense::<f32>::new(current_dim, output_dim, output_activation, &mut rng)?;
        layers.push(Box::new(output_layer));
        Ok(Self {
            task_name,
            layers,
            input_dim,
            output_dim,
            task_type,
        })
    }

    /// Runs backbone features `(batch, input_dim)` through the head and
    /// returns task outputs of shape `(batch, output_dim)`.
    ///
    /// # Errors
    /// Propagates layer forward errors; returns `ShapeMismatch` if the final
    /// dynamic-rank output cannot be converted back to 2-D.
    pub fn forward(&self, input: &ArrayView2<f32>) -> Result<Array2<f32>> {
        let mut current_output = input.to_owned().into_dyn();
        for layer in &self.layers {
            // Fixed garbled reference (`¤t_output`) — pass `&current_output`.
            current_output = layer.forward(&current_output)?;
        }
        current_output
            .into_dimensionality()
            .map_err(|e| NeuralError::ShapeMismatch(format!("Shape error: {:?}", e)))
    }

    /// Name of the task this head serves.
    pub fn task_name(&self) -> &str {
        &self.task_name
    }

    /// The task type this head was built for.
    pub fn task_type(&self) -> &TaskType {
        &self.task_type
    }

    /// Width of this head's output.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

    /// Collects every layer's parameter arrays, in layer order.
    pub fn parameters(&self) -> Vec<Array<f32, IxDyn>> {
        self.layers
            .iter()
            .flat_map(|layer| layer.params())
            .collect()
    }

    /// Computes the loss appropriate for this head's task type.
    /// `predictions` and `targets` must both be `(batch, output_dim)`.
    pub fn compute_loss(
        &self,
        predictions: &ArrayView2<f32>,
        targets: &ArrayView2<f32>,
    ) -> Result<f32> {
        match &self.task_type {
            TaskType::Classification { .. } => {
                self.compute_cross_entropy_loss(predictions, targets)
            }
            TaskType::MultiLabel { .. } => {
                self.compute_binary_cross_entropy_loss(predictions, targets)
            }
            TaskType::Regression { .. } => self.compute_mse_loss(predictions, targets),
            TaskType::Structured { .. } => self.compute_mse_loss(predictions, targets),
        }
    }

    /// Mean categorical cross-entropy over the batch. Only positions with a
    /// positive target contribute (targets are expected to be one-hot or soft
    /// label distributions). Returns 0.0 for an empty batch.
    fn compute_cross_entropy_loss(
        &self,
        predictions: &ArrayView2<f32>,
        targets: &ArrayView2<f32>,
    ) -> Result<f32> {
        let batch_size = predictions.shape()[0];
        if batch_size == 0 {
            // Guard the 0/0 division below, which would yield NaN.
            return Ok(0.0);
        }
        let mut total_loss = 0.0;
        for i in 0..batch_size {
            let pred_row = predictions.row(i);
            let target_row = targets.row(i);
            for (p, t) in pred_row.iter().zip(target_row.iter()) {
                if *t > 0.0 {
                    // Clamp the prediction to avoid ln(0).
                    total_loss -= t * p.max(1e-7).ln();
                }
            }
        }
        Ok(total_loss / batch_size as f32)
    }

    /// Mean binary cross-entropy over every (sample, label) cell.
    /// Returns 0.0 when the batch or label dimension is empty.
    fn compute_binary_cross_entropy_loss(
        &self,
        predictions: &ArrayView2<f32>,
        targets: &ArrayView2<f32>,
    ) -> Result<f32> {
        let batch_size = predictions.shape()[0];
        let num_labels = predictions.shape()[1];
        if batch_size == 0 || num_labels == 0 {
            // Guard the 0/0 division below, which would yield NaN.
            return Ok(0.0);
        }
        let mut total_loss = 0.0;
        for i in 0..batch_size {
            for j in 0..num_labels {
                // Clamp predictions away from {0, 1} so both logs are finite.
                let p = predictions[[i, j]].clamp(1e-7, 1.0 - 1e-7);
                let t = targets[[i, j]];
                total_loss -= t * p.ln() + (1.0 - t) * (1.0 - p).ln();
            }
        }
        Ok(total_loss / (batch_size * num_labels) as f32)
    }

    /// Mean squared error over all elements.
    ///
    /// # Errors
    /// Returns `InferenceError` when the mean is undefined (empty input).
    fn compute_mse_loss(
        &self,
        predictions: &ArrayView2<f32>,
        targets: &ArrayView2<f32>,
    ) -> Result<f32> {
        let diff = predictions.to_owned() - targets;
        let squared_diff = diff.mapv(|x| x * x);
        squared_diff.mean().ok_or_else(|| {
            NeuralError::InferenceError("Failed to compute mean of squared differences".to_string())
        })
    }
}
/// Multi-task network: one shared backbone feeding several task-specific heads.
pub struct MultiTaskArchitecture {
    // Feature extractor shared by every task.
    shared_backbone: SharedBackbone,
    // Task heads keyed by task name.
    task_heads: HashMap<String, TaskSpecificHead>,
    // Per-task loss weights used by `compute_multi_task_loss` (default 1.0).
    task_weights: HashMap<String, f32>,
    // Training-mode flag; currently only stored, never consulted.
    #[allow(dead_code)]
    training: bool,
}
impl MultiTaskArchitecture {
    /// Assembles a shared backbone plus one head per entry in `task_configs`.
    ///
    /// Each config is `(task name, head hidden-layer sizes, output dim,
    /// task type)`. Every task starts with a loss weight of 1.0 and the
    /// architecture starts in training mode.
    ///
    /// # Errors
    /// Propagates backbone or head construction failures.
    pub fn new(
        input_dim: usize,
        backbone_layers: &[usize],
        task_configs: &[(String, Vec<usize>, usize, TaskType)],
    ) -> Result<Self> {
        let shared_backbone = SharedBackbone::new(input_dim, backbone_layers)?;
        let feature_dim = shared_backbone.output_dim();
        let mut task_heads = HashMap::with_capacity(task_configs.len());
        let mut task_weights = HashMap::with_capacity(task_configs.len());
        for (name, hidden_sizes, head_output_dim, kind) in task_configs {
            let head = TaskSpecificHead::new(
                name.clone(),
                feature_dim,
                hidden_sizes,
                *head_output_dim,
                kind.clone(),
            )?;
            task_weights.insert(name.clone(), 1.0);
            task_heads.insert(name.clone(), head);
        }
        Ok(Self {
            shared_backbone,
            task_heads,
            task_weights,
            training: true,
        })
    }

    /// Runs the backbone and then the named task's head.
    ///
    /// # Errors
    /// `InvalidArgument` when no head is registered under `task_name`;
    /// otherwise propagates forward errors.
    pub fn forward_task(&self, input: &ArrayView2<f32>, task_name: &str) -> Result<Array2<f32>> {
        let features = self.shared_backbone.forward(input)?;
        let head = self.task_heads.get(task_name).ok_or_else(|| {
            NeuralError::InvalidArgument(format!("Task '{}' not found", task_name))
        })?;
        head.forward(&features.view())
    }

    /// Runs the backbone once and every task head on the shared features,
    /// returning outputs keyed by task name.
    pub fn forward_all_tasks(
        &self,
        input: &ArrayView2<f32>,
    ) -> Result<HashMap<String, Array2<f32>>> {
        let features = self.shared_backbone.forward(input)?;
        let feature_view = features.view();
        self.task_heads
            .iter()
            .map(|(name, head)| Ok((name.clone(), head.forward(&feature_view)?)))
            .collect()
    }

    /// Computes each task's loss and the weighted sum of all of them.
    ///
    /// Tasks present in `predictions` but missing a target, head, or weight
    /// are silently skipped. Returns `(weighted total, per-task unweighted
    /// losses)`.
    pub fn compute_multi_task_loss(
        &self,
        predictions: &HashMap<String, Array2<f32>>,
        targets: &HashMap<String, Array2<f32>>,
    ) -> Result<(f32, HashMap<String, f32>)> {
        let mut total_loss = 0.0;
        let mut task_losses = HashMap::new();
        for (name, pred) in predictions {
            let matched = targets
                .get(name)
                .zip(self.task_heads.get(name))
                .zip(self.task_weights.get(name));
            if let Some(((target, head), &weight)) = matched {
                let task_loss = head.compute_loss(&pred.view(), &target.view())?;
                total_loss += weight * task_loss;
                task_losses.insert(name.clone(), task_loss);
            }
        }
        Ok((total_loss, task_losses))
    }

    /// Replaces loss weights for known tasks; entries naming unknown tasks
    /// are ignored.
    pub fn set_task_weights(&mut self, weights: HashMap<String, f32>) {
        let known_tasks = &self.task_heads;
        self.task_weights.extend(
            weights
                .into_iter()
                .filter(|(name, _)| known_tasks.contains_key(name)),
        );
    }

    /// Names of all registered tasks (in arbitrary order).
    pub fn task_names(&self) -> Vec<String> {
        self.task_heads.keys().cloned().collect()
    }

    /// Toggles the (currently unused) training-mode flag.
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Parameter arrays of the shared backbone.
    pub fn backbone_parameters(&self) -> Vec<Array<f32, IxDyn>> {
        self.shared_backbone.parameters()
    }

    /// Parameter arrays of one task's head.
    ///
    /// # Errors
    /// `InvalidArgument` when no head is registered under `task_name`.
    pub fn task_parameters(&self, task_name: &str) -> Result<Vec<Array<f32, IxDyn>>> {
        self.task_heads
            .get(task_name)
            .map(|head| head.parameters())
            .ok_or_else(|| {
                NeuralError::InvalidArgument(format!("Task '{}' not found", task_name))
            })
    }

    /// All parameters grouped by component: `"backbone"` plus one
    /// `"task_{name}"` entry per head.
    pub fn all_parameters(&self) -> HashMap<String, Vec<Array<f32, IxDyn>>> {
        let mut all_params: HashMap<String, Vec<Array<f32, IxDyn>>> = self
            .task_heads
            .iter()
            .map(|(name, head)| (format!("task_{}", name), head.parameters()))
            .collect();
        all_params.insert("backbone".to_string(), self.backbone_parameters());
        all_params
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Backbone output dimension follows the last configured layer size.
    #[test]
    fn test_shared_backbone_output_dim() {
        let backbone = SharedBackbone::new(10, &[64, 32]).expect("SharedBackbone::new failed");
        assert_eq!(backbone.output_dim(), 32);
    }

    // Forward pass preserves the batch dimension and projects to the
    // final layer width.
    #[test]
    fn test_shared_backbone_forward() {
        let backbone = SharedBackbone::new(10, &[16, 8]).expect("SharedBackbone::new failed");
        let input = Array2::from_elem((5, 10), 0.5_f32);
        let output = backbone.forward(&input.view()).expect("forward failed");
        assert_eq!(output.dim(), (5, 8));
    }

    // A classification head reports its name and output width.
    #[test]
    fn test_task_specific_head_classification() {
        let head = TaskSpecificHead::new(
            "cls".to_string(),
            16,
            &[8],
            5,
            TaskType::Classification { num_classes: 5 },
        )
        .expect("TaskSpecificHead::new failed");
        assert_eq!(head.task_name(), "cls");
        assert_eq!(head.output_dim(), 5);
    }

    // Forward over all tasks yields one output per configured head.
    #[test]
    fn test_multi_task_architecture() {
        let task_configs = vec![
            (
                "task1".to_string(),
                vec![16],
                3,
                TaskType::Classification { num_classes: 3 },
            ),
            (
                "task2".to_string(),
                vec![8],
                1,
                TaskType::Regression { output_dim: 1 },
            ),
        ];
        let arch = MultiTaskArchitecture::new(10, &[32, 16], &task_configs)
            .expect("MultiTaskArchitecture::new failed");
        let input = Array2::from_elem((2, 10), 0.5_f32);
        let outputs = arch
            .forward_all_tasks(&input.view())
            .expect("forward_all_tasks failed");
        assert_eq!(outputs.len(), 2);
        for task in ["task1", "task2"] {
            assert!(outputs.contains_key(task));
        }
    }

    // TaskType variants carry their configuration payloads.
    #[test]
    fn test_task_types() {
        let classification = TaskType::Classification { num_classes: 10 };
        let regression = TaskType::Regression { output_dim: 5 };
        let multi_label = TaskType::MultiLabel { num_labels: 8 };
        assert!(matches!(
            classification,
            TaskType::Classification { num_classes: 10 }
        ));
        assert!(matches!(regression, TaskType::Regression { output_dim: 5 }));
        assert!(matches!(multi_label, TaskType::MultiLabel { num_labels: 8 }));
    }
}