pub(crate) use super::*;
use tests_training_sample::cosine_sim;
#[test]
fn test_config_default() {
    // The stock configuration must expose its documented dimensions.
    let cfg = NeuralEncoderConfig::default();
    let dims = (cfg.vocab_size, cfg.embed_dim, cfg.output_dim);
    assert_eq!(dims, (8192, 256, 256));
}
#[test]
fn test_config_minimal() {
    // The minimal preset shrinks vocab/width and is shallower than default.
    let minimal = NeuralEncoderConfig::minimal();
    assert_eq!(minimal.vocab_size, 1000);
    assert_eq!(minimal.embed_dim, 64);
    let baseline = NeuralEncoderConfig::default();
    assert!(minimal.num_layers < baseline.num_layers);
}
#[test]
fn test_config_small() {
    // `small` sits strictly between `minimal` and `default` in width.
    let small = NeuralEncoderConfig::small();
    let lower = NeuralEncoderConfig::minimal().embed_dim;
    let upper = NeuralEncoderConfig::default().embed_dim;
    assert!(small.embed_dim > lower);
    assert!(small.embed_dim < upper);
}
#[test]
fn test_vocabulary_creation() {
    // A freshly built Rust-error vocabulary is never empty.
    assert!(Vocabulary::for_rust_errors().vocab_size() > 0);
}
#[test]
fn test_vocabulary_special_tokens() {
    // Each special token id must fall inside the vocabulary range.
    let vocab = Vocabulary::for_rust_errors();
    let size = vocab.vocab_size();
    assert!(vocab.cls_token() < size);
    assert!(vocab.sep_token() < size);
    assert!(vocab.eos_token() < size);
}
#[test]
fn test_vocabulary_lang_tokens() {
    // Distinct languages must map to distinct language tokens.
    let vocab = Vocabulary::for_rust_errors();
    assert_ne!(vocab.lang_token("python"), vocab.lang_token("rust"));
}
#[test]
fn test_vocabulary_tokenize_simple() {
    // Plain error text must produce at least one token.
    let vocab = Vocabulary::for_rust_errors();
    let ids = vocab.tokenize("error expected type");
    assert!(!ids.is_empty());
}
#[test]
fn test_vocabulary_tokenize_with_punctuation() {
    // A punctuated message should split into several tokens.
    let vocab = Vocabulary::for_rust_errors();
    let ids = vocab.tokenize("E0308: mismatched types");
    assert!(ids.len() >= 3);
}
#[test]
fn test_vocabulary_tokenize_error_code() {
    // A bare error code is a single, known (non-UNK) token.
    let vocab = Vocabulary::for_rust_errors();
    let ids = vocab.tokenize("E0308");
    assert_eq!(ids.len(), 1);
    assert_ne!(ids[0], vocab.unk_id);
}
#[test]
fn test_vocabulary_unknown_token() {
    // Out-of-vocabulary gibberish collapses to the single UNK token.
    let vocab = Vocabulary::for_rust_errors();
    let ids = vocab.tokenize("xyzzy12345");
    assert_eq!(ids.len(), 1);
    assert_eq!(ids[0], vocab.unk_id);
}
#[test]
fn test_encoder_creation() {
    // Encoders start out in evaluation (non-training) mode.
    let enc = NeuralErrorEncoder::with_config(NeuralEncoderConfig::minimal());
    assert!(!enc.is_training());
}
#[test]
fn test_encoder_train_eval_mode() {
    // `train()` and `eval()` toggle the training flag back and forth.
    let mut enc = NeuralErrorEncoder::with_config(NeuralEncoderConfig::minimal());
    enc.train();
    assert!(enc.is_training());
    enc.eval();
    assert!(!enc.is_training());
}
#[test]
fn test_encoder_num_parameters() {
    // Even the minimal encoder carries a non-zero parameter count.
    let enc = NeuralErrorEncoder::with_config(NeuralEncoderConfig::minimal());
    assert!(enc.num_parameters() > 0);
}
#[test]
fn test_encoder_encode_returns_correct_dim() {
    // The embedding length must equal the configured output dimension.
    let cfg = NeuralEncoderConfig::minimal();
    let expected_dim = cfg.output_dim;
    let enc = NeuralErrorEncoder::with_config(cfg);
    let emb = enc.encode("E0308: mismatched types", "let x: i32 = \"hello\";", "rust");
    assert_eq!(emb.len(), expected_dim);
}
#[test]
fn test_encoder_embedding_is_normalized() {
    // Encoded vectors should come back (approximately) L2-normalized.
    let enc = NeuralErrorEncoder::with_config(NeuralEncoderConfig::minimal());
    let emb = enc.encode("E0308: mismatched types", "let x: i32 = \"hello\";", "rust");
    // `norm` is captured by the panic message below — keep the name.
    let norm = emb.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
    assert!(
        (norm - 1.0).abs() < 0.01,
        "Embedding norm should be ~1.0, got {norm}"
    );
}
#[test]
fn test_encoder_similar_errors_similar_embeddings() {
    // Two E0308 type-mismatch errors should embed closer to each other
    // than either does to an unrelated E0382 move error.
    let enc = NeuralErrorEncoder::with_config(NeuralEncoderConfig::minimal());
    let mismatch_a = enc.encode(
        "E0308: mismatched types, expected i32 found &str",
        "let x: i32 = \"hello\";",
        "rust",
    );
    let mismatch_b = enc.encode(
        "E0308: mismatched types, expected i32 found String",
        "let y: i32 = String::new();",
        "rust",
    );
    let moved_value = enc.encode(
        "E0382: borrow of moved value",
        "let x = vec![1]; let y = x; let z = x;",
        "rust",
    );
    // `sim_12`/`sim_13` are captured by the panic message — keep the names.
    let sim_12 = cosine_sim(&mismatch_a, &mismatch_b);
    let sim_13 = cosine_sim(&mismatch_a, &moved_value);
    // A small tolerance keeps a near-tie from flaking the test.
    let tolerance = 0.01;
    assert!(
        sim_12 > sim_13 - tolerance,
        "Similar errors should have higher similarity (or near-tie): sim_12={sim_12}, sim_13={sim_13}"
    );
}
#[test]
fn test_encoder_different_languages() {
    // Cross-language similarity is unconstrained here; we only check that
    // cosine similarity stays within its mathematical range.
    let enc = NeuralErrorEncoder::with_config(NeuralEncoderConfig::minimal());
    let rust_emb = enc.encode("E0308: mismatched types", "let x: i32 = \"hello\";", "rust");
    let python_emb = enc.encode(
        "TypeError: expected int, got str",
        "x: int = \"hello\"",
        "python",
    );
    let sim = cosine_sim(&rust_emb, &python_emb);
    assert!((-1.0..=1.0).contains(&sim));
}
#[test]
fn test_contrastive_loss_creation() {
    // A fresh loss uses the 0.07 default temperature.
    let contrastive = ContrastiveLoss::new();
    assert!((contrastive.temperature - 0.07).abs() < 0.001);
}
#[test]
fn test_contrastive_loss_custom_temperature() {
    // An explicitly supplied temperature is stored as-is.
    let contrastive = ContrastiveLoss::with_temperature(0.1);
    assert!((contrastive.temperature - 0.1).abs() < 0.001);
}
#[test]
fn test_triplet_loss_creation() {
    // Defaults: margin 1.0 with plain Euclidean distance.
    let triplet = TripletLoss::new();
    assert!((triplet.margin() - 1.0).abs() < 0.001);
    assert_eq!(triplet.distance_metric(), TripletDistance::Euclidean);
}
#[test]
fn test_triplet_loss_default() {
    // `Default` must agree with `new()` on the 1.0 margin.
    let triplet = TripletLoss::default();
    assert!((triplet.margin() - 1.0).abs() < 0.001);
}
#[test]
fn test_triplet_loss_custom_margin() {
    // An explicitly supplied margin is stored as-is.
    let triplet = TripletLoss::with_margin(0.5);
    assert!((triplet.margin() - 0.5).abs() < 0.001);
}
#[test]
fn test_triplet_loss_with_distance() {
    // The builder-style `with_distance` swaps the metric.
    let triplet = TripletLoss::new().with_distance(TripletDistance::Cosine);
    assert_eq!(triplet.distance_metric(), TripletDistance::Cosine);
}
#[test]
fn test_triplet_loss_zero_when_satisfied() {
    // The positive sits far closer to the anchor than the negative, so the
    // margin constraint is already met and the hinge loss vanishes.
    let loss = TripletLoss::with_margin(0.1);
    let anchor = Tensor::new(&[0.0, 0.0], &[1, 2]);
    let positive = Tensor::new(&[0.1, 0.0], &[1, 2]);
    let negative = Tensor::new(&[5.0, 0.0], &[1, 2]);
    let value = loss.forward(&anchor, &positive, &negative);
    assert!(
        value.data()[0] < 0.01,
        "Loss should be ~0 when triplet is satisfied"
    );
}
#[test]
fn test_triplet_loss_positive_when_violated() {
    // d(anchor, positive) = 3 and d(anchor, negative) = 1 with margin 0.5,
    // so the hinge evaluates to 3 - 1 + 0.5 = 2.5.
    let loss = TripletLoss::with_margin(0.5);
    let anchor = Tensor::new(&[0.0, 0.0], &[1, 2]);
    let positive = Tensor::new(&[3.0, 0.0], &[1, 2]);
    let negative = Tensor::new(&[1.0, 0.0], &[1, 2]);
    let value = loss.forward(&anchor, &positive, &negative);
    assert!(
        value.data()[0] > 2.0,
        "Loss should be positive when triplet is violated"
    );
}
#[test]
fn test_triplet_loss_batch() {
    // A two-row batch still reduces to a single scalar loss.
    let loss = TripletLoss::with_margin(1.0);
    let anchor = Tensor::new(&[0.0, 0.0, 0.0, 0.0], &[2, 2]);
    let positive = Tensor::new(&[0.1, 0.0, 0.1, 0.0], &[2, 2]);
    let negative = Tensor::new(&[5.0, 0.0, 5.0, 0.0], &[2, 2]);
    let value = loss.forward(&anchor, &positive, &negative);
    assert_eq!(value.shape(), &[1]);
}
#[test]
fn test_triplet_loss_squared_euclidean() {
    // Squared distances: pos = 1, neg = 4; hinge(1 - 4 + 1) clamps to 0.
    let loss = TripletLoss::with_margin(1.0).with_distance(TripletDistance::SquaredEuclidean);
    let anchor = Tensor::new(&[0.0, 0.0], &[1, 2]);
    let positive = Tensor::new(&[1.0, 0.0], &[1, 2]);
    let negative = Tensor::new(&[2.0, 0.0], &[1, 2]);
    let value = loss.forward(&anchor, &positive, &negative);
    assert!(value.data()[0] < 0.01);
}
#[test]
fn test_triplet_loss_cosine() {
    // The positive nearly parallels the anchor while the negative is
    // orthogonal, so the cosine-distance triplet is comfortably satisfied.
    let loss = TripletLoss::with_margin(0.1).with_distance(TripletDistance::Cosine);
    let anchor = Tensor::new(&[1.0, 0.0], &[1, 2]);
    let positive = Tensor::new(&[0.9, 0.1], &[1, 2]);
    let negative = Tensor::new(&[0.0, 1.0], &[1, 2]);
    let value = loss.forward(&anchor, &positive, &negative);
    assert!(value.data()[0] < 0.5);
}
#[test]
fn test_pairwise_distances() {
    // Three points: origin, unit-x, unit-y; row-major 3x3 distance matrix.
    let loss = TripletLoss::new();
    let embeddings = Tensor::new(&[0.0, 0.0, 1.0, 0.0, 0.0, 1.0], &[3, 2]);
    let distances = loss.pairwise_distances(&embeddings);
    assert_eq!(distances.shape(), &[3, 3]);
    let data = distances.data();
    // Diagonal entries (0,0), (1,1), (2,2): self-distance is zero.
    assert!(data[0] < 0.01);
    assert!(data[4] < 0.01);
    assert!(data[8] < 0.01);
    // Origin to each unit vector: distance 1.
    assert!((data[1] - 1.0).abs() < 0.01);
    assert!((data[2] - 1.0).abs() < 0.01);
    // Unit-x to unit-y: distance sqrt(2).
    assert!((data[5] - std::f32::consts::SQRT_2).abs() < 0.01);
}
#[test]
fn test_mine_hard_triplets() {
    // Two tight clusters of two points each, labelled 0 and 1.
    let loss = TripletLoss::new();
    let embeddings = Tensor::new(&[0.0, 0.0, 0.1, 0.0, 5.0, 0.0, 5.1, 0.0], &[4, 2]);
    let labels = vec![0, 0, 1, 1];
    let triplets = loss.mine_hard_triplets(&embeddings, &labels);
    assert!(!triplets.is_empty());
    // Every mined triplet must pair a same-class anchor/positive with a
    // different-class negative.
    for &(anchor_idx, pos_idx, neg_idx) in &triplets {
        assert_eq!(labels[anchor_idx], labels[pos_idx], "Positive should be same class");
        assert_ne!(labels[anchor_idx], labels[neg_idx], "Negative should be different class");
    }
}
#[test]
fn test_mine_hard_triplets_single_class() {
    // With only one class there is no valid negative, hence no triplets.
    let loss = TripletLoss::new();
    let embeddings = Tensor::new(&[0.0, 0.0, 1.0, 0.0, 2.0, 0.0], &[3, 2]);
    let labels = vec![0, 0, 0];
    assert!(
        loss.mine_hard_triplets(&embeddings, &labels).is_empty(),
        "No triplets when all same class"
    );
}
#[test]
fn test_batch_hard_loss() {
    // Well-separated clusters: even the hardest triplets satisfy the
    // margin, so the batch-hard loss stays small.
    let loss = TripletLoss::with_margin(0.5);
    let embeddings = Tensor::new(&[0.0, 0.0, 0.1, 0.1, 10.0, 0.0, 10.1, 0.1], &[4, 2]);
    let labels = vec![0, 0, 1, 1];
    let value = loss.batch_hard_loss(&embeddings, &labels);
    assert!(value.data()[0] < 1.0);
}
#[test]
fn test_batch_hard_loss_overlapping() {
    // Interleaved classes violate the margin, producing a positive loss.
    let loss = TripletLoss::with_margin(1.0);
    let embeddings = Tensor::new(&[0.0, 0.0, 0.5, 0.0, 0.3, 0.0, 0.8, 0.0], &[4, 2]);
    let labels = vec![0, 0, 1, 1];
    let value = loss.batch_hard_loss(&embeddings, &labels);
    assert!(value.data()[0] > 0.0);
}
#[test]
fn test_batch_hard_loss_empty() {
    // A single sample admits no triplets, so the loss degenerates to 0.
    let loss = TripletLoss::new();
    let embeddings = Tensor::new(&[0.0, 0.0], &[1, 2]);
    let labels = vec![0];
    let value = loss.batch_hard_loss(&embeddings, &labels);
    assert!(
        value.data()[0].abs() < 0.001,
        "Loss should be 0 for single sample"
    );
}
#[test]
fn test_training_sample_creation() {
    // A new sample records its language and starts without a positive pair.
    let sample = TrainingSample::new("E0308: mismatched types", "let x: i32 = \"hello\";", "rust");
    assert_eq!(sample.source_lang, "rust");
    assert!(sample.positive.is_none());
}
#[path = "tests_training_sample.rs"]
mod tests_training_sample;