use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{
errors::invalid_input,
layers::Linear,
tensor::Tensor,
traits::{Layer, Model},
Result,
};
/// Top-level configuration for a multi-task learning (MTL) training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MTLConfig {
/// How parameters are shared between tasks (hard sharing, MMoE, ...).
pub architecture: MTLArchitecture,
/// Strategy used to combine per-task losses into one training loss.
pub loss_balancing: LossBalancingStrategy,
/// The tasks to train jointly.
pub tasks: Vec<TaskConfig>,
/// Whether to learn a per-task embedding vector.
pub use_task_embeddings: bool,
/// Dimensionality of the task embedding (only relevant when enabled).
pub task_embedding_dim: usize,
/// Whether auxiliary (unsupervised/self-supervised) tasks are trained too.
pub use_auxiliary_tasks: bool,
/// Auxiliary task definitions (only used when `use_auxiliary_tasks`).
pub auxiliary_tasks: Vec<AuxiliaryTaskConfig>,
/// Optional grouping of related tasks into clusters.
pub task_clustering: Option<TaskClusteringConfig>,
/// Evaluate every this many steps.
pub evaluation_frequency: usize,
/// Whether the task scheduling strategy below is applied.
pub use_task_scheduling: bool,
/// Which tasks are selected for each training step.
pub task_scheduling: TaskSchedulingStrategy,
}
impl Default for MTLConfig {
    /// Defaults: an 8+2 hard-parameter-sharing stack, equal loss weighting,
    /// no tasks, and every optional feature disabled.
    fn default() -> Self {
        let architecture = MTLArchitecture::HardParameterSharing {
            shared_layers: 8,
            task_specific_layers: 2,
        };
        Self {
            architecture,
            loss_balancing: LossBalancingStrategy::EqualWeighting,
            tasks: Vec::new(),
            use_task_embeddings: false,
            task_embedding_dim: 64,
            use_auxiliary_tasks: false,
            auxiliary_tasks: Vec::new(),
            task_clustering: None,
            evaluation_frequency: 1000,
            use_task_scheduling: false,
            task_scheduling: TaskSchedulingStrategy::RoundRobin,
        }
    }
}
/// Parameter-sharing architectures for multi-task learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MTLArchitecture {
/// One shared trunk with small per-task towers on top.
HardParameterSharing {
shared_layers: usize,
task_specific_layers: usize,
},
/// Separate per-task networks regularized towards each other.
SoftParameterSharing {
regularization_weight: f32,
regularization_type: RegularizationType,
},
/// MMoE: shared experts with one gating network per task.
MultiGateMixtureOfExperts {
num_experts: usize,
expert_dim: usize,
num_gates: usize,
},
/// Learned linear combinations between per-task networks at given layers.
CrossStitchNetworks {
num_tasks: usize,
cross_stitch_layers: Vec<usize>,
},
/// Routing modules deciding which subnetwork each task uses.
TaskRoutingNetworks {
num_routers: usize,
routing_dim: usize,
},
/// Frozen columns per task with optional lateral connections/adapters.
ProgressiveNetworks {
lateral_connections: bool,
adapter_layers: bool,
},
/// Sharing mediated by attention over task representations.
AttentionBasedSharing {
attention_dim: usize,
num_attention_heads: usize,
},
}
/// Regularizers available for soft parameter sharing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegularizationType {
L2Regularization,
TraceNorm,
GroupLasso,
/// Mixed L1/L2 penalty with independent weights.
ElasticNet { l1_weight: f32, l2_weight: f32 },
}
/// Strategies for combining per-task losses into one scalar objective.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossBalancingStrategy {
/// Every task contributes with weight 1.
EqualWeighting,
/// Fixed weights, aligned with the order of `MTLConfig::tasks`.
ManualWeighting { weights: Vec<f32> },
UncertaintyWeighting,
/// Weights adapted from the relative rate of loss decrease (DWA).
DynamicWeightAverage,
/// Gradient-norm balancing with restoring-force exponent `alpha`.
GradNorm { alpha: f32 },
TaskBalancedSampling,
FocalLoss { gamma: f32 },
MetaLearning { meta_lr: f32 },
}
/// Configuration of one task trained in the MTL run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskConfig {
/// Unique task name, used as the key everywhere in the trainer.
pub name: String,
pub task_type: TaskType,
/// Static loss weight (default 1.0).
pub weight: f32,
/// Scheduling priority (affects weighted sampling).
pub priority: TaskPriority,
/// Marks the primary task of the run.
pub is_main_task: bool,
/// Optional per-task learning-rate override.
pub learning_rate: Option<f32>,
/// Optional per-task batch-size override.
pub batch_size: Option<usize>,
}
impl TaskConfig {
    /// Creates a task with weight 1.0, normal priority, and no overrides.
    pub fn new(name: &str, task_type: TaskType) -> Self {
        Self {
            name: name.to_owned(),
            task_type,
            weight: 1.0,
            priority: TaskPriority::Normal,
            is_main_task: false,
            learning_rate: None,
            batch_size: None,
        }
    }

    /// Builder-style setter for the static loss weight.
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

    /// Builder-style setter for the scheduling priority.
    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
        self.priority = priority;
        self
    }

    /// Marks this task as the run's main task.
    pub fn as_main_task(mut self) -> Self {
        self.is_main_task = true;
        self
    }
}
/// The kind of prediction a task makes; drives head construction and losses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskType {
Classification {
num_classes: usize,
use_class_weights: bool,
},
Regression {
output_dim: usize,
loss_type: RegressionLossType,
},
SequenceLabeling { num_labels: usize, use_crf: bool },
Generation {
vocab_size: usize,
max_length: usize,
},
Ranking { ranking_type: RankingType },
Auxiliary { auxiliary_type: AuxiliaryType },
}
/// Loss functions for regression tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegressionLossType {
MSE,
MAE,
/// Quadratic below `delta`, linear above.
Huber { delta: f32 },
LogCosh,
}
/// Formulations of the learning-to-rank objective.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RankingType {
Pairwise,
Listwise,
Pointwise,
}
/// Self-supervised auxiliary objectives that can accompany the main tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuxiliaryType {
LanguageModeling,
MaskedLanguageModeling,
NextSentencePrediction,
SentenceOrderPrediction,
WordOrderPrediction,
/// User-defined auxiliary objective, identified by name.
Custom { name: String },
}
/// Relative scheduling priority of a task (Low < Normal < High < Critical).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskPriority {
Low,
Normal,
High,
Critical,
}
/// Configuration of one auxiliary task: what it is, how much it weighs,
/// and how often it fires.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuxiliaryTaskConfig {
/// Key used to look up this task's batch in the step's `task_data`.
pub name: String,
pub auxiliary_type: AuxiliaryType,
/// Multiplier applied to the auxiliary loss before summation.
pub weight: f32,
pub frequency: AuxiliaryTaskFrequency,
}
/// When an auxiliary task participates in training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuxiliaryTaskFrequency {
/// Fire on steps that are multiples of the given count.
EveryNSteps(usize),
/// Fire independently at random with the given probability per step.
WithProbability(f32),
/// Fire on every step.
Continuous,
/// Fire only while the current epoch lies in `[start, end]` (inclusive).
EpochRange { start: usize, end: usize },
}
/// Settings for grouping related tasks into clusters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskClusteringConfig {
pub clustering_method: ClusteringMethod,
pub num_clusters: usize,
/// Re-cluster every this many steps.
pub update_frequency: usize,
}
/// How task clusters are derived.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringMethod {
GradientSimilarity,
PerformanceCorrelation,
DataSimilarity,
/// Explicit user-provided clusters, each a list of task names.
Manual { clusters: Vec<Vec<String>> },
}
/// How tasks are selected for each training step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskSchedulingStrategy {
/// Cycle through the tasks one per step.
RoundRobin,
/// Sample tasks proportionally to their priority.
WeightedSampling,
PerformanceBased,
/// Train tasks in a fixed easy-to-hard order.
CurriculumBased { difficulty_order: Vec<String> },
/// Pick a uniformly random task each step.
Random,
}
/// Trainer that runs several tasks against one shared base model, with one
/// task-specific head per configured task.
pub struct MultiTaskLearningTrainer<M: Model> {
/// Shared encoder producing features consumed by every task head.
pub base_model: M,
/// Task name -> prediction head.
pub task_heads: HashMap<String, TaskHead>,
pub config: MTLConfig,
/// Current loss weight per task (may be adapted during training).
pub task_weights: HashMap<String, f32>,
/// Accuracy history per task, appended to on every training step.
pub task_performance: HashMap<String, Vec<f32>>,
/// Number of completed training steps.
pub step_counter: usize,
pub scheduler_state: TaskSchedulerState,
/// Per-task gradient statistics (currently never written by this impl).
pub gradient_stats: HashMap<String, GradientStats>,
}
impl<M: Model<Input = Tensor, Output = Tensor>> MultiTaskLearningTrainer<M> {
/// Builds a trainer from a shared base model and an MTL configuration,
/// creating one prediction head and one weight entry per configured task.
pub fn new(base_model: M, config: MTLConfig) -> Result<Self> {
    let mut task_heads = HashMap::new();
    let mut task_weights = HashMap::new();
    for task_config in &config.tasks {
        let name = task_config.name.clone();
        task_weights.insert(name.clone(), task_config.weight);
        task_heads.insert(name, TaskHead::new(&task_config.task_type)?);
    }
    Ok(Self {
        scheduler_state: TaskSchedulerState::new(&config.task_scheduling),
        base_model,
        task_heads,
        config,
        task_weights,
        task_performance: HashMap::new(),
        step_counter: 0,
        gradient_stats: HashMap::new(),
    })
}
pub fn train_multi_task_step(
&mut self,
task_data: &HashMap<String, TaskBatch>,
) -> Result<MultiTaskOutput> {
let mut task_losses = HashMap::new();
let mut task_accuracies = HashMap::new();
let mut total_loss = Tensor::zeros(&[1])?;
let active_tasks = self.get_active_tasks(task_data)?;
for task_name in &active_tasks {
if let Some(batch) = task_data.get(task_name) {
let shared_features = self.base_model.forward(batch.inputs.clone())?;
let task_head = self
.task_heads
.get(task_name)
.ok_or_else(|| anyhow::anyhow!("Task head not found: {}", task_name))?;
let task_outputs = task_head.forward(&shared_features)?;
let task_loss = self.compute_task_loss(task_name, &task_outputs, &batch.targets)?;
let task_accuracy =
self.compute_task_accuracy(task_name, &task_outputs, &batch.targets)?;
task_losses.insert(task_name.clone(), task_loss.clone());
task_accuracies.insert(task_name.clone(), task_accuracy);
self.task_performance.entry(task_name.clone()).or_default().push(task_accuracy);
}
}
let balanced_losses = self.balance_losses(&task_losses)?;
for (task_name, loss) in &balanced_losses {
let weight = self.task_weights.get(task_name).copied().unwrap_or(1.0);
total_loss = total_loss.add(&loss.scalar_mul(weight)?)?;
}
self.update_task_weights(&task_losses)?;
if self.config.use_auxiliary_tasks {
let aux_loss = self.compute_auxiliary_losses(task_data)?;
total_loss = total_loss.add(&aux_loss)?;
}
self.step_counter += 1;
Ok(MultiTaskOutput {
total_loss,
task_losses: task_losses
.into_iter()
.map(|(k, v)| (k, v.to_scalar().unwrap_or(0.0)))
.collect(),
task_accuracies,
active_tasks,
task_weights: self.task_weights.clone(),
})
}
/// Picks which tasks train on this step, per the scheduling strategy.
///
/// Strategies without a dedicated implementation (PerformanceBased,
/// CurriculumBased) fall back to activating every task that has data.
fn get_active_tasks(&mut self, task_data: &HashMap<String, TaskBatch>) -> Result<Vec<String>> {
    match &self.config.task_scheduling {
        TaskSchedulingStrategy::RoundRobin => {
            let mut task_names: Vec<String> = task_data.keys().cloned().collect();
            if task_names.is_empty() {
                return Ok(Vec::new());
            }
            // Sort the names: HashMap iteration order is unspecified, so
            // indexing by `step_counter` over unsorted keys would pick tasks
            // in an unstable order instead of cycling round-robin.
            task_names.sort();
            let current_task = &task_names[self.step_counter % task_names.len()];
            Ok(vec![current_task.clone()])
        },
        TaskSchedulingStrategy::WeightedSampling => {
            // Expand each available task proportionally to its priority,
            // then cycle deterministically through the expanded list
            // (config.tasks provides a stable order here).
            let mut weighted_tasks = Vec::new();
            for task_config in &self.config.tasks {
                if task_data.contains_key(&task_config.name) {
                    let weight = match task_config.priority {
                        TaskPriority::Low => 0.5,
                        TaskPriority::Normal => 1.0,
                        TaskPriority::High => 2.0,
                        TaskPriority::Critical => 3.0,
                    };
                    for _ in 0..(weight * 10.0) as usize {
                        weighted_tasks.push(task_config.name.clone());
                    }
                }
            }
            if weighted_tasks.is_empty() {
                return Ok(Vec::new());
            }
            let selected_task = &weighted_tasks[self.step_counter % weighted_tasks.len()];
            Ok(vec![selected_task.clone()])
        },
        TaskSchedulingStrategy::Random => {
            let task_names: Vec<String> = task_data.keys().cloned().collect();
            if task_names.is_empty() {
                return Ok(Vec::new());
            }
            let random_idx = fastrand::usize(..task_names.len());
            Ok(vec![task_names[random_idx].clone()])
        },
        // Fallback: every task with available data trains this step.
        _ => Ok(task_data.keys().cloned().collect()),
    }
}
/// Applies the configured loss-balancing strategy to the raw task losses.
///
/// Strategies not yet implemented (UncertaintyWeighting and the remaining
/// variants) currently pass the losses through unchanged.
fn balance_losses(
    &self,
    task_losses: &HashMap<String, Tensor>,
) -> Result<HashMap<String, Tensor>> {
    match &self.config.loss_balancing {
        LossBalancingStrategy::EqualWeighting => Ok(task_losses.clone()),
        LossBalancingStrategy::ManualWeighting { weights } => {
            let mut balanced = HashMap::new();
            for (task_name, loss) in task_losses {
                // Pair each weight with the task's position in the config's
                // task list. The previous code enumerated the HashMap, whose
                // iteration order is unspecified, so weights were paired
                // with arbitrary tasks.
                let weight = self
                    .config
                    .tasks
                    .iter()
                    .position(|t| t.name == *task_name)
                    .and_then(|i| weights.get(i))
                    .copied()
                    .unwrap_or(1.0);
                balanced.insert(task_name.clone(), loss.scalar_mul(weight)?);
            }
            Ok(balanced)
        },
        // NOTE(review): not implemented yet — losses pass through unchanged.
        LossBalancingStrategy::UncertaintyWeighting => Ok(task_losses.clone()),
        LossBalancingStrategy::DynamicWeightAverage => {
            self.apply_dynamic_weight_average(task_losses)
        },
        LossBalancingStrategy::GradNorm { alpha } => self.apply_gradnorm(task_losses, *alpha),
        _ => Ok(task_losses.clone()),
    }
}
/// Dynamic Weight Average: reweights each loss by the (softened) ratio of
/// its current value to its previous value.
///
/// NOTE(review): `get_previous_task_loss` is currently a stub that always
/// returns 1.0, so the weight reduces to `exp(current_loss / T)` — confirm
/// once real loss history is tracked.
fn apply_dynamic_weight_average(
    &self,
    task_losses: &HashMap<String, Tensor>,
) -> Result<HashMap<String, Tensor>> {
let mut balanced = HashMap::new();
// Need at least two steps of history before reweighting makes sense.
if self.step_counter < 2 {
return Ok(task_losses.clone());
}
// Softmax-style temperature from the DWA formulation.
let temperature = 2.0;
for (task_name, loss) in task_losses {
let prev_loss = self.get_previous_task_loss(task_name);
let current_loss = loss.to_scalar().unwrap_or(0.0);
let weight = if prev_loss > 0.0 {
let relative_decrease = current_loss / prev_loss;
(relative_decrease / temperature).exp()
} else {
1.0
};
balanced.insert(task_name.clone(), loss.clone().mul_scalar(weight)?);
}
Ok(balanced)
}
/// GradNorm balancing — placeholder: currently returns the losses unchanged
/// and ignores `alpha`. A real implementation needs per-task gradient norms.
fn apply_gradnorm(
&self,
task_losses: &HashMap<String, Tensor>,
_alpha: f32,
) -> Result<HashMap<String, Tensor>> {
Ok(task_losses.clone())
}
/// Adapts the static task weights after a step. Only DynamicWeightAverage
/// updates weights online; every other strategy leaves them untouched.
fn update_task_weights(&mut self, task_losses: &HashMap<String, Tensor>) -> Result<()> {
    if matches!(
        self.config.loss_balancing,
        LossBalancingStrategy::DynamicWeightAverage
    ) {
        for (task_name, loss) in task_losses {
            let observed = loss.to_scalar().unwrap_or(0.0);
            if let Some(weight) = self.task_weights.get_mut(task_name) {
                // Exponential moving average of the loss, clamped to a
                // sane range so no task dominates or vanishes.
                *weight = (*weight * 0.9 + observed * 0.1).clamp(0.1, 10.0);
            }
        }
    }
    Ok(())
}
/// Placeholder: should return the task's loss from the previous step, but
/// no loss history is stored yet, so it always returns 1.0.
fn get_previous_task_loss(&self, _task_name: &str) -> f32 {
1.0
}
/// Sums the weighted losses of every auxiliary task scheduled for this step.
fn compute_auxiliary_losses(&self, task_data: &HashMap<String, TaskBatch>) -> Result<Tensor> {
    let mut total = Tensor::zeros(&[1])?;
    for aux_config in &self.config.auxiliary_tasks {
        // Skip tasks whose firing schedule excludes this step, and tasks
        // for which no batch was supplied.
        if !self.should_train_auxiliary_task(aux_config) {
            continue;
        }
        let Some(aux_data) = task_data.get(&aux_config.name) else {
            continue;
        };
        let aux_task_loss = self.compute_auxiliary_task_loss(aux_config, aux_data)?;
        total = total.add(&aux_task_loss.mul_scalar(aux_config.weight)?)?;
    }
    Ok(total)
}
/// Decides whether an auxiliary task fires on the current step.
fn should_train_auxiliary_task(&self, aux_config: &AuxiliaryTaskConfig) -> bool {
    match &aux_config.frequency {
        AuxiliaryTaskFrequency::Continuous => true,
        AuxiliaryTaskFrequency::EveryNSteps(n) => self.step_counter.is_multiple_of(*n),
        AuxiliaryTaskFrequency::WithProbability(p) => fastrand::f32() < *p,
        AuxiliaryTaskFrequency::EpochRange { start, end } => {
            // NOTE(review): an epoch length of 1000 steps is hard-coded
            // here — presumably it should track the real epoch size.
            let current_epoch = self.step_counter / 1000;
            (*start..=*end).contains(&current_epoch)
        },
    }
}
/// Computes the loss of one auxiliary task on the shared features.
///
/// Only LM and MLM are routed to dedicated (stub) loss functions; every
/// other auxiliary type currently yields a zero loss.
fn compute_auxiliary_task_loss(
&self,
aux_config: &AuxiliaryTaskConfig,
data: &TaskBatch,
) -> Result<Tensor> {
let shared_features: Tensor = self.base_model.forward(data.inputs.clone())?;
match &aux_config.auxiliary_type {
AuxiliaryType::LanguageModeling => {
self.compute_lm_loss(&shared_features, &data.targets)
},
AuxiliaryType::MaskedLanguageModeling => {
self.compute_mlm_loss(&shared_features, &data.targets)
},
_ => {
// Unimplemented auxiliary types contribute nothing.
Ok(Tensor::zeros(&[1])?)
},
}
}
/// Placeholder language-modeling loss — always returns a zero tensor.
fn compute_lm_loss(&self, _features: &Tensor, _targets: &Tensor) -> Result<Tensor> {
Tensor::zeros(&[1])
}
/// Placeholder masked-LM loss — always returns a zero tensor.
fn compute_mlm_loss(&self, _features: &Tensor, _targets: &Tensor) -> Result<Tensor> {
Tensor::zeros(&[1])
}
/// Computes the training loss of one task's outputs against its targets.
///
/// Classification uses a cross-entropy-style loss, regression dispatches on
/// the configured loss type, and every other task type currently returns a
/// zero loss.
fn compute_task_loss(
&self,
task_name: &str,
outputs: &Tensor,
targets: &Tensor,
) -> Result<Tensor> {
let task_config = self
.config
.tasks
.iter()
.find(|t| t.name == task_name)
.ok_or_else(|| invalid_input(format!("Task not found: {}", task_name)))?;
match &task_config.task_type {
TaskType::Classification { .. } => {
// NOTE(review): despite the variable name, this is softmax, not
// log-softmax — cross-entropy presumably needs a log here.
// Confirm against the Tensor API before changing.
let log_probs = outputs.softmax(-1)?;
// Assumes `targets` is one-hot along the class axis — TODO confirm.
let nll_loss = targets.mul(&log_probs)?.sum(Some(vec![1]), false)?;
Ok(nll_loss.mean()?.mul_scalar(-1.0)?)
},
TaskType::Regression { loss_type, .. } => {
match loss_type {
RegressionLossType::MSE => {
let diff = outputs.sub(targets)?;
Ok(diff.mul(&diff)?.mean()?)
},
RegressionLossType::MAE => {
let diff = outputs.sub(targets)?;
Ok(diff.abs()?.mean()?)
},
RegressionLossType::Huber { delta } => {
// NOTE(review): only the quadratic branch is returned; the
// linear branch is computed and then discarded, so this is
// 0.5*MSE, not true Huber. Left as-is pending a piecewise
// select operation.
let diff = outputs.sub(targets)?;
let abs_diff = diff.abs()?;
let small_loss = diff.mul(&diff)?.mul_scalar(0.5)?;
let _large_loss =
abs_diff.mul_scalar(*delta)?.sub_scalar(*delta * *delta * 0.5)?;
Ok(small_loss.mean()?)
},
_ => {
// Fallback: MSE for loss types without an implementation.
let diff = outputs.sub(targets)?;
Ok(diff.mul(&diff)?.mean()?)
},
}
},
_ => {
// Unimplemented task types contribute no loss.
Ok(Tensor::zeros(&[1])?)
},
}
}
/// Computes an evaluation metric for one task: 0/1 accuracy for
/// classification, clamped R² for regression, 0.0 otherwise.
///
/// NOTE(review): the classification branch calls `to_scalar` on the argmax,
/// which presumably only works for a single-example batch — confirm how
/// `to_scalar` behaves on batched tensors.
fn compute_task_accuracy(
&self,
task_name: &str,
outputs: &Tensor,
targets: &Tensor,
) -> Result<f32> {
let task_config = self
.config
.tasks
.iter()
.find(|t| t.name == task_name)
.ok_or_else(|| invalid_input(format!("Task not found: {}", task_name)))?;
match &task_config.task_type {
TaskType::Classification { .. } => {
let predicted = outputs.argmax(-1)?;
let target_class = targets.argmax(-1)?;
// Distinct fallbacks (-1 vs -2) ensure a scalar-conversion
// failure counts as incorrect rather than accidentally equal.
let correct = (predicted.to_scalar().unwrap_or(-1.0)
== target_class.to_scalar().unwrap_or(-2.0))
as i32 as f32;
Ok(correct)
},
TaskType::Regression { .. } => {
// R² = 1 - MSE / Var(targets), floored at 0.
let diff = outputs.sub(targets)?;
let mse = diff.mul(&diff)?.mean()?;
let mean_targets = targets.mean()?;
let diff_from_mean = targets.sub(&mean_targets)?;
let variance = diff_from_mean.pow_scalar(2.0)?.mean()?;
let r_squared =
1.0 - mse.to_scalar().unwrap_or(1.0) / variance.to_scalar().unwrap_or(1.0);
Ok(r_squared.max(0.0))
},
_ => Ok(0.0),
}
}
/// Evaluates every task that has both a head and a test batch, returning
/// per-task loss/accuracy plus the unweighted mean accuracy across tasks.
pub fn evaluate_all_tasks(
&self,
test_data: &HashMap<String, TaskBatch>,
) -> Result<MultiTaskEvaluation> {
let mut task_evaluations = HashMap::new();
for (task_name, batch) in test_data {
// Batches without a matching task head are silently skipped.
if let Some(task_head) = self.task_heads.get(task_name) {
let shared_features = self.base_model.forward(batch.inputs.clone())?;
let task_outputs = task_head.forward(&shared_features)?;
let loss = self.compute_task_loss(task_name, &task_outputs, &batch.targets)?;
let accuracy =
self.compute_task_accuracy(task_name, &task_outputs, &batch.targets)?;
task_evaluations.insert(
task_name.clone(),
TaskEvaluation {
task_name: task_name.clone(),
loss: loss.to_scalar().unwrap_or(0.0),
accuracy,
// Assumes dim 0 of the inputs is the batch axis — TODO confirm.
num_examples: batch.inputs.shape()[0],
},
);
}
}
// Unweighted mean accuracy; 0.0 when nothing was evaluated.
let overall_accuracy = if !task_evaluations.is_empty() {
task_evaluations.values().map(|e| e.accuracy).sum::<f32>()
/ task_evaluations.len() as f32
} else {
0.0
};
Ok(MultiTaskEvaluation {
task_evaluations,
overall_accuracy,
step: self.step_counter,
})
}
/// Returns a snapshot of the trainer's current state for reporting.
pub fn get_mtl_stats(&self) -> MTLStats {
    let config = &self.config;
    MTLStats {
        architecture: config.architecture.clone(),
        loss_balancing: config.loss_balancing.clone(),
        num_tasks: config.tasks.len(),
        task_weights: self.task_weights.clone(),
        step_counter: self.step_counter,
    }
}
}
/// Task-specific prediction head applied on top of the shared features.
pub struct TaskHead {
/// Linear layers applied in sequence.
layers: Vec<Linear>,
#[allow(dead_code)]
task_type: TaskType,
}
impl TaskHead {
    /// Builds a single-layer projection head sized for the given task type.
    ///
    /// NOTE(review): the 768 input width is hard-coded; it presumably must
    /// match the base model's hidden size — confirm against the encoder.
    pub fn new(task_type: &TaskType) -> Result<Self> {
        let output_dim = match task_type {
            TaskType::Classification { num_classes, .. } => *num_classes,
            TaskType::Regression { output_dim, .. } => *output_dim,
            // Other task types keep the hidden width unchanged.
            _ => 768,
        };
        Ok(Self {
            layers: vec![Linear::new(768, output_dim, true)],
            task_type: task_type.clone(),
        })
    }

    /// Applies the head's layers to the input in order.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let mut activations = input.clone();
        for layer in &self.layers {
            activations = layer.forward(activations)?;
        }
        Ok(activations)
    }
}
/// One batch of training or evaluation data for a single task.
#[derive(Debug, Clone)]
pub struct TaskBatch {
pub inputs: Tensor,
pub targets: Tensor,
/// Name of the task this batch belongs to.
pub task_name: String,
}
/// Mutable bookkeeping for the task scheduler.
pub struct TaskSchedulerState {
/// Index of the most recently selected task (round-robin style).
pub current_task_index: usize,
/// How many times each task has been scheduled.
pub task_counters: HashMap<String, usize>,
}
impl TaskSchedulerState {
    /// Creates a fresh scheduler state. The strategy is currently unused
    /// because every strategy starts from the same zeroed bookkeeping.
    pub fn new(_strategy: &TaskSchedulingStrategy) -> Self {
        let task_counters = HashMap::new();
        Self {
            current_task_index: 0,
            task_counters,
        }
    }
}
/// Per-task gradient statistics (intended for GradNorm-style balancing).
#[derive(Debug, Clone)]
pub struct GradientStats {
pub gradient_norm: f32,
pub gradient_variance: f32,
/// Number of updates folded into these statistics.
pub update_count: usize,
}
/// Result of one multi-task training step.
#[derive(Debug, Clone)]
pub struct MultiTaskOutput {
/// Combined, weighted loss used for the backward pass.
pub total_loss: Tensor,
/// Scalarized per-task losses (before balancing weights).
pub task_losses: HashMap<String, f32>,
pub task_accuracies: HashMap<String, f32>,
/// Tasks the scheduler selected this step.
pub active_tasks: Vec<String>,
/// Snapshot of the task weights after this step's update.
pub task_weights: HashMap<String, f32>,
}
/// Evaluation result for a single task.
#[derive(Debug, Clone)]
pub struct TaskEvaluation {
pub task_name: String,
pub loss: f32,
pub accuracy: f32,
pub num_examples: usize,
}
/// Evaluation results across all tasks at one training step.
#[derive(Debug, Clone)]
pub struct MultiTaskEvaluation {
pub task_evaluations: HashMap<String, TaskEvaluation>,
/// Unweighted mean of the per-task accuracies.
pub overall_accuracy: f32,
/// Step counter value at evaluation time.
pub step: usize,
}
/// Reporting snapshot of the trainer's configuration and progress.
#[derive(Debug, Clone)]
pub struct MTLStats {
pub num_tasks: usize,
pub task_weights: HashMap<String, f32>,
pub step_counter: usize,
pub architecture: MTLArchitecture,
pub loss_balancing: LossBalancingStrategy,
}
pub mod utils {
use super::*;
/// Convenience constructor for a hard-parameter-sharing configuration:
/// a shared trunk of `shared_layers` with `task_specific_layers` per task.
pub fn hard_parameter_sharing_config(
    tasks: Vec<TaskConfig>,
    shared_layers: usize,
    task_specific_layers: usize,
) -> MTLConfig {
    let architecture = MTLArchitecture::HardParameterSharing {
        shared_layers,
        task_specific_layers,
    };
    MTLConfig {
        architecture,
        tasks,
        ..Default::default()
    }
}
/// Convenience constructor for soft parameter sharing with an L2 penalty
/// of the given weight between per-task networks.
pub fn soft_parameter_sharing_config(
    tasks: Vec<TaskConfig>,
    regularization_weight: f32,
) -> MTLConfig {
    let architecture = MTLArchitecture::SoftParameterSharing {
        regularization_weight,
        regularization_type: RegularizationType::L2Regularization,
    };
    MTLConfig {
        architecture,
        tasks,
        ..Default::default()
    }
}
/// Convenience constructor for a Multi-gate Mixture-of-Experts config,
/// with one gating network per task.
pub fn mmoe_config(tasks: Vec<TaskConfig>, num_experts: usize, expert_dim: usize) -> MTLConfig {
    let num_gates = tasks.len();
    MTLConfig {
        architecture: MTLArchitecture::MultiGateMixtureOfExperts {
            num_experts,
            expert_dim,
            num_gates,
        },
        tasks,
        ..Default::default()
    }
}
/// Builds a classification task config with class weighting disabled.
pub fn classification_task(name: &str, num_classes: usize) -> TaskConfig {
    let task_type = TaskType::Classification {
        num_classes,
        use_class_weights: false,
    };
    TaskConfig::new(name, task_type)
}
/// Builds a regression task config using a mean-squared-error loss.
pub fn regression_task(name: &str, output_dim: usize) -> TaskConfig {
    let task_type = TaskType::Regression {
        output_dim,
        loss_type: RegressionLossType::MSE,
    };
    TaskConfig::new(name, task_type)
}
/// Builds a masked-LM auxiliary task named "mlm" that fires every 10 steps.
pub fn mlm_auxiliary_task(weight: f32) -> AuxiliaryTaskConfig {
    AuxiliaryTaskConfig {
        name: "mlm".to_string(),
        weight,
        auxiliary_type: AuxiliaryType::MaskedLanguageModeling,
        frequency: AuxiliaryTaskFrequency::EveryNSteps(10),
    }
}
/// Pairwise correlation of task performance histories, stored symmetrically
/// under both `(a, b)` and `(b, a)` keys.
pub fn compute_task_similarity(
    task_performances: &HashMap<String, Vec<f32>>,
) -> HashMap<(String, String), f32> {
    let tasks: Vec<String> = task_performances.keys().cloned().collect();
    let mut similarities = HashMap::new();
    for (i, task1) in tasks.iter().enumerate() {
        for task2 in &tasks[i + 1..] {
            if let (Some(perf1), Some(perf2)) =
                (task_performances.get(task1), task_performances.get(task2))
            {
                let similarity = compute_correlation(perf1, perf2);
                similarities.insert((task1.clone(), task2.clone()), similarity);
                similarities.insert((task2.clone(), task1.clone()), similarity);
            }
        }
    }
    similarities
}
/// Pearson correlation of two equally-long sequences. Returns 0.0 when the
/// lengths differ, the input is empty, or either sequence is constant.
pub fn compute_correlation(seq1: &[f32], seq2: &[f32]) -> f32 {
    if seq1.len() != seq2.len() || seq1.is_empty() {
        return 0.0;
    }
    let n = seq1.len() as f32;
    let mean1 = seq1.iter().sum::<f32>() / n;
    let mean2 = seq2.iter().sum::<f32>() / n;
    let mut covariance = 0.0;
    let mut var1 = 0.0;
    let mut var2 = 0.0;
    for (&a, &b) in seq1.iter().zip(seq2) {
        let d1 = a - mean1;
        let d2 = b - mean2;
        covariance += d1 * d2;
        var1 += d1 * d1;
        var2 += d2 * d2;
    }
    // Guard against division by zero for constant sequences.
    if var1 * var2 > 0.0 {
        covariance / (var1 * var2).sqrt()
    } else {
        0.0
    }
}
/// Compares multi-task performance against single-task baselines, task by
/// task, and classifies each task as positive or negative transfer. Tasks
/// missing a single-task baseline are skipped.
pub fn analyze_mtl_effectiveness(
    single_task_performances: &HashMap<String, f32>,
    multi_task_performances: &HashMap<String, f32>,
) -> MTLAnalysis {
    let mut positive_transfer_tasks = Vec::new();
    let mut negative_transfer_tasks = Vec::new();
    let mut total_improvement = 0.0;
    let mut num_tasks = 0;
    for (task_name, &mtl_perf) in multi_task_performances {
        let Some(&single_perf) = single_task_performances.get(task_name) else {
            continue;
        };
        let improvement = mtl_perf - single_perf;
        total_improvement += improvement;
        num_tasks += 1;
        if improvement > 0.0 {
            positive_transfer_tasks.push(task_name.clone());
        } else if improvement < 0.0 {
            negative_transfer_tasks.push(task_name.clone());
        }
    }
    let average_improvement =
        if num_tasks > 0 { total_improvement / num_tasks as f32 } else { 0.0 };
    MTLAnalysis {
        average_improvement,
        positive_transfer_tasks,
        negative_transfer_tasks,
        num_tasks,
    }
}
}
/// Summary of multi-task vs. single-task transfer effects.
#[derive(Debug, Clone)]
pub struct MTLAnalysis {
/// Mean performance delta (MTL minus single-task) over compared tasks.
pub average_improvement: f32,
/// Tasks that improved under multi-task training.
pub positive_transfer_tasks: Vec<String>,
/// Tasks that regressed under multi-task training.
pub negative_transfer_tasks: Vec<String>,
/// Number of tasks present in both performance maps.
pub num_tasks: usize,
}
#[cfg(test)]
mod tests {
use super::*;
// Default config: 8 shared + 2 task-specific layers, all features off.
#[test]
fn test_mtl_config_default() {
let config = MTLConfig::default();
assert_eq!(config.tasks.len(), 0);
assert!(!config.use_task_embeddings);
assert!(!config.use_auxiliary_tasks);
if let MTLArchitecture::HardParameterSharing {
shared_layers,
task_specific_layers,
} = config.architecture
{
assert_eq!(shared_layers, 8);
assert_eq!(task_specific_layers, 2);
} else {
panic!("Expected HardParameterSharing architecture");
}
}
// TaskConfig defaults and the with_weight builder.
#[test]
fn test_task_config() {
let task = TaskConfig::new(
"test",
TaskType::Classification {
num_classes: 10,
use_class_weights: false,
},
);
assert_eq!(task.name, "test");
assert_eq!(task.weight, 1.0);
assert!(!task.is_main_task);
let weighted_task = task.with_weight(2.0);
assert_eq!(weighted_task.weight, 2.0);
}
// utils::classification_task wires the class count through.
#[test]
fn test_classification_task_util() {
let task = utils::classification_task("sentiment", 3);
assert_eq!(task.name, "sentiment");
if let TaskType::Classification { num_classes, .. } = task.task_type {
assert_eq!(num_classes, 3);
} else {
panic!("Expected Classification task type");
}
}
// utils::regression_task wires the output dimension through.
#[test]
fn test_regression_task_util() {
let task = utils::regression_task("score", 1);
assert_eq!(task.name, "score");
if let TaskType::Regression { output_dim, .. } = task.task_type {
assert_eq!(output_dim, 1);
} else {
panic!("Expected Regression task type");
}
}
// hard_parameter_sharing_config preserves tasks and layer counts.
#[test]
fn test_hard_parameter_sharing_config() {
let tasks = vec![
utils::classification_task("task1", 5),
utils::regression_task("task2", 1),
];
let config = utils::hard_parameter_sharing_config(tasks, 6, 2);
assert_eq!(config.tasks.len(), 2);
if let MTLArchitecture::HardParameterSharing {
shared_layers,
task_specific_layers,
} = config.architecture
{
assert_eq!(shared_layers, 6);
assert_eq!(task_specific_layers, 2);
} else {
panic!("Expected HardParameterSharing architecture");
}
}
// soft_parameter_sharing_config stores the regularization weight.
#[test]
fn test_soft_parameter_sharing_config() {
let tasks = vec![utils::classification_task("task1", 5)];
let config = utils::soft_parameter_sharing_config(tasks, 0.01);
if let MTLArchitecture::SoftParameterSharing {
regularization_weight,
..
} = config.architecture
{
assert_eq!(regularization_weight, 0.01);
} else {
panic!("Expected SoftParameterSharing architecture");
}
}
// mmoe_config derives num_gates from the number of tasks.
#[test]
fn test_mmoe_config() {
let tasks = vec![
utils::classification_task("task1", 5),
utils::classification_task("task2", 3),
];
let config = utils::mmoe_config(tasks, 4, 128);
if let MTLArchitecture::MultiGateMixtureOfExperts {
num_experts,
expert_dim,
num_gates,
} = config.architecture
{
assert_eq!(num_experts, 4);
assert_eq!(expert_dim, 128);
assert_eq!(num_gates, 2);
} else {
panic!("Expected MultiGateMixtureOfExperts architecture");
}
}
// mlm_auxiliary_task builds a masked-LM task named "mlm".
#[test]
fn test_mlm_auxiliary_task() {
let aux_task = utils::mlm_auxiliary_task(0.1);
assert_eq!(aux_task.name, "mlm");
assert_eq!(aux_task.weight, 0.1);
if let AuxiliaryType::MaskedLanguageModeling = aux_task.auxiliary_type {
} else {
panic!("Expected MaskedLanguageModeling auxiliary type");
}
}
// Perfect positive and perfect negative correlation.
#[test]
fn test_compute_correlation() {
let seq1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let seq2 = vec![2.0, 4.0, 6.0, 8.0, 10.0];
let correlation = utils::compute_correlation(&seq1, &seq2);
assert!((correlation - 1.0).abs() < 1e-6);
let seq3 = vec![5.0, 4.0, 3.0, 2.0, 1.0]; let correlation_neg = utils::compute_correlation(&seq1, &seq3);
assert!((correlation_neg + 1.0).abs() < 1e-6);
}
// Transfer analysis classifies improved/regressed tasks correctly.
#[test]
fn test_mtl_analysis() {
let mut single_task = HashMap::new();
single_task.insert("task1".to_string(), 0.8);
single_task.insert("task2".to_string(), 0.7);
single_task.insert("task3".to_string(), 0.6);
let mut multi_task = HashMap::new();
multi_task.insert("task1".to_string(), 0.85); multi_task.insert("task2".to_string(), 0.65); multi_task.insert("task3".to_string(), 0.65);
let analysis = utils::analyze_mtl_effectiveness(&single_task, &multi_task);
assert_eq!(analysis.num_tasks, 3);
assert_eq!(analysis.positive_transfer_tasks.len(), 2);
assert_eq!(analysis.negative_transfer_tasks.len(), 1);
assert!(analysis.positive_transfer_tasks.contains(&"task1".to_string()));
assert!(analysis.negative_transfer_tasks.contains(&"task2".to_string()));
}
}