use super::{AutoTunerConfig, Challenger};
use crate::automl::auto_builder;
use crate::automl::budget::BudgetLedger;
use crate::automl::space::ParamMap;
use crate::automl::{ModelFactory, RewardNormalizer};
use crate::bandits::{Bandit, DiscountedThompsonSampling};
use crate::metrics::ewma::EwmaRegressionMetrics;
use irithyll_core::drift::{DriftDetector, DriftSignal};
use irithyll_core::learner::StreamingLearner;
/// Self-tuning streaming learner.
///
/// Predictions are served by a single `champion` model. A set of challenger `candidates` is
/// trained on the same stream. Round-based elimination and tournament restarts are delegated
/// to `super::scheduler` and `super::racing`.
pub struct AutoTuner {
/// Current champion; `predict` delegates to it and it is trained on every sample.
pub(crate) champion: Box<dyn StreamingLearner>,
/// EWMA regression metrics for the champion, updated on each sample with the prediction made
/// before that sample was trained on.
pub(crate) champion_ewma: EwmaRegressionMetrics,
/// Hyperparameters associated with the champion — presumably the `ParamMap` it was built
/// from; TODO confirm against the scheduler.
pub(crate) champion_params: ParamMap,
/// Presumably the index into `factories` of the factory that built the champion — TODO confirm.
pub(crate) champion_factory_idx: usize,
/// Challengers in the current tournament. Each is predicted on, scored and trained on every
/// sample, and tracks an EWMA plus running absolute-error mean/M2 statistics.
pub(crate) candidates: Vec<Challenger>,
/// Round counter; reset to 0 by `reset` and exposed through `current_round()`.
pub(crate) current_round: usize,
/// Samples seen in the current round. `scheduler::eliminate_round` is invoked once this
/// reaches `config.round_budget`.
pub(crate) samples_in_round: u64,
/// Model factories available to the tuner. On `reset`, the first factory decides whether the
/// auto-builder adaptor is enabled.
pub(crate) factories: Vec<Box<dyn ModelFactory>>,
/// RNG state values, presumably one per factory for configuration sampling — TODO confirm.
pub(crate) sampler_rngs: Vec<u64>,
/// Discounted Thompson-sampling bandit, reset by `reset`. Presumably used by the scheduler to
/// choose among factories — TODO confirm.
pub(crate) bandit: DiscountedThompsonSampling,
/// Reward normalizer, reset by `reset`. Presumably feeds rewards to `bandit` — TODO confirm.
pub(crate) normalizer: RewardNormalizer,
/// Tuner configuration (round budget, initial-candidate limits, metric, auto-builder flags, ...).
pub(crate) config: AutoTunerConfig,
/// Number of `train_one` calls since construction or the last `reset`.
pub(crate) total_samples: u64,
/// Promotion counter; reset to 0 by `reset` (incremented outside this file).
pub(crate) promotions: u64,
/// Completed-tournament counter; reset to 0 by `reset` (incremented outside this file).
pub(crate) tournaments_completed: u64,
/// Starts at `config.n_initial` on `reset`. Doubles, capped at `config.max_n_initial`, whenever
/// the drift detector signals drift. Presumably the number of candidates sampled when a
/// tournament starts.
pub(crate) effective_n_initial: usize,
/// Optional drift detector fed the champion's absolute error. A `DriftSignal::Drift` clears the
/// challengers and starts a new tournament.
pub(crate) drift_detector: Option<Box<dyn DriftDetector>>,
/// Optional auto-builder adaptor. After each sample it may adjust the champion's learning-rate
/// and regularization settings, and on champion replacements its structure.
pub(crate) adaptor: Option<auto_builder::DiagnosticLearner>,
/// Champion `replacement_count()` last observed by `train_one`; used to detect new replacements.
pub(crate) last_replacement_count: u64,
/// Budget ledger; one sample is recorded per candidate per `train_one` call.
pub(crate) budget_ledger: BudgetLedger,
}
impl AutoTuner {
    /// Value of the promotion counter (reset to 0 by `reset`).
    pub fn promotions(&self) -> u64 {
        self.promotions
    }

    /// Value of the completed-tournament counter (reset to 0 by `reset`).
    pub fn tournaments_completed(&self) -> u64 {
        self.tournaments_completed
    }

    /// Number of samples passed to `train_one` since construction or the last `reset`.
    pub fn total_samples(&self) -> u64 {
        self.total_samples
    }

    /// Names of all configured model factories, in factory order.
    pub fn factory_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.factories.len());
        for factory in &self.factories {
            names.push(factory.name());
        }
        names
    }

    /// Number of challengers still alive in the current tournament.
    pub fn candidates_remaining(&self) -> usize {
        self.candidates.len()
    }

    /// Current round counter.
    pub fn current_round(&self) -> usize {
        self.current_round
    }

    /// Current initial-candidate count, which grows when drift is detected.
    pub fn effective_n_initial(&self) -> usize {
        self.effective_n_initial
    }

    /// Forwards a proactive-prune check to the champion and returns its answer.
    // The champion's pruning hook is presumably deprecated on `StreamingLearner`; this wrapper
    // keeps it reachable, so the lint is silenced here — TODO confirm.
    #[allow(deprecated)]
    pub fn check_proactive_prune(&mut self) -> bool {
        self.champion.check_proactive_prune()
    }

    /// Forwards a prune half-life setting to the champion.
    // Same reasoning as `check_proactive_prune`: delegating to a (presumably) deprecated hook.
    #[allow(deprecated)]
    pub fn set_prune_half_life(&mut self, hl: usize) {
        self.champion.set_prune_half_life(hl)
    }

    /// Builds a snapshot of the tuner's state via the scheduler.
    pub fn snapshot(&self) -> super::AutoTunerSnapshot {
        super::scheduler::snapshot_impl(self)
    }

    /// Starts a new tournament; thin crate-internal wrapper over the scheduler.
    pub(crate) fn start_tournament(&mut self) {
        super::scheduler::start_tournament(self)
    }
}
impl StreamingLearner for AutoTuner {
    /// Trains the tuner on one weighted sample.
    ///
    /// Each call performs the following steps, in order:
    /// 1. Predicts with the champion *before* training, updates its EWMA metrics, then trains it.
    /// 2. Feeds the champion's absolute error to the drift detector, if one is configured.
    /// 3. For every challenger: predicts, updates its EWMA, trains it, updates its running
    ///    absolute-error mean/M2 (Welford), and records the sample in the budget ledger.
    /// 4. Advances the round and total counters. At every quarter of `round_budget` it tries
    ///    early elimination. Once the budget is reached it runs the end-of-round elimination.
    /// 5. Lets the auto-builder adaptor, if present, tune the champion's config. When the
    ///    champion's replacement count increases, the adaptor may also change its structure.
    /// 6. On drift, discards the challengers, doubles `effective_n_initial` (capped at
    ///    `config.max_n_initial`), and starts a new tournament.
    // The champion's tuning hooks (`diagnostics_array`, `adjust_config`, `replacement_count`,
    // `apply_structural_change`) are presumably deprecated on `StreamingLearner` — TODO confirm.
    #[allow(deprecated)]
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Predict before training so the recorded error is out-of-sample.
        let champ_pred = self.champion.predict(features);
        self.champion_ewma.update(target, champ_pred);
        self.champion.train_one(features, target, weight);

        let drift_restart = match self.drift_detector {
            Some(ref mut detector) => matches!(
                detector.update((target - champ_pred).abs()),
                DriftSignal::Drift
            ),
            None => false,
        };

        for c in &mut self.candidates {
            let pred = c.model.predict(features);
            c.ewma.update(target, pred);
            c.model.train_one(features, target, weight);

            // Welford's online update of the challenger's absolute-error mean and M2.
            let error = (target - pred).abs();
            c.err_count += 1;
            let delta = error - c.err_mean;
            c.err_mean += delta / c.err_count as f64;
            c.err_m2 += delta * (error - c.err_mean);

            self.budget_ledger.record_sample(c.budget_idx);
        }

        self.samples_in_round += 1;
        self.total_samples += 1;

        // Try early elimination at every quarter of the round budget. The last checkpoint, at
        // the end of the round, is left to the full elimination below. `samples_in_round` was
        // just incremented, so it is never zero here.
        let check_interval = (self.config.round_budget / 4).max(1) as u64;
        if self.samples_in_round % check_interval == 0
            && self.samples_in_round < self.config.round_budget as u64
        {
            super::racing::try_early_elimination(self);
        }
        if self.samples_in_round >= self.config.round_budget as u64 {
            super::scheduler::eliminate_round(self);
        }

        if let Some(ref mut adaptor) = self.adaptor {
            let arr = self.champion.diagnostics_array();
            let diagnostics = auto_builder::ConfigDiagnostics {
                residual_alignment: arr[0],
                regularization_sensitivity: arr[1],
                depth_sufficiency: arr[2],
                effective_dof: arr[3],
                // Use the model's own uncertainty when it reports a positive value; otherwise
                // fall back to the champion's tracked metric.
                uncertainty: if arr[4] > 0.0 {
                    arr[4]
                } else {
                    super::racing::get_metric(&self.champion_ewma, self.config.metric)
                },
            };
            let adjustments = adaptor.after_train(&diagnostics, champ_pred, target);
            // Skip the call when both adjustments are numerically neutral.
            if (adjustments.lr_multiplier - 1.0).abs() > 1e-15
                || adjustments.lambda_direction.abs() > 1e-15
            {
                self.champion
                    .adjust_config(adjustments.lr_multiplier, adjustments.lambda_direction);
            }

            // Always resync the cached count. If it went backwards (e.g. the champion object
            // was replaced or rebuilt), waiting for the new count to overtake the stale one
            // would silence `at_replacement`. Only an increase counts as a new replacement.
            let current_rc = self.champion.replacement_count();
            let replaced = current_rc > self.last_replacement_count;
            self.last_replacement_count = current_rc;
            if replaced {
                if let Some(change) = adaptor.at_replacement(&diagnostics) {
                    if change.depth_delta != 0 || change.steps_delta != 0 {
                        self.champion
                            .apply_structural_change(change.depth_delta, change.steps_delta);
                    }
                }
            }
        }

        if drift_restart {
            self.candidates.clear();
            // Saturate so an extreme `n_initial` can never overflow before the cap applies.
            self.effective_n_initial = self
                .effective_n_initial
                .saturating_mul(2)
                .min(self.config.max_n_initial);
            super::scheduler::start_tournament(self);
        }
    }

    /// Returns the champion's prediction.
    fn predict(&self, features: &[f64]) -> f64 {
        self.champion.predict(features)
    }

    /// Returns the number of samples passed to `train_one` since construction or the last reset.
    fn n_samples_seen(&self) -> u64 {
        self.total_samples
    }

    /// Resets the champion, the statistics, counters, bandit, normalizer and drift detector.
    /// Rebuilds the auto-builder adaptor and starts a fresh tournament.
    fn reset(&mut self) {
        self.champion.reset();
        self.champion_ewma.reset();
        self.candidates.clear();
        self.budget_ledger.reset();
        self.total_samples = 0;
        self.promotions = 0;
        self.tournaments_completed = 0;
        self.samples_in_round = 0;
        self.current_round = 0;
        self.effective_n_initial = self.config.n_initial;
        self.bandit.reset();
        self.normalizer.reset();
        if let Some(ref mut d) = self.drift_detector {
            d.reset();
        }

        // The adaptor is enabled only when the config asks for it AND the first factory
        // supports the auto-builder.
        self.adaptor = if self.config.auto_builder {
            match self.factories.first() {
                Some(first) if first.supports_auto_builder() => {
                    let n_feat = first.n_features_hint().max(1);
                    let region = auto_builder::FeasibleRegion::from_data(100, n_feat, 1.0);
                    Some(auto_builder::DiagnosticLearner::with_objective(
                        region,
                        self.config.meta_objective,
                    ))
                }
                _ => None,
            }
        } else {
            None
        };
        self.last_replacement_count = 0;
        super::scheduler::start_tournament(self);
    }

    /// Returns the champion's tree structure, unchanged.
    // Forwarding a (presumably) deprecated `StreamingLearner` hook — TODO confirm.
    #[allow(deprecated)]
    fn tree_structure(&self) -> Vec<(usize, usize, f64, f64, u64)> {
        self.champion.tree_structure()
    }

    /// Forwards a config adjustment to the champion.
    // Forwarding a (presumably) deprecated `StreamingLearner` hook — TODO confirm.
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.champion.adjust_config(lr_multiplier, lambda_delta);
    }

    /// Forwards a structural change to the champion.
    // Forwarding a (presumably) deprecated `StreamingLearner` hook — TODO confirm.
    #[allow(deprecated)]
    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        self.champion
            .apply_structural_change(depth_delta, steps_delta);
    }

    /// Returns the champion's replacement count.
    // Forwarding a (presumably) deprecated `StreamingLearner` hook — TODO confirm.
    #[allow(deprecated)]
    fn replacement_count(&self) -> u64 {
        self.champion.replacement_count()
    }

    /// Returns the champion's readout weights, if it exposes any.
    // Forwarding a (presumably) deprecated `StreamingLearner` hook — TODO confirm.
    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        self.champion.readout_weights()
    }
}
impl crate::automl::DiagnosticSource for AutoTuner {
    /// Reports the champion's configuration diagnostics.
    ///
    /// Slots 0–3 of the champion's `diagnostics_array()` are copied into the
    /// residual-alignment, regularization-sensitivity, depth-sufficiency and effective-DoF
    /// fields. Slot 4 becomes `uncertainty` only when it is strictly positive. Otherwise the
    /// champion's EWMA value for the configured metric is used instead. Always returns `Some`.
    // Reading the champion's diagnostics goes through a `StreamingLearner` hook that is
    // presumably deprecated — TODO confirm.
    #[allow(deprecated)]
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        let diag = self.champion.diagnostics_array();
        let reported_uncertainty = diag[4];
        let uncertainty = if reported_uncertainty > 0.0 {
            reported_uncertainty
        } else {
            // No usable uncertainty from the model: fall back to its tracked metric.
            super::racing::get_metric(&self.champion_ewma, self.config.metric)
        };

        let diagnostics = crate::automl::ConfigDiagnostics {
            residual_alignment: diag[0],
            regularization_sensitivity: diag[1],
            depth_sufficiency: diag[2],
            effective_dof: diag[3],
            uncertainty,
        };
        Some(diagnostics)
    }
}