use scirs2_core::ndarray::Array2;
use tensorlogic_train::{
CrossEntropyLoss, DistillationLoss, LabelSmoothingLoss, LinearModel, Loss, Model, MseLoss,
MultiTaskLoss, TrainError,
};
#[test]
fn test_distillation_with_label_smoothing() -> Result<(), TrainError> {
    // Compare knowledge distillation driven by plain cross-entropy against
    // distillation driven by a label-smoothed student loss.
    let classes = 5;
    let rows = 8;
    let mut teacher_logits = Array2::zeros((rows, classes));
    let mut student_logits = Array2::zeros((rows, classes));
    let mut targets = Array2::zeros((rows, classes));
    // One-hot targets; teacher is more confident than the student.
    for row in 0..rows {
        let hot = row % classes;
        targets[[row, hot]] = 1.0;
        for col in 0..classes {
            teacher_logits[[row, col]] = if col == hot { 2.0 } else { -0.5 };
            student_logits[[row, col]] = if col == hot { 1.5 } else { -0.3 };
        }
    }
    // Distillation with a standard cross-entropy hard-label term.
    let ce_loss = CrossEntropyLoss::default();
    let distillation = DistillationLoss::new(3.0, 0.7, Box::new(ce_loss))?;
    let loss1 = distillation.compute_distillation(
        &student_logits.view(),
        &teacher_logits.view(),
        &targets.view(),
    )?;
    assert!(loss1 > 0.0 && loss1.is_finite());
    // Same setup, but the hard-label term applies label smoothing.
    let ls_loss = LabelSmoothingLoss::new(0.1, classes)?;
    let distillation_ls = DistillationLoss::new(3.0, 0.7, Box::new(ls_loss))?;
    let loss2 = distillation_ls.compute_distillation(
        &student_logits.view(),
        &teacher_logits.view(),
        &targets.view(),
    )?;
    assert!(loss2 > 0.0 && loss2.is_finite());
    // The smoothed variant should land in the same ballpark as the plain one.
    assert!(loss2 < loss1 + 1.0);
    Ok(())
}
#[test]
fn test_multitask_with_progressive_difficulty() -> Result<(), TrainError> {
    // Multi-task loss over a classification head (task 1) and a regression
    // head (task 2), evaluated at several prediction-scaling "difficulties".
    let batch_size = 16;
    let task1_outputs = 5;
    let task2_outputs = 10;
    let total_outputs = task1_outputs + task2_outputs;
    let mut predictions = Array2::zeros((batch_size, total_outputs));
    let mut targets = Array2::zeros((batch_size, total_outputs));
    for row in 0..batch_size {
        // Task 1: one-hot class target with a confident prediction.
        let task1_class = row % task1_outputs;
        targets[[row, task1_class]] = 1.0;
        predictions[[row, task1_class]] = 1.5;
        // Task 2: regression targets with a small constant prediction offset.
        for col in task1_outputs..total_outputs {
            targets[[row, col]] = (col - task1_outputs) as f64 * 0.1;
            predictions[[row, col]] = targets[[row, col]] + 0.05;
        }
    }
    // Column boundaries for the two task heads.
    let task_splits = vec![0, task1_outputs, total_outputs];
    let losses: Vec<Box<dyn Loss>> = vec![Box::new(CrossEntropyLoss::default()), Box::new(MseLoss)];
    let weights = vec![0.6, 0.4];
    let mut mt_loss = MultiTaskLoss::new_fixed(losses, weights)?;
    // Evaluate the combined loss as predictions are progressively scaled.
    let difficulties = [0.5, 0.75, 1.0];
    let mut losses_at_difficulties = Vec::with_capacity(difficulties.len());
    for &difficulty in difficulties.iter() {
        let scaled_preds = &predictions * difficulty;
        let loss =
            mt_loss.compute_multi_task(&scaled_preds.view(), &targets.view(), &task_splits)?;
        losses_at_difficulties.push(loss);
    }
    assert_eq!(losses_at_difficulties.len(), 3);
    assert!(losses_at_difficulties[0] > 0.0);
    Ok(())
}
#[test]
fn test_combined_regularization() -> Result<(), TrainError> {
    // Sanity-check that plain cross-entropy and label-smoothed cross-entropy
    // both produce positive, finite losses on the same synthetic batch.
    let num_classes = 5;
    let batch_size = 16;
    let mut predictions = Array2::zeros((batch_size, num_classes));
    let mut targets = Array2::zeros((batch_size, num_classes));
    for i in 0..batch_size {
        let target_class = i % num_classes;
        targets[[i, target_class]] = 1.0;
        // Confident logit for the target class, slight per-row/column jitter
        // so the batch is not perfectly uniform.
        for j in 0..num_classes {
            predictions[[i, j]] = if j == target_class {
                2.0 + (i as f64 * 0.1)
            } else {
                -0.5 + (j as f64 * 0.05)
            };
        }
    }
    let ce_loss = CrossEntropyLoss::default();
    let loss_ce = ce_loss.compute(&predictions.view(), &targets.view())?;
    let ls_loss = LabelSmoothingLoss::new(0.1, num_classes)?;
    let loss_ls = ls_loss.compute(&predictions.view(), &targets.view())?;
    // Each loss must be a positive, finite scalar. (A third assertion that
    // merely repeated these positivity checks was removed as redundant.)
    assert!(loss_ce > 0.0 && loss_ce.is_finite());
    assert!(loss_ls > 0.0 && loss_ls.is_finite());
    Ok(())
}
#[test]
fn test_integrated_training_workflow() -> Result<(), TrainError> {
    // End-to-end check: forward pass through a linear model, loss
    // computation, and gradient computation all produce well-formed output.
    let input_size = 20;
    let output_size = 10;
    let batch_size = 32;
    let model = LinearModel::new(input_size, output_size);
    let mut inputs = Array2::zeros((batch_size, input_size));
    let mut targets = Array2::zeros((batch_size, output_size));
    for row in 0..batch_size {
        // Deterministic, non-zero input features.
        for col in 0..input_size {
            inputs[[row, col]] = (row as f64 + col as f64) * 0.01;
        }
        // Cycle one-hot targets over the output classes.
        targets[[row, row % output_size]] = 1.0;
    }
    let loss_fn = CrossEntropyLoss::default();
    let predictions = model.forward(&inputs.view())?;
    assert_eq!(predictions.shape(), &[batch_size, output_size]);
    let loss = loss_fn.compute(&predictions.view(), &targets.view())?;
    assert!(loss > 0.0 && loss.is_finite());
    let grad = loss_fn.gradient(&predictions.view(), &targets.view())?;
    assert_eq!(grad.shape(), predictions.shape());
    // The gradient must be non-zero and finite (Euclidean norm).
    let grad_norm = grad.iter().fold(0.0f64, |acc, &g| acc + g * g).sqrt();
    assert!(grad_norm > 0.0 && grad_norm.is_finite());
    Ok(())
}
#[test]
fn test_ensemble_with_mixed_objectives() -> Result<(), TrainError> {
    use tensorlogic_train::{DiceLoss, TverskyLoss};
    // Dice and Tversky losses evaluated on the same one-hot classification
    // batch must both yield non-negative, finite values.
    let num_classes = 3;
    let batch_size = 16;
    let mut predictions = Array2::zeros((batch_size, num_classes));
    let mut targets = Array2::zeros((batch_size, num_classes));
    for i in 0..batch_size {
        let target_class = i % num_classes;
        targets[[i, target_class]] = 1.0;
        for j in 0..num_classes {
            predictions[[i, j]] = if j == target_class { 2.0 } else { 0.5 };
        }
    }
    let dice_loss = DiceLoss::default();
    let loss1 = dice_loss.compute(&predictions.view(), &targets.view())?;
    assert!(loss1 >= 0.0 && loss1.is_finite());
    let tversky_loss = TverskyLoss::default();
    let loss2 = tversky_loss.compute(&predictions.view(), &targets.view())?;
    assert!(loss2 >= 0.0 && loss2.is_finite());
    // NOTE(review): a trailing `is_finite` re-assertion duplicating the two
    // checks above was removed as redundant.
    Ok(())
}
#[test]
fn test_stateful_training_checkpoint() -> Result<(), TrainError> {
    // A freshly built linear model must expose a non-empty state dict whose
    // parameter count equals weights (in*out) plus biases (out).
    let input_size = 10;
    let output_size = 5;
    let model = LinearModel::new(input_size, output_size);
    let state_dict = model.state_dict();
    assert!(!state_dict.is_empty(), "State dict should not be empty");
    let total_params: usize = state_dict.values().map(|v| v.len()).sum();
    let expected_params = input_size * output_size + output_size;
    assert_eq!(
        total_params, expected_params,
        "Total parameters should match model size"
    );
    Ok(())
}
#[test]
fn test_multiple_loss_combinations() -> Result<(), TrainError> {
    use tensorlogic_train::{HuberLoss, MixupLoss};
    // MSE, Huber, and mixup-wrapped MSE on the same regression-style data
    // must all be positive and finite.
    let num_samples = 10;
    let num_features = 5;
    let mut predictions = Array2::zeros((num_samples, num_features));
    let mut targets = Array2::zeros((num_samples, num_features));
    // Predictions are off by a constant 0.1 in every cell.
    for i in 0..num_samples {
        for j in 0..num_features {
            predictions[[i, j]] = (j + 1) as f64;
            targets[[i, j]] = (j + 1) as f64 + 0.1;
        }
    }
    let mse_loss = MseLoss;
    let loss1 = mse_loss.compute(&predictions.view(), &targets.view())?;
    assert!(loss1 > 0.0 && loss1.is_finite());
    let huber_loss = HuberLoss::default();
    let loss2 = huber_loss.compute(&predictions.view(), &targets.view())?;
    assert!(loss2 > 0.0 && loss2.is_finite());
    let mixup_loss = MixupLoss::new(1.0, Box::new(MseLoss))?;
    let loss3 = mixup_loss.compute_mixup(&predictions.view(), &targets.view())?;
    assert!(loss3 > 0.0 && loss3.is_finite());
    // NOTE(review): a final combined `is_finite` assertion duplicating the
    // three checks above was removed as redundant.
    Ok(())
}