pub(crate) use super::*;
/// A prediction that exactly matches its target has zero mean squared error.
#[test]
fn test_mse_loss_perfect() {
    let prediction = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let target = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let loss = mse_loss(&prediction, &target);
    assert!(loss.abs() < 1e-6);
}
/// Shifting every prediction up by one yields an MSE of exactly 1.
#[test]
fn test_mse_loss_basic() {
    let target = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let shifted = Vector::from_slice(&[2.0, 3.0, 4.0]);
    let loss = mse_loss(&shifted, &target);
    assert!((loss - 1.0).abs() < 1e-6);
}
/// Errors of 1, 2 and 3 against an all-zero target average to (1 + 4 + 9) / 3.
#[test]
fn test_mse_loss_different_errors() {
    let zeros = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let ramp = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let actual = mse_loss(&ramp, &zeros);
    assert!((actual - 14.0 / 3.0).abs() < 1e-5);
}
/// `mse_loss` panics (message containing "same length") when the two inputs
/// differ in length.
#[test]
#[should_panic(expected = "same length")]
fn test_mse_loss_mismatched_lengths() {
    let three = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let two = Vector::from_slice(&[1.0, 2.0]);
    let _ = mse_loss(&three, &two);
}
/// A prediction that exactly matches its target has zero mean absolute error.
#[test]
fn test_mae_loss_perfect() {
    let prediction = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let target = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let loss = mae_loss(&prediction, &target);
    assert!(loss.abs() < 1e-6);
}
/// Absolute errors of 0.5 (two over, one under) average to 0.5.
#[test]
fn test_mae_loss_basic() {
    let target = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let prediction = Vector::from_slice(&[1.5, 2.5, 2.5]);
    let loss = mae_loss(&prediction, &target);
    assert!((loss - 0.5).abs() < 1e-6);
}
/// One large outlier (errors 1, 1, 97) inflates MSE far more than MAE:
/// MAE must stay below a tenth of MSE.
#[test]
fn test_mae_loss_outlier_robustness() {
    let target = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let with_outlier = Vector::from_slice(&[2.0, 3.0, 100.0]);
    let mae = mae_loss(&with_outlier, &target);
    let mse = mse_loss(&with_outlier, &target);
    assert!(mae < mse / 10.0);
}
/// `mae_loss` panics (message containing "same length") when the two inputs
/// differ in length.
#[test]
#[should_panic(expected = "same length")]
fn test_mae_loss_mismatched_lengths() {
    let three = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let two = Vector::from_slice(&[1.0, 2.0]);
    let _ = mae_loss(&three, &two);
}
/// Errors of 0.5 lie inside delta = 1, the quadratic region, so each element
/// contributes 0.5 * 0.5^2 = 0.125.
#[test]
fn test_huber_loss_small_errors() {
    let target = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let prediction = Vector::from_slice(&[1.5, 2.5, 3.5]);
    let delta = 1.0;
    let loss = huber_loss(&prediction, &target, delta);
    assert!((loss - 0.125).abs() < 1e-6);
}
/// Errors of 5 exceed delta = 1, the linear region:
/// delta * (|e| - delta / 2) = 4.5 per element.
#[test]
fn test_huber_loss_large_errors() {
    let target = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let prediction = Vector::from_slice(&[5.0, 5.0, 5.0]);
    let delta = 1.0;
    let loss = huber_loss(&prediction, &target, delta);
    assert!((loss - 4.5).abs() < 1e-6);
}
/// One quadratic-region error (0.5 -> 0.125) and one linear-region error
/// (5 -> 4.5) average to 2.3125.
#[test]
fn test_huber_loss_mixed_errors() {
    let target = Vector::from_slice(&[0.0, 0.0]);
    let prediction = Vector::from_slice(&[0.5, 5.0]);
    let delta = 1.0;
    let loss = huber_loss(&prediction, &target, delta);
    assert!((loss - 2.3125).abs() < 1e-5);
}
/// A delta of zero is rejected with a panic containing "Delta must be positive".
#[test]
#[should_panic(expected = "Delta must be positive")]
fn test_huber_loss_zero_delta() {
    let prediction = Vector::from_slice(&[2.0]);
    let target = Vector::from_slice(&[1.0]);
    let _ = huber_loss(&prediction, &target, 0.0);
}
/// A negative delta is rejected with a panic containing "Delta must be positive".
#[test]
#[should_panic(expected = "Delta must be positive")]
fn test_huber_loss_negative_delta() {
    let prediction = Vector::from_slice(&[2.0]);
    let target = Vector::from_slice(&[1.0]);
    let _ = huber_loss(&prediction, &target, -1.0);
}
/// For errors inside delta, Huber is exactly half the squared error. With
/// uniformly small errors (about 0.1, delta = 1), Huber must therefore equal
/// 0.5 * MSE.
#[test]
fn test_huber_vs_mse_small_errors() {
    let y_true = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let y_pred = Vector::from_slice(&[1.1, 2.1, 3.1]);
    let huber = huber_loss(&y_pred, &y_true, 1.0);
    let mse = mse_loss(&y_pred, &y_true);
    // Comparing `huber` to `mse` directly only holds by accident when both are
    // tiny (here roughly 0.005 vs 0.01). It would fail for larger in-delta
    // errors. The quadratic-region relationship is an exact factor of one half.
    assert!((huber - 0.5 * mse).abs() < 1e-6);
    assert!(huber < mse);
}
/// In the linear region, Huber tracks MAE but sits delta / 2 below it. With
/// delta = 1 and errors of 10, Huber is 9.5 and MAE is 10.
#[test]
fn test_huber_vs_mae_large_errors() {
    let target = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let prediction = Vector::from_slice(&[10.0, 10.0, 10.0]);
    let mae = mae_loss(&prediction, &target);
    let huber = huber_loss(&prediction, &target, 1.0);
    assert!(huber < mae);
    assert!((huber - (mae - 0.5)).abs() < 0.1);
}
/// `MSELoss` gives zero on identical vectors and reports the name "MSE".
#[test]
fn test_mse_loss_struct() {
    let criterion = MSELoss;
    let target = Vector::from_slice(&[1.0, 2.0]);
    let prediction = Vector::from_slice(&[1.0, 2.0]);
    let value = criterion.compute(&prediction, &target);
    assert!(value.abs() < 1e-6);
    assert_eq!(criterion.name(), "MSE");
}
/// `MAELoss` averages absolute errors of 0.5 to 0.5 and reports the name "MAE".
#[test]
fn test_mae_loss_struct() {
    let criterion = MAELoss;
    let target = Vector::from_slice(&[1.0, 2.0]);
    let prediction = Vector::from_slice(&[1.5, 2.5]);
    let value = criterion.compute(&prediction, &target);
    assert!((value - 0.5).abs() < 1e-6);
    assert_eq!(criterion.name(), "MAE");
}
/// `HuberLoss` stores its delta, reports the name "Huber", and computes the
/// same value as `huber_loss`. Errors of 0.5 with delta = 1 lie in the
/// quadratic region, giving 0.5 * 0.5^2 = 0.125.
#[test]
fn test_huber_loss_struct() {
    let loss_fn = HuberLoss::new(1.0);
    let y_true = Vector::from_slice(&[1.0, 2.0]);
    let y_pred = Vector::from_slice(&[1.5, 2.5]);
    let loss = loss_fn.compute(&y_pred, &y_true);
    // Pin the exact value, as the MSE/MAE struct tests do. A bare
    // `loss > 0.0` would accept almost any implementation.
    assert!((loss - 0.125).abs() < 1e-6);
    assert!((loss - huber_loss(&y_pred, &y_true, 1.0)).abs() < 1e-6);
    assert_eq!(loss_fn.name(), "Huber");
    assert!((loss_fn.delta() - 1.0).abs() < 1e-6);
}
/// Each `Loss` implementation can be driven through a `Box<dyn Loss>`. Each
/// must give a positive loss on imperfect predictions and a non-empty name.
#[test]
fn test_loss_trait_polymorphism() {
    let target = Vector::from_slice(&[1.0, 2.0]);
    let prediction = Vector::from_slice(&[1.5, 2.5]);
    let criteria: Vec<Box<dyn Loss>> = vec![
        Box::new(MSELoss),
        Box::new(MAELoss),
        Box::new(HuberLoss::new(1.0)),
    ];
    for criterion in &criteria {
        let value = criterion.compute(&prediction, &target);
        assert!(value > 0.0);
        assert!(!criterion.name().is_empty());
    }
}
/// Negative targets and predictions (uniform offset of -0.5) still yield
/// strictly positive MSE, MAE and Huber losses.
#[test]
fn test_negative_values() {
    let target = Vector::from_slice(&[-1.0, -2.0, -3.0]);
    let prediction = Vector::from_slice(&[-1.5, -2.5, -3.5]);
    let mse = mse_loss(&prediction, &target);
    let mae = mae_loss(&prediction, &target);
    let huber = huber_loss(&prediction, &target, 1.0);
    assert!(mse > 0.0);
    assert!(mae > 0.0);
    assert!(huber > 0.0);
}
/// All three regression losses vanish when prediction and target are both
/// all-zero.
#[test]
fn test_zero_values() {
    let target = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let prediction = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let mse = mse_loss(&prediction, &target);
    let mae = mae_loss(&prediction, &target);
    let huber = huber_loss(&prediction, &target, 1.0);
    assert!(mse.abs() < 1e-6);
    assert!(mae.abs() < 1e-6);
    assert!(huber.abs() < 1e-6);
}
/// A single error of magnitude 2 (prediction 3, target 5) gives MSE = 4 and
/// MAE = 2. Huber with delta = 1 is in its linear region, giving
/// 1 * (2 - 0.5) = 1.5.
#[test]
fn test_single_value() {
    let y_true = Vector::from_slice(&[5.0]);
    let y_pred = Vector::from_slice(&[3.0]);
    let mse = mse_loss(&y_pred, &y_true);
    let mae = mae_loss(&y_pred, &y_true);
    let huber = huber_loss(&y_pred, &y_true, 1.0);
    assert!((mse - 4.0).abs() < 1e-6);
    assert!((mae - 2.0).abs() < 1e-6);
    // Pin the exact value. A bare `> 0.0` check would not catch taking the
    // wrong branch (the quadratic form would give 0.5 * 2^2 = 2.0).
    assert!((huber - 1.5).abs() < 1e-6);
}
/// The positive is much closer to the anchor than the negative, so the 0.2
/// margin is already met and the loss is zero.
#[test]
fn test_triplet_loss_satisfied_margin() {
    let anchor = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let near = Vector::from_slice(&[0.9, 0.1, 0.0]);
    let far = Vector::from_slice(&[0.0, 1.0, 0.0]);
    let margin = 0.2;
    let loss = triplet_loss(&anchor, &near, &far, margin);
    assert!(loss.abs() < 1e-6, "Loss should be 0 when margin satisfied");
}
/// Positive and negative are both at distance 1 from the anchor. The margin
/// is therefore entirely unmet, and the loss equals the margin (0.5).
#[test]
fn test_triplet_loss_violated_margin() {
    let anchor = Vector::from_slice(&[0.0, 0.0]);
    let near = Vector::from_slice(&[1.0, 0.0]);
    let far = Vector::from_slice(&[0.0, 1.0]);
    let margin = 0.5;
    let loss = triplet_loss(&anchor, &near, &far, margin);
    assert!((loss - 0.5).abs() < 1e-5);
}
/// The negative (distance 0.5) is closer to the anchor than the positive
/// (distance 2), so the loss is 2 - 0.5 + 0.2 = 1.7.
#[test]
fn test_triplet_loss_hard_negative() {
    let anchor = Vector::from_slice(&[0.0, 0.0]);
    let distant_positive = Vector::from_slice(&[2.0, 0.0]);
    let close_negative = Vector::from_slice(&[0.5, 0.0]);
    let margin = 0.2;
    let loss = triplet_loss(&anchor, &distant_positive, &close_negative, margin);
    assert!((loss - 1.7).abs() < 1e-5);
}
/// With no margin, a negative that is already farther from the anchor than
/// the positive (2 vs 1) contributes no loss.
#[test]
fn test_triplet_loss_zero_margin() {
    let anchor = Vector::from_slice(&[0.0, 0.0]);
    let near = Vector::from_slice(&[1.0, 0.0]);
    let far = Vector::from_slice(&[0.0, 2.0]);
    let loss = triplet_loss(&anchor, &near, &far, 0.0);
    assert!(loss.abs() < 1e-6);
}
/// `triplet_loss` panics (message containing "same dimension") when the
/// positive has a different dimension from the anchor.
#[test]
#[should_panic(expected = "same dimension")]
fn test_triplet_loss_dimension_mismatch() {
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let positive_3d = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let negative = Vector::from_slice(&[0.0, 1.0]);
    let _ = triplet_loss(&anchor, &positive_3d, &negative, 0.2);
}
/// A nearly aligned positive with two orthogonal negatives gives a loss in
/// the range [0, 2).
#[test]
fn test_info_nce_loss_basic() {
    let anchor = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let positive = Vector::from_slice(&[0.95, 0.05, 0.0]);
    let negatives = vec![
        Vector::from_slice(&[0.0, 1.0, 0.0]),
        Vector::from_slice(&[0.0, 0.0, 1.0]),
    ];
    let temperature = 0.1;
    let loss = info_nce_loss(&anchor, &positive, &negatives, temperature);
    assert!((0.0..2.0).contains(&loss));
}
/// With a positive identical to the anchor and a diametrically opposite
/// negative, the loss stays below 0.5 at temperature 0.5.
#[test]
fn test_info_nce_loss_perfect_alignment() {
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let positive = Vector::from_slice(&[1.0, 0.0]);
    let negatives = vec![Vector::from_slice(&[-1.0, 0.0])];
    let temperature = 0.5;
    let loss = info_nce_loss(&anchor, &positive, &negatives, temperature);
    assert!(loss < 0.5);
}
/// Temperature rescales the similarity logits, so moving it from 0.1 to 1.0
/// must change the loss by a measurable amount.
#[test]
fn test_info_nce_loss_temperature_effect() {
    let anchor = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let positive = Vector::from_slice(&[0.7, 0.7, 0.0]);
    let negatives = vec![Vector::from_slice(&[0.0, 1.0, 0.0])];
    let loss_low_temp = info_nce_loss(&anchor, &positive, &negatives, 0.1);
    let loss_high_temp = info_nce_loss(&anchor, &positive, &negatives, 1.0);
    // A bare `!=` is true when both values are NaN (NaN != NaN), and also on a
    // rounding-level difference. Require finite results and a real gap.
    assert!(loss_low_temp.is_finite());
    assert!(loss_high_temp.is_finite());
    assert!((loss_low_temp - loss_high_temp).abs() > 1e-3);
}
/// With ten negatives spread around the unit circle, the loss is still
/// strictly positive.
#[test]
fn test_info_nce_loss_many_negatives() {
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let positive = Vector::from_slice(&[0.9, 0.1]);
    // Unit vectors at angles 1.0, 1.3, ..., 3.7 radians.
    let negatives: Vec<Vector<f32>> = (0..10)
        .map(|step| {
            let theta = (step as f32) * 0.3 + 1.0;
            Vector::from_slice(&[theta.cos(), theta.sin()])
        })
        .collect();
    let loss = info_nce_loss(&anchor, &positive, &negatives, 0.2);
    assert!(loss > 0.0);
}
/// `info_nce_loss` panics (message containing "same dimension") when a
/// negative has a different dimension from the anchor.
#[test]
#[should_panic(expected = "same dimension")]
fn test_info_nce_loss_dimension_mismatch() {
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let positive = Vector::from_slice(&[0.9, 0.1]);
    let mismatched = vec![Vector::from_slice(&[0.0, 1.0, 0.0])];
    let _ = info_nce_loss(&anchor, &positive, &mismatched, 0.1);
}
/// A temperature of zero is rejected with a panic containing
/// "Temperature must be positive".
#[test]
#[should_panic(expected = "Temperature must be positive")]
fn test_info_nce_loss_zero_temperature() {
    let anchor = Vector::from_slice(&[1.0, 0.0]);
    let positive = Vector::from_slice(&[0.9, 0.1]);
    let negatives = vec![Vector::from_slice(&[0.0, 1.0])];
    let temperature = 0.0;
    let _ = info_nce_loss(&anchor, &positive, &negatives, temperature);
}
/// Confident, correct predictions (0.99 on the positive, 0.01 on the
/// negative) are down-weighted to almost no loss.
#[test]
fn test_focal_loss_confident_correct() {
    let targets = Vector::from_slice(&[1.0, 0.0]);
    let predictions = Vector::from_slice(&[0.99, 0.01]);
    let (alpha, gamma) = (0.25, 2.0);
    let loss = focal_loss(&predictions, &targets, alpha, gamma);
    assert!(loss < 0.01);
}
/// Confident but wrong predictions are penalised heavily (loss above 1).
#[test]
fn test_focal_loss_confident_wrong() {
    let targets = Vector::from_slice(&[1.0, 0.0]);
    let predictions = Vector::from_slice(&[0.01, 0.99]);
    let (alpha, gamma) = (0.25, 2.0);
    let loss = focal_loss(&predictions, &targets, alpha, gamma);
    assert!(loss > 1.0);
}
/// A 50/50 prediction yields a moderate loss strictly between 0 and 1.
#[test]
fn test_focal_loss_uncertain() {
    let targets = Vector::from_slice(&[1.0, 0.0]);
    let predictions = Vector::from_slice(&[0.5, 0.5]);
    let (alpha, gamma) = (0.25, 2.0);
    let loss = focal_loss(&predictions, &targets, alpha, gamma);
    assert!(loss > 0.0);
    assert!(loss < 1.0);
}
/// Raising gamma from 0 to 2 shrinks the loss on an easy example
/// (0.9 predicted for a positive target).
#[test]
fn test_focal_loss_gamma_effect() {
    let predictions = Vector::from_slice(&[0.9]);
    let targets = Vector::from_slice(&[1.0]);
    let alpha = 0.25;
    let unfocused = focal_loss(&predictions, &targets, alpha, 0.0);
    let focused = focal_loss(&predictions, &targets, alpha, 2.0);
    assert!(focused < unfocused);
}
/// Cross-entropy of finite logits against a one-hot target should be a
/// positive, finite number.
#[test]
fn test_cross_entropy_loss_one_hot() {
    let logits = Vector::from_slice(&[2.0, 1.0, 0.5]);
    let one_hot = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let loss = cross_entropy_loss(&logits, &one_hot);
    assert!(loss > 0.0);
    assert!(loss.is_finite());
}
include!("tests_include_01.rs");