use ruvector_cnn::contrastive::{
AugmentationConfig, ContrastiveAugmentation, InfoNCELoss, TripletLoss, TripletDistance,
};
/// InfoNCE on a small batch of two well-separated groups should yield a
/// positive, finite loss.
#[test]
fn test_infonce_basic_computation() {
    let criterion = InfoNCELoss::new(0.07);
    // The first two rows point mostly along x, the last two mostly along y.
    // NOTE(review): how `forward` uses its second argument (`2`) to group
    // rows is not visible from this file — confirm against the library.
    let batch = vec![
        vec![1.0f64, 0.0, 0.0],
        vec![0.9, 0.1, 0.0],
        vec![0.0, 1.0, 0.0],
        vec![0.1, 0.9, 0.0],
    ];
    let value = criterion.forward(&batch, 2);
    assert!(value > 0.0, "InfoNCE loss should be positive, got {}", value);
    assert!(value.is_finite(), "Loss should be finite");
}
/// When every row has an exact duplicate next to it and the two duplicated
/// directions are orthogonal, the InfoNCE loss should stay below 1.0.
#[test]
fn test_infonce_perfect_alignment() {
    let criterion = InfoNCELoss::new(0.07);
    let along_x = vec![1.0f64, 0.0, 0.0, 0.0];
    let along_y = vec![0.0, 1.0, 0.0, 0.0];
    // Rows: x, x, y, y.
    let batch = vec![along_x.clone(), along_x, along_y.clone(), along_y];
    let value = criterion.forward(&batch, 2);
    assert!(
        value < 1.0,
        "Perfect alignment should have lower loss, got {}",
        value
    );
}
/// InfoNCE must stay finite at both a very low (0.01) and a high (1.0)
/// temperature, and the temperature must actually influence the loss.
#[test]
fn test_infonce_temperature_effect() {
    let loss_low_temp = InfoNCELoss::new(0.01);
    let loss_high_temp = InfoNCELoss::new(1.0);
    // Mixed geometry: a near-duplicate of row 0, an orthogonal row, and the
    // exact opposite of row 0, so the similarities are not all equal.
    let embeddings = vec![
        vec![1.0f64, 0.0, 0.0, 0.0],
        vec![0.9, 0.1, 0.0, 0.0],
        vec![0.0, 1.0, 0.0, 0.0],
        vec![-1.0, 0.0, 0.0, 0.0],
    ];
    let low_temp_loss = loss_low_temp.forward(&embeddings, 2);
    let high_temp_loss = loss_high_temp.forward(&embeddings, 2);
    assert!(low_temp_loss.is_finite(), "Low temp loss should be finite");
    assert!(high_temp_loss.is_finite(), "High temp loss should be finite");
    // Without this check the test would pass even if the temperature were
    // ignored. The similarities above are not all equal, so rescaling the
    // logits by a 100x different temperature must move the loss.
    assert!(
        (low_temp_loss - high_temp_loss).abs() > 1e-6,
        "Temperature should change the loss: low={}, high={}",
        low_temp_loss,
        high_temp_loss
    );
}
/// Ten pairs spread around the xy unit circle (many negatives per row)
/// should still give a positive, finite InfoNCE loss.
#[test]
fn test_infonce_many_negatives() {
    let criterion = InfoNCELoss::new(0.07);
    // Pair k: a unit vector at angle 0.5*k in the xy-plane, followed by a
    // copy rotated by a further 0.1 rad and lifted 0.1 along z.
    let batch: Vec<Vec<f64>> = (0..10)
        .flat_map(|k| {
            let theta = (k as f64) * 0.5;
            let shifted = theta + 0.1;
            [
                vec![theta.cos(), theta.sin(), 0.0],
                vec![shifted.cos(), shifted.sin(), 0.1],
            ]
        })
        .collect();
    let value = criterion.forward(&batch, 2);
    assert!(value > 0.0, "Loss should be positive");
    assert!(value.is_finite(), "Loss should be finite");
}
/// `forward_detailed` with per-sample losses and the similarity matrix
/// enabled should return both extras, correctly sized, with a unit diagonal.
#[test]
fn test_infonce_detailed_results() {
    let loss_fn = InfoNCELoss::new(0.07)
        .with_per_sample_losses()
        .with_similarity_matrix();
    let embeddings = vec![
        vec![1.0f64, 0.0],
        vec![0.9, 0.1],
        vec![0.0, 1.0],
        vec![0.1, 0.9],
    ];
    let n = embeddings.len();
    let result = loss_fn
        .forward_detailed(&embeddings, 2)
        .expect("forward_detailed should succeed on a well-formed batch");
    assert!(result.loss > 0.0);

    let per_sample = result
        .per_sample_losses
        .expect("per-sample losses were requested via with_per_sample_losses()");
    assert_eq!(per_sample.len(), n);
    assert!(
        per_sample.iter().all(|l| l.is_finite()),
        "every per-sample loss should be finite"
    );

    let sim_matrix = result
        .similarity_matrix
        .expect("similarity matrix was requested via with_similarity_matrix()");
    assert_eq!(sim_matrix.len(), n);
    // Validate every row, not only row 0, so a ragged matrix fails with a
    // clear message instead of an index panic (or passing unnoticed).
    for (i, row) in sim_matrix.iter().enumerate() {
        assert_eq!(row.len(), n, "row {} of the similarity matrix has the wrong length", i);
        assert!(
            (row[i] - 1.0).abs() < 1e-6,
            "Self-similarity should be 1.0, got {} at ({}, {})",
            row[i],
            i,
            i
        );
    }
}
/// `forward_with_pairs` on explicit anchor/positive lists (third argument
/// `None`) should succeed and yield a positive, finite loss.
#[test]
fn test_infonce_forward_with_pairs() {
    let criterion = InfoNCELoss::new(0.07);
    let (anchors, positives) = (
        vec![vec![1.0f64, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        vec![vec![0.9, 0.1, 0.0], vec![0.1, 0.9, 0.0]],
    );
    let value = criterion
        .forward_with_pairs(&anchors, &positives, None)
        .unwrap();
    assert!(value > 0.0);
    assert!(value.is_finite());
}
/// A clearly separated triplet must never produce a negative loss.
#[test]
fn test_triplet_loss_basic() {
    let criterion = TripletLoss::new(1.0);
    // The positive sits next to the anchor; the negative points the opposite way.
    let (anchor, positive, negative) = (
        vec![1.0f64, 0.0, 0.0],
        vec![0.9, 0.1, 0.0],
        vec![-1.0, 0.0, 0.0],
    );
    let value = criterion.forward(&anchor, &positive, &negative);
    assert!(value >= 0.0, "Triplet loss should be non-negative");
}
/// With the positive identical to the anchor and the negative its exact
/// opposite, a 0.5 margin is satisfied and the loss must be exactly zero.
#[test]
fn test_triplet_loss_zero_case() {
    let criterion = TripletLoss::new(0.5);
    let anchor = vec![1.0f64, 0.0, 0.0];
    let positive = anchor.clone();
    let negative = vec![-1.0, 0.0, 0.0];
    let value = criterion.forward(&anchor, &positive, &negative);
    assert_eq!(value, 0.0, "Loss should be zero when margin is satisfied");
}
/// Three mutually orthogonal unit vectors: the positive is no closer to the
/// anchor than the negative, so a 1.0 margin cannot be met and the loss
/// must be strictly positive.
#[test]
fn test_triplet_loss_positive_case() {
    let criterion = TripletLoss::new(1.0);
    let (anchor, positive, negative) = (
        vec![1.0f64, 0.0, 0.0],
        vec![0.0, 1.0, 0.0],
        vec![0.0, 0.0, 1.0],
    );
    let value = criterion.forward(&anchor, &positive, &negative);
    assert!(
        value > 0.0,
        "Loss should be positive when margin not satisfied, got {}",
        value
    );
}
/// With Euclidean distance selected, a positive 1.0 away and a negative 3.0
/// away along the x-axis should satisfy a 1.0 margin (zero loss).
#[test]
fn test_triplet_loss_euclidean() {
    let criterion = TripletLoss::new(1.0).with_distance(TripletDistance::Euclidean);
    let origin = vec![0.0f64, 0.0, 0.0];
    let near = vec![1.0, 0.0, 0.0];
    let far = vec![3.0, 0.0, 0.0];
    let value = criterion.forward(&origin, &near, &far);
    assert_eq!(value, 0.0, "Margin should be satisfied");
}
/// `forward_batch` over two well-separated triplets should succeed and give a
/// non-negative, finite loss.
#[test]
fn test_triplet_loss_batch() {
    let criterion = TripletLoss::new(0.5);
    let (anchors, positives, negatives) = (
        vec![vec![1.0f64, 0.0], vec![0.0, 1.0]],
        vec![vec![0.9, 0.1], vec![0.1, 0.9]],
        vec![vec![-1.0, 0.0], vec![0.0, -1.0]],
    );
    let value = criterion
        .forward_batch(&anchors, &positives, &negatives)
        .unwrap();
    assert!(value >= 0.0);
    assert!(value.is_finite());
}
/// A negative lying closer to the anchor than the positive should be reported
/// as hard and margin-violating, with a positive loss.
#[test]
fn test_triplet_loss_detailed() {
    let criterion = TripletLoss::new(1.0);
    // NOTE(review): the anchor is the zero vector and no distance is chosen
    // explicitly, so this relies on the default distance being well-defined
    // at the origin (a cosine distance would not be) — confirm the default.
    let anchor = vec![0.0f64, 0.0];
    let positive = vec![1.0, 0.0];
    let negative = vec![0.5, 0.0];
    let report = criterion
        .forward_detailed(&anchor, &positive, &negative)
        .unwrap();
    assert!(report.loss > 0.0);
    assert!(report.is_hard);
    assert!(report.violates_margin);
    assert!(report.positive_distance > report.negative_distance);
}
/// `AugmentationConfig::default()` should expose the documented default
/// crop range, flip probability and output size.
#[test]
fn test_augmentation_config_default() {
    let defaults = AugmentationConfig::default();
    assert_eq!(defaults.crop_scale_min, 0.08);
    assert_eq!(defaults.crop_scale_max, 1.0);
    assert_eq!(defaults.horizontal_flip_prob, 0.5);
    assert_eq!(defaults.output_size, (224, 224));
}
/// Every builder setter should be reflected verbatim in the built config.
#[test]
fn test_augmentation_builder() {
    let augmentation = ContrastiveAugmentation::builder()
        .crop_scale(0.2, 0.8)
        .horizontal_flip_prob(1.0)
        .color_jitter(0.2, 0.2, 0.2, 0.05)
        .grayscale_prob(0.1)
        .output_size(128, 128)
        .build();
    let cfg = augmentation.config();

    // Geometry.
    assert_eq!(cfg.crop_scale_min, 0.2);
    assert_eq!(cfg.crop_scale_max, 0.8);
    assert_eq!(cfg.horizontal_flip_prob, 1.0);

    // Colour jitter: brightness, contrast, saturation, hue, in setter order.
    assert_eq!(cfg.brightness, 0.2);
    assert_eq!(cfg.contrast, 0.2);
    assert_eq!(cfg.saturation, 0.2);
    assert_eq!(cfg.hue, 0.05);

    assert_eq!(cfg.grayscale_prob, 0.1);
    assert_eq!(cfg.output_size, (128, 128));
}
/// Two augmenters built with the same seed should expose matching configs.
///
/// NOTE(review): only config fields are compared, not augmented output, so
/// this does not by itself show that seeding makes augmentation reproducible.
#[test]
fn test_augmentation_with_seed() {
    let first = ContrastiveAugmentation::builder().seed(42).build();
    let second = ContrastiveAugmentation::builder().seed(42).build();
    let (lhs, rhs) = (first.config(), second.config());
    assert_eq!(lhs.crop_scale_min, rhs.crop_scale_min);
    assert_eq!(lhs.crop_scale_max, rhs.crop_scale_max);
    assert_eq!(lhs.horizontal_flip_prob, rhs.horizontal_flip_prob);
}
/// Gaussian-blur builder settings (kernel size, sigma range, probability)
/// should round-trip into the config.
#[test]
fn test_augmentation_blur_config() {
    let augmentation = ContrastiveAugmentation::builder()
        .gaussian_blur(3, (0.1, 3.0))
        .blur_prob(0.5)
        .build();
    let cfg = augmentation.config();
    assert_eq!(cfg.blur_kernel_size, 3);
    assert_eq!(cfg.blur_prob, 0.5);
    assert_eq!(cfg.blur_sigma_range, (0.1, 3.0));
}
/// `ContrastiveAugmentation::default()` should produce a self-consistent
/// config: a non-empty crop-scale range inside (0, 1] and probabilities in
/// [0, 1].
#[test]
fn test_augmentation_default() {
    let aug = ContrastiveAugmentation::default();
    let config = aug.config();
    assert!(config.crop_scale_min > 0.0);
    assert!(config.crop_scale_max <= 1.0);
    // The range must also be ordered; min > max would pass the two bounds
    // checks above yet describe an empty sampling interval.
    assert!(
        config.crop_scale_min <= config.crop_scale_max,
        "crop scale range must be ordered: min={} max={}",
        config.crop_scale_min,
        config.crop_scale_max
    );
    assert!(
        (0.0..=1.0).contains(&config.horizontal_flip_prob),
        "horizontal_flip_prob out of [0, 1]: {}",
        config.horizontal_flip_prob
    );
    assert!(
        (0.0..=1.0).contains(&config.grayscale_prob),
        "grayscale_prob out of [0, 1]: {}",
        config.grayscale_prob
    );
    assert!(
        (0.0..=1.0).contains(&config.blur_prob),
        "blur_prob out of [0, 1]: {}",
        config.blur_prob
    );
}
/// InfoNCE on L2-normalised inputs should be finite and positive.
#[test]
fn test_infonce_with_normalized_embeddings() {
    let criterion = InfoNCELoss::new(0.1);
    let raw: [[f64; 3]; 4] = [
        [1.0, 0.0, 0.0],
        [0.8, 0.6, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.8, 0.6],
    ];
    // Scale every row to unit Euclidean length.
    let normalized: Vec<Vec<f64>> = raw
        .iter()
        .map(|row| {
            let length = row.iter().map(|c| c * c).sum::<f64>().sqrt();
            row.iter().map(|c| c / length).collect()
        })
        .collect();
    let value = criterion.forward(&normalized, 2);
    assert!(value.is_finite(), "Loss with normalized vectors should be finite");
    assert!(value > 0.0, "Loss should be positive");
}
/// For a fixed triplet, the loss should be non-decreasing as the margin grows.
#[test]
fn test_triplet_with_varying_margins() {
    let anchor = vec![1.0f64, 0.0, 0.0, 0.0];
    let positive = vec![0.7, 0.7, 0.0, 0.0];
    let negative = vec![0.0, 1.0, 0.0, 0.0];
    // Margins in strictly increasing order; losses[k] belongs to margins[k].
    let margins = [0.1, 0.5, 1.0, 2.0];
    let losses: Vec<_> = margins
        .iter()
        .map(|&margin| TripletLoss::new(margin).forward(&anchor, &positive, &negative))
        .collect();
    // Compare neighbours with `windows(2)` instead of `losses[i - 1]`
    // indexing, and report which margins failed.
    for (loss_pair, margin_pair) in losses.windows(2).zip(margins.windows(2)) {
        assert!(
            loss_pair[1] >= loss_pair[0],
            "Higher margin should lead to equal or higher loss: margin {} -> {}, loss {} -> {}",
            margin_pair[0],
            margin_pair[1],
            loss_pair[0],
            loss_pair[1]
        );
    }
}
/// Hard-triplet mining on a batch whose two classes sit almost on top of each
/// other should find triplets, and every mined triplet must pair a
/// same-class positive with a different-class negative.
#[test]
fn test_triplet_mine_hard_triplets() {
    let triplet = TripletLoss::new(0.01);
    // Consecutive points are about 0.07 apart along one diagonal, so each
    // class's nearest other-class point is as close as its own classmate.
    let embeddings = vec![
        vec![1.0f64, 0.0],
        vec![0.95, 0.05],
        vec![0.9, 0.1],
        vec![0.85, 0.15],
    ];
    let labels = vec![0, 0, 1, 1];
    let hard_triplets = triplet.mine_hard_triplets(&embeddings, &labels);
    // Without this check the label assertions below pass vacuously when
    // mining returns nothing.
    assert!(
        !hard_triplets.is_empty(),
        "mining should find at least one hard triplet in this overlapping batch"
    );
    for (a, p, n) in &hard_triplets {
        assert_eq!(labels[*a], labels[*p], "anchor and positive should be same class");
        assert_ne!(labels[*a], labels[*n], "anchor and negative should be different class");
    }
}