use super::*;
#[test]
fn test_focal_loss_alpha_balancing() {
    // A single maximally uncertain prediction (p = 0.5) is scored against both labels.
    // With alpha = 0.25, the negative-label loss is expected to exceed the positive-label
    // one (assuming the usual alpha vs. 1 - alpha class weighting).
    let p = Vector::from_slice(&[0.5]);
    let (alpha, gamma) = (0.25, 2.0);

    let positive_loss = focal_loss(&p, &Vector::from_slice(&[1.0]), alpha, gamma);
    let negative_loss = focal_loss(&p, &Vector::from_slice(&[0.0]), alpha, gamma);

    assert!(positive_loss < negative_loss);
}
#[test]
#[should_panic(expected = "same length")]
fn test_focal_loss_length_mismatch() {
    // Two predictions but only one target: focal_loss must reject this by panicking
    // with a message that mentions "same length".
    let preds = Vector::from_slice(&[0.9, 0.1]);
    let short_targets = Vector::from_slice(&[1.0]);
    let _ = focal_loss(&preds, &short_targets, 0.25, 2.0);
}
#[test]
fn test_kl_divergence_identical() {
    // KL(P || P) must vanish: the same distribution is passed on both sides.
    let dist = [0.3, 0.4, 0.3];
    let kl = kl_divergence(&Vector::from_slice(&dist), &Vector::from_slice(&dist));
    assert!(kl.abs() < 1e-6);
}
#[test]
fn test_kl_divergence_different() {
    // Mirror-image distributions are far apart, so the divergence must be strictly positive.
    let leans_first = Vector::from_slice(&[0.9, 0.1]);
    let leans_second = Vector::from_slice(&[0.1, 0.9]);
    let divergence = kl_divergence(&leans_first, &leans_second);
    assert!(divergence > 0.0);
}
#[test]
fn test_kl_divergence_asymmetry() {
    // KL divergence is not symmetric. For a skewed vs. uniform pair, swapping the
    // arguments must change the result by a noticeable amount.
    let skewed = Vector::from_slice(&[0.9, 0.1]);
    let uniform = Vector::from_slice(&[0.5, 0.5]);

    let forward = kl_divergence(&skewed, &uniform);
    let reverse = kl_divergence(&uniform, &skewed);
    let gap = (forward - reverse).abs();

    assert!(gap > 0.01);
}
#[test]
fn test_kl_divergence_zero_in_p() {
    // The test expects zero-probability entries of P to contribute nothing.
    // That leaves 1.0 * ln(1.0 / 0.5) = ln 2.
    let one_hot = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let reference = Vector::from_slice(&[0.5, 0.3, 0.2]);

    let expected = 2.0_f32.ln();
    let actual = kl_divergence(&one_hot, &reference);

    assert!((actual - expected).abs() < 1e-5);
}
#[test]
fn test_kl_divergence_handles_small_q() {
    // A tiny entry in Q makes ln(p / q) large for that component. The result must
    // still be finite and positive.
    let kl = kl_divergence(
        &Vector::from_slice(&[0.5, 0.5]),
        &Vector::from_slice(&[0.999, 0.001]),
    );
    assert!(kl.is_finite() && kl > 0.0);
}
#[test]
#[should_panic(expected = "same length")]
fn test_kl_divergence_length_mismatch() {
    // P has two entries and Q has three, so the call must panic with a "same length" message.
    let two_entries = Vector::from_slice(&[0.5, 0.5]);
    let three_entries = Vector::from_slice(&[0.3, 0.3, 0.4]);
    let _ = kl_divergence(&two_entries, &three_entries);
}
#[test]
fn test_triplet_loss_struct() {
    let margin = 0.3;
    let triplet = TripletLoss::new(margin);
    // The constructor argument is exposed through the accessor.
    assert!((triplet.margin() - margin).abs() < 1e-6);

    // The positive sits next to the anchor; the negative is orthogonal to it.
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let near = Vector::from_slice(&[0.9, 0.1]);
    let far = Vector::from_slice(&[0.0, 1.0]);

    assert!(triplet.compute_triplet(&anchor, &near, &far) >= 0.0);
}
#[test]
fn test_focal_loss_struct() {
    let focal = FocalLoss::new(0.25, 2.0);

    // The accessors return the constructor arguments.
    assert!((focal.alpha() - 0.25).abs() < 1e-6);
    assert!((focal.gamma() - 2.0).abs() < 1e-6);

    // Confident, correct predictions still yield a non-negative loss.
    let loss = focal.compute(
        &Vector::from_slice(&[0.9, 0.1]),
        &Vector::from_slice(&[1.0, 0.0]),
    );
    assert!(loss >= 0.0);

    assert_eq!(focal.name(), "Focal");
}
#[test]
fn test_focal_loss_trait_polymorphism() {
    // FocalLoss must work as a `dyn Loss` trait object alongside another loss.
    let y_pred = Vector::from_slice(&[0.9, 0.1]);
    let y_true = Vector::from_slice(&[1.0, 0.0]);
    let losses: [Box<dyn Loss>; 2] = [
        Box::new(MSELoss),
        Box::new(FocalLoss::new(0.25, 2.0)),
    ];

    losses
        .iter()
        .map(|loss_fn| loss_fn.compute(&y_pred, &y_true))
        .for_each(|value| assert!(value >= 0.0));
}
#[test]
fn test_info_nce_loss_struct() {
    let temperature = 0.1;
    let criterion = InfoNCELoss::new(temperature);
    // The configured temperature is reported back by the accessor.
    assert!((criterion.temperature() - temperature).abs() < 1e-6);

    // One key close to the query, plus a single orthogonal distractor.
    let query = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let close_key = Vector::from_slice(&[0.9, 0.1, 0.0]);
    let distractors = vec![Vector::from_slice(&[0.0, 1.0, 0.0])];

    let loss = criterion.compute_contrastive(&query, &close_key, &distractors);
    assert!(loss >= 0.0);
}
#[test]
fn test_euclidean_distance_via_triplet() {
    // The positive is an exact copy of the anchor, and the negative lies far away.
    // The 0.5 margin is therefore already satisfied, and the loss should be zero
    // (assuming the usual max(0, d(a, p) - d(a, n) + margin) form).
    let anchor = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let duplicate = anchor.clone();
    let far_away = Vector::from_slice(&[4.0, 5.0, 6.0]);

    let loss = triplet_loss(&anchor, &duplicate, &far_away, 0.5);
    assert!(loss.abs() < 1e-6);
}
#[test]
fn test_cosine_similarity_via_info_nce() {
    let temperature = 0.5;
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let negatives = [Vector::from_slice(&[0.0, 1.0])];

    // A positive that points the same way as the anchor...
    let same_direction = Vector::from_slice(&[1.0, 0.0]);
    let aligned = info_nce_loss(&anchor, &same_direction, &negatives, temperature);

    // ...must score a lower loss than a positive orthogonal to it.
    let perpendicular = Vector::from_slice(&[0.0, 1.0]);
    let orthogonal = info_nce_loss(&anchor, &perpendicular, &negatives, temperature);

    assert!(orthogonal > aligned);
    assert!(aligned >= 0.0);
}
#[test]
fn test_contrastive_losses_numerical_stability() {
    // Components twelve orders of magnitude apart must not produce inf or NaN in
    // either contrastive loss.
    let anchor = Vector::from_slice(&[1e6, 1e-6]);
    let positive = Vector::from_slice(&[1e6, 1e-6]);
    let negative = Vector::from_slice(&[-1e6, 1e-6]);

    let triplet_value = triplet_loss(&anchor, &positive, &negative, 0.1);
    assert!(triplet_value.is_finite());

    // `negative` is not used again, so move it into the slice instead of cloning it.
    let info_nce_value = info_nce_loss(&anchor, &positive, &[negative], 0.1);
    assert!(info_nce_value.is_finite());
}
#[test]
fn test_dice_loss_perfect() {
    // The prediction equals the ground-truth mask, so the loss should be close to zero.
    let mask = [1.0, 0.0, 1.0, 0.0];
    let loss = dice_loss(&Vector::from_slice(&mask), &Vector::from_slice(&mask), 1e-6);
    assert!(loss < 0.01);
}
#[test]
fn test_dice_loss_no_overlap() {
    // The predicted and true foreground pixels are disjoint, so the loss should be close to 1.
    let predicted = Vector::from_slice(&[1.0, 1.0, 0.0, 0.0]);
    let actual = Vector::from_slice(&[0.0, 0.0, 1.0, 1.0]);
    let smooth = 1e-6;
    assert!(dice_loss(&predicted, &actual, smooth) > 0.99);
}
#[test]
fn test_dice_loss_partial_overlap() {
    // Exactly one of the two predicted foreground pixels hits the true foreground.
    let predicted = Vector::from_slice(&[1.0, 1.0, 0.0, 0.0]);
    let actual = Vector::from_slice(&[1.0, 0.0, 1.0, 0.0]);

    let loss = dice_loss(&predicted, &actual, 1e-6);

    // Partial overlap must fall strictly between the perfect and disjoint extremes.
    assert!(0.0 < loss && loss < 1.0);
}
#[test]
fn test_dice_loss_struct() {
    let dice = DiceLoss::new(1.0);
    // The test expects the smoothing constant back unchanged, so exact equality is used.
    assert_eq!(1.0, dice.smooth());
    assert_eq!(dice.name(), "Dice");
}
#[test]
fn test_hinge_loss_correct() {
    // Both samples have label * score = 2, which is beyond the margin of 1.
    // No hinge penalty applies.
    let scores = Vector::from_slice(&[2.0, -2.0]);
    let labels = Vector::from_slice(&[1.0, -1.0]);
    assert!(hinge_loss(&scores, &labels, 1.0).abs() < 1e-6);
}
#[test]
fn test_hinge_loss_incorrect() {
    // Every score has the opposite sign to its label, which must be penalised.
    let labels = Vector::from_slice(&[1.0, -1.0]);
    let scores = Vector::from_slice(&[-1.0, 1.0]);
    let penalty = hinge_loss(&scores, &labels, 1.0);
    assert!(penalty > 0.0);
}
#[test]
fn test_hinge_loss_margin_violation() {
    // The sign is correct but the score sits inside the margin: 1 - 1.0 * 0.5 = 0.5.
    let expected = 0.5;
    let loss = hinge_loss(&Vector::from_slice(&[0.5]), &Vector::from_slice(&[1.0]), 1.0);
    assert!((loss - expected).abs() < 1e-6);
}
#[test]
fn test_squared_hinge_loss() {
    // A margin shortfall of 0.5 (1 - 1.0 * 0.5), squared, gives 0.25.
    let margin = 1.0;
    let scores = Vector::from_slice(&[0.5]);
    let labels = Vector::from_slice(&[1.0]);
    let loss = squared_hinge_loss(&scores, &labels, margin);
    assert!((loss - 0.25).abs() < 1e-6);
}
#[test]
fn test_hinge_loss_struct() {
    let hinge = HingeLoss::new(1.0);
    // Both the configured margin and the reported name are observable.
    assert_eq!(hinge.margin(), 1.0);
    assert_eq!(hinge.name(), "Hinge");
}
#[test]
fn test_ctc_loss_creation() {
    // The blank index passed to the constructor is reported back unchanged.
    let blank = 0;
    assert_eq!(CTCLoss::new(blank).blank_idx(), blank);
}
#[test]
fn test_ctc_loss_simple() {
    let ctc = CTCLoss::new(0);

    // Three frames over three classes, with class 0 as the blank. Frame t slightly
    // disfavours class t. The rows are not normalised log-probabilities
    // (exp(-1) + 2 * exp(-0.7) > 1), which is acceptable because only finiteness is checked.
    let log_probs = vec![
        vec![-1.0, -0.7, -0.7],
        vec![-0.7, -1.0, -0.7],
        vec![-0.7, -0.7, -1.0],
    ];
    let targets = [1, 2];

    let loss = ctc.forward(&log_probs, &targets, 3, 2);
    assert!(loss.is_finite());
}
#[test]
fn test_ctc_loss_empty_target() {
    // Five frames with all log-probabilities at 0.0 and nothing to transcribe.
    // The loss is expected to be exactly zero.
    let log_probs = vec![vec![0.0; 3]; 5];
    let loss = CTCLoss::new(0).forward(&log_probs, &[], 5, 0);
    assert_eq!(loss, 0.0);
}
#[test]
fn test_ctc_loss_single_char() {
    // Two identical frames strongly favour class 1 (-0.01 vs. -10.0 for the others).
    // The target is the single symbol [1].
    let confident_frame = vec![-10.0, -0.01, -10.0];
    let log_probs = vec![confident_frame.clone(), confident_frame];

    let loss = CTCLoss::new(0).forward(&log_probs, &[1], 2, 1);
    assert!(loss.is_finite());
}
#[test]
fn test_wasserstein_loss_equal() {
    // Identical real and fake inputs must give zero loss.
    let scores = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let loss = wasserstein_loss(&scores, &scores.clone());
    assert!(loss.abs() < 1e-6);
}
#[test]
fn test_wasserstein_loss_diff() {
    // Every fake value sits exactly 1.0 above its real counterpart; the expected loss is 1.0.
    let real = Vector::from_slice(&[1.0; 3]);
    let fake = Vector::from_slice(&[2.0; 3]);
    assert!((wasserstein_loss(&real, &fake) - 1.0).abs() < 1e-6);
}
#[test]
fn test_wasserstein_discriminator_loss() {
    // Real scores exceed fake scores by 2.0 per sample; the test pins the loss at 2.0.
    let real = Vector::from_slice(&[3.0; 2]);
    let fake = Vector::from_slice(&[1.0; 2]);
    let expected = 2.0;

    let loss = wasserstein_discriminator_loss(&real, &fake);
    assert!((loss - expected).abs() < 1e-6);
}
#[test]
fn test_wasserstein_generator_loss() {
    // The fake scores average to 2.0, and the test expects the generator loss to be -2.0.
    let fake = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let loss = wasserstein_generator_loss(&fake);
    assert!((loss + 2.0).abs() < 1e-6);
}
#[test]
fn test_gradient_penalty() {
    let lambda = 10.0;

    // (0.6, 0.8) has an L2 norm of exactly 1, so there is nothing to penalise.
    let unit_norm = vec![0.6, 0.8];
    assert!(gradient_penalty(&unit_norm, lambda).abs() < 1e-6);

    // (1.2, 1.6) has norm 2. The expected 10.0 matches lambda * (||g|| - 1)^2.
    let double_norm = vec![1.2, 1.6];
    assert!((gradient_penalty(&double_norm, lambda) - 10.0).abs() < 1e-6);
}
#[test]
fn test_wasserstein_loss_struct() {
    let wgan = WassersteinLoss::new(10.0);
    assert_eq!(wgan.lambda_gp(), 10.0);
    assert_eq!(wgan.name(), "Wasserstein");

    // Real samples score above fake ones, so the discriminator loss is expected to be positive.
    let d_loss = wgan.discriminator_loss(
        &Vector::from_slice(&[2.0, 2.0]),
        &Vector::from_slice(&[1.0, 1.0]),
    );
    assert!(d_loss > 0.0);
}