use super::config::OptimizationMetric;
use super::trial::TrialResult;
/// Record of every trial evaluated during a search, plus a pointer to the best one.
///
/// Trials are stored in insertion order. The "best" trial is tracked
/// incrementally as trials are added, using the configured optimization metric.
#[derive(Debug, Clone)]
pub struct SearchHistory {
// All recorded trials, in the order they were added.
trials: Vec<TrialResult>,
// Index into `trials` of the current best trial; `None` while no trial has been recorded.
best_trial_idx: Option<usize>,
// Metric that decides which of two trials is considered better.
optimization_metric: OptimizationMetric,
}
/// The default history holds no trials and ranks them by validation loss.
impl Default for SearchHistory {
    fn default() -> Self {
        // Reuse the metric-taking constructor so the empty state is built in one place.
        Self::with_metric(OptimizationMetric::ValidationLoss)
    }
}
impl SearchHistory {
    /// Creates an empty history using the default configuration.
    ///
    /// Equivalent to `SearchHistory::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty history that ranks trials by `metric`.
    pub fn with_metric(metric: OptimizationMetric) -> Self {
        Self {
            trials: Vec::new(),
            best_trial_idx: None,
            optimization_metric: metric,
        }
    }

    /// Records `result` and updates the best-trial pointer if it wins.
    ///
    /// The first trial added always becomes the best. Later trials replace it
    /// only when [`compare_trials`](Self::compare_trials) returns `true`.
    pub fn add(&mut self, result: TrialResult) {
        // Decide before pushing so `result` can still be borrowed for the comparison.
        let is_better = self
            .best()
            .map_or(true, |best| self.compare_trials(&result, best));
        if is_better {
            // The length before the push is the index the new trial will occupy.
            self.best_trial_idx = Some(self.trials.len());
        }
        self.trials.push(result);
    }

    /// Returns `true` when `new` should replace `best` under the configured metric.
    ///
    /// * `ValidationLoss` — lower `val_metric` wins.
    /// * `F1Score` / `RocAuc` — higher score wins; a present score beats a
    ///   missing one, and a missing score never wins.
    ///
    /// Comparisons are strict, so on a tie the existing best is kept.
    ///
    /// NaN handling is the same for every metric: a NaN candidate never beats
    /// a real incumbent, and a NaN incumbent is beaten by any real candidate.
    /// Without this, a NaN best (e.g. a diverged first trial) could never be
    /// displaced because every comparison against NaN is `false`.
    pub fn compare_trials(&self, new: &TrialResult, best: &TrialResult) -> bool {
        match self.optimization_metric {
            OptimizationMetric::ValidationLoss => {
                !new.val_metric.is_nan()
                    && (best.val_metric.is_nan() || new.val_metric < best.val_metric)
            }
            OptimizationMetric::F1Score => match (new.f1_score, best.f1_score) {
                (Some(new_f1), Some(best_f1)) if !new_f1.is_nan() && !best_f1.is_nan() => {
                    new_f1 > best_f1
                }
                // Incumbent is NaN, candidate is a real number.
                (Some(new_f1), Some(_)) if !new_f1.is_nan() => true,
                (Some(_), None) => true,
                _ => false,
            },
            OptimizationMetric::RocAuc => match (new.roc_auc, best.roc_auc) {
                (Some(new_auc), Some(best_auc)) if !new_auc.is_nan() && !best_auc.is_nan() => {
                    new_auc > best_auc
                }
                // Incumbent is NaN, candidate is a real number.
                (Some(new_auc), Some(_)) if !new_auc.is_nan() => true,
                (Some(_), None) => true,
                _ => false,
            },
        }
    }

    /// Returns the metric used to rank trials.
    pub fn optimization_metric(&self) -> OptimizationMetric {
        self.optimization_metric
    }

    /// Returns the best trial recorded so far, or `None` if the history is empty.
    pub fn best(&self) -> Option<&TrialResult> {
        self.best_trial_idx.and_then(|idx| self.trials.get(idx))
    }

    /// Returns all recorded trials in insertion order.
    pub fn trials(&self) -> &[TrialResult] {
        &self.trials
    }

    /// Returns the trials whose `iteration` field equals `iteration`, in insertion order.
    pub fn trials_for_iteration(&self, iteration: usize) -> Vec<&TrialResult> {
        self.trials
            .iter()
            .filter(|t| t.iteration == iteration)
            .collect()
    }

    /// Returns the number of recorded trials.
    pub fn len(&self) -> usize {
        self.trials.len()
    }

    /// Returns `true` when no trial has been recorded.
    pub fn is_empty(&self) -> bool {
        self.trials.is_empty()
    }

    /// Serializes the history to a JSON document.
    ///
    /// The document has a `"trials"` array (one object per trial, in insertion
    /// order) and a `"best_trial_id"` field holding the best trial's
    /// `trial_id`, or `null` when the history is empty.
    ///
    /// Non-finite metric and parameter values (NaN, ±infinity) are written as
    /// `null`, because JSON has no representation for them.
    pub fn to_json(&self) -> String {
        let mut json = String::from("{\n \"trials\": [\n");
        let trial_count = self.trials.len();
        for (i, trial) in self.trials.iter().enumerate() {
            json.push_str(" {\n");
            json.push_str(&format!(" \"trial_id\": {},\n", trial.trial_id));
            json.push_str(&format!(" \"iteration\": {},\n", trial.iteration));
            json.push_str(&format!(
                " \"val_metric\": {},\n",
                Self::json_number(&trial.val_metric)
            ));
            json.push_str(&format!(
                " \"train_metric\": {},\n",
                Self::json_number(&trial.train_metric)
            ));
            json.push_str(&format!(" \"num_trees\": {},\n", trial.num_trees));
            json.push_str(&format!(
                " \"train_time_ms\": {},\n",
                trial.train_time_ms
            ));
            json.push_str(" \"params\": {\n");
            let param_count = trial.params.len();
            // NOTE(review): parameter names are interpolated verbatim; this assumes
            // they never contain `"` or `\` — confirm against how params are built.
            for (j, (k, v)) in trial.params.iter().enumerate() {
                // `j + 1 < len` marks "not the last element" without subtracting from a length.
                let comma = if j + 1 < param_count { "," } else { "" };
                json.push_str(&format!(
                    " \"{}\": {}{}\n",
                    k,
                    Self::json_number(v),
                    comma
                ));
            }
            json.push_str(" }\n");
            let comma = if i + 1 < trial_count { "," } else { "" };
            json.push_str(&format!(" }}{}\n", comma));
        }
        json.push_str(" ],\n");
        match self.best() {
            Some(best) => json.push_str(&format!(" \"best_trial_id\": {}\n", best.trial_id)),
            None => json.push_str(" \"best_trial_id\": null\n"),
        }
        json.push_str("}\n");
        json
    }

    /// Formats `value` as a JSON number, mapping non-finite floats to `null`.
    ///
    /// Rust's `Display` renders non-finite floats as `NaN`, `inf` and `-inf`,
    /// none of which are valid JSON tokens.
    fn json_number(value: impl std::fmt::Display) -> String {
        let text = value.to_string();
        match text.as_str() {
            "NaN" | "inf" | "-inf" => String::from("null"),
            _ => text,
        }
    }
}
/// Boxed callback type that receives a trial result and two `usize` values.
///
/// The `Send + Sync` bounds allow the callback to be moved to and shared
/// between threads.
///
/// NOTE(review): the two `usize` arguments are presumably progress counters
/// (e.g. current trial number and total trials) — confirm against the call site.
pub type ProgressCallback = Box<dyn Fn(&TrialResult, usize, usize) + Send + Sync>;