pub(crate) use super::*;
#[test]
fn test_mixup_new() {
    // The constructor should store alpha exactly as given.
    let m = Mixup::new(0.4);
    let stored = m.alpha();
    assert_eq!(stored, 0.4);
}
#[test]
fn test_mixup_sample_lambda() {
    // Every sampled mixing weight must lie in the closed interval [0, 1].
    let sampler = Mixup::new(1.0);
    (0..10)
        .map(|_| sampler.sample_lambda())
        .for_each(|lam| assert!((0.0..=1.0).contains(&lam)));
}
#[test]
fn test_mixup_alpha_zero() {
    // With alpha == 0 the sampler is expected to return exactly 1.0.
    let degenerate = Mixup::new(0.0);
    let lam = degenerate.sample_lambda();
    assert_eq!(lam, 1.0);
}
#[test]
fn test_mixup_mix_samples() {
    // A 50/50 blend of two one-hot vectors puts 0.5 in each of the first two slots.
    let mixup = Mixup::new(1.0);
    let a = Vector::from_slice(&[1.0, 0.0]);
    let b = Vector::from_slice(&[0.0, 1.0]);
    let blended = mixup.mix_samples(&a, &b, 0.5);
    let out = blended.as_slice();
    for &value in &out[..2] {
        assert!((value - 0.5).abs() < 1e-6);
    }
}
#[test]
fn test_mixup_mix_extreme_lambda() {
    // lambda = 0 yields the second sample verbatim; lambda = 1 yields the first.
    let mixup = Mixup::new(1.0);
    let first = Vector::from_slice(&[1.0, 2.0]);
    let second = Vector::from_slice(&[3.0, 4.0]);

    let all_second = mixup.mix_samples(&first, &second, 0.0);
    assert_eq!(all_second.as_slice(), &[3.0, 4.0]);

    let all_first = mixup.mix_samples(&first, &second, 1.0);
    assert_eq!(all_first.as_slice(), &[1.0, 2.0]);
}
#[test]
fn test_label_smoothing_new() {
    // The smoothing factor passed to the constructor is reported back unchanged.
    let smoother = LabelSmoothing::new(0.1);
    let eps = smoother.epsilon();
    assert_eq!(eps, 0.1);
}
#[test]
fn test_label_smoothing_smooth() {
    // epsilon = 0.1 over 3 classes: about 0.9333 on the hot class, about 0.0333 elsewhere.
    let ls = LabelSmoothing::new(0.1);
    let one_hot = Vector::from_slice(&[1.0, 0.0, 0.0]);
    let soft = ls.smooth(&one_hot);
    let vals = soft.as_slice();
    let hot = vals[0];
    let cold = vals[1];
    assert!((hot - 0.9333).abs() < 0.01);
    assert!((cold - 0.0333).abs() < 0.01);
}
#[test]
fn test_label_smoothing_smooth_index() {
    // The smoothed distribution built from a class index must still sum to one.
    let ls = LabelSmoothing::new(0.1);
    let dist = ls.smooth_index(0, 3);
    let total = dist.as_slice().iter().sum::<f32>();
    assert!((total - 1.0).abs() < 1e-5);
}
#[test]
fn test_label_smoothing_sums_to_one() {
    // Probability mass is conserved for a larger epsilon and a non-zero target index.
    let smoother = LabelSmoothing::new(0.2);
    let smoothed = smoother.smooth_index(2, 5);
    let total: f32 = smoothed.as_slice().iter().sum();
    let deviation = (total - 1.0).abs();
    assert!(deviation < 1e-5);
}
#[test]
fn test_cross_entropy_with_smoothing() {
    // The smoothed cross-entropy is a strictly positive, finite number.
    let logits = Vector::from_slice(&[2.0, 1.0, 0.5]);
    let target = 0;
    let loss = cross_entropy_with_smoothing(&logits, target, 0.1);
    assert!(loss > 0.0);
    assert!(loss.is_finite());
}
#[test]
fn test_cross_entropy_no_smoothing() {
    // A dominant logit on the correct class, with no smoothing, gives near-zero loss.
    let confident = Vector::from_slice(&[10.0, 0.0, 0.0]);
    let loss = cross_entropy_with_smoothing(&confident, 0, 0.0);
    assert!(loss < 0.1);
}
#[test]
fn test_cutmix_creation() {
    // The constructor stores alpha exactly.
    let cutmix = CutMix::new(1.0);
    let alpha = cutmix.alpha();
    assert_eq!(alpha, 1.0);
}
#[test]
fn test_cutmix_sample() {
    // A box sampled for a 32x32 image must be well-formed and stay inside the image.
    let cm = CutMix::new(1.0);
    let params = cm.sample(32, 32);
    // `RangeInclusive::contains` is the idiomatic bounds check (clippy::manual_range_contains).
    assert!((0.0..=1.0).contains(&params.lambda));
    // Corners are ordered...
    assert!(params.x1 <= params.x2);
    assert!(params.y1 <= params.y2);
    // ...and the far corner does not exceed the image extent.
    assert!(params.x2 <= 32);
    assert!(params.y2 <= 32);
}
#[test]
fn test_cutmix_apply() {
    // A small cut box at (x=1, y=1) applied to a single-channel 12-pixel image.
    let params = CutMixParams {
        lambda: 0.5,
        x1: 1,
        y1: 1,
        x2: 2,
        y2: 2,
    };
    let img1 = vec![1.0; 12];
    let img2 = vec![2.0; 12];
    let result = params.apply(&img1, &img2, 1, 3, 4);
    assert_eq!(result.len(), 12);
    // Inside the box the pixel is taken from img2 (assumes row-major layout, row stride 4).
    assert_eq!(result[4 + 1], 2.0);
    // Pixels well outside the box keep img1's value.
    assert_eq!(result[0], 1.0);
    assert_eq!(result[11], 1.0);
    // CutMix pastes rather than blends, so every output value comes from one of the inputs.
    assert!(result.iter().all(|&v| v == 1.0 || v == 2.0));
}
#[test]
fn test_stochastic_depth_creation() {
    // The configured drop probability is reported back unchanged.
    let layer = StochasticDepth::new(0.2, DropMode::Batch);
    let p = layer.drop_prob();
    assert_eq!(p, 0.2);
}
#[test]
fn test_stochastic_depth_eval_always_keeps() {
    // On the `should_keep(false)` path the layer is always kept, even with drop prob 0.9.
    let layer = StochasticDepth::new(0.9, DropMode::Batch);
    let kept_every_time = (0..10).all(|_| layer.should_keep(false));
    assert!(kept_every_time);
}
#[test]
fn test_stochastic_depth_zero_drop() {
    // A drop probability of zero never drops the layer, even on the `should_keep(true)` path.
    let layer = StochasticDepth::new(0.0, DropMode::Batch);
    let never_dropped = (0..10).all(|_| layer.should_keep(true));
    assert!(never_dropped);
}
#[test]
fn test_stochastic_depth_linear_decay() {
    // Halfway through 10 layers with final value 0.5 gives 0.75; the last layer gives 0.5.
    let halfway = StochasticDepth::linear_decay(5, 10, 0.5);
    assert!((halfway - 0.75).abs() < 1e-6);

    let final_layer = StochasticDepth::linear_decay(10, 10, 0.5);
    assert!((final_layer - 0.5).abs() < 1e-6);
}
#[test]
fn test_rdrop_creation() {
    // The consistency weight alpha is stored exactly.
    let reg = RDrop::new(0.5);
    let weight = reg.alpha();
    assert_eq!(weight, 0.5);
}
#[test]
fn test_rdrop_kl_divergence_same() {
    // The KL divergence of a distribution from itself is (numerically) zero.
    let reg = RDrop::new(1.0);
    let dist = vec![0.5, 0.3, 0.2];
    let self_kl = reg.kl_divergence(&dist, &dist);
    assert!(self_kl.abs() < 1e-5);
}
#[test]
fn test_rdrop_kl_divergence_different() {
    // Clearly different distributions yield a strictly positive divergence.
    let reg = RDrop::new(1.0);
    let peaked_left = vec![0.9, 0.1];
    let peaked_right = vec![0.1, 0.9];
    let divergence = reg.kl_divergence(&peaked_left, &peaked_right);
    assert!(divergence > 0.0);
}
#[test]
fn test_rdrop_symmetric_kl() {
    // The symmetric KL of two different distributions is positive and argument-order independent.
    let rdrop = RDrop::new(1.0);
    let p = vec![0.7, 0.3];
    let q = vec![0.4, 0.6];
    let sym = rdrop.symmetric_kl(&p, &q);
    assert!(sym > 0.0);
    // Swapping the arguments must not change a symmetric divergence.
    let sym_swapped = rdrop.symmetric_kl(&q, &p);
    assert!((sym - sym_swapped).abs() < 1e-6);
}
#[test]
fn test_rdrop_compute_loss_same() {
    // Identical logits from both forward passes produce no consistency loss.
    let reg = RDrop::new(1.0);
    let shared_logits = vec![2.0, 1.0, 0.5];
    let consistency = reg.compute_loss(&shared_logits, &shared_logits);
    assert!(consistency.abs() < 1e-5);
}
#[test]
fn test_rdrop_compute_loss_different() {
    // Logits that favour different classes produce a positive consistency loss.
    let reg = RDrop::new(1.0);
    let pass_a = vec![2.0, 0.0, 0.0];
    let pass_b = vec![0.0, 2.0, 0.0];
    let consistency = reg.compute_loss(&pass_a, &pass_b);
    assert!(consistency > 0.0);
}
#[test]
fn test_rdrop_alpha_zero() {
    // A zero weight disables the regularizer: the loss is exactly zero even for divergent logits.
    let disabled = RDrop::new(0.0);
    let pass_a = vec![2.0, 0.0];
    let pass_b = vec![0.0, 2.0];
    let weighted = disabled.compute_loss(&pass_a, &pass_b);
    assert_eq!(weighted, 0.0);
}
#[test]
fn test_specaugment_new() {
    // `new()` configures two frequency masks and two time masks.
    let augmenter = SpecAugment::new();
    let freq_masks = augmenter.num_freq_masks();
    let time_masks = augmenter.num_time_masks();
    assert_eq!(freq_masks, 2);
    assert_eq!(time_masks, 2);
}
#[test]
fn test_specaugment_custom() {
    // The first and third `with_params` arguments set the frequency and time mask counts.
    let augmenter = SpecAugment::with_params(1, 10, 3, 50);
    let freq_masks = augmenter.num_freq_masks();
    let time_masks = augmenter.num_time_masks();
    assert_eq!(freq_masks, 1);
    assert_eq!(time_masks, 3);
}
#[test]
fn test_specaugment_apply_shape() {
    // Masking an 80 x 100 spectrogram must not change its element count.
    let augmenter = SpecAugment::with_params(1, 5, 1, 10);
    let spectrogram = vec![1.0; 80 * 100];
    let masked = augmenter.apply(&spectrogram, 80, 100);
    assert_eq!(masked.len(), spectrogram.len());
}
#[test]
fn test_specaugment_masks_applied() {
    // A sentinel mask value makes masked cells easy to count.
    let augmenter = SpecAugment::with_params(2, 10, 2, 20).with_mask_value(-999.0);
    let spectrogram = vec![1.0; 40 * 50];
    let masked = augmenter.apply(&spectrogram, 40, 50);
    let sentinel_hits = masked.iter().filter(|v| **v == -999.0).count();
    assert!(sentinel_hits > 0, "Should have some masked values");
}
#[test]
fn test_specaugment_freq_mask() {
    // One frequency mask over a 20 x 30 spectrogram. The second `with_params`
    // argument (5) is presumably the maximum mask width in frequency bins.
    let sa = SpecAugment::with_params(1, 5, 0, 0).with_mask_value(0.0);
    let spec = vec![1.0; 20 * 30];
    let result = sa.freq_mask(&spec, 20, 30);
    assert_eq!(result.len(), spec.len());
    // Every cell is either untouched or replaced by the mask value.
    assert!(result.iter().all(|&v| v == 1.0 || v == 0.0));
    // Frequency masking blanks whole bins across all 30 frames, so the masked
    // count is a whole number of bins and covers at most 5 of them.
    let zero_count = result.iter().filter(|&&v| v == 0.0).count();
    assert_eq!(zero_count % 30, 0, "freq mask must cover whole frequency bins");
    assert!(zero_count <= 5 * 30, "freq mask wider than its maximum width");
}
#[test]
fn test_specaugment_time_mask() {
    // One time mask over a 20 x 30 spectrogram. The fourth `with_params`
    // argument (5) is presumably the maximum mask width in frames.
    let sa = SpecAugment::with_params(0, 0, 1, 5).with_mask_value(0.0);
    let spec = vec![1.0; 20 * 30];
    let result = sa.time_mask(&spec, 20, 30);
    assert_eq!(result.len(), spec.len());
    // Every cell is either untouched or replaced by the mask value.
    assert!(result.iter().all(|&v| v == 1.0 || v == 0.0));
    // Time masking blanks whole frames across all 20 frequency bins, so the
    // masked count is a whole number of frames and covers at most 5 of them.
    let zero_count = result.iter().filter(|&&v| v == 0.0).count();
    assert_eq!(zero_count % 20, 0, "time mask must cover whole frames");
    assert!(zero_count <= 5 * 20, "time mask wider than its maximum width");
}
#[test]
fn test_randaugment_new() {
    // N (ops per image) and M (magnitude) are stored as given.
    let policy = RandAugment::new(2, 9);
    assert_eq!((policy.n(), policy.m()), (2, 9));
}
#[test]
fn test_randaugment_default() {
    // The default policy is N = 2, M = 9.
    let policy = RandAugment::default();
    let ops = policy.n();
    let magnitude = policy.m();
    assert_eq!(ops, 2);
    assert_eq!(magnitude, 9);
}
#[test]
fn test_randaugment_magnitude_clamp() {
    // A magnitude above the maximum of 30 is clamped down to 30.
    let policy = RandAugment::new(1, 50);
    let clamped = policy.m();
    assert_eq!(clamped, 30);
}
#[test]
fn test_randaugment_normalized_magnitude() {
    // Magnitude 15 normalizes to 0.5.
    let policy = RandAugment::new(1, 15);
    let normalized = policy.normalized_magnitude();
    assert!((normalized - 0.5).abs() < 1e-6);
}
#[test]
fn test_randaugment_sample_augmentations() {
    // Sampling returns exactly N augmentations.
    let policy = RandAugment::new(3, 10);
    assert_eq!(policy.sample_augmentations().len(), 3);
}
#[test]
fn test_randaugment_apply_identity() {
    // The Identity op returns a 4x4 image unchanged.
    let policy = RandAugment::new(1, 15);
    let pixels = vec![0.5; 16];
    let out = policy.apply_single(&pixels, AugmentationType::Identity, 4, 4);
    assert_eq!(out, pixels);
}
#[test]
fn test_randaugment_apply_brightness() {
    // At maximum magnitude, Brightness moves at least one pixel by more than 0.01.
    let policy = RandAugment::new(1, 30);
    let pixels = vec![0.5; 16];
    let out = policy.apply_single(&pixels, AugmentationType::Brightness, 4, 4);
    let moved = out
        .iter()
        .zip(pixels.iter())
        .any(|(&after, &before)| (after - before).abs() > 0.01);
    assert!(moved, "Brightness should modify values");
}
#[test]
fn test_randaugment_apply_contrast() {
    // Contrast on a 0..1 intensity ramp preserves the image size.
    let policy = RandAugment::new(1, 20);
    let ramp: Vec<f32> = (0..16).map(|step| step as f32 / 15.0).collect();
    let out = policy.apply_single(&ramp, AugmentationType::Contrast, 4, 4);
    assert_eq!(out.len(), ramp.len());
}
#[test]
fn test_randaugment_custom_augmentations() {
    // Sampling draws only from the user-supplied augmentation pool.
    let policy = RandAugment::new(2, 10).with_augmentations(vec![
        AugmentationType::Identity,
        AugmentationType::Brightness,
    ]);
    let allowed = [AugmentationType::Identity, AugmentationType::Brightness];
    for picked in policy.sample_augmentations() {
        assert!(allowed.contains(&picked));
    }
}
#[test]
fn test_augmentation_type_equality() {
    // A variant equals itself and differs from other variants.
    let rotate = AugmentationType::Rotate;
    assert_eq!(rotate, AugmentationType::Rotate);
    assert_ne!(rotate, AugmentationType::Brightness);
}
#[test]
fn test_stochastic_depth_mode() {
    // Each drop mode passed to the constructor is reported back by `mode()`.
    let batch_layer = StochasticDepth::new(0.1, DropMode::Batch);
    let row_layer = StochasticDepth::new(0.1, DropMode::Row);
    assert_eq!(batch_layer.mode(), DropMode::Batch);
    assert_eq!(row_layer.mode(), DropMode::Row);
}
#[test]
fn test_drop_mode_eq() {
    // DropMode equality distinguishes the two variants.
    let batch = DropMode::Batch;
    assert_eq!(batch, DropMode::Batch);
    assert_ne!(batch, DropMode::Row);
}
#[test]
fn test_specaugment_default() {
    // `Default` yields the same mask counts as `new()`: two of each.
    let augmenter = SpecAugment::default();
    let counts = (augmenter.num_freq_masks(), augmenter.num_time_masks());
    assert_eq!(counts.0, 2);
    assert_eq!(counts.1, 2);
}
#[test]
fn test_specaugment_with_mask_value() {
    // A custom fill value does not affect the output size of a 10 x 10 spectrogram.
    let augmenter = SpecAugment::new().with_mask_value(-1.0);
    let spectrogram = vec![1.0; 100];
    let masked = augmenter.apply(&spectrogram, 10, 10);
    assert_eq!(masked.len(), 100);
}
#[test]
fn test_randaugment_apply_rotate() {
    // At magnitude 20, Rotate on a 2x2 image reverses pixel order (a 180-degree turn).
    let policy = RandAugment::new(1, 20);
    let pixels = vec![1.0, 2.0, 3.0, 4.0];
    let rotated = policy.apply_single(&pixels, AugmentationType::Rotate, 2, 2);
    assert_eq!(rotated, vec![4.0, 3.0, 2.0, 1.0]);
}
#[test]
fn test_randaugment_apply_rotate_low_mag() {
    // At magnitude 5, Rotate leaves the 2x2 image unchanged
    // (presumably below the rotation threshold).
    let policy = RandAugment::new(1, 5);
    let pixels = vec![1.0, 2.0, 3.0, 4.0];
    let rotated = policy.apply_single(&pixels, AugmentationType::Rotate, 2, 2);
    assert_eq!(rotated, pixels);
}
#[test]
fn test_randaugment_apply_translate_x() {
    // Horizontal translation keeps the 4x4 image size.
    let policy = RandAugment::new(1, 15);
    let pixels = vec![1.0; 16];
    let shifted = policy.apply_single(&pixels, AugmentationType::TranslateX, 4, 4);
    assert_eq!(shifted.len(), 16);
}
#[test]
fn test_randaugment_apply_translate_y() {
    // Vertical translation keeps the 4x4 image size.
    let policy = RandAugment::new(1, 15);
    let pixels = vec![1.0; 16];
    let shifted = policy.apply_single(&pixels, AugmentationType::TranslateY, 4, 4);
    assert_eq!(shifted.len(), 16);
}
#[test]
fn test_randaugment_apply_shear_x() {
    // Horizontal shear keeps the 4x4 image size.
    let policy = RandAugment::new(1, 15);
    let pixels = vec![0.5; 16];
    let sheared = policy.apply_single(&pixels, AugmentationType::ShearX, 4, 4);
    assert_eq!(sheared.len(), 16);
}
#[path = "tests_augmentation.rs"]
mod tests_augmentation;
#[path = "tests_cutmix_specaugment.rs"]
mod tests_cutmix_specaugment;