use crate::{ErrorCategory, TrainingDataset};
use anyhow::Result;
use aprender::metaheuristics::{
Budget, DifferentialEvolution, OptimizationResult, PerturbativeMetaheuristic, SearchSpace,
};
use aprender::synthetic::{
DiversityMonitor, DiversityScore, QualityDegradationDetector, SyntheticConfig,
SyntheticGenerator,
};
use std::collections::HashMap;
/// Metadata describing one Python standard-library function, used as a seed
/// for synthetic training-example generation.
#[derive(Debug, Clone, PartialEq)]
pub struct StdlibFunction {
/// Dotted module path, e.g. `"os.path"`.
pub module: String,
/// Bare function name within the module, e.g. `"join"`.
pub name: String,
/// Human-readable signature string, e.g. `"(path, *paths) -> str"`.
pub signature: String,
/// Python types of the positional arguments.
pub arg_types: Vec<PyType>,
/// Declared return type, if known.
pub return_type: Option<PyType>,
/// Example snippets mined from the function's docstring.
pub docstring_examples: Vec<String>,
}
/// Simplified model of Python's type system — just enough structure to
/// synthesize sample argument values (see [`PyType::sample_value`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PyType {
Int,
Float,
Str,
Bool,
Bytes,
/// Homogeneous list with the given element type.
List(Box<PyType>),
/// Dict with the given key and value types.
Dict(Box<PyType>, Box<PyType>),
/// Fixed-arity tuple with per-position types.
Tuple(Vec<PyType>),
/// `Optional[T]` — sample values use the inner type's sample.
Optional(Box<PyType>),
Any,
/// `pathlib.Path`.
Path,
/// An open file object.
FileHandle,
Callable,
/// Iterator yielding the given element type.
Iterator(Box<PyType>),
}
impl PyType {
#[must_use]
pub fn sample_value(&self) -> String {
match self {
PyType::Int => "42".to_string(),
PyType::Float => "3.14".to_string(),
PyType::Str => "\"hello\"".to_string(),
PyType::Bool => "True".to_string(),
PyType::Bytes => "b\"data\"".to_string(),
PyType::List(inner) => format!("[{}]", inner.sample_value()),
PyType::Dict(k, v) => format!("{{{}: {}}}", k.sample_value(), v.sample_value()),
PyType::Tuple(types) => {
let vals: Vec<_> = types.iter().map(PyType::sample_value).collect();
format!("({})", vals.join(", "))
}
PyType::Optional(inner) => inner.sample_value(),
PyType::Any => "None".to_string(),
PyType::Path => "Path(\"/tmp/test\")".to_string(),
PyType::FileHandle => "open(\"/tmp/test.txt\")".to_string(),
PyType::Callable => "lambda x: x".to_string(),
PyType::Iterator(inner) => format!("iter([{}])", inner.sample_value()),
}
}
}
/// One synthesized Python snippet targeting a specific stdlib function.
#[derive(Debug, Clone)]
pub struct PythonExample {
/// The Python source text of the example.
pub source: String,
/// Fully-qualified target, formatted as `"module.name"`.
pub target_function: String,
/// Which generation strategy produced this example.
pub strategy: GenerationStrategy,
/// Hash of `source` (see `hash_content`), used for deduplication.
pub content_hash: u64,
}
/// The five ways an example can be synthesized. Several diversity metrics in
/// this module divide by 5 — keep them in sync if variants are added.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GenerationStrategy {
/// Examples mined directly from docstrings.
DocstringMining,
/// Calls built from declared argument types.
TypeEnumeration,
/// Boundary/edge-case inputs.
EdgeCases,
/// Deliberately ill-typed calls meant to provoke errors.
ErrorInduction,
/// Multiple functions composed together.
Composition,
}
/// Generates synthetic Python examples from a pool of stdlib seed functions,
/// tracking output diversity as it goes.
#[derive(Debug)]
pub struct PythonExampleGenerator {
// Seed functions examples are generated from.
stdlib_funcs: Vec<StdlibFunction>,
// Rolling diversity tracker (window configured in `new`).
diversity_monitor: DiversityMonitor,
}
impl PythonExampleGenerator {
#[must_use]
pub fn new(stdlib_funcs: Vec<StdlibFunction>) -> Self {
Self {
stdlib_funcs,
diversity_monitor: DiversityMonitor::new(100), }
}
#[must_use]
pub fn function_count(&self) -> usize {
self.stdlib_funcs.len()
}
#[must_use]
pub fn diversity_score(&self) -> DiversityScore {
self.diversity_monitor.latest().unwrap_or_default()
}
}
/// Plugs the Python example generator into the generic synthetic-data
/// pipeline from `aprender`.
impl SyntheticGenerator for PythonExampleGenerator {
type Input = StdlibFunction;
type Output = PythonExample;
/// Produces up to three kinds of examples per seed function:
/// docstring-mined, type-enumerated, and error-inducing.
fn generate(
&self,
seeds: &[Self::Input],
config: &SyntheticConfig,
) -> aprender::error::Result<Vec<Self::Output>> {
let mut examples = Vec::new();
let target_count = (seeds.len() as f32 * config.augmentation_ratio) as usize;
// NOTE(review): `target_count.max(seeds.len())` is always >= seeds.len(),
// so this `take` never limits anything and `augmentation_ratio` has no
// effect here — confirm whether `.min` was intended.
for func in seeds.iter().take(target_count.max(seeds.len())) {
// Docstring-mined examples are quality-gated.
for doc_example in &func.docstring_examples {
let example = PythonExample {
source: doc_example.clone(),
target_function: format!("{}.{}", func.module, func.name),
strategy: GenerationStrategy::DocstringMining,
content_hash: hash_content(doc_example),
};
if self.quality_score(&example, func) >= config.quality_threshold {
examples.push(example);
}
}
// One type-enumeration example per function, also quality-gated.
let type_example = generate_type_example(func);
let example = PythonExample {
source: type_example.clone(),
target_function: format!("{}.{}", func.module, func.name),
strategy: GenerationStrategy::TypeEnumeration,
content_hash: hash_content(&type_example),
};
if self.quality_score(&example, func) >= config.quality_threshold {
examples.push(example);
}
// Error-induction examples bypass the quality gate: they are
// deliberately malformed, so low quality is expected.
let error_example = generate_error_example(func);
let example = PythonExample {
source: error_example.clone(),
target_function: format!("{}.{}", func.module, func.name),
strategy: GenerationStrategy::ErrorInduction,
content_hash: hash_content(&error_example),
};
examples.push(example);
}
Ok(examples)
}
/// Heuristic score in [0.5, 1.0]: base 0.5, plus +0.2 for non-blank
/// source, +0.2 for mentioning the target's bare function name, and
/// +0.1 for sources longer than 20 bytes.
fn quality_score(&self, generated: &Self::Output, _seed: &Self::Input) -> f32 {
let mut score: f32 = 0.5;
if !generated.source.trim().is_empty() {
score += 0.2;
}
// Last dot-separated segment of `target_function` is the bare name.
// (If that segment is empty, `contains("")` is trivially true.)
if generated.source.contains(
generated
.target_function
.split('.')
.next_back()
.unwrap_or(""),
) {
score += 0.2;
}
if generated.source.len() > 20 {
score += 0.1;
}
score.min(1.0)
}
/// Mean of three ratios: unique content hashes per example, unique
/// strategies out of the 5 defined, and unique target functions over
/// the batch size (batch size capped at 100 for this denominator).
fn diversity_score(&self, batch: &[Self::Output]) -> f32 {
if batch.is_empty() {
return 0.0;
}
use std::collections::HashSet;
let unique_hashes: HashSet<_> = batch.iter().map(|e| e.content_hash).collect();
let unique_strategies: HashSet<_> = batch.iter().map(|e| e.strategy).collect();
let unique_functions: HashSet<_> = batch.iter().map(|e| &e.target_function).collect();
let hash_diversity = unique_hashes.len() as f32 / batch.len() as f32;
let strategy_diversity = unique_strategies.len() as f32 / 5.0; let function_diversity = unique_functions.len() as f32 / batch.len().min(100) as f32;
(hash_diversity + strategy_diversity + function_diversity) / 3.0
}
}
/// Outcome of transpiling one Python snippet to Rust and compiling it.
#[derive(Debug, Clone)]
pub struct TranspileResult {
/// The original Python input.
pub python_source: String,
/// Generated Rust, if transpilation succeeded.
pub rust_output: Option<String>,
/// Transpiler failure message, if transpilation failed.
pub transpile_error: Option<String>,
/// rustc diagnostics from compiling the generated Rust.
pub compile_errors: Vec<RustcError>,
/// Hash of the source, used for deduplication in `add_result`.
pub content_hash: u64,
}
/// One rustc diagnostic, reduced to the fields used for labeling.
#[derive(Debug, Clone)]
pub struct RustcError {
/// rustc error code, e.g. `"E0308"` — drives `auto_label`.
pub code: String,
/// Primary diagnostic message.
pub message: String,
/// 1-based line of the error location.
pub line: usize,
/// Column of the error location.
pub column: usize,
/// rustc's suggested fix, when one was emitted.
pub suggestion: Option<String>,
}
/// Map a rustc diagnostic code onto a coarse training label.
///
/// Codes outside every known group fall through to `ErrorCategory::Other`.
#[must_use]
pub fn auto_label(error: &RustcError) -> ErrorCategory {
    let code = error.code.as_str();
    if matches!(code, "E0308" | "E0277" | "E0282" | "E0283") {
        ErrorCategory::TypeMismatch
    } else if matches!(
        code,
        "E0382" | "E0499" | "E0502" | "E0503" | "E0505" | "E0507" | "E0596" | "E0597"
    ) {
        ErrorCategory::BorrowChecker
    } else if matches!(code, "E0432" | "E0433" | "E0412") {
        ErrorCategory::MissingImport
    } else if matches!(code, "E0425" | "E0423" | "E0424" | "E0609") {
        ErrorCategory::SyntaxError
    } else if matches!(code, "E0106" | "E0495" | "E0621") {
        ErrorCategory::LifetimeError
    } else if matches!(code, "E0599" | "E0600" | "E0369" | "E0631") {
        ErrorCategory::TraitBound
    } else {
        ErrorCategory::Other
    }
}
/// Tunables for a corpus-generation run.
#[derive(Debug, Clone)]
pub struct CorpusConfig {
/// Total number of samples to aim for.
pub target_samples: usize,
/// Samples processed per batch.
pub batch_size: usize,
/// Minimum acceptable per-example quality score.
pub quality_threshold: f32,
/// Maximum tolerated fraction of duplicate samples.
pub max_duplicate_rate: f32,
}
/// Defaults for a full-scale run: 50k samples in batches of 100, accepting
/// quality >= 0.7 and at most 5% duplicates.
impl Default for CorpusConfig {
fn default() -> Self {
Self {
target_samples: 50_000,
batch_size: 100,
quality_threshold: 0.7,
max_duplicate_rate: 0.05,
}
}
}
/// Counters accumulated while building a corpus.
#[derive(Debug, Clone, Default)]
pub struct CorpusMetrics {
/// Total results seen (accepted + rejected).
pub total_generated: usize,
/// Results accepted into the corpus.
pub accepted: usize,
/// Results rejected for low quality.
pub rejected_quality: usize,
/// Results rejected as content-hash duplicates.
pub rejected_duplicate: usize,
/// Count of compile errors seen per label category.
pub category_distribution: HashMap<ErrorCategory, usize>,
/// Number of distinct rustc error codes observed.
pub unique_error_codes: usize,
/// Most recent batch diversity measurement.
pub diversity_score: f32,
}
impl CorpusMetrics {
    /// Fraction of generated samples that were accepted; 0.0 before anything
    /// has been generated.
    #[must_use]
    pub fn acceptance_rate(&self) -> f32 {
        match self.total_generated {
            0 => 0.0,
            total => self.accepted as f32 / total as f32,
        }
    }

    /// Fraction of generated samples rejected as duplicates; 0.0 before
    /// anything has been generated.
    #[must_use]
    pub fn duplicate_rate(&self) -> f32 {
        match self.total_generated {
            0 => 0.0,
            total => self.rejected_duplicate as f32 / total as f32,
        }
    }

    /// Ratio of the most- to least-populated error category.
    ///
    /// Returns 0.0 for an empty distribution; the denominator is clamped to
    /// at least 1 so a zero count cannot divide by zero.
    #[must_use]
    pub fn imbalance_ratio(&self) -> f32 {
        if self.category_distribution.is_empty() {
            return 0.0;
        }
        let counts = || self.category_distribution.values().copied();
        let largest = counts().max().unwrap_or(0) as f32;
        let smallest = counts().min().unwrap_or(1).max(1) as f32;
        largest / smallest
    }
}
/// Drives self-supervised corpus construction: generates examples, labels
/// transpile results, and deduplicates by content hash.
#[allow(dead_code)] pub struct SelfSupervisedCorpusGenerator {
// Underlying Python example generator.
generator: PythonExampleGenerator,
// Run configuration (thresholds, batch sizes).
config: CorpusConfig,
// Detects drops in generated-sample quality over a sliding window.
quality_detector: QualityDegradationDetector,
// Content hashes already accepted or seen — used to reject duplicates.
seen_hashes: std::collections::HashSet<u64>,
// Running counters for the current run.
metrics: CorpusMetrics,
}
impl SelfSupervisedCorpusGenerator {
    /// Create a generator over the given stdlib seed functions.
    #[must_use]
    pub fn new(stdlib_funcs: Vec<StdlibFunction>, config: CorpusConfig) -> Self {
        Self {
            generator: PythonExampleGenerator::new(stdlib_funcs),
            // Struct-literal fields evaluate in written order, so we read the
            // Copy `quality_threshold` here before `config` is moved below —
            // this removes the clone the previous version needed. Window of
            // 100 recent samples for degradation detection.
            quality_detector: QualityDegradationDetector::new(config.quality_threshold, 100),
            config,
            seen_hashes: std::collections::HashSet::new(),
            metrics: CorpusMetrics::default(),
        }
    }

    /// Read-only view of the accumulated corpus metrics.
    #[must_use]
    pub fn metrics(&self) -> &CorpusMetrics {
        &self.metrics
    }

    /// Generate a training dataset.
    ///
    /// NOTE(review): currently a stub returning an empty dataset; results
    /// flow in through `add_result` instead — confirm intended wiring.
    pub fn generate(&mut self) -> Result<TrainingDataset> {
        let dataset = TrainingDataset::new();
        Ok(dataset)
    }

    /// Record one transpile result; returns `true` when it was accepted.
    ///
    /// Duplicates (by `content_hash`) are rejected and counted. Accepted
    /// results contribute each of their rustc errors to the per-category
    /// distribution.
    pub fn add_result(&mut self, result: &TranspileResult) -> bool {
        self.metrics.total_generated += 1;
        // `insert` returns false when the hash was already present, so this
        // tests membership and records the hash with a single lookup (the
        // previous `contains` + `insert` pair hashed twice).
        if !self.seen_hashes.insert(result.content_hash) {
            self.metrics.rejected_duplicate += 1;
            return false;
        }
        for error in &result.compile_errors {
            let category = auto_label(error);
            *self
                .metrics
                .category_distribution
                .entry(category)
                .or_insert(0) += 1;
        }
        self.metrics.accepted += 1;
        true
    }
}
/// Tunable generation parameters, optimizable as an 8-dimensional vector
/// (see `from_vec`/`to_vec`). The five strategy weights are kept normalized
/// to sum to 1 by `from_vec`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GenerationParams {
/// Weight for docstring-mined examples.
pub weight_docstring: f64,
/// Weight for type-enumeration examples.
pub weight_type_enum: f64,
/// Weight for edge-case examples.
pub weight_edge_cases: f64,
/// Weight for error-induction examples.
pub weight_error_induction: f64,
/// Weight for composition examples.
pub weight_composition: f64,
/// Per-example quality acceptance threshold (clamped to [0.1, 0.99]).
pub quality_threshold: f64,
/// Minimum required batch diversity (clamped to [0.1, 0.99]).
pub min_diversity: f64,
/// Examples generated per seed (clamped to [1.0, 10.0]).
pub augmentation_ratio: f64,
}
/// Defaults: strategy weights sum to 1.0 (0.3/0.3/0.15/0.15/0.1), quality
/// threshold 0.7, minimum diversity 0.5, 2x augmentation.
impl Default for GenerationParams {
fn default() -> Self {
Self {
weight_docstring: 0.3,
weight_type_enum: 0.3,
weight_edge_cases: 0.15,
weight_error_induction: 0.15,
weight_composition: 0.1,
quality_threshold: 0.7,
min_diversity: 0.5,
augmentation_ratio: 2.0,
}
}
}
impl GenerationParams {
    /// Number of scalars encoded by `to_vec`/decoded by `from_vec`.
    pub const DIM: usize = 8;

    /// Decode an optimizer parameter vector.
    ///
    /// The first five entries (strategy weights) are normalized to sum to 1
    /// unless their sum is non-positive, in which case they pass through
    /// unscaled; the remaining entries are clamped into their valid ranges.
    ///
    /// # Panics
    /// Panics when fewer than [`Self::DIM`] values are supplied.
    #[must_use]
    pub fn from_vec(params: &[f64]) -> Self {
        assert!(
            params.len() >= Self::DIM,
            "Need {} params, got {}",
            Self::DIM,
            params.len()
        );
        let weight_sum: f64 = params[..5].iter().sum();
        let norm = if weight_sum > 0.0 { weight_sum } else { 1.0 };
        Self {
            weight_docstring: params[0] / norm,
            weight_type_enum: params[1] / norm,
            weight_edge_cases: params[2] / norm,
            weight_error_induction: params[3] / norm,
            weight_composition: params[4] / norm,
            quality_threshold: params[5].clamp(0.1, 0.99),
            min_diversity: params[6].clamp(0.1, 0.99),
            augmentation_ratio: params[7].clamp(1.0, 10.0),
        }
    }

    /// Encode the parameters in the fixed order `from_vec` expects.
    #[must_use]
    pub fn to_vec(&self) -> Vec<f64> {
        vec![
            self.weight_docstring,
            self.weight_type_enum,
            self.weight_edge_cases,
            self.weight_error_induction,
            self.weight_composition,
            self.quality_threshold,
            self.min_diversity,
            self.augmentation_ratio,
        ]
    }

    /// Continuous search space matching the `from_vec` layout: unbounded-ish
    /// raw weights in [0, 1], then the clamp ranges of the scalar knobs.
    #[must_use]
    pub fn search_space() -> SearchSpace {
        SearchSpace::Continuous {
            dim: Self::DIM,
            lower: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 1.0],
            upper: vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.99, 0.99, 10.0],
        }
    }

    /// Per-strategy weights as a lookup map.
    #[must_use]
    pub fn strategy_weights(&self) -> HashMap<GenerationStrategy, f64> {
        HashMap::from([
            (GenerationStrategy::DocstringMining, self.weight_docstring),
            (GenerationStrategy::TypeEnumeration, self.weight_type_enum),
            (GenerationStrategy::EdgeCases, self.weight_edge_cases),
            (GenerationStrategy::ErrorInduction, self.weight_error_induction),
            (GenerationStrategy::Composition, self.weight_composition),
        ])
    }
}
/// Settings for the differential-evolution optimizer.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
/// Fitness-evaluation budget.
pub max_evaluations: usize,
/// DE population size.
pub population_size: usize,
/// DE mutation factor (F).
pub mutation_factor: f64,
/// DE crossover rate (CR).
pub crossover_rate: f64,
/// RNG seed for reproducible runs, if any.
pub seed: Option<u64>,
}
/// Defaults: 1000 evaluations, population 20, F=0.8, CR=0.9, fixed seed 42
/// for reproducibility.
impl Default for OptimizerConfig {
fn default() -> Self {
Self {
max_evaluations: 1000,
population_size: 20,
mutation_factor: 0.8,
crossover_rate: 0.9,
seed: Some(42),
}
}
}
/// Outcome of one optimization run.
#[derive(Debug, Clone)]
pub struct OptimizedResult {
/// Best decoded parameter set found.
pub params: GenerationParams,
/// Fitness of `params` (higher is better; sign already un-negated).
pub fitness: f64,
/// Number of fitness evaluations consumed.
pub evaluations: usize,
/// Per-iteration best-fitness trace (un-negated).
pub history: Vec<f64>,
/// Whether the optimizer reported convergence.
pub converged: bool,
}
/// Wraps a differential-evolution backend to tune `GenerationParams`.
pub struct MetaheuristicOptimizer {
// Run settings; `max_evaluations` feeds the DE budget.
config: OptimizerConfig,
// Differential-evolution engine (minimizes — see `optimize` for negation).
de: DifferentialEvolution,
}
impl MetaheuristicOptimizer {
    /// Build an optimizer whose DE backend is configured from `config`.
    #[must_use]
    pub fn new(config: OptimizerConfig) -> Self {
        let mut de = DifferentialEvolution::default();
        de.population_size = config.population_size;
        de.mutation_factor = config.mutation_factor;
        de.crossover_rate = config.crossover_rate;
        Self { config, de }
    }

    /// Maximize `fitness_fn` over the generation-parameter search space.
    ///
    /// The DE backend minimizes its objective, so fitness is negated on the
    /// way in, and the objective value and history are negated back on the
    /// way out.
    pub fn optimize<F>(&mut self, fitness_fn: F) -> OptimizedResult
    where
        F: Fn(&GenerationParams) -> f64,
    {
        let space = GenerationParams::search_space();
        let budget = Budget::Evaluations(self.config.max_evaluations);
        let wrapped_fitness = |raw_params: &[f64]| {
            let params = GenerationParams::from_vec(raw_params);
            // Fix: the previous revision contained the mojibake token
            // `fitness_fn(¶ms)` (an HTML-entity mangling of `&params`),
            // which does not compile. Negate because DE minimizes.
            -fitness_fn(&params)
        };
        let result: OptimizationResult<Vec<f64>> =
            self.de.optimize(&wrapped_fitness, &space, budget);
        OptimizedResult {
            params: GenerationParams::from_vec(&result.solution),
            fitness: -result.objective_value,
            evaluations: result.evaluations,
            history: result.history.iter().map(|v| -v).collect(),
            converged: result.converged(),
        }
    }

    /// Best parameters found so far, if any evaluation has completed.
    #[must_use]
    pub fn best(&self) -> Option<GenerationParams> {
        self.de.best().map(|v| GenerationParams::from_vec(v))
    }

    /// Clear optimizer state so a fresh run can begin.
    pub fn reset(&mut self) {
        self.de.reset();
    }
}
/// Score a parameter set by generating a small corpus and combining
/// acceptance rate, class balance, diversity, and error-code coverage
/// (weights 0.3/0.3/0.2/0.2).
///
/// NOTE(review): `SelfSupervisedCorpusGenerator::generate` is currently a
/// stub, so the metrics driving this score stay at their defaults — confirm
/// once generation is wired up.
#[allow(dead_code)]
pub fn evaluate_fitness(
    params: &GenerationParams,
    stdlib_funcs: &[StdlibFunction],
    eval_samples: usize,
) -> f64 {
    let config = CorpusConfig {
        target_samples: eval_samples,
        batch_size: 50,
        quality_threshold: params.quality_threshold as f32,
        max_duplicate_rate: 0.05,
    };
    let mut generator = SelfSupervisedCorpusGenerator::new(stdlib_funcs.to_vec(), config);
    // A failed generation scores zero outright.
    if generator.generate().is_err() {
        return 0.0;
    }
    let metrics = generator.metrics();
    let acceptance_score = f64::from(metrics.acceptance_rate());
    // Soft penalty for class imbalance: 1.0 when balanced, falling toward 0.
    let balance_score = 1.0 / (1.0 + f64::from(metrics.imbalance_ratio()) / 10.0);
    let diversity_score = f64::from(metrics.diversity_score);
    // Coverage saturates at 50 distinct error codes.
    let coverage_score = (metrics.unique_error_codes as f64 / 50.0).min(1.0);
    0.3 * acceptance_score + 0.3 * balance_score + 0.2 * diversity_score + 0.2 * coverage_score
}
/// Aggregated quality signals for one corpus/training configuration.
#[derive(Debug, Clone, Default)]
pub struct EvaluationMetrics {
/// Number of accepted samples in the corpus.
pub corpus_size: usize,
/// 1 − duplicate rate.
pub uniqueness_rate: f64,
/// Inverse-imbalance score in (0, 1]; 1 means perfectly balanced.
pub class_balance: f64,
/// Fraction of the known error categories observed.
pub category_coverage: f64,
/// Batch diversity measurement.
pub diversity_score: f64,
/// k-fold cross-validated accuracy estimate.
pub estimated_accuracy: f64,
/// Macro-averaged F1 across categories.
pub macro_f1: f64,
}
impl EvaluationMetrics {
    /// Derive evaluation metrics from corpus statistics plus externally
    /// measured k-fold accuracy and macro-F1.
    #[must_use]
    pub fn from_corpus(metrics: &CorpusMetrics, k_fold_accuracy: f64, macro_f1: f64) -> Self {
        // Seven known error categories; coverage is the fraction observed.
        const TOTAL_CATEGORIES: usize = 7;
        let covered = metrics.category_distribution.len();
        Self {
            corpus_size: metrics.accepted,
            uniqueness_rate: 1.0 - f64::from(metrics.duplicate_rate()),
            class_balance: 1.0 / (1.0 + f64::from(metrics.imbalance_ratio()) / 10.0),
            category_coverage: covered as f64 / TOTAL_CATEGORIES as f64,
            diversity_score: f64::from(metrics.diversity_score),
            estimated_accuracy: k_fold_accuracy,
            macro_f1,
        }
    }

    /// True when both accuracy and diversity clear their thresholds.
    #[must_use]
    pub fn meets_thresholds(&self, min_accuracy: f64, min_diversity: f64) -> bool {
        self.estimated_accuracy >= min_accuracy && self.diversity_score >= min_diversity
    }

    /// Weighted aggregate of the quality signals (weights sum to 1.0).
    #[must_use]
    pub fn overall_score(&self) -> f64 {
        self.estimated_accuracy * 0.35
            + self.macro_f1 * 0.25
            + self.diversity_score * 0.15
            + self.class_balance * 0.15
            + self.category_coverage * 0.10
    }
}
/// Settings for evaluating a corpus configuration.
#[derive(Debug, Clone)]
pub struct EvaluationConfig {
/// Number of folds for cross-validation.
pub k_folds: usize,
/// Minimum acceptable estimated accuracy.
pub min_accuracy: f64,
/// Minimum acceptable diversity score.
pub min_diversity: f64,
/// RNG seed for reproducible evaluation.
pub seed: u64,
}
/// Defaults: 5-fold CV, accuracy floor 0.85, diversity floor 0.5, seed 42.
impl Default for EvaluationConfig {
fn default() -> Self {
Self {
k_folds: 5,
min_accuracy: 0.85,
min_diversity: 0.5,
seed: 42,
}
}
}
/// One named configuration with its measured metrics and timings.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
/// Human-readable configuration name.
pub name: String,
/// Parameters that produced these metrics.
pub params: GenerationParams,
/// Measured evaluation metrics.
pub metrics: EvaluationMetrics,
/// Wall-clock time spent generating the corpus, in seconds.
pub generation_time_secs: f64,
/// Wall-clock time spent training, in seconds.
pub training_time_secs: f64,
}
impl BenchmarkResult {
    /// Bundle a named configuration with its measured metrics and timings.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        params: GenerationParams,
        metrics: EvaluationMetrics,
        generation_time_secs: f64,
        training_time_secs: f64,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            params,
            metrics,
            generation_time_secs,
            training_time_secs,
        }
    }

    /// Strictly-greater comparison on the aggregate evaluation score.
    #[must_use]
    pub fn is_better_than(&self, other: &Self) -> bool {
        let ours = self.metrics.overall_score();
        let theirs = other.metrics.overall_score();
        ours > theirs
    }
}
/// Collects benchmark results and compares them against a fixed baseline.
pub struct Evaluator {
// Evaluation settings (exposed via `config()`).
config: EvaluationConfig,
// Results recorded so far, in insertion order.
results: Vec<BenchmarkResult>,
}
impl Evaluator {
/// Create an evaluator with no recorded results.
#[must_use]
pub fn new(config: EvaluationConfig) -> Self {
Self {
config,
results: Vec::new(),
}
}
/// The evaluation settings this evaluator was built with.
#[must_use]
pub fn config(&self) -> &EvaluationConfig {
&self.config
}
/// All recorded benchmark results, in insertion order.
#[must_use]
pub fn results(&self) -> &[BenchmarkResult] {
&self.results
}
/// Record one benchmark result.
pub fn add_result(&mut self, result: BenchmarkResult) {
self.results.push(result);
}
/// Result with the highest overall score, if any have been recorded.
/// Incomparable scores (e.g. NaN) are treated as equal rather than
/// panicking.
#[must_use]
pub fn best_result(&self) -> Option<&BenchmarkResult> {
self.results.iter().max_by(|a, b| {
a.metrics
.overall_score()
.partial_cmp(&b.metrics.overall_score())
.unwrap_or(std::cmp::Ordering::Equal)
})
}
/// Fixed reference metrics that candidates are compared against.
/// NOTE(review): these are hard-coded constants — confirm where the
/// baseline numbers come from before relying on them.
#[must_use]
pub fn baseline_metrics(&self) -> EvaluationMetrics {
EvaluationMetrics {
corpus_size: 99, uniqueness_rate: 0.95,
class_balance: 0.6,
category_coverage: 0.71, diversity_score: 0.7,
estimated_accuracy: 0.84, macro_f1: 0.80,
}
}
/// True when `metrics` scores strictly above the baseline.
#[must_use]
pub fn improves_over_baseline(&self, metrics: &EvaluationMetrics) -> bool {
let baseline = self.baseline_metrics();
metrics.overall_score() > baseline.overall_score()
}
/// Render a human-readable report: baseline, each result with its delta
/// versus baseline, and the best configuration found.
#[must_use]
pub fn summary_report(&self) -> String {
let mut report = String::new();
report.push_str("=== Self-Supervised Corpus Evaluation Report ===\n\n");
let baseline = self.baseline_metrics();
report.push_str(&format!(
"Baseline: accuracy={:.2}%, F1={:.2}, score={:.3}\n\n",
baseline.estimated_accuracy * 100.0,
baseline.macro_f1,
baseline.overall_score()
));
for (i, result) in self.results.iter().enumerate() {
// Positive improvement earns a check mark, otherwise a cross.
let improvement = result.metrics.overall_score() - baseline.overall_score();
let status = if improvement > 0.0 { "✓" } else { "✗" };
report.push_str(&format!(
"{}. {} {}\n   Accuracy: {:.2}% | F1: {:.2} | Diversity: {:.2}\n   Score: {:.3} ({:+.3})\n   Time: {:.1}s gen + {:.1}s train\n\n",
i + 1,
result.name,
status,
result.metrics.estimated_accuracy * 100.0,
result.metrics.macro_f1,
result.metrics.diversity_score,
result.metrics.overall_score(),
improvement,
result.generation_time_secs,
result.training_time_secs,
));
}
if let Some(best) = self.best_result() {
report.push_str(&format!("Best configuration: {}\n", best.name));
}
report
}
}
/// 64-bit hash of a source snippet, used for duplicate detection.
///
/// Deterministic within one build (DefaultHasher has fixed keys), but the
/// value is not guaranteed stable across Rust releases — do not persist it.
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    content.hash(&mut state);
    state.finish()
}
/// Render a well-typed example: an import line followed by a call that
/// passes a sample value for each declared argument type.
fn generate_type_example(func: &StdlibFunction) -> String {
    let rendered_args = func
        .arg_types
        .iter()
        .map(PyType::sample_value)
        .collect::<Vec<_>>()
        .join(", ");
    format!(
        "from {m} import {f}\nresult = {f}({a})",
        m = func.module,
        f = func.name,
        a = rendered_args
    )
}
/// Render an intentionally ill-typed call (passing `None`) used to provoke
/// transpile/compile errors for the error-induction strategy.
fn generate_error_example(func: &StdlibFunction) -> String {
    format!(
        "from {m} import {f}\nresult = {f}(None) # Wrong type",
        m = func.module,
        f = func.name
    )
}
/// Curriculum tiers, ordered from easiest to hardest (the derived `Ord`
/// follows declaration order).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DifficultyLevel {
Basic,
Intermediate,
Advanced,
Expert,
}
impl DifficultyLevel {
#[must_use]
pub fn strategies(&self) -> Vec<GenerationStrategy> {
match self {
DifficultyLevel::Basic => vec![GenerationStrategy::DocstringMining],
DifficultyLevel::Intermediate => vec![
GenerationStrategy::DocstringMining,
GenerationStrategy::TypeEnumeration,
],
DifficultyLevel::Advanced => vec![
GenerationStrategy::TypeEnumeration,
GenerationStrategy::EdgeCases,
GenerationStrategy::ErrorInduction,
],
DifficultyLevel::Expert => vec![
GenerationStrategy::EdgeCases,
GenerationStrategy::ErrorInduction,
GenerationStrategy::Composition,
],
}
}
#[must_use]
pub fn weight(&self) -> f64 {
match self {
DifficultyLevel::Basic => 0.3,
DifficultyLevel::Intermediate => 0.3,
DifficultyLevel::Advanced => 0.25,
DifficultyLevel::Expert => 0.15,
}
}
}
/// Tracks progression through the curriculum tiers: after a quota of samples
/// at one level, the scheduler can advance to the next.
#[derive(Debug, Clone)]
pub struct CurriculumScheduler {
// Tier currently being generated.
current_level: DifficultyLevel,
// Samples required before advancing out of a tier.
samples_per_level: usize,
// Samples generated at the current tier (reset on advancement).
samples_generated: usize,
// Samples generated across all tiers (never reset except by `reset`).
total_generated: usize,
}
impl CurriculumScheduler {
    /// Start a curriculum at `Basic`, requiring `samples_per_level` samples
    /// before each advancement.
    #[must_use]
    pub fn new(samples_per_level: usize) -> Self {
        Self {
            current_level: DifficultyLevel::Basic,
            samples_per_level,
            samples_generated: 0,
            total_generated: 0,
        }
    }

    /// The difficulty tier currently being generated.
    #[must_use]
    pub fn current_level(&self) -> DifficultyLevel {
        self.current_level
    }

    /// Samples generated across all tiers.
    #[must_use]
    pub fn total_generated(&self) -> usize {
        self.total_generated
    }

    /// Count one generated sample toward both the tier and the total.
    pub fn record_sample(&mut self) {
        self.samples_generated += 1;
        self.total_generated += 1;
    }

    /// Advance to the next tier once the current quota is met.
    ///
    /// Returns `true` when a transition happened; `Expert` is terminal and
    /// never advances.
    pub fn try_advance(&mut self) -> bool {
        if self.samples_generated < self.samples_per_level {
            return false;
        }
        let next = match self.current_level {
            DifficultyLevel::Basic => DifficultyLevel::Intermediate,
            DifficultyLevel::Intermediate => DifficultyLevel::Advanced,
            DifficultyLevel::Advanced => DifficultyLevel::Expert,
            DifficultyLevel::Expert => return false,
        };
        self.current_level = next;
        self.samples_generated = 0;
        true
    }

    /// True once the `Expert` quota has been filled.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        self.current_level == DifficultyLevel::Expert
            && self.samples_generated >= self.samples_per_level
    }

    /// Return to the initial `Basic` state with all counters cleared.
    pub fn reset(&mut self) {
        self.current_level = DifficultyLevel::Basic;
        self.samples_generated = 0;
        self.total_generated = 0;
    }
}
/// Settings for one end-to-end optimization run.
#[derive(Debug, Clone)]
pub struct OptimizationRunConfig {
/// How many stdlib functions to include in each fitness evaluation.
pub eval_stdlib_count: usize,
/// Samples generated per fitness evaluation.
pub eval_samples: usize,
/// Fitness-evaluation budget for the optimizer.
pub max_evaluations: usize,
/// Whether to score via the curriculum schedule.
pub use_curriculum: bool,
}
/// Defaults sized for a quick run: 20 functions, 100 samples, 500
/// evaluations, curriculum enabled.
impl Default for OptimizationRunConfig {
fn default() -> Self {
Self {
eval_stdlib_count: 20,
eval_samples: 100,
max_evaluations: 500,
use_curriculum: true,
}
}
}
/// Run the full parameter optimization over a subset of the stdlib seeds.
pub fn run_optimization(
    stdlib_funcs: &[StdlibFunction],
    config: &OptimizationRunConfig,
) -> OptimizedResult {
    // Only the first `eval_stdlib_count` functions participate, keeping each
    // fitness evaluation cheap.
    let eval_funcs: Vec<StdlibFunction> = stdlib_funcs
        .iter()
        .take(config.eval_stdlib_count)
        .cloned()
        .collect();
    let mut optimizer = MetaheuristicOptimizer::new(OptimizerConfig {
        max_evaluations: config.max_evaluations,
        population_size: 15,
        mutation_factor: 0.7,
        crossover_rate: 0.9,
        seed: Some(42),
    });
    let eval_samples = config.eval_samples;
    let use_curriculum = config.use_curriculum;
    optimizer.optimize(|params| {
        evaluate_fitness_with_curriculum(params, &eval_funcs, eval_samples, use_curriculum)
    })
}
/// Fitness that walks the four curriculum tiers, weighting each tier's score
/// by its `DifficultyLevel::weight`. Falls back to the flat
/// `evaluate_fitness` when the curriculum is disabled.
fn evaluate_fitness_with_curriculum(
params: &GenerationParams,
stdlib_funcs: &[StdlibFunction],
eval_samples: usize,
use_curriculum: bool,
) -> f64 {
if !use_curriculum {
return evaluate_fitness(params, stdlib_funcs, eval_samples);
}
// Split the sample budget evenly across the four tiers (integer division;
// any remainder is dropped).
let samples_per_level = eval_samples / 4;
let mut scheduler = CurriculumScheduler::new(samples_per_level);
let mut total_fitness = 0.0;
let mut level_count = 0;
// Each iteration scores one tier, records its full sample quota, and
// advances; the loop ends once the Expert quota is filled, so it runs at
// most four times (fewer iterations cannot happen — even with
// samples_per_level == 0 every tier is visited once).
while !scheduler.is_complete() {
let level = scheduler.current_level();
let strategies = level.strategies();
let level_fitness =
evaluate_level_fitness(params, stdlib_funcs, &strategies, samples_per_level);
total_fitness += level_fitness * level.weight();
level_count += 1;
for _ in 0..samples_per_level {
scheduler.record_sample();
}
scheduler.try_advance();
}
// With 4 levels and tier weights summing to 1.0, `/ 4 * 4` cancels and
// this returns the weighted sum of tier fitnesses.
if level_count > 0 {
total_fitness / level_count as f64 * 4.0 } else {
0.0
}
}
/// Cheap proxy fitness for one curriculum tier: the mean of strategy share
/// (out of the 5 defined strategies) and stdlib coverage (saturating at 50
/// functions). `_params` and `_samples` are currently unused placeholders.
fn evaluate_level_fitness(
    _params: &GenerationParams,
    stdlib_funcs: &[StdlibFunction],
    strategies: &[GenerationStrategy],
    _samples: usize,
) -> f64 {
    let strategy_share = strategies.len() as f64 / 5.0;
    let coverage = (stdlib_funcs.len() as f64 / 50.0).min(1.0);
    (strategy_share + coverage) / 2.0
}
/// Mechanical repair patterns paired with error categories
/// (see `FixPattern::from_category`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FixPattern {
/// Insert a conversion (`.into()` / `as`).
TypeConversion,
/// Clone the moved/borrowed value.
AddClone,
/// Add a missing `use` item.
AddImport,
/// Add a lifetime annotation.
AddLifetime,
/// Implement the missing trait.
ImplementTrait,
/// No mechanical fix; carries a free-form tag (e.g. "manual_review").
Other(String),
}
impl FixPattern {
    /// Default repair pattern associated with an error category.
    ///
    /// Syntax errors and uncategorized errors have no mechanical fix and map
    /// to `Other("manual_review")`.
    #[must_use]
    pub fn from_category(category: ErrorCategory) -> Self {
        match category {
            ErrorCategory::TypeMismatch => Self::TypeConversion,
            ErrorCategory::BorrowChecker => Self::AddClone,
            ErrorCategory::MissingImport => Self::AddImport,
            ErrorCategory::LifetimeError => Self::AddLifetime,
            ErrorCategory::TraitBound => Self::ImplementTrait,
            ErrorCategory::SyntaxError | ErrorCategory::Other => {
                Self::Other("manual_review".to_string())
            }
        }
    }
}
/// A reusable fix learned from one labeled training sample.
#[derive(Debug, Clone)]
pub struct ExtractedFix {
/// Error category this fix applies to.
pub category: ErrorCategory,
/// Mechanical repair pattern derived from the category.
pub pattern: FixPattern,
/// The diagnostic message the fix was learned from.
pub error_pattern: String,
/// Template text suggesting the shape of the repair.
pub fix_template: String,
/// Confidence in [0, 1]; used to rank fixes in `predict`.
pub confidence: f64,
}
/// Extract a reusable fix template from a labeled training sample.
///
/// Returns `None` for categories with no mechanical fix (syntax errors and
/// anything uncategorized).
#[must_use]
pub fn extract_fix_pattern(sample: &crate::TrainingSample) -> Option<ExtractedFix> {
    let category = sample.category;
    let fix_template = match category {
        ErrorCategory::TypeMismatch => "value.into() or value as Type".to_string(),
        ErrorCategory::BorrowChecker => "value.clone()".to_string(),
        ErrorCategory::MissingImport => "use crate::module::Type;".to_string(),
        ErrorCategory::LifetimeError => "'a annotation".to_string(),
        ErrorCategory::TraitBound => "impl Trait for Type".to_string(),
        _ => return None,
    };
    Some(ExtractedFix {
        category,
        // Derived only after the category is known to be fixable, so the
        // `Other("manual_review")` allocation on the None path (which the
        // previous version made and then discarded) never happens.
        pattern: FixPattern::from_category(category),
        error_pattern: sample.message.clone(),
        fix_template,
        // Fixed prior; no calibration data available yet.
        confidence: 0.7,
    })
}
/// Looks up the best-known fix for an error category, learned from a
/// training corpus.
#[derive(Debug, Default)]
pub struct CorpusFixPredictor {
// Fixes grouped by error category.
patterns: HashMap<ErrorCategory, Vec<ExtractedFix>>,
// Total fixes registered across all categories.
pattern_count: usize,
}
impl CorpusFixPredictor {
    /// Empty predictor with no learned patterns.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Register an extracted fix under its error category.
    pub fn add_fix(&mut self, fix: ExtractedFix) {
        self.patterns.entry(fix.category).or_default().push(fix);
        self.pattern_count += 1;
    }

    /// Total number of fixes registered across all categories.
    #[must_use]
    pub fn pattern_count(&self) -> usize {
        self.pattern_count
    }

    /// Categories for which at least one fix is known.
    #[must_use]
    pub fn categories(&self) -> Vec<ErrorCategory> {
        self.patterns.keys().copied().collect()
    }

    /// Highest-confidence fix for a category, if any is known.
    #[must_use]
    pub fn predict(&self, category: ErrorCategory) -> Option<&ExtractedFix> {
        self.patterns.get(&category).and_then(|fixes| {
            fixes
                .iter()
                // `f64::total_cmp` is a total order, so a NaN confidence can
                // no longer panic the way `partial_cmp(..).unwrap()` did.
                .max_by(|a, b| a.confidence.total_cmp(&b.confidence))
        })
    }

    /// Learn fix patterns from every sample in a training corpus.
    pub fn train_from_corpus(&mut self, corpus: &crate::TrainingDataset) {
        for sample in corpus.samples() {
            if let Some(fix) = extract_fix_pattern(sample) {
                self.add_fix(fix);
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::training::TrainingSample;
#[test]
fn test_pytype_sample_value_int() {
assert_eq!(PyType::Int.sample_value(), "42");
}
#[test]
fn test_pytype_sample_value_str() {
assert_eq!(PyType::Str.sample_value(), "\"hello\"");
}
#[test]
fn test_pytype_sample_value_list() {
let list_type = PyType::List(Box::new(PyType::Int));
assert_eq!(list_type.sample_value(), "[42]");
}
#[test]
fn test_pytype_sample_value_dict() {
let dict_type = PyType::Dict(Box::new(PyType::Str), Box::new(PyType::Int));
assert_eq!(dict_type.sample_value(), "{\"hello\": 42}");
}
#[test]
fn test_stdlib_function_creation() {
let func = StdlibFunction {
module: "os.path".to_string(),
name: "join".to_string(),
signature: "(path, *paths) -> str".to_string(),
arg_types: vec![PyType::Str, PyType::Str],
return_type: Some(PyType::Str),
docstring_examples: vec!["os.path.join('/home', 'user')".to_string()],
};
assert_eq!(func.module, "os.path");
assert_eq!(func.name, "join");
assert_eq!(func.arg_types.len(), 2);
}
fn sample_stdlib_function() -> StdlibFunction {
StdlibFunction {
module: "os.path".to_string(),
name: "join".to_string(),
signature: "(path, *paths) -> str".to_string(),
arg_types: vec![PyType::Str, PyType::Str],
return_type: Some(PyType::Str),
docstring_examples: vec!["os.path.join('/home', 'user')".to_string()],
}
}
#[test]
fn test_python_example_generator_creation() {
let funcs = vec![sample_stdlib_function()];
let gen = PythonExampleGenerator::new(funcs);
assert_eq!(gen.function_count(), 1);
}
#[test]
fn test_python_example_generator_generates_examples() {
let funcs = vec![sample_stdlib_function()];
let gen = PythonExampleGenerator::new(funcs.clone());
let config = SyntheticConfig::default();
let examples = gen
.generate(&funcs, &config)
.expect("generation should succeed");
assert!(
examples.len() >= 2,
"Expected at least 2 examples, got {}",
examples.len()
);
}
#[test]
fn test_python_example_generator_quality_score() {
let func = sample_stdlib_function();
let gen = PythonExampleGenerator::new(vec![func.clone()]);
let good_example = PythonExample {
source: "os.path.join('/home', 'user')".to_string(),
target_function: "os.path.join".to_string(),
strategy: GenerationStrategy::DocstringMining,
content_hash: 12345,
};
let score = gen.quality_score(&good_example, &func);
assert!(
score >= 0.7,
"Good example should have high quality score: {}",
score
);
}
#[test]
fn test_python_example_generator_diversity_score() {
let func = sample_stdlib_function();
let gen = PythonExampleGenerator::new(vec![func]);
let examples = vec![
PythonExample {
source: "example1".to_string(),
target_function: "os.path.join".to_string(),
strategy: GenerationStrategy::DocstringMining,
content_hash: 1,
},
PythonExample {
source: "example2".to_string(),
target_function: "os.path.join".to_string(),
strategy: GenerationStrategy::TypeEnumeration,
content_hash: 2,
},
PythonExample {
source: "example3".to_string(),
target_function: "os.path.exists".to_string(),
strategy: GenerationStrategy::ErrorInduction,
content_hash: 3,
},
];
let score = SyntheticGenerator::diversity_score(&gen, &examples);
assert!(
score > 0.5,
"Diverse examples should have high diversity: {:.2}",
score
);
}
#[test]
fn test_auto_label_type_mismatch() {
let error = RustcError {
code: "E0308".to_string(),
message: "mismatched types".to_string(),
line: 10,
column: 5,
suggestion: None,
};
assert_eq!(auto_label(&error), ErrorCategory::TypeMismatch);
}
#[test]
fn test_auto_label_borrow_checker() {
let error = RustcError {
code: "E0382".to_string(),
message: "use of moved value".to_string(),
line: 10,
column: 5,
suggestion: None,
};
assert_eq!(auto_label(&error), ErrorCategory::BorrowChecker);
}
#[test]
fn test_auto_label_missing_import() {
let error = RustcError {
code: "E0433".to_string(),
message: "failed to resolve".to_string(),
line: 10,
column: 5,
suggestion: None,
};
assert_eq!(auto_label(&error), ErrorCategory::MissingImport);
}
#[test]
fn test_auto_label_lifetime() {
let error = RustcError {
code: "E0106".to_string(),
message: "missing lifetime specifier".to_string(),
line: 10,
column: 5,
suggestion: None,
};
assert_eq!(auto_label(&error), ErrorCategory::LifetimeError);
}
#[test]
fn test_auto_label_trait_bound() {
let error = RustcError {
code: "E0599".to_string(),
message: "no method named".to_string(),
line: 10,
column: 5,
suggestion: None,
};
assert_eq!(auto_label(&error), ErrorCategory::TraitBound);
}
#[test]
fn test_auto_label_unknown() {
let error = RustcError {
code: "E9999".to_string(),
message: "unknown error".to_string(),
line: 10,
column: 5,
suggestion: None,
};
assert_eq!(auto_label(&error), ErrorCategory::Other);
}
#[test]
fn test_corpus_config_defaults() {
let config = CorpusConfig::default();
assert_eq!(config.target_samples, 50_000);
assert_eq!(config.batch_size, 100);
assert!((config.quality_threshold - 0.7).abs() < f32::EPSILON);
}
#[test]
fn test_corpus_metrics_acceptance_rate() {
let metrics = CorpusMetrics {
total_generated: 100,
accepted: 80,
..Default::default()
};
assert!((metrics.acceptance_rate() - 0.8).abs() < f32::EPSILON);
}
#[test]
fn test_corpus_metrics_duplicate_rate() {
let metrics = CorpusMetrics {
total_generated: 100,
rejected_duplicate: 5,
..Default::default()
};
assert!((metrics.duplicate_rate() - 0.05).abs() < f32::EPSILON);
}
#[test]
fn test_corpus_metrics_imbalance_ratio() {
let mut metrics = CorpusMetrics::default();
metrics
.category_distribution
.insert(ErrorCategory::TypeMismatch, 100);
metrics
.category_distribution
.insert(ErrorCategory::BorrowChecker, 50);
metrics
.category_distribution
.insert(ErrorCategory::Other, 10);
assert!((metrics.imbalance_ratio() - 10.0).abs() < f32::EPSILON);
}
#[test]
fn test_self_supervised_generator_creation() {
let funcs = vec![sample_stdlib_function()];
let config = CorpusConfig::default();
let gen = SelfSupervisedCorpusGenerator::new(funcs, config);
assert_eq!(gen.metrics().total_generated, 0);
assert_eq!(gen.metrics().accepted, 0);
}
#[test]
fn test_self_supervised_generator_add_result() {
let funcs = vec![sample_stdlib_function()];
let config = CorpusConfig::default();
let mut gen = SelfSupervisedCorpusGenerator::new(funcs, config);
let result = TranspileResult {
python_source: "test code".to_string(),
rust_output: Some("fn main() {}".to_string()),
transpile_error: None,
compile_errors: vec![RustcError {
code: "E0308".to_string(),
message: "mismatched types".to_string(),
line: 1,
column: 1,
suggestion: None,
}],
content_hash: 12345,
};
assert!(gen.add_result(&result));
assert_eq!(gen.metrics().accepted, 1);
assert_eq!(
gen.metrics()
.category_distribution
.get(&ErrorCategory::TypeMismatch),
Some(&1)
);
}
#[test]
fn test_self_supervised_generator_deduplication() {
let funcs = vec![sample_stdlib_function()];
let config = CorpusConfig::default();
let mut gen = SelfSupervisedCorpusGenerator::new(funcs, config);
let result = TranspileResult {
python_source: "test code".to_string(),
rust_output: None,
transpile_error: None,
compile_errors: vec![],
content_hash: 12345, };
assert!(gen.add_result(&result));
assert!(!gen.add_result(&result)); assert_eq!(gen.metrics().rejected_duplicate, 1);
}
#[test]
fn test_generate_type_example() {
    // The generated snippet must import and call the target function.
    let example = generate_type_example(&sample_stdlib_function());
    assert!(example.contains("from os.path import join"));
    assert!(example.contains("join("));
}
#[test]
fn test_generate_error_example() {
    // The error-inducing snippet should mention both None and the target name.
    let example = generate_error_example(&sample_stdlib_function());
    assert!(example.contains("None"));
    assert!(example.contains("join"));
}
#[test]
fn test_hash_content_deterministic() {
    // Hashing identical input twice must yield identical values.
    let content = "test content";
    assert_eq!(hash_content(content), hash_content(content));
}
#[test]
fn test_hash_content_different_for_different_content() {
    // Distinct inputs should hash to distinct values.
    assert_ne!(hash_content("content A"), hash_content("content B"));
}
#[test]
fn test_generation_params_default() {
    let params = GenerationParams::default();
    // The five strategy weights must form a probability distribution.
    let weight_sum = params.weight_docstring
        + params.weight_type_enum
        + params.weight_edge_cases
        + params.weight_error_induction
        + params.weight_composition;
    assert!((weight_sum - 1.0).abs() < 0.01, "Weights should sum to 1.0");
    // Range-contains form matches the rest of this module and satisfies
    // clippy::manual_range_contains.
    assert!((0.0..=1.0).contains(&params.quality_threshold));
}
#[test]
fn test_generation_params_from_vec() {
    let raw = vec![0.2, 0.3, 0.1, 0.2, 0.2, 0.75, 0.6, 3.0];
    let params = GenerationParams::from_vec(&raw);
    // Strategy weights are renormalized to sum to exactly 1.0.
    let weight_sum = params.weight_docstring
        + params.weight_type_enum
        + params.weight_edge_cases
        + params.weight_error_induction
        + params.weight_composition;
    assert!(
        (weight_sum - 1.0).abs() < 0.001,
        "Weights should be normalized"
    );
    assert!((params.quality_threshold - 0.75).abs() < 0.001);
    assert!((params.augmentation_ratio - 3.0).abs() < 0.001);
}
#[test]
fn test_generation_params_to_vec() {
    let params = GenerationParams::default();
    let vec = params.to_vec();
    assert_eq!(vec.len(), GenerationParams::DIM);
    // Index 5 carries quality_threshold (default 0.7) and index 7 the
    // augmentation ratio (default 2.0), mirroring from_vec's layout.
    assert!((vec[5] - 0.7).abs() < 0.001);
    assert!((vec[7] - 2.0).abs() < 0.001);
}
#[test]
fn test_generation_params_roundtrip() {
    // to_vec followed by from_vec must preserve the tunable fields.
    let original = GenerationParams::default();
    let restored = GenerationParams::from_vec(&original.to_vec());
    assert!((original.quality_threshold - restored.quality_threshold).abs() < 0.001);
    assert!((original.min_diversity - restored.min_diversity).abs() < 0.001);
}
#[test]
fn test_generation_params_search_space() {
    let space = GenerationParams::search_space();
    match space {
        SearchSpace::Continuous { dim, lower, upper } => {
            assert_eq!(dim, GenerationParams::DIM);
            assert_eq!(lower.len(), dim);
            assert_eq!(upper.len(), dim);
            // Every dimension needs a well-ordered [lower, upper] interval.
            for i in 0..dim {
                assert!(lower[i] <= upper[i], "Invalid bounds at dim {}", i);
            }
        }
        _ => panic!("Expected Continuous search space"),
    }
}
#[test]
fn test_generation_params_strategy_weights() {
    let params = GenerationParams::default();
    let weights = params.strategy_weights();
    assert_eq!(weights.len(), 5);
    // Every generation strategy must be represented in the weight map.
    for strategy in [
        GenerationStrategy::DocstringMining,
        GenerationStrategy::TypeEnumeration,
        GenerationStrategy::EdgeCases,
        GenerationStrategy::ErrorInduction,
        GenerationStrategy::Composition,
    ] {
        assert!(weights.contains_key(&strategy));
    }
}
#[test]
fn test_generation_params_clamp_bounds() {
    // Out-of-range raw values must be clamped into their legal intervals.
    let raw = vec![0.5, 0.5, 0.0, 0.0, 0.0, -0.5, 2.0, 0.1];
    let params = GenerationParams::from_vec(&raw);
    assert!((0.1..=0.99).contains(&params.quality_threshold));
    assert!((0.1..=0.99).contains(&params.min_diversity));
    assert!((1.0..=10.0).contains(&params.augmentation_ratio));
}
#[test]
fn test_optimizer_config_default() {
    let config = OptimizerConfig::default();
    assert!(config.max_evaluations > 0);
    assert!(config.population_size > 0);
    // Range-contains form matches the rest of this module and satisfies
    // clippy::manual_range_contains.
    assert!((0.0..=2.0).contains(&config.mutation_factor));
    assert!((0.0..=1.0).contains(&config.crossover_rate));
}
#[test]
fn test_metaheuristic_optimizer_creation() {
    let config = OptimizerConfig {
        max_evaluations: 100,
        population_size: 10,
        mutation_factor: 0.5,
        crossover_rate: 0.7,
        seed: Some(42),
    };
    // A fresh optimizer has not evaluated anything, so no best yet.
    let optimizer = MetaheuristicOptimizer::new(config);
    assert!(optimizer.best().is_none());
}
#[test]
fn test_metaheuristic_optimizer_simple_fitness() {
    let config = OptimizerConfig {
        max_evaluations: 50,
        population_size: 10,
        mutation_factor: 0.8,
        crossover_rate: 0.9,
        seed: Some(42),
    };
    let mut optimizer = MetaheuristicOptimizer::new(config);
    // Maximizing quality_threshold directly should drive fitness upward.
    let result = optimizer.optimize(|params| params.quality_threshold);
    assert!(result.fitness > 0.5, "Should improve from initial");
    assert!(result.evaluations > 0);
    assert!(!result.history.is_empty());
}
#[test]
fn test_metaheuristic_optimizer_reset() {
    let config = OptimizerConfig {
        max_evaluations: 20,
        population_size: 5,
        ..Default::default()
    };
    // `config` is moved here; the previous `.clone()` was redundant since
    // the value was never used again.
    let mut optimizer = MetaheuristicOptimizer::new(config);
    let _ = optimizer.optimize(|_| 0.5);
    assert!(optimizer.best().is_some());
    // reset() must discard the best-so-far state.
    optimizer.reset();
    assert!(optimizer.best().is_none());
}
#[test]
fn test_optimized_result_fields() {
    let config = OptimizerConfig {
        max_evaluations: 30,
        population_size: 5,
        ..Default::default()
    };
    let mut optimizer = MetaheuristicOptimizer::new(config);
    // A constant fitness still yields a populated, well-formed result.
    let result = optimizer.optimize(|_| 0.75);
    assert!(result.fitness >= 0.0);
    assert!(result.evaluations > 0);
    assert!(!result.history.is_empty());
    assert!(result.params.quality_threshold >= 0.1);
}
#[test]
fn test_evaluate_fitness_empty_stdlib() {
    // Fitness must stay in [0, 1] even with no stdlib functions to sample.
    // NOTE: the argument was previously garbled as `¶ms` — an HTML-entity
    // decoding artifact where `&para` in `&params` became `¶`.
    let params = GenerationParams::default();
    let fitness = evaluate_fitness(&params, &[], 10);
    assert!((0.0..=1.0).contains(&fitness));
}
#[test]
fn test_evaluate_fitness_with_sample_stdlib() {
    // Same bound check with a single sample function available.
    let stdlib_funcs = vec![sample_stdlib_function()];
    let params = GenerationParams::default();
    let fitness = evaluate_fitness(&params, &stdlib_funcs, 10);
    assert!((0.0..=1.0).contains(&fitness));
}
#[test]
fn test_evaluation_metrics_default() {
    // Default metrics describe an empty corpus with zero accuracy.
    let metrics = EvaluationMetrics::default();
    assert_eq!(metrics.corpus_size, 0);
    assert!((metrics.estimated_accuracy - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_evaluation_metrics_from_corpus() {
    let mut corpus_metrics = CorpusMetrics {
        accepted: 1000,
        total_generated: 1100,
        rejected_duplicate: 50,
        diversity_score: 0.8,
        ..Default::default()
    };
    for (category, count) in [
        (ErrorCategory::TypeMismatch, 300),
        (ErrorCategory::BorrowChecker, 200),
        (ErrorCategory::Other, 500),
    ] {
        corpus_metrics.category_distribution.insert(category, count);
    }
    // Accuracy and macro-F1 are passed through; size comes from `accepted`.
    let eval_metrics = EvaluationMetrics::from_corpus(&corpus_metrics, 0.92, 0.88);
    assert_eq!(eval_metrics.corpus_size, 1000);
    assert!(eval_metrics.uniqueness_rate > 0.9);
    assert!((eval_metrics.estimated_accuracy - 0.92).abs() < 0.001);
    assert!((eval_metrics.macro_f1 - 0.88).abs() < 0.001);
}
#[test]
fn test_evaluation_metrics_meets_thresholds() {
    let metrics = EvaluationMetrics {
        estimated_accuracy: 0.90,
        diversity_score: 0.7,
        ..Default::default()
    };
    // Passes only when both accuracy and diversity clear their bars.
    assert!(metrics.meets_thresholds(0.85, 0.5));
    assert!(!metrics.meets_thresholds(0.95, 0.5)); // accuracy bar too high
    assert!(!metrics.meets_thresholds(0.85, 0.8)); // diversity bar too high
}
#[test]
fn test_evaluation_metrics_overall_score() {
    let metrics = EvaluationMetrics {
        estimated_accuracy: 0.95,
        macro_f1: 0.93,
        diversity_score: 0.8,
        class_balance: 0.9,
        category_coverage: 1.0,
        ..Default::default()
    };
    // Weighted blend: accuracy 35%, F1 25%, diversity 15%, balance 15%, coverage 10%.
    let expected = 0.95 * 0.35 + 0.93 * 0.25 + 0.8 * 0.15 + 0.9 * 0.15 + 1.0 * 0.10;
    assert!((metrics.overall_score() - expected).abs() < 0.001);
}
#[test]
fn test_evaluation_config_default() {
    // Defaults: 5-fold CV with 85% accuracy and 0.5 diversity minimums.
    let config = EvaluationConfig::default();
    assert_eq!(config.k_folds, 5);
    assert!((config.min_accuracy - 0.85).abs() < 0.001);
    assert!((config.min_diversity - 0.5).abs() < 0.001);
}
#[test]
fn test_benchmark_result_creation() {
    // Constructor must carry through the name and both timing fields.
    let metrics = EvaluationMetrics {
        estimated_accuracy: 0.92,
        ..Default::default()
    };
    let result =
        BenchmarkResult::new("Test Config", GenerationParams::default(), metrics, 10.5, 5.2);
    assert_eq!(result.name, "Test Config");
    assert!((result.generation_time_secs - 10.5).abs() < 0.001);
    assert!((result.training_time_secs - 5.2).abs() < 0.001);
}
#[test]
fn test_benchmark_result_comparison() {
    // Helper: build a result with the given name, accuracy, and macro-F1.
    let make = |name: &str, accuracy, f1| {
        BenchmarkResult::new(
            name,
            GenerationParams::default(),
            EvaluationMetrics {
                estimated_accuracy: accuracy,
                macro_f1: f1,
                ..Default::default()
            },
            1.0,
            1.0,
        )
    };
    let better = make("Better", 0.95, 0.93);
    let worse = make("Worse", 0.80, 0.75);
    // Comparison is asymmetric: better beats worse but not vice versa.
    assert!(better.is_better_than(&worse));
    assert!(!worse.is_better_than(&better));
}
#[test]
fn test_evaluator_creation() {
    // A new evaluator starts empty and keeps the supplied configuration.
    let evaluator = Evaluator::new(EvaluationConfig::default());
    assert!(evaluator.results().is_empty());
    assert_eq!(evaluator.config().k_folds, 5);
}
#[test]
fn test_evaluator_add_results() {
    let mut evaluator = Evaluator::new(EvaluationConfig::default());
    for (name, accuracy) in [("Config A", 0.90), ("Config B", 0.95)] {
        evaluator.add_result(BenchmarkResult::new(
            name,
            GenerationParams::default(),
            EvaluationMetrics {
                estimated_accuracy: accuracy,
                ..Default::default()
            },
            1.0,
            1.0,
        ));
    }
    assert_eq!(evaluator.results().len(), 2);
}
#[test]
fn test_evaluator_best_result() {
    let mut evaluator = Evaluator::new(EvaluationConfig::default());
    // Insert in non-sorted order so best_result has to actually compare.
    for (name, accuracy, f1) in [
        ("Low", 0.80, 0.75),
        ("High", 0.95, 0.93),
        ("Medium", 0.88, 0.85),
    ] {
        evaluator.add_result(BenchmarkResult::new(
            name,
            GenerationParams::default(),
            EvaluationMetrics {
                estimated_accuracy: accuracy,
                macro_f1: f1,
                ..Default::default()
            },
            1.0,
            1.0,
        ));
    }
    let best = evaluator.best_result().expect("Should have best result");
    assert_eq!(best.name, "High");
}
#[test]
fn test_evaluator_baseline_metrics() {
    // The baseline describes the reference corpus: 99 samples, ~84% accuracy.
    let evaluator = Evaluator::new(EvaluationConfig::default());
    let baseline = evaluator.baseline_metrics();
    assert_eq!(baseline.corpus_size, 99);
    assert!((baseline.estimated_accuracy - 0.84).abs() < 0.01);
}
#[test]
fn test_evaluator_improves_over_baseline() {
    let evaluator = Evaluator::new(EvaluationConfig::default());
    // Strong metrics on every axis must beat the baseline...
    let improved = EvaluationMetrics {
        corpus_size: 5000,
        uniqueness_rate: 0.98,
        class_balance: 0.9,
        category_coverage: 1.0,
        diversity_score: 0.85,
        estimated_accuracy: 0.95,
        macro_f1: 0.93,
    };
    // ...while default metrics with low accuracy/F1 must not.
    let worse = EvaluationMetrics {
        estimated_accuracy: 0.70,
        macro_f1: 0.65,
        ..Default::default()
    };
    assert!(evaluator.improves_over_baseline(&improved));
    assert!(!evaluator.improves_over_baseline(&worse));
}
#[test]
fn test_evaluator_summary_report() {
    let mut evaluator = Evaluator::new(EvaluationConfig::default());
    evaluator.add_result(BenchmarkResult::new(
        "Test Config",
        GenerationParams::default(),
        EvaluationMetrics {
            estimated_accuracy: 0.92,
            macro_f1: 0.90,
            diversity_score: 0.8,
            ..Default::default()
        },
        15.5,
        3.2,
    ));
    let report = evaluator.summary_report();
    // The report must include the header, the baseline row, the named
    // configuration, and the formatted accuracy percentage.
    for needle in [
        "Self-Supervised Corpus Evaluation Report",
        "Baseline",
        "Test Config",
        "92.00%",
    ] {
        assert!(report.contains(needle), "report missing: {}", needle);
    }
}
#[test]
fn test_evaluator_empty_best_result() {
    // No results added means no best result to report.
    let evaluator = Evaluator::new(EvaluationConfig::default());
    assert!(evaluator.best_result().is_none());
}
#[test]
fn test_difficulty_level_ordering() {
    // Levels must form a strictly ascending order for curriculum progression.
    let levels = [
        DifficultyLevel::Basic,
        DifficultyLevel::Intermediate,
        DifficultyLevel::Advanced,
        DifficultyLevel::Expert,
    ];
    for pair in levels.windows(2) {
        assert!(pair[0] < pair[1]);
    }
}
#[test]
fn test_difficulty_level_strategies() {
    // Each tier exposes strategies appropriate to its difficulty.
    assert!(DifficultyLevel::Basic
        .strategies()
        .contains(&GenerationStrategy::DocstringMining));
    assert!(DifficultyLevel::Advanced
        .strategies()
        .contains(&GenerationStrategy::ErrorInduction));
    assert!(DifficultyLevel::Expert
        .strategies()
        .contains(&GenerationStrategy::Composition));
}
#[test]
fn test_difficulty_level_weight() {
    // The four level weights must form a probability distribution.
    let total_weight: f64 = DifficultyLevel::Basic.weight()
        + DifficultyLevel::Intermediate.weight()
        + DifficultyLevel::Advanced.weight()
        + DifficultyLevel::Expert.weight();
    assert!(
        (total_weight - 1.0).abs() < 0.001,
        "Weights should sum to 1.0"
    );
}
#[test]
fn test_curriculum_scheduler_creation() {
    // A new scheduler starts at the easiest level with nothing generated.
    let scheduler = CurriculumScheduler::new(100);
    assert_eq!(scheduler.current_level(), DifficultyLevel::Basic);
    assert_eq!(scheduler.total_generated(), 0);
    assert!(!scheduler.is_complete());
}
#[test]
fn test_curriculum_scheduler_record_sample() {
    let mut scheduler = CurriculumScheduler::new(10);
    (0..5).for_each(|_| scheduler.record_sample());
    // Five of ten samples recorded: the counter advances, the level does not.
    assert_eq!(scheduler.total_generated(), 5);
    assert_eq!(scheduler.current_level(), DifficultyLevel::Basic);
}
#[test]
fn test_curriculum_scheduler_advance() {
    let mut scheduler = CurriculumScheduler::new(2);
    scheduler.record_sample();
    scheduler.record_sample();
    // Quota met, so the scheduler steps up to the next level.
    assert!(scheduler.try_advance());
    assert_eq!(scheduler.current_level(), DifficultyLevel::Intermediate);
}
#[test]
fn test_curriculum_scheduler_full_progression() {
    // Walk Basic -> Intermediate -> Advanced -> Expert with a quota of one.
    let mut scheduler = CurriculumScheduler::new(1);
    for _ in 0..3 {
        scheduler.record_sample();
        scheduler.try_advance();
    }
    scheduler.record_sample();
    // At Expert there is nowhere left to advance to.
    assert!(!scheduler.try_advance());
    assert!(scheduler.is_complete());
    assert_eq!(scheduler.current_level(), DifficultyLevel::Expert);
}
#[test]
fn test_curriculum_scheduler_reset() {
    let mut scheduler = CurriculumScheduler::new(1);
    scheduler.record_sample();
    scheduler.try_advance();
    // reset() returns the scheduler to its initial state.
    scheduler.reset();
    assert_eq!(scheduler.current_level(), DifficultyLevel::Basic);
    assert_eq!(scheduler.total_generated(), 0);
}
#[test]
fn test_optimization_run_config_default() {
    // Defaults must describe a non-trivial run with curriculum enabled.
    let config = OptimizationRunConfig::default();
    assert!(config.eval_stdlib_count > 0);
    assert!(config.eval_samples > 0);
    assert!(config.max_evaluations > 0);
    assert!(config.use_curriculum);
}
#[test]
fn test_run_optimization_basic() {
    // A tiny run without curriculum scheduling still produces a valid result.
    let stdlib_funcs = vec![sample_stdlib_function()];
    let config = OptimizationRunConfig {
        eval_stdlib_count: 1,
        eval_samples: 5,
        max_evaluations: 10,
        use_curriculum: false,
    };
    let outcome = run_optimization(&stdlib_funcs, &config);
    assert!(outcome.fitness >= 0.0);
    assert!(outcome.evaluations > 0);
}
#[test]
fn test_run_optimization_with_curriculum() {
    // Same tiny run as the basic test, but with curriculum scheduling on.
    let stdlib_funcs = vec![sample_stdlib_function()];
    let config = OptimizationRunConfig {
        eval_stdlib_count: 1,
        eval_samples: 5,
        max_evaluations: 10,
        use_curriculum: true,
    };
    let outcome = run_optimization(&stdlib_funcs, &config);
    assert!(outcome.fitness >= 0.0);
    assert!(outcome.evaluations > 0);
}
#[test]
fn test_fix_pattern_from_category() {
    // Each well-known error category maps to its dedicated fix pattern.
    let fix = FixPattern::from_category;
    assert!(matches!(fix(ErrorCategory::TypeMismatch), FixPattern::TypeConversion));
    assert!(matches!(fix(ErrorCategory::BorrowChecker), FixPattern::AddClone));
    assert!(matches!(fix(ErrorCategory::MissingImport), FixPattern::AddImport));
    assert!(matches!(fix(ErrorCategory::LifetimeError), FixPattern::AddLifetime));
    assert!(matches!(fix(ErrorCategory::TraitBound), FixPattern::ImplementTrait));
}
#[test]
fn test_fix_pattern_other() {
    // Unclassified categories fall back to the catch-all Other variant.
    assert!(matches!(
        FixPattern::from_category(ErrorCategory::Other),
        FixPattern::Other(_)
    ));
}
#[test]
fn test_corpus_fix_predictor_creation() {
    // A freshly built predictor knows no patterns and predicts nothing.
    let predictor = CorpusFixPredictor::new();
    assert_eq!(predictor.pattern_count(), 0);
    assert!(predictor.predict(ErrorCategory::TypeMismatch).is_none());
}
#[test]
fn test_corpus_fix_predictor_add_fix() {
    // Adding a single extracted fix registers exactly one pattern.
    let mut predictor = CorpusFixPredictor::new();
    predictor.add_fix(ExtractedFix {
        category: ErrorCategory::TypeMismatch,
        pattern: FixPattern::TypeConversion,
        error_pattern: "expected i32, found String".to_string(),
        fix_template: "use .parse::<i32>()".to_string(),
        confidence: 0.9,
    });
    assert_eq!(predictor.pattern_count(), 1);
}
#[test]
fn test_corpus_fix_predictor_predict() {
    let mut predictor = CorpusFixPredictor::new();
    // Two candidate fixes for the same category, differing in confidence.
    let candidates = [
        ("type mismatch", "use .into()", 0.8),
        ("expected struct", "use From::from()", 0.9),
    ];
    for &(error_pattern, fix_template, confidence) in candidates.iter() {
        predictor.add_fix(ExtractedFix {
            category: ErrorCategory::TypeMismatch,
            pattern: FixPattern::TypeConversion,
            error_pattern: error_pattern.to_string(),
            fix_template: fix_template.to_string(),
            confidence,
        });
    }
    // The prediction should surface the higher-confidence fix.
    let fix = predictor
        .predict(ErrorCategory::TypeMismatch)
        .expect("prediction for a trained category");
    assert!((fix.confidence - 0.9).abs() < 0.001);
}
#[test]
fn test_corpus_fix_predictor_train_from_corpus() {
    // NOTE: the function-local `use crate::TrainingDataset;` was removed —
    // the type is already imported at file scope.
    let mut predictor = CorpusFixPredictor::new();
    let mut dataset = TrainingDataset::new();
    // Two samples in distinct categories, each carrying a concrete fix.
    dataset.add(TrainingSample {
        message: "mismatched types expected i32 found String".to_string(),
        category: ErrorCategory::TypeMismatch,
        fix: Some("use .parse::<i32>()".to_string()),
    });
    dataset.add(TrainingSample {
        message: "cannot borrow as mutable".to_string(),
        category: ErrorCategory::BorrowChecker,
        fix: Some("use .clone()".to_string()),
    });
    predictor.train_from_corpus(&dataset);
    // Training should learn at least one pattern per sampled category.
    assert!(predictor.pattern_count() >= 2);
}
#[test]
fn test_extract_fix_pattern_type_mismatch() {
    // A sample that carries a fix yields an extracted pattern in kind.
    let sample = TrainingSample {
        message: "type mismatch error".to_string(),
        category: ErrorCategory::TypeMismatch,
        fix: Some("convert type".to_string()),
    };
    let fix = extract_fix_pattern(&sample).expect("fix should be extracted");
    assert_eq!(fix.category, ErrorCategory::TypeMismatch);
    assert!(matches!(fix.pattern, FixPattern::TypeConversion));
}
#[test]
fn test_extract_fix_pattern_syntax_error() {
    // Without an associated fix there is nothing to extract.
    let sample = TrainingSample {
        message: "syntax error".to_string(),
        category: ErrorCategory::SyntaxError,
        fix: None,
    };
    assert!(extract_fix_pattern(&sample).is_none());
}
}