use ndarray::{array, Array1, Array2};
use super::utils::{cross_entropy_loss, kl_divergence, l2_normalize, log_softmax, softmax};
use super::{AttentionTransfer, DistillationLoss, ProgressiveDistillation};
#[test]
fn test_softmax_sums_to_one() {
    // A valid probability distribution must have total mass 1.
    let input = array![1.0, 2.0, 3.0, 4.0];
    let distribution = softmax(&input);
    let total: f32 = distribution.iter().sum();
    assert!((total - 1.0).abs() < 1e-5);
}
#[test]
fn test_softmax_all_positive() {
    // Even with extreme logits, every probability stays non-negative.
    let extreme = array![-100.0, 0.0, 100.0];
    let out = softmax(&extreme);
    assert!(out.iter().all(|&p| p >= 0.0));
}
#[test]
fn test_softmax_numerical_stability() {
    // Very large logits must not overflow into NaN/inf, and the output
    // must still normalize to 1.
    let big = array![1000.0, 1001.0, 1002.0];
    let out = softmax(&big);
    for &p in &out {
        assert!(p.is_finite());
    }
    let total: f32 = out.sum();
    assert!((total - 1.0).abs() < 1e-5);
}
#[test]
fn test_log_softmax_identity() {
    // exp(log_softmax(x)) must reproduce softmax(x) elementwise.
    let x = array![1.0, 2.0, 3.0];
    let expected = softmax(&x);
    let recovered: Array1<f32> = log_softmax(&x).mapv(f32::exp);
    for (e, r) in expected.iter().zip(recovered.iter()) {
        assert!((e - r).abs() < 1e-5);
    }
}
#[test]
fn test_kl_divergence_zero_for_same() {
    // KL(p || p) is identically zero.
    let logits = array![1.0, 2.0, 3.0];
    let p = softmax(&logits);
    let log_p = log_softmax(&logits);
    assert!(kl_divergence(&log_p, &p).abs() < 1e-5);
}
#[test]
fn test_kl_divergence_positive() {
    // Gibbs' inequality: KL(p || q) > 0 whenever p != q.
    // The logit vectors are reversed, so the two distributions clearly
    // differ; the old `>= 0.0` assertion would also have passed for a
    // broken implementation that always returns 0.
    let p = softmax(&array![1.0, 2.0, 3.0]);
    let log_q = log_softmax(&array![3.0, 2.0, 1.0]);
    let kl = kl_divergence(&log_q, &p);
    assert!(
        kl > 0.0,
        "KL of distinct distributions must be strictly positive, got {kl}"
    );
}
#[test]
fn test_distillation_loss_default() {
    // Documented defaults: temperature 4.0, alpha 0.7.
    let defaults = DistillationLoss::default();
    assert_eq!(defaults.alpha, 0.7);
    assert_eq!(defaults.temperature, 4.0);
}
#[test]
fn test_distillation_loss_positive() {
    // A combined distillation loss is never negative.
    let criterion = DistillationLoss::new(4.0, 0.5);
    let student_logits = array![1.0, 2.0, 3.0];
    let teacher_logits = array![1.5, 2.5, 2.0];
    let value = criterion.forward_single(&student_logits, &teacher_logits, 2);
    assert!(value >= 0.0);
}
#[test]
fn test_distillation_loss_zero_alpha() {
    // With alpha = 0 the soft (teacher) term is switched off entirely:
    // the loss must reduce to plain cross-entropy on the student logits.
    let criterion = DistillationLoss::new(4.0, 0.0);
    let student_logits = array![1.0, 2.0, 3.0];
    // Wildly different teacher logits — must have no effect at alpha = 0.
    let teacher_logits = array![100.0, 200.0, 300.0];
    let combined = criterion.forward_single(&student_logits, &teacher_logits, 2);
    let plain_ce = cross_entropy_loss(&student_logits, 2);
    assert!((combined - plain_ce).abs() < 0.01);
}
#[test]
fn test_distillation_loss_high_temp() {
    // Identical student and teacher logits give a ~zero soft loss,
    // regardless of the temperature used.
    let logits_s = array![1.0, 2.0, 3.0];
    let logits_t = array![1.0, 2.0, 3.0];
    for temperature in [1.0_f32, 10.0] {
        let soft = DistillationLoss::new(temperature, 1.0).soft_loss(&logits_s, &logits_t);
        assert!(soft.abs() < 0.1);
    }
}
#[test]
fn test_distillation_loss_batch() {
    // Batched forward over (2, 3) logit matrices yields a finite,
    // non-negative scalar.
    let criterion = DistillationLoss::new(4.0, 0.5);
    let student = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 2.0, 1.0, 3.0])
        .expect("operation should succeed");
    let teacher = Array2::from_shape_vec((2, 3), vec![1.5, 2.5, 2.5, 2.5, 1.5, 2.5])
        .expect("operation should succeed");
    let targets = vec![2, 0];
    let batch_loss = criterion.forward(&student, &teacher, &targets);
    assert!(batch_loss >= 0.0 && batch_loss.is_finite());
}
#[test]
fn test_progressive_default() {
    // The default must carry a non-empty layer mapping and a hidden
    // weight of 1.0.
    let prog = ProgressiveDistillation::default();
    assert_eq!(prog.hidden_weight, 1.0);
    assert!(!prog.layer_mapping.is_empty());
}
#[test]
fn test_progressive_hidden_loss_zero_for_same() {
    // Identical hidden states on every mapped layer give ~zero loss.
    let prog = ProgressiveDistillation::new(vec![(0, 0), (1, 1)]);
    let hidden = Array2::<f32>::ones((4, 768));
    let student_states = vec![hidden.clone(), hidden.clone()];
    let teacher_states = student_states.clone();
    let loss = prog.hidden_state_loss(&student_states, &teacher_states);
    assert!(loss.abs() < 1e-5);
}
#[test]
fn test_progressive_hidden_loss_positive_for_diff() {
    // All-zero vs all-one hidden states must yield a strictly positive loss.
    let prog = ProgressiveDistillation::new(vec![(0, 0)]);
    let student_h = [Array2::<f32>::zeros((4, 768))];
    let teacher_h = [Array2::<f32>::ones((4, 768))];
    assert!(prog.hidden_state_loss(&student_h, &teacher_h) > 0.0);
}
#[test]
fn test_progressive_with_weight() {
    // with_weight overrides the default hidden-state loss weight.
    let weighted = ProgressiveDistillation::new(vec![(0, 0)]).with_weight(0.5);
    assert_eq!(weighted.hidden_weight, 0.5);
}
#[test]
fn test_progressive_projection_layer_creation() {
    // with_projection(512, 768) must allocate a (512, 768) matrix.
    let prog = ProgressiveDistillation::new(vec![(0, 0)]).with_projection(512, 768);
    match prog.projection {
        Some(ref proj) => assert_eq!(proj.dim(), (512, 768)),
        None => panic!("with_projection should populate the projection matrix"),
    }
}
#[test]
fn test_progressive_hidden_loss_with_projection() {
    // With a projection, the 512-dim student can be compared against the
    // 768-dim teacher. The loss must be a finite, non-negative scalar —
    // the old `>= 0.0` assertion alone would also have passed for +inf.
    let prog = ProgressiveDistillation::new(vec![(0, 0)]).with_projection(512, 768);
    let student = vec![Array2::<f32>::ones((4, 512))];
    let teacher = vec![Array2::<f32>::ones((4, 768))];
    let loss = prog.hidden_state_loss(&student, &teacher);
    assert!(loss.is_finite(), "projected hidden loss must be finite, got {loss}");
    assert!(loss >= 0.0);
}
#[test]
fn test_progressive_projection_correct_transform() {
    // Force the projection to the identity matrix: projecting the student
    // then must map it exactly onto the (equal) teacher, giving ~0 loss.
    let mut prog = ProgressiveDistillation::new(vec![(0, 0)]).with_projection(768, 768);
    if let Some(ref mut proj) = prog.projection {
        proj.fill(0.0);
        for d in 0..768 {
            proj[[d, d]] = 1.0;
        }
    }
    let hidden = Array2::<f32>::from_elem((4, 768), 1.0);
    let student_states = vec![hidden.clone()];
    let teacher_states = vec![hidden];
    let loss = prog.hidden_state_loss(&student_states, &teacher_states);
    assert!(loss.abs() < 1e-4, "Identity projection should give ~0 loss");
}
#[test]
fn test_progressive_no_projection_skips_mismatched() {
    // Without a projection, layer pairs with disagreeing hidden dims are
    // skipped and contribute exactly nothing to the loss.
    let prog = ProgressiveDistillation::new(vec![(0, 0)]);
    let narrow = vec![Array2::<f32>::ones((4, 512))];
    let wide = vec![Array2::<f32>::ones((4, 768))];
    let loss = prog.hidden_state_loss(&narrow, &wide);
    assert_eq!(loss, 0.0, "Should skip mismatched shapes without projection");
}
#[test]
fn test_attention_transfer_default() {
    // Default attention-transfer weight is 0.1.
    let at = AttentionTransfer::default();
    assert_eq!(at.weight, 0.1);
}
#[test]
fn test_attention_transfer_zero_for_same() {
    // Identical attention maps carry ~zero transfer loss.
    let transfer = AttentionTransfer::new(1.0);
    let map = Array2::<f32>::ones((8, 8));
    let student_maps = vec![map.clone()];
    let teacher_maps = vec![map];
    let loss = transfer.loss(&student_maps, &teacher_maps);
    assert!(loss.abs() < 1e-5);
}
#[test]
fn test_attention_transfer_positive_for_diff() {
    // Disagreeing attention maps must incur a strictly positive loss.
    let transfer = AttentionTransfer::new(1.0);
    let student_map = [Array2::<f32>::zeros((8, 8))];
    let teacher_map = [Array2::<f32>::ones((8, 8))];
    assert!(transfer.loss(&student_map, &teacher_map) > 0.0);
}
#[test]
fn test_l2_normalize_unit_norm() {
    // After normalization, the result has unit global L2 norm
    // (input: a 3-4-5 row plus an all-zero row).
    let arr = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0])
        .expect("operation should succeed");
    let normalized = l2_normalize(&arr);
    let magnitude = normalized.iter().map(|&x| x * x).sum::<f32>().sqrt();
    assert!((magnitude - 1.0).abs() < 1e-5);
}
#[test]
fn test_l2_normalize_zero() {
    // An all-zero matrix must not produce NaN/inf (no division by zero).
    let zeros = Array2::<f32>::zeros((2, 2));
    let normalized = l2_normalize(&zeros);
    for &v in &normalized {
        assert!(v.is_finite());
    }
}
#[test]
fn test_distillation_loss_monotonic_in_alpha() {
    // NOTE(review): despite the name, only non-negativity at the two
    // alpha extremes is asserted here; actual monotonicity in alpha is
    // not yet checked.
    let student = array![1.0, 2.0, 3.0];
    let teacher = array![3.0, 2.0, 1.0];
    for alpha in [0.0_f32, 1.0] {
        let l = DistillationLoss::new(4.0, alpha).forward_single(&student, &teacher, 2);
        assert!(l >= 0.0);
    }
}
#[test]
fn test_temperature_scaling_effect() {
    // The soft loss stays finite across a wide temperature range.
    let student = array![1.0, 2.0, 3.0];
    let teacher = array![0.5, 2.0, 3.5];
    for temperature in [1.0_f32, 10.0] {
        let l = DistillationLoss::new(temperature, 1.0).soft_loss(&student, &teacher);
        assert!(l.is_finite());
    }
}