use crate::transpiler::{CodeFeatures, TranspilerVerdict};
use std::path::Path;
/// One labeled observation used to train or evaluate a bug-prediction model.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Static features extracted from the code under test.
    pub features: CodeFeatures,
    /// Ground-truth label: `true` when the transpiler verdict was anything
    /// other than `Pass` (see `verdict_to_label`).
    pub is_bug: bool,
}
/// Hyper-parameters controlling a training / cross-validation run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Fraction of examples routed to the training set by `train_test_split`.
    pub train_ratio: f64,
    /// Number of folds for cross-validation.
    pub cv_folds: usize,
    /// Seed making the train/test split deterministic.
    pub seed: u64,
    /// Minimum number of examples a trainer should accept.
    pub min_examples: usize,
}

impl Default for TrainingConfig {
    /// Defaults: 80/20 split, 5-fold CV, fixed seed 42, 100-example floor.
    fn default() -> Self {
        TrainingConfig {
            min_examples: 100,
            seed: 42,
            cv_folds: 5,
            train_ratio: 0.8,
        }
    }
}
/// Held-out evaluation metrics for a trained model.
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    pub accuracy: f64,
    pub precision: f64,
    pub recall: f64,
    pub f1_score: f64,
    pub auc_roc: f64,
    /// Number of examples the model was fit on.
    pub train_size: usize,
    /// Number of held-out examples these metrics were computed over.
    pub test_size: usize,
}

impl TrainingMetrics {
    /// Harmonic mean of precision and recall.
    ///
    /// Returns 0.0 when both inputs are zero, avoiding a 0/0 division.
    #[must_use]
    pub fn calculate_f1(precision: f64, recall: f64) -> f64 {
        let denominator = precision + recall;
        if denominator == 0.0 {
            return 0.0;
        }
        2.0 * precision * recall / denominator
    }
}
#[derive(Debug, Clone, Default)]
pub struct CrossValidationResults {
pub fold_metrics: Vec<TrainingMetrics>,
pub mean_accuracy: f64,
pub std_accuracy: f64,
pub mean_f1: f64,
}
impl CrossValidationResults {
#[must_use]
pub fn summarize(fold_metrics: Vec<TrainingMetrics>) -> Self {
if fold_metrics.is_empty() {
return Self::default();
}
let n = fold_metrics.len() as f64;
let mean_accuracy = fold_metrics.iter().map(|m| m.accuracy).sum::<f64>() / n;
let mean_f1 = fold_metrics.iter().map(|m| m.f1_score).sum::<f64>() / n;
let variance = fold_metrics
.iter()
.map(|m| (m.accuracy - mean_accuracy).powi(2))
.sum::<f64>()
/ n;
let std_accuracy = variance.sqrt();
Self {
fold_metrics,
mean_accuracy,
std_accuracy,
mean_f1,
}
}
}
/// A fitted model produced by a [`ModelTrainer`], usable for inference
/// and persistence. `Send + Sync` so implementations can be shared across
/// threads (e.g. behind an `Arc`).
pub trait TrainedModel: Send + Sync {
    /// Scores `features` with a single floating-point value.
    /// NOTE(review): the range/meaning is not specified here — presumably a
    /// bug probability or risk score; confirm with implementors.
    fn predict(&self, features: &CodeFeatures) -> f64;
    /// Persists the model to `path`.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying write.
    fn save(&self, path: &Path) -> std::io::Result<()>;
    /// Returns provenance and quality information recorded at training time.
    fn metadata(&self) -> ModelMetadata;
}
/// Provenance and quality information attached to a trained model.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Human-readable model family, e.g. "RandomForest".
    pub model_type: String,
    /// Training timestamp as a string; format is not enforced here
    /// (tests use a plain "YYYY-MM-DD" date).
    pub trained_at: String,
    /// Number of examples the model was fit on.
    pub train_examples: usize,
    /// Held-out evaluation metrics recorded at training time.
    pub metrics: TrainingMetrics,
}
/// Strategy interface for fitting and evaluating bug-prediction models.
pub trait ModelTrainer {
    /// Fits a model on `examples` according to `config`.
    ///
    /// # Errors
    /// Returns a [`TrainingError`] — e.g. `InsufficientData` when fewer than
    /// `config.min_examples` are supplied, or `TrainingFailed` when fitting
    /// itself fails.
    fn train(
        &self,
        examples: &[TrainingExample],
        config: &TrainingConfig,
    ) -> Result<Box<dyn TrainedModel>, TrainingError>;
    /// Runs k-fold cross-validation (`config.cv_folds` folds) over `examples`.
    ///
    /// # Errors
    /// Returns a [`TrainingError`] on insufficient data, invalid config, or a
    /// failed fold.
    fn cross_validate(
        &self,
        examples: &[TrainingExample],
        config: &TrainingConfig,
    ) -> Result<CrossValidationResults, TrainingError>;
}
/// Errors surfaced by [`ModelTrainer`] implementations.
#[derive(Debug, Clone)]
pub enum TrainingError {
    /// Fewer examples were supplied than the trainer requires.
    InsufficientData { required: usize, provided: usize },
    /// The training configuration was rejected; the reason is carried inline.
    InvalidConfig(String),
    /// The underlying fitting procedure failed; the reason is carried inline.
    TrainingFailed(String),
    /// A filesystem or serialization error, already stringified.
    IoError(String),
}

impl std::fmt::Display for TrainingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each arm renders a fixed prefix plus the variant's payload.
        match self {
            Self::InsufficientData { required, provided } => {
                write!(f, "Insufficient data: need {required}, got {provided}")
            }
            Self::InvalidConfig(reason) => write!(f, "Invalid config: {reason}"),
            Self::TrainingFailed(reason) => write!(f, "Training failed: {reason}"),
            Self::IoError(reason) => write!(f, "IO error: {reason}"),
        }
    }
}

impl std::error::Error for TrainingError {}
/// Maps a transpiler verdict onto a binary training label:
/// `false` for `Pass`, `true` for every failing verdict.
#[must_use]
pub fn verdict_to_label(verdict: &TranspilerVerdict) -> bool {
    match verdict {
        TranspilerVerdict::Pass => false,
        _ => true,
    }
}
/// Deterministically splits `examples` into (train, test) sets.
///
/// Each example's destination is decided by hashing `(seed, index)`, so the
/// split is stable for a given seed and independent of example content.
/// `train_ratio` is the expected (not exact) fraction routed to train;
/// values <= 0.0 send everything to test, values >= 1.0 send everything
/// (or nearly everything) to train.
#[must_use]
pub fn train_test_split(
    examples: &[TrainingExample],
    train_ratio: f64,
    seed: u64,
) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    // The cutoff depends only on `train_ratio`, so compute it once instead of
    // once per example (it was previously recomputed inside the loop).
    // A negative ratio casts to 0 (empty train set), which is the sane result.
    #[allow(clippy::cast_sign_loss)]
    let threshold = (train_ratio * u64::MAX as f64) as u64;
    let mut train = Vec::new();
    let mut test = Vec::new();
    for (i, example) in examples.iter().enumerate() {
        let mut hasher = DefaultHasher::new();
        (seed, i).hash(&mut hasher);
        if hasher.finish() < threshold {
            train.push(example.clone());
        } else {
            test.push(example.clone());
        }
    }
    (train, test)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds `n` examples with small, cycling feature values; every third
    // example (i % 3 == 0) is labeled as a bug.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| TrainingExample {
                features: CodeFeatures {
                    ast_depth: i % 5,
                    cyclomatic_complexity: i % 10,
                    ..Default::default()
                },
                is_bug: i % 3 == 0,
            })
            .collect()
    }

    // Defaults must stay stable: downstream runs rely on the 80/20 split,
    // 5 folds, and the 100-example minimum.
    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.train_ratio, 0.8);
        assert_eq!(config.cv_folds, 5);
        assert_eq!(config.min_examples, 100);
    }

    // F1 = 2pr/(p+r); the 0/0 case must return 0.0 rather than NaN.
    // Exact float equality is intentional: the expected values are the exact
    // results of these IEEE-754 operations.
    #[test]
    fn test_training_metrics_f1() {
        assert_eq!(TrainingMetrics::calculate_f1(0.8, 0.6), 0.6857142857142857);
        assert_eq!(TrainingMetrics::calculate_f1(0.0, 0.0), 0.0);
        assert_eq!(TrainingMetrics::calculate_f1(1.0, 1.0), 1.0);
    }

    // Three folds with evenly spaced accuracies: mean is the middle value and
    // the (population) std deviation is strictly positive.
    #[test]
    fn test_cross_validation_summarize() {
        let folds = vec![
            TrainingMetrics {
                accuracy: 0.8,
                f1_score: 0.75,
                ..Default::default()
            },
            TrainingMetrics {
                accuracy: 0.85,
                f1_score: 0.80,
                ..Default::default()
            },
            TrainingMetrics {
                accuracy: 0.9,
                f1_score: 0.85,
                ..Default::default()
            },
        ];
        let cv = CrossValidationResults::summarize(folds);
        assert!((cv.mean_accuracy - 0.85).abs() < 0.001);
        assert!((cv.mean_f1 - 0.8).abs() < 0.001);
        assert!(cv.std_accuracy > 0.0);
    }

    // No folds -> the all-zero Default, not NaN from a 0/0 division.
    #[test]
    fn test_cross_validation_empty() {
        let cv = CrossValidationResults::summarize(vec![]);
        assert_eq!(cv.mean_accuracy, 0.0);
        assert_eq!(cv.fold_metrics.len(), 0);
    }

    // Only Pass maps to "not a bug"; every failure verdict is a positive label.
    #[test]
    fn test_verdict_to_label() {
        assert!(!verdict_to_label(&TranspilerVerdict::Pass));
        assert!(verdict_to_label(&TranspilerVerdict::OutputMismatch));
        assert!(verdict_to_label(&TranspilerVerdict::TranspileError(
            "err".into()
        )));
        assert!(verdict_to_label(&TranspilerVerdict::Timeout));
    }

    // The hash-based split is probabilistic, so only check the realized
    // train fraction lands near 0.8 and that no example is lost.
    #[test]
    fn test_train_test_split_ratio() {
        let examples = sample_examples(1000);
        let (train, test) = train_test_split(&examples, 0.8, 42);
        let train_ratio = train.len() as f64 / examples.len() as f64;
        assert!(train_ratio > 0.7 && train_ratio < 0.9);
        assert_eq!(train.len() + test.len(), examples.len());
    }

    // Same seed -> same split (the split hashes (seed, index) only).
    #[test]
    fn test_train_test_split_deterministic() {
        let examples = sample_examples(100);
        let (train1, _) = train_test_split(&examples, 0.8, 42);
        let (train2, _) = train_test_split(&examples, 0.8, 42);
        assert_eq!(train1.len(), train2.len());
    }

    // Display output must carry both counts so log messages are actionable.
    #[test]
    fn test_training_error_display() {
        let err = TrainingError::InsufficientData {
            required: 100,
            provided: 50,
        };
        assert!(err.to_string().contains("100"));
        assert!(err.to_string().contains("50"));
    }

    #[test]
    fn test_model_metadata_clone() {
        let meta = ModelMetadata {
            model_type: "RandomForest".into(),
            trained_at: "2025-01-01".into(),
            train_examples: 1000,
            metrics: TrainingMetrics::default(),
        };
        let cloned = meta.clone();
        assert_eq!(cloned.model_type, meta.model_type);
    }

    // The remaining tests are placeholders for the planned `aprender` ML
    // backend; they are `#[ignore]`d and panic via `unimplemented!` if run
    // explicitly (e.g. with `cargo test -- --ignored`).
    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_random_forest_training() {
        unimplemented!("RandomForest training not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_cross_validation_with_model() {
        unimplemented!("Cross-validation not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_model_save_load() {
        unimplemented!("Model save/load not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_stratified_split() {
        unimplemented!("Stratified split not yet implemented")
    }
}