use super::AutoTuner;
use crate::automl::budget::BudgetStatus;
use crate::metrics::ewma::EwmaRegressionMetrics;
use tracing::warn;
/// Point-in-time, read-only view of an `AutoTuner`'s racing state.
///
/// Produced by `snapshot_impl`; the `Display` impl renders it as a
/// multi-line human-readable report. Metrics follow the tuner's configured
/// metric type, where lower values are better (a finalist is promoted only
/// when its metric is strictly below the champion's).
#[derive(Debug, Clone)]
pub struct AutoTunerSnapshot {
/// Name of the factory that produced the current champion; empty if the
/// champion's factory index is not present in the factory list.
pub champion_factory: String,
/// Champion's EWMA metric under the configured metric type (lower is better).
pub champion_metric: f64,
/// Number of samples the champion model has seen (`n_samples_seen()`).
pub champion_samples: u64,
/// One entry per challenger still racing in the current tournament.
pub candidates: Vec<super::racing::CandidateSnapshot>,
/// Elimination round within the current tournament; reset to 0 whenever a
/// tournament starts and incremented by each elimination round.
pub current_round: usize,
/// Per-round sample counter; reset to 0 at each new round and tournament.
pub samples_in_round: u64,
/// Number of challenger slots a new tournament tries to fill. Doubled (capped
/// at `max_n_initial`) after a promotion, halved (floored at `min_n_initial`)
/// otherwise.
pub effective_n_initial: usize,
/// Number of tournaments finalized so far.
pub tournaments_completed: u64,
/// Number of times a finalist replaced the champion.
pub promotions: u64,
/// Lifetime sample counter copied from the tuner.
pub total_samples: u64,
/// Names of all registered factories, in factory-index order.
pub factory_names: Vec<String>,
}
impl std::fmt::Display for AutoTunerSnapshot {
    /// Writes a multi-line report: a header line, the champion, tournament
    /// progress (live candidates out of the slot target), lifetime history,
    /// and — only when candidates exist — one line per candidate, suffixed
    /// with ` [warmup]` while that candidate is still warming up.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let live = self.candidates.len();

        f.write_str("=== AutoTuner Snapshot ===\n")?;
        writeln!(
            f,
            "Champion: {name} (metric: {metric:.6}, samples: {samples})",
            name = self.champion_factory,
            metric = self.champion_metric,
            samples = self.champion_samples,
        )?;
        writeln!(
            f,
            "Tournament: round {round}, {live}/{cap} candidates, {pending} samples in round",
            round = self.current_round,
            cap = self.effective_n_initial,
            pending = self.samples_in_round,
        )?;
        writeln!(
            f,
            "History: {done} tournaments, {promos} promotions, {seen} total samples",
            done = self.tournaments_completed,
            promos = self.promotions,
            seen = self.total_samples,
        )?;

        // Nothing more to report when no challenger is racing.
        if live == 0 {
            return Ok(());
        }

        writeln!(f, "Candidates:")?;
        for (idx, cand) in self.candidates.iter().enumerate() {
            let tag = cand.in_warmup.then_some(" [warmup]").unwrap_or_default();
            writeln!(
                f,
                " [{idx}] {} metric={:.6} samples={}{tag}",
                cand.factory_name, cand.metric, cand.samples_trained
            )?;
        }
        Ok(())
    }
}
/// Captures a point-in-time [`AutoTunerSnapshot`] of `tuner` without mutating it.
///
/// Factory lookups use `get`, so an unknown factory index yields an empty
/// name (and, for candidates, a warmup hint of 0) instead of panicking. A
/// candidate is flagged `in_warmup` while its EWMA sample count is below its
/// factory's `warmup_hint()`.
pub(crate) fn snapshot_impl(tuner: &AutoTuner) -> AutoTunerSnapshot {
    let kind = tuner.config.metric;
    let champion_metric = super::racing::get_metric(&tuner.champion_ewma, kind);
    let champion_factory = match tuner.factories.get(tuner.champion_factory_idx) {
        Some(factory) => factory.name().to_string(),
        None => String::new(),
    };

    let mut candidates = Vec::with_capacity(tuner.candidates.len());
    for challenger in &tuner.candidates {
        // One lookup serves both the display name and the warmup threshold.
        let (factory_name, warmup_hint) = match tuner.factories.get(challenger.factory_idx) {
            Some(factory) => (factory.name().to_string(), factory.warmup_hint()),
            None => (String::new(), 0),
        };
        let metric = super::racing::get_metric(&challenger.ewma, kind);
        let samples_trained = challenger.ewma.n_samples();
        candidates.push(super::racing::CandidateSnapshot {
            factory_name,
            metric,
            samples_trained,
            in_warmup: (samples_trained as usize) < warmup_hint,
        });
    }

    let factory_names: Vec<String> = tuner
        .factories
        .iter()
        .map(|factory| factory.name().to_string())
        .collect();

    AutoTunerSnapshot {
        champion_factory,
        champion_metric,
        champion_samples: tuner.champion.n_samples_seen(),
        candidates,
        current_round: tuner.current_round,
        samples_in_round: tuner.samples_in_round,
        effective_n_initial: tuner.effective_n_initial,
        tournaments_completed: tuner.tournaments_completed,
        promotions: tuner.promotions,
        total_samples: tuner.total_samples,
        factory_names,
    }
}
/// Runs one elimination round of the current tournament, keeping roughly the
/// better half of the challengers.
///
/// Each challenger falls into exactly one group:
/// * **protected** – its `err_count` is still below its factory's
///   `warmup_hint()` (0 if the factory index is unknown); always kept.
/// * **exhausted** – the budget ledger reports [`BudgetStatus::Exhausted`] for
///   its arm; dropped, except that when every challenger is exhausted the
///   first one is kept so the tournament still has a finalist.
/// * **eligible** – ranked by its budget-adjusted metric (lower is better);
///   a NaN metric ranks worst.
///
/// The survivor target is `ceil(len / 2)`. Protected challengers count toward
/// it first, and at least one eligible challenger survives whenever any
/// exist. Afterwards the round counter advances, the per-round sample counter
/// resets, and once at most one challenger remains the tournament is
/// finalized. With no challengers at all, a new tournament is started instead.
pub(crate) fn eliminate_round(tuner: &mut AutoTuner) {
    if tuner.candidates.is_empty() {
        start_tournament(tuner);
        return;
    }
    tuner.budget_ledger.recompute();
    let warmup_hints: Vec<usize> = tuner
        .candidates
        .iter()
        .map(|c| {
            tuner
                .factories
                .get(c.factory_idx)
                .map(|f| f.warmup_hint())
                .unwrap_or(0)
        })
        .collect();
    let metric_type = tuner.config.metric;
    let mut protected_indices: Vec<usize> = Vec::new();
    let mut eligible: Vec<(usize, f64)> = Vec::new();
    let mut exhausted_indices: Vec<usize> = Vec::new();
    for (i, c) in tuner.candidates.iter().enumerate() {
        if (c.err_count as usize) < warmup_hints[i] {
            protected_indices.push(i);
        } else if tuner.budget_ledger.status(c.budget_idx) == BudgetStatus::Exhausted {
            exhausted_indices.push(i);
        } else {
            let raw = super::racing::get_metric(&c.ewma, metric_type);
            let adj = tuner
                .budget_ledger
                .adjusted_metric(c.budget_idx, raw, raw.abs().max(1e-12));
            eligible.push((i, adj));
        }
    }

    // Lower is better. A NaN metric (e.g. from a diverged model) is mapped to
    // +inf so it ranks worst and is eliminated first. `partial_cmp(..)
    // .unwrap_or(Equal)` is not transitive once NaN is present, which leaves
    // the ranking unspecified and lets `sort_by` panic on Rust >= 1.81;
    // `total_cmp` on NaN-free keys is a genuine total order.
    let rank_key = |m: f64| if m.is_nan() { f64::INFINITY } else { m };
    eligible.sort_by(|a, b| rank_key(a.1).total_cmp(&rank_key(b.1)));

    let n_candidates = tuner.candidates.len();
    let total_to_keep = n_candidates.div_ceil(2);
    // Protected challengers consume survivor slots first, but the eligible
    // pool is never starved completely.
    let min_eligible = usize::from(!eligible.is_empty());
    let eligible_to_keep = total_to_keep
        .saturating_sub(protected_indices.len())
        .max(min_eligible);

    // Per-index survivor mask: O(n) instead of a `contains` scan per candidate.
    let mut keep = vec![false; n_candidates];
    let ranked_survivors = eligible.iter().take(eligible_to_keep).map(|&(i, _)| i);
    for i in protected_indices.into_iter().chain(ranked_survivors) {
        keep[i] = true;
    }
    // Nothing is marked only when every challenger is budget-exhausted; keep
    // the first one so the tournament can still be finalized with a finalist.
    if !keep.contains(&true) {
        if let Some(&first) = exhausted_indices.first() {
            keep[first] = true;
        }
    }

    let old = std::mem::take(&mut tuner.candidates);
    tuner.candidates = old
        .into_iter()
        .zip(keep)
        .filter_map(|(c, kept)| kept.then_some(c))
        .collect();
    tuner.current_round += 1;
    tuner.samples_in_round = 0;
    if tuner.candidates.len() <= 1 {
        finalize_tournament(tuner);
    }
}
/// Ends the current tournament and immediately starts the next one.
///
/// Only the last remaining challenger (popped from the list) is considered;
/// any others are discarded. It replaces the champion — model, EWMA, params
/// and factory index — when its metric is strictly lower than the champion's
/// and its EWMA has at least 10 samples. The finalist's factory arm in the
/// bandit (if it is a valid arm index) is updated with a reward normalized
/// from `min(finalist_metric, champion_metric)`.
///
/// `effective_n_initial` is doubled (capped at `max_n_initial`) after a
/// promotion and halved (floored at `min_n_initial`) otherwise, so the next
/// tournament grows after successes and shrinks after failures.
pub(crate) fn finalize_tournament(tuner: &mut AutoTuner) {
    use crate::bandits::Bandit;

    let promoted = match tuner.candidates.pop() {
        None => false,
        Some(finalist) => {
            let kind = tuner.config.metric;
            let challenger_score = super::racing::get_metric(&finalist.ewma, kind);
            let incumbent_score = super::racing::get_metric(&tuner.champion_ewma, kind);
            // Copy the arm index up front: the promotion below moves fields out.
            let arm = finalist.factory_idx;
            let wins = challenger_score < incumbent_score && finalist.ewma.n_samples() >= 10;
            if wins {
                tuner.champion = finalist.model;
                tuner.champion_ewma = finalist.ewma;
                tuner.champion_params = finalist.params;
                tuner.champion_factory_idx = arm;
                tuner.promotions += 1;
            }
            // NOTE(review): the reward uses the better of the two metrics, so a
            // finalist that lost to the champion is credited with the
            // champion's score — confirm this is the intended bandit signal.
            let reward = tuner
                .normalizer
                .normalize(challenger_score.min(incumbent_score));
            if arm < tuner.bandit.n_arms() {
                tuner.bandit.update(arm, reward);
            }
            wins
        }
    };

    let current = tuner.effective_n_initial;
    tuner.effective_n_initial = if promoted {
        (current * 2).min(tuner.config.max_n_initial)
    } else {
        (current / 2).max(tuner.config.min_n_initial)
    };
    tuner.candidates.clear();
    tuner.tournaments_completed += 1;
    start_tournament(tuner);
}
/// Starts a fresh tournament by filling up to `effective_n_initial` challenger slots.
///
/// * The first `n / 2` slots perturb the champion's parameters within the
///   champion factory's config space (`perturb`, passing `perturb_sigma` and
///   that factory's sampler RNG).
/// * The remaining slots pick a factory via the bandit
///   (`select_arm() % n_factories`, or factory 0 when at most one factory is
///   registered) and sample fresh parameters from its config space.
///
/// Every successfully built model registers a budget arm using its factory's
/// `complexity_hint()` and starts with a fresh EWMA tracker and zeroed error
/// statistics. A slot whose parameter sampling or model construction fails
/// is skipped with a `warn!`, so the new field may hold fewer than `n`
/// challengers (possibly none). The budget ledger is reset beforehand, and
/// the round and per-round sample counters are reset afterwards.
///
/// # Panics
/// Panics if the champion factory index (when `n / 2 > 0`) or a selected
/// factory index is out of range for `factories` / `sampler_rngs`.
pub(crate) fn start_tournament(tuner: &mut AutoTuner) {
    use crate::bandits::Bandit;
    let n = tuner.effective_n_initial;
    let n_perturb = n / 2;
    let n_factories = tuner.factories.len();
    let ewma_span = tuner.config.ewma_span;
    let sigma = tuner.config.perturb_sigma;
    tuner.budget_ledger.reset();
    let mut candidates = Vec::with_capacity(n);
    let champ_idx = tuner.champion_factory_idx;

    // Exploitation: local perturbations around the champion's parameters.
    for _ in 0..n_perturb {
        let space = tuner.factories[champ_idx].config_space();
        let params = match space.perturb(
            &tuner.champion_params,
            sigma,
            &mut tuner.sampler_rngs[champ_idx],
        ) {
            Ok(p) => p,
            Err(e) => {
                warn!(
                    factory = tuner.factories[champ_idx].name(),
                    error = %e,
                    "search-space perturbation unsatisfiable; skipping this slot"
                );
                continue;
            }
        };
        let complexity = tuner.factories[champ_idx].complexity_hint();
        match tuner.factories[champ_idx].create(&params) {
            Ok(model) => {
                let budget_idx = tuner.budget_ledger.add_arm(complexity);
                candidates.push(super::Challenger {
                    model,
                    ewma: EwmaRegressionMetrics::new(ewma_span),
                    params,
                    factory_idx: champ_idx,
                    err_mean: 0.0,
                    err_m2: 0.0,
                    err_count: 0,
                    budget_idx,
                });
            }
            Err(e) => {
                warn!(
                    factory = tuner.factories[champ_idx].name(),
                    error = %e,
                    "factory rejected perturbed arm config; skipping this slot"
                );
            }
        }
    }

    // Exploration: bandit-chosen factories with freshly sampled parameters.
    for _ in n_perturb..n {
        let factory_idx = if n_factories > 1 {
            tuner.bandit.select_arm() % n_factories
        } else {
            0
        };
        let space = tuner.factories[factory_idx].config_space();
        let params = match space.sample(&mut tuner.sampler_rngs[factory_idx]) {
            Ok(p) => p,
            Err(e) => {
                warn!(
                    factory = tuner.factories[factory_idx].name(),
                    error = %e,
                    "search-space sampler unsatisfiable; skipping this slot"
                );
                continue;
            }
        };
        let complexity = tuner.factories[factory_idx].complexity_hint();
        match tuner.factories[factory_idx].create(&params) {
            Ok(model) => {
                let budget_idx = tuner.budget_ledger.add_arm(complexity);
                candidates.push(super::Challenger {
                    model,
                    ewma: EwmaRegressionMetrics::new(ewma_span),
                    params,
                    factory_idx,
                    err_mean: 0.0,
                    err_m2: 0.0,
                    err_count: 0,
                    budget_idx,
                });
            }
            Err(e) => {
                warn!(
                    factory = tuner.factories[factory_idx].name(),
                    error = %e,
                    "factory rejected random arm config; skipping this slot"
                );
            }
        }
    }

    tuner.candidates = candidates;
    tuner.current_round = 0;
    tuner.samples_in_round = 0;
}