#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod cross_validation_tests {
    //! Tests for `SurvivabilityPredictor::cross_validate`.
    //!
    //! Covers the accuracy range on small synthetic datasets and the error
    //! paths for a fold count below 2, an empty dataset, and more folds than
    //! samples. The second argument to `cross_validate` is the fold count.

    use crate::services::mutation::{
        Mutant, MutantStatus, MutationOperatorType, SourceLocation, SurvivabilityPredictor,
        TrainingData,
    };

    /// Builds a fixed, pending mutant of `test.rs` (line 10, cols 5-15) that
    /// uses the arithmetic-replacement operator. Tests override fields with
    /// struct-update syntax.
    fn create_test_mutant() -> Mutant {
        Mutant {
            id: "test_1".to_string(),
            original_file: std::path::PathBuf::from("test.rs"),
            mutated_source: "fn test(a: i32) -> i32 { a - 1 }".to_string(),
            location: SourceLocation {
                line: 10,
                column: 5,
                end_line: 10,
                end_column: 15,
            },
            operator: MutationOperatorType::ArithmeticReplacement,
            hash: "test_hash".to_string(),
            status: MutantStatus::Pending,
        }
    }

    /// Builds one training record: the template mutant with `operator`,
    /// labelled with `was_killed`, with no test failures and a constant
    /// 100 ms execution time.
    fn training_sample(operator: MutationOperatorType, was_killed: bool) -> TrainingData {
        TrainingData {
            mutant: Mutant {
                operator,
                ..create_test_mutant()
            },
            was_killed,
            test_failures: vec![],
            execution_time_ms: 100,
        }
    }

    /// Builds 20 samples that alternate between arithmetic (even index) and
    /// relational (odd index) operators. The first 15 are labelled killed and
    /// the last 5 survived, a 75% / 25% class split.
    fn create_diverse_training_data() -> Vec<TrainingData> {
        (0..20)
            .map(|i| {
                let operator = if i % 2 == 0 {
                    MutationOperatorType::ArithmeticReplacement
                } else {
                    MutationOperatorType::RelationalReplacement
                };
                training_sample(operator, i < 15)
            })
            .collect()
    }

    #[test]
    fn test_cross_validate_basic() {
        let predictor = SurvivabilityPredictor::new();
        let training_data = create_diverse_training_data();

        // `expect` surfaces the actual error if cross-validation fails,
        // unlike `assert!(is_ok())`, which reports only "assertion failed".
        let acc = predictor
            .cross_validate(&training_data, 5)
            .expect("cross_validate should succeed on 20 samples with 5 folds");
        println!("Cross-validation accuracy: {:.2}%", acc * 100.0);

        assert!(
            acc >= 0.5,
            "Expected accuracy >= 50%, got {:.2}%",
            acc * 100.0
        );
        assert!(acc <= 1.0, "Accuracy should not exceed 100%");
    }

    #[test]
    fn test_cross_validate_with_perfect_data() {
        let predictor = SurvivabilityPredictor::new();
        // Every sample has the same operator and the same label (killed), so
        // the outcome is fully determined by the features.
        let data: Vec<TrainingData> = (0..20)
            .map(|_| training_sample(MutationOperatorType::ArithmeticReplacement, true))
            .collect();

        let accuracy = predictor
            .cross_validate(&data, 4)
            .expect("cross_validate should succeed on 20 samples with 4 folds");
        println!("Perfect data accuracy: {:.2}%", accuracy * 100.0);

        assert!(accuracy >= 0.8, "Expected high accuracy on perfect data");
        assert!(accuracy <= 1.0, "Accuracy should not exceed 100%");
    }

    #[test]
    fn test_cross_validate_insufficient_folds() {
        let predictor = SurvivabilityPredictor::new();
        let training_data = create_diverse_training_data();

        // `expect_err` reports the unexpected Ok value if validation is missing.
        let message = predictor
            .cross_validate(&training_data, 1)
            .expect_err("cross_validate should reject fewer than 2 folds")
            .to_string();
        assert!(
            message.contains("at least 2"),
            "unexpected error message: {}",
            message
        );
    }

    #[test]
    fn test_cross_validate_empty_data() {
        let predictor = SurvivabilityPredictor::new();

        let message = predictor
            .cross_validate(&[], 5)
            .expect_err("cross_validate should reject an empty dataset")
            .to_string();
        assert!(
            message.contains("cannot be empty"),
            "unexpected error message: {}",
            message
        );
    }

    #[test]
    fn test_cross_validate_too_many_folds() {
        let predictor = SurvivabilityPredictor::new();
        // 5 samples cannot be split into 10 folds.
        let small_data: Vec<_> = create_diverse_training_data().into_iter().take(5).collect();

        let message = predictor
            .cross_validate(&small_data, 10)
            .expect_err("cross_validate should reject more folds than samples")
            .to_string();
        assert!(
            message.contains("Not enough samples"),
            "unexpected error message: {}",
            message
        );
    }
}