use aprender::automl::params::ParamKey;
use aprender::automl::{ParamValue, RandomSearch, SearchSpace, SearchStrategy};
use std::collections::HashMap;
use crate::depyler_training::build_combined_corpus;
use crate::ngram::NgramFixPredictor;
use crate::training::TrainingSample;
use crate::tuning::TuningResult;
/// Keys of the hyperparameters explored by the oracle AutoML search.
///
/// Each variant names one value that [`AutoMLConfig::from_params`] reads out
/// of a sampled parameter map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OracleParam {
    /// Minimum similarity passed to the n-gram predictor.
    MinSimilarity,
    /// Lower end of the n-gram size range.
    NgramMin,
    /// Upper end of the n-gram size range.
    NgramMax,
    /// How many times a compiler error code is repeated in front of a message.
    ErrorCodeWeight,
}
impl ParamKey for OracleParam {
fn name(&self) -> &'static str {
match self {
Self::MinSimilarity => "min_similarity",
Self::NgramMin => "ngram_min",
Self::NgramMax => "ngram_max",
Self::ErrorCodeWeight => "error_code_weight",
}
}
}
/// Builds the search space explored by [`automl_optimize`].
///
/// Registers four dimensions: two continuous ones (`MinSimilarity` with
/// bounds `0.01` and `0.3`, `ErrorCodeWeight` with bounds `1.0` and `5.0`) and
/// two integer ranges (`NgramMin` from `1..4`, `NgramMax` from `2..6`).
///
/// The two n-gram ranges overlap, so a sampled point may have
/// `NgramMin > NgramMax`; [`AutoMLConfig::from_params`] repairs that case.
#[must_use]
pub fn build_oracle_search_space() -> SearchSpace<OracleParam> {
    let space = SearchSpace::new();
    let space = space.add_continuous(OracleParam::MinSimilarity, 0.01, 0.3);
    let space = space.add(OracleParam::NgramMin, 1..4);
    let space = space.add(OracleParam::NgramMax, 2..6);
    space.add_continuous(OracleParam::ErrorCodeWeight, 1.0, 5.0)
}
/// One concrete hyperparameter setting for the n-gram fix predictor.
#[derive(Clone, Debug)]
pub struct AutoMLConfig {
    /// Minimum similarity handed to the predictor.
    pub min_similarity: f32,
    /// N-gram size range as `(min, max)`.
    pub ngram_range: (usize, usize),
    /// Weight applied to compiler error codes; it is rounded to a repeat
    /// count when messages are pre-processed.
    pub error_code_weight: f32,
}
impl AutoMLConfig {
    /// Builds a configuration from one sampled point of the search space.
    ///
    /// Every key is optional. A missing key, or a value of the wrong kind,
    /// falls back to its default: `min_similarity = 0.1`, `ngram_min = 1`,
    /// `ngram_max = 3`, `error_code_weight = 2.0`. Negative integers for the
    /// n-gram bounds are treated the same way instead of being wrapped into
    /// enormous `usize` values.
    ///
    /// The returned range always satisfies `min <= max`: when the sampled
    /// maximum is below the minimum it is raised to the minimum.
    pub fn from_params(params: &HashMap<OracleParam, ParamValue>) -> Self {
        // Floats are looked up as f64 and narrowed afterwards so the defaults
        // go through exactly the same conversion as sampled values.
        let float_or = |key: OracleParam, default: f64| -> f32 {
            params
                .get(&key)
                .and_then(ParamValue::as_f64)
                .unwrap_or(default) as f32
        };
        // `try_from` rejects negative values, which a plain `as usize` cast
        // would silently turn into a huge n-gram size.
        let count_or = |key: OracleParam, default: usize| -> usize {
            params
                .get(&key)
                .and_then(ParamValue::as_i64)
                .and_then(|raw| usize::try_from(raw).ok())
                .unwrap_or(default)
        };

        let ngram_min = count_or(OracleParam::NgramMin, 1);
        let ngram_max = count_or(OracleParam::NgramMax, 3);

        Self {
            min_similarity: float_or(OracleParam::MinSimilarity, 0.1),
            ngram_range: (ngram_min, ngram_max.max(ngram_min)),
            error_code_weight: float_or(OracleParam::ErrorCodeWeight, 2.0),
        }
    }
}
/// Scores `config` by leave-one-out evaluation over `samples`.
///
/// For every sample a fresh [`NgramFixPredictor`] is trained on all *other*
/// samples and asked for its single best suggestion for the held-out message.
/// The hold-out counts as correct when that suggestion's category equals the
/// sample's category. Hold-outs whose predictor fails to `fit`, or that
/// produce no suggestion, count as incorrect.
///
/// Returns the fraction of correct hold-outs, or `0.0` for an empty slice.
fn evaluate_config(config: &AutoMLConfig, samples: &[TrainingSample]) -> f64 {
    let n = samples.len();
    if n == 0 {
        return 0.0;
    }

    // The weighted form of a message depends only on the message and the
    // config, so compute it once per sample instead of once per
    // (held-out, training) pair.
    let weighted: Vec<String> = samples
        .iter()
        .map(|sample| weight_error_codes(&sample.message, config.error_code_weight))
        .collect();
    let (ngram_min, ngram_max) = config.ngram_range;

    let mut correct = 0usize;
    for (held_out, (test_sample, test_message)) in samples.iter().zip(&weighted).enumerate() {
        let mut predictor = NgramFixPredictor::new()
            .with_min_similarity(config.min_similarity)
            .with_ngram_range(ngram_min, ngram_max);

        for (idx, (sample, message)) in samples.iter().zip(&weighted).enumerate() {
            if idx == held_out {
                continue;
            }
            // Samples without a recorded fix still contribute their category.
            let fix = sample.fix.as_deref().unwrap_or("Check error");
            predictor.learn_pattern(message, fix, sample.category);
        }

        if predictor.fit().is_err() {
            continue;
        }
        let suggestions = predictor.predict_fixes(test_message, 1);
        if let Some(top) = suggestions.first() {
            if top.category == test_sample.category {
                correct += 1;
            }
        }
    }

    correct as f64 / n as f64
}
/// Emphasises the compiler error code of `message` by repeating it up front.
///
/// The first `error[E…]` token in `message` is located and prepended
/// `weight.round()` times, separated by single spaces, followed by a space
/// and the original message. For example, weight `2.0` turns
/// `error[E0308]: x` into `error[E0308] error[E0308] error[E0308]: x`.
///
/// The message is returned unchanged when it has no `error[E` marker, when
/// the marker has no closing `]`, or when the weight rounds to zero
/// (including negative and NaN weights, which the cast saturates to zero).
fn weight_error_codes(message: &str, weight: f32) -> String {
    let Some(code_start) = message.find("error[E") else {
        return message.to_string();
    };
    let Some(bracket_offset) = message[code_start..].find(']') else {
        return message.to_string();
    };
    // Both offsets come from `find`, and `]` is one byte wide, so the slice
    // below always lands on character boundaries.
    let code = &message[code_start..=code_start + bracket_offset];

    let repeat_count = weight.round() as usize;
    if repeat_count == 0 {
        // Nothing to prepend; avoid emitting a stray leading space.
        return message.to_string();
    }

    let repeated = vec![code; repeat_count].join(" ");
    format!("{repeated} {message}")
}
/// Outcome of one AutoML search run.
#[derive(Clone, Debug)]
pub struct AutoMLResult {
    /// Best configuration found by the search.
    pub config: AutoMLConfig,
    /// Accuracy score recorded for `config`.
    pub accuracy: f64,
    /// Number of trials the search was asked to run.
    pub trials: usize,
    /// Every evaluated configuration paired with its accuracy, in evaluation
    /// order.
    pub history: Vec<(AutoMLConfig, f64)>,
}
/// Runs a random search over the oracle search space and returns the best
/// configuration found.
///
/// The combined corpus is loaded once, `n_trials` candidate points are
/// requested from a [`RandomSearch`], and each candidate is scored with
/// `evaluate_config`. A candidate replaces the current best only when its
/// accuracy is strictly higher, so if no candidate scores above `0.0` the
/// result carries the built-in default configuration with accuracy `0.0`.
///
/// `trials` in the result echoes `n_trials`; `history` holds every evaluated
/// candidate in evaluation order.
#[must_use]
pub fn automl_optimize(n_trials: usize) -> AutoMLResult {
    let corpus = build_combined_corpus();
    let samples: Vec<_> = corpus.samples().to_vec();
    let space = build_oracle_search_space();
    let mut strategy = RandomSearch::new(n_trials);

    // Best (configuration, accuracy) seen so far, seeded with the defaults.
    let mut best = (
        AutoMLConfig {
            min_similarity: 0.1,
            ngram_range: (1, 3),
            error_code_weight: 2.0,
        },
        0.0_f64,
    );
    let mut history = Vec::new();

    for trial in strategy.suggest(&space, n_trials) {
        let candidate = AutoMLConfig::from_params(&trial.values);
        let score = evaluate_config(&candidate, &samples);
        if score > best.1 {
            best = (candidate.clone(), score);
        }
        history.push((candidate, score));
    }

    let (config, accuracy) = best;
    AutoMLResult {
        config,
        accuracy,
        trials: n_trials,
        history,
    }
}
/// Runs [`automl_optimize`] with a small budget of 20 trials.
#[must_use]
pub fn automl_quick() -> AutoMLResult {
    const QUICK_TRIALS: usize = 20;
    automl_optimize(QUICK_TRIALS)
}
/// Runs [`automl_optimize`] with a budget of 100 trials.
#[must_use]
pub fn automl_full() -> AutoMLResult {
    const FULL_TRIALS: usize = 100;
    automl_optimize(FULL_TRIALS)
}
/// Runs [`automl_optimize`] with an extended budget of 300 trials.
#[must_use]
pub fn automl_extended() -> AutoMLResult {
    const EXTENDED_TRIALS: usize = 300;
    automl_optimize(EXTENDED_TRIALS)
}
impl From<AutoMLResult> for TuningResult {
fn from(result: AutoMLResult) -> Self {
TuningResult {
config: crate::tuning::TuningConfig {
min_similarity: result.config.min_similarity,
ngram_range: result.config.ngram_range,
error_code_weight: result.config.error_code_weight,
},
accuracy: result.accuracy as f32,
correct: (result.accuracy * 27.0) as usize, total: 27,
}
}
}
/// Unit tests for the AutoML search helpers.
///
/// The three search tests at the bottom are `#[ignore]`d; run them explicitly
/// with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_oracle_param_names() {
        assert_eq!(OracleParam::MinSimilarity.name(), "min_similarity");
        assert_eq!(OracleParam::NgramMin.name(), "ngram_min");
        assert_eq!(OracleParam::NgramMax.name(), "ngram_max");
        assert_eq!(OracleParam::ErrorCodeWeight.name(), "error_code_weight");
    }

    #[test]
    fn test_oracle_param_eq() {
        assert_eq!(OracleParam::MinSimilarity, OracleParam::MinSimilarity);
        assert_ne!(OracleParam::MinSimilarity, OracleParam::NgramMin);
    }

    #[test]
    fn test_oracle_param_debug() {
        let param = OracleParam::MinSimilarity;
        let debug = format!("{:?}", param);
        assert!(debug.contains("MinSimilarity"));
    }

    #[test]
    fn test_oracle_param_clone() {
        let param = OracleParam::NgramMax;
        // `OracleParam` is `Copy`, so a plain binding duplicates it.
        let cloned = param;
        assert_eq!(param, cloned);
    }

    #[test]
    fn test_oracle_param_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(OracleParam::MinSimilarity);
        set.insert(OracleParam::NgramMin);
        assert!(set.contains(&OracleParam::MinSimilarity));
        assert!(set.contains(&OracleParam::NgramMin));
        assert!(!set.contains(&OracleParam::NgramMax));
    }

    #[test]
    fn test_build_search_space() {
        let space = build_oracle_search_space();
        assert_eq!(space.len(), 4);
    }

    #[test]
    fn test_automl_config_from_params() {
        let mut params = HashMap::new();
        params.insert(OracleParam::MinSimilarity, ParamValue::Float(0.15));
        params.insert(OracleParam::NgramMin, ParamValue::Int(2));
        params.insert(OracleParam::NgramMax, ParamValue::Int(4));
        params.insert(OracleParam::ErrorCodeWeight, ParamValue::Float(3.0));
        let config = AutoMLConfig::from_params(&params);
        assert!((config.min_similarity - 0.15).abs() < 0.01);
        assert_eq!(config.ngram_range, (2, 4));
        assert!((config.error_code_weight - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_automl_config_from_empty_params() {
        let params = HashMap::new();
        let config = AutoMLConfig::from_params(&params);
        assert!((config.min_similarity - 0.1).abs() < 0.01);
        assert_eq!(config.ngram_range, (1, 3));
        assert!((config.error_code_weight - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_automl_config_ngram_max_min_check() {
        // A maximum below the minimum must be raised to the minimum.
        let mut params = HashMap::new();
        params.insert(OracleParam::NgramMin, ParamValue::Int(5));
        params.insert(OracleParam::NgramMax, ParamValue::Int(3));
        let config = AutoMLConfig::from_params(&params);
        assert_eq!(config.ngram_range, (5, 5));
    }

    #[test]
    fn test_automl_config_clone() {
        let config = AutoMLConfig {
            min_similarity: 0.2,
            ngram_range: (1, 4),
            error_code_weight: 2.5,
        };
        let cloned = config.clone();
        assert!((config.min_similarity - cloned.min_similarity).abs() < f32::EPSILON);
        assert_eq!(config.ngram_range, cloned.ngram_range);
    }

    #[test]
    fn test_automl_config_debug() {
        let config = AutoMLConfig {
            min_similarity: 0.1,
            ngram_range: (1, 3),
            error_code_weight: 2.0,
        };
        let debug = format!("{:?}", config);
        assert!(debug.contains("AutoMLConfig"));
    }

    #[test]
    fn test_weight_error_codes_with_code() {
        let msg = "error[E0308]: mismatched types";
        let weighted = weight_error_codes(msg, 2.0);
        assert!(weighted.contains("error[E0308]"));
        assert!(weighted.contains("mismatched types"));
    }

    #[test]
    fn test_weight_error_codes_without_code() {
        let msg = "some generic error";
        let weighted = weight_error_codes(msg, 2.0);
        assert_eq!(weighted, msg);
    }

    #[test]
    fn test_weight_error_codes_higher_weight() {
        let msg = "error[E0599]: no method named `foo` found";
        let weighted = weight_error_codes(msg, 3.0);
        let count = weighted.matches("error[E0599]").count();
        assert!(count >= 3);
    }

    #[test]
    fn test_automl_result_debug() {
        let result = AutoMLResult {
            config: AutoMLConfig {
                min_similarity: 0.1,
                ngram_range: (1, 3),
                error_code_weight: 2.0,
            },
            accuracy: 0.85,
            trials: 10,
            history: vec![],
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("AutoMLResult"));
    }

    #[test]
    fn test_automl_result_clone() {
        let result = AutoMLResult {
            config: AutoMLConfig {
                min_similarity: 0.15,
                ngram_range: (2, 4),
                error_code_weight: 2.5,
            },
            accuracy: 0.9,
            trials: 20,
            history: vec![],
        };
        let cloned = result.clone();
        assert!((result.accuracy - cloned.accuracy).abs() < f64::EPSILON);
        assert_eq!(result.trials, cloned.trials);
    }

    #[test]
    fn test_automl_result_to_tuning_result() {
        let result = AutoMLResult {
            config: AutoMLConfig {
                min_similarity: 0.1,
                ngram_range: (1, 3),
                error_code_weight: 2.0,
            },
            accuracy: 0.8,
            trials: 50,
            history: vec![],
        };
        let tuning: TuningResult = result.into();
        assert!((tuning.accuracy - 0.8).abs() < 0.01);
        assert_eq!(tuning.total, 27);
    }

    #[test]
    #[ignore]
    fn test_automl_quick() {
        // DEPYLER_FAST_TESTS shrinks the trial budget when set.
        let fast_mode = std::env::var("DEPYLER_FAST_TESTS").is_ok();
        let trials = if fast_mode { 3 } else { 20 };
        let result = automl_optimize(trials);
        assert!(result.accuracy > 0.0);
        assert_eq!(result.trials, trials);
        assert!(!result.history.is_empty());
        println!(
            "AutoML Quick: {:.2}% accuracy with sim={:.3}, ngram={:?}, weight={:.1}",
            result.accuracy * 100.0,
            result.config.min_similarity,
            result.config.ngram_range,
            result.config.error_code_weight
        );
    }

    #[test]
    #[ignore]
    fn test_automl_full() {
        let result = automl_full();
        assert!(result.accuracy > 0.0);
        assert_eq!(result.trials, 100);
        println!(
            "AutoML Full: {:.2}% accuracy with sim={:.3}, ngram={:?}, weight={:.1}",
            result.accuracy * 100.0,
            result.config.min_similarity,
            result.config.ngram_range,
            result.config.error_code_weight
        );
    }

    #[test]
    #[ignore]
    fn test_automl_extended() {
        let result = automl_extended();
        assert!(result.accuracy > 0.0);
        assert_eq!(result.trials, 300);
        println!(
            "AutoML Extended (300 trials): {:.2}% accuracy with sim={:.3}, ngram={:?}, weight={:.1}",
            result.accuracy * 100.0,
            result.config.min_similarity,
            result.config.ngram_range,
            result.config.error_code_weight
        );
        // Show the best configurations, highest accuracy first.
        let mut sorted: Vec<_> = result.history.clone();
        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        println!("\nTop 5 configurations:");
        for (i, (cfg, acc)) in sorted.iter().take(5).enumerate() {
            println!(
                "  {}. {:.2}% - sim={:.3}, ngram=({},{}), weight={:.1}",
                i + 1,
                acc * 100.0,
                cfg.min_similarity,
                cfg.ngram_range.0,
                cfg.ngram_range.1,
                cfg.error_code_weight
            );
        }
    }
}