use crate::automl::auto_builder;
use crate::automl::budget::BudgetLedger;
use crate::automl::space::ParamMap;
use crate::automl::{AutoMetric, ModelFactory, RewardNormalizer};
use crate::bandits::DiscountedThompsonSampling;
use crate::drift::adwin::Adwin;
use crate::metrics::ewma::EwmaRegressionMetrics;
use irithyll_core::drift::DriftDetector;
use irithyll_core::error::ConfigError;
use irithyll_core::learner::StreamingLearner;
use tracing::warn;
mod core;
mod racing;
mod scheduler;
pub use self::core::AutoTuner;
pub use self::racing::CandidateSnapshot;
pub use self::scheduler::AutoTunerSnapshot;
#[doc(inline)]
pub use self::scheduler::AutoTunerSnapshot as Snapshot;
#[doc(inline)]
pub use self::racing::CandidateSnapshot as CandidateInfo;
/// Configuration for an [`AutoTuner`].
///
/// Usually assembled through [`AutoTunerBuilder`], which starts from
/// `AutoTunerConfig::default()` and validates every field in `build()`.
#[derive(Debug, Clone)]
pub struct AutoTunerConfig {
    /// Initial number of candidates started per tournament.
    ///
    /// Must be >= 1 and lie within `[min_n_initial, max_n_initial]`.
    pub n_initial: usize,
    /// Budget of a single tournament round; must be >= 1.
    ///
    /// NOTE(review): presumably counted in training samples — confirm
    /// against the scheduler.
    pub round_budget: usize,
    /// Metric used to score models (default: `AutoMetric::MAE`).
    pub metric: AutoMetric,
    /// Span handed to the EWMA regression metrics and the reward normalizer;
    /// must be >= 1.
    pub ewma_span: usize,
    /// Discount factor of the Thompson-sampling bandit over factories;
    /// must be in `(0, 1]`.
    pub discount: f64,
    /// Perturbation scale; must be >= 0.
    ///
    /// NOTE(review): presumably the std-dev used when perturbing
    /// hyperparameters — confirm where it is consumed.
    pub perturb_sigma: f64,
    /// Base RNG seed for the bandit and the per-factory samplers.
    ///
    /// A seed of 0 is replaced by 1 at build time.
    pub seed: u64,
    /// Lower bound that `n_initial` must respect; must be <= `max_n_initial`.
    pub min_n_initial: usize,
    /// Upper bound that `n_initial` must respect.
    pub max_n_initial: usize,
    /// When true, an ADWIN drift detector is attached to the tuner.
    pub use_drift_rerace: bool,
    /// When true and the first factory supports it, an SPSA-based
    /// diagnostic adaptor is attached; otherwise a warning is logged and the
    /// flag has no effect.
    pub auto_builder: bool,
    /// Objective passed to the auto-builder adaptor (only used when
    /// `auto_builder` is enabled).
    pub meta_objective: auto_builder::MetaObjective,
}
impl Default for AutoTunerConfig {
fn default() -> Self {
Self {
n_initial: 8,
round_budget: 100,
metric: AutoMetric::MAE,
ewma_span: 50,
discount: 0.99,
perturb_sigma: 0.2,
seed: 42,
min_n_initial: 4,
max_n_initial: 32,
use_drift_rerace: false,
auto_builder: false,
meta_objective: auto_builder::MetaObjective::default(),
}
}
}
/// Per-candidate state for a challenger model in a tournament.
///
/// NOTE(review): this type is constructed and consumed in sibling modules;
/// the field descriptions below are inferred from names and how the tuner
/// is wired up here — confirm against those modules.
pub(crate) struct Challenger {
    /// The candidate learner being trained.
    pub model: Box<dyn StreamingLearner>,
    /// EWMA-smoothed regression metrics for this candidate.
    pub ewma: EwmaRegressionMetrics,
    /// Hyperparameters the candidate was created from.
    pub params: ParamMap,
    /// Presumably the index of the producing factory in the tuner's factory list.
    pub factory_idx: usize,
    /// Running error mean.
    ///
    /// NOTE(review): together with `err_m2` and `err_count` this looks like a
    /// Welford-style mean/variance accumulator — confirm.
    pub err_mean: f64,
    /// Presumably the running sum of squared deviations from `err_mean`.
    pub err_m2: f64,
    /// Number of errors folded into `err_mean` / `err_m2`.
    pub err_count: u64,
    /// Presumably this candidate's slot in the tuner's `BudgetLedger` — confirm.
    pub budget_idx: usize,
}
/// Builder for [`AutoTuner`]; obtain one with [`AutoTuner::builder`].
///
/// Collects one or more model factories plus an [`AutoTunerConfig`] that
/// starts from its defaults. Everything is validated when `build()` is called.
pub struct AutoTunerBuilder {
    /// Registered model factories; the first one supplies the initial champion.
    pub(crate) factories: Vec<Box<dyn ModelFactory>>,
    /// Configuration being assembled; starts as `AutoTunerConfig::default()`.
    pub(crate) config: AutoTunerConfig,
}
impl AutoTuner {
    /// Starts a new [`AutoTunerBuilder`] with no factories registered and
    /// the default [`AutoTunerConfig`].
    ///
    /// At least one factory must be added before `build()` succeeds.
    pub fn builder() -> AutoTunerBuilder {
        let config = AutoTunerConfig::default();
        let factories = Vec::new();
        AutoTunerBuilder { config, factories }
    }
}
impl AutoTunerBuilder {
    /// Registers `f` as the *only* factory, discarding any factories added
    /// earlier.
    ///
    /// Use [`add_factory`](Self::add_factory) to register additional
    /// factories without dropping the existing ones.
    pub fn factory(mut self, f: impl ModelFactory + 'static) -> Self {
        self.factories.clear();
        self.factories.push(Box::new(f));
        self
    }

    /// Appends `f` to the registered factories.
    pub fn add_factory(mut self, f: impl ModelFactory + 'static) -> Self {
        self.factories.push(Box::new(f));
        self
    }

    /// Sets the initial number of candidates per tournament.
    ///
    /// Must be >= 1 and within `[min_n_initial, max_n_initial]`.
    pub fn n_initial(mut self, n: usize) -> Self {
        self.config.n_initial = n;
        self
    }

    /// Sets the per-round budget; must be >= 1.
    pub fn round_budget(mut self, b: usize) -> Self {
        self.config.round_budget = b;
        self
    }

    /// Sets the metric used to score models.
    pub fn metric(mut self, m: AutoMetric) -> Self {
        self.config.metric = m;
        self
    }

    /// Sets the EWMA span used for metrics and reward normalization; must be >= 1.
    pub fn ewma_span(mut self, s: usize) -> Self {
        self.config.ewma_span = s;
        self
    }

    /// Sets the bandit discount factor; must be in `(0, 1]` (NaN is rejected).
    pub fn discount(mut self, d: f64) -> Self {
        self.config.discount = d;
        self
    }

    /// Sets the perturbation scale; must be finite and >= 0.
    pub fn perturb_sigma(mut self, s: f64) -> Self {
        self.config.perturb_sigma = s;
        self
    }

    /// Sets the base RNG seed (0 is replaced by 1 at build time).
    pub fn seed(mut self, s: u64) -> Self {
        self.config.seed = s;
        self
    }

    /// Sets the lower bound for `n_initial`; must be <= `max_n_initial`.
    pub fn min_n_initial(mut self, n: usize) -> Self {
        self.config.min_n_initial = n;
        self
    }

    /// Sets the upper bound for `n_initial`.
    pub fn max_n_initial(mut self, n: usize) -> Self {
        self.config.max_n_initial = n;
        self
    }

    /// Enables or disables the ADWIN drift detector.
    pub fn use_drift_rerace(mut self, enabled: bool) -> Self {
        self.config.use_drift_rerace = enabled;
        self
    }

    /// Enables or disables the SPSA auto-builder adaptor.
    ///
    /// The adaptor only takes effect when the first factory reports
    /// `supports_auto_builder()`.
    pub fn auto_builder(mut self, enabled: bool) -> Self {
        self.config.auto_builder = enabled;
        self
    }

    /// Sets the objective used by the auto-builder adaptor.
    pub fn meta_objective(mut self, obj: auto_builder::MetaObjective) -> Self {
        self.config.meta_objective = obj;
        self
    }

    /// Validates the configuration and constructs an [`AutoTuner`].
    ///
    /// Construction proceeds as follows:
    /// - The first registered factory supplies the initial champion. Its
    ///   configuration space is sampled once, and the sample is used to
    ///   create the champion model.
    /// - A discounted Thompson-sampling bandit is created with one arm per
    ///   factory.
    /// - The first tournament is started before the tuner is returned.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] when:
    /// - no factory has been registered;
    /// - `n_initial`, `round_budget` or `ewma_span` is zero;
    /// - `discount` is NaN or outside `(0, 1]`;
    /// - `perturb_sigma` is negative, NaN or infinite;
    /// - `min_n_initial > max_n_initial`, or `n_initial` lies outside
    ///   `[min_n_initial, max_n_initial]`;
    /// - sampling the first factory's configuration space fails, or the first
    ///   factory cannot create a model from that sample.
    pub fn build(self) -> Result<AutoTuner, ConfigError> {
        if self.factories.is_empty() {
            return Err(ConfigError::invalid(
                "factories",
                "at least one ModelFactory is required",
            ));
        }

        let c = &self.config;
        if c.n_initial == 0 {
            return Err(ConfigError::out_of_range(
                "n_initial",
                "must be >= 1",
                c.n_initial,
            ));
        }
        if c.round_budget == 0 {
            return Err(ConfigError::out_of_range(
                "round_budget",
                "must be >= 1",
                c.round_budget,
            ));
        }
        if c.ewma_span == 0 {
            return Err(ConfigError::out_of_range(
                "ewma_span",
                "must be >= 1",
                c.ewma_span,
            ));
        }
        // Every comparison with NaN is false, so a pure range test would let
        // NaN through; reject it explicitly.
        if c.discount.is_nan() || c.discount <= 0.0 || c.discount > 1.0 {
            return Err(ConfigError::out_of_range(
                "discount",
                "must be in (0, 1]",
                c.discount,
            ));
        }
        // Same NaN pitfall as above; infinities are not a usable scale either.
        if !c.perturb_sigma.is_finite() || c.perturb_sigma < 0.0 {
            return Err(ConfigError::out_of_range(
                "perturb_sigma",
                "must be finite and >= 0",
                c.perturb_sigma,
            ));
        }
        if c.min_n_initial > c.max_n_initial {
            return Err(ConfigError::invalid(
                "min_n_initial",
                format!(
                    "must be <= max_n_initial ({}), got {}",
                    c.max_n_initial, c.min_n_initial
                ),
            ));
        }
        if c.n_initial < c.min_n_initial || c.n_initial > c.max_n_initial {
            return Err(ConfigError::invalid(
                "n_initial",
                format!(
                    "must be in [min_n_initial ({}), max_n_initial ({})], got {}",
                    c.min_n_initial, c.max_n_initial, c.n_initial
                ),
            ));
        }

        let config = self.config;
        // A zero seed is remapped to 1, and each per-factory sampler state is
        // kept non-zero. NOTE(review): presumably the RNGs degenerate on a
        // zero state — confirm.
        let seed = if config.seed == 0 { 1 } else { config.seed };
        let mut sampler_rngs: Vec<u64> = (0..self.factories.len())
            .map(|i| seed.wrapping_add(i as u64).max(1))
            .collect();
        // One bandit arm per factory; `factories` is known to be non-empty.
        let n_factory_arms = self.factories.len();
        let bandit = DiscountedThompsonSampling::with_seed(n_factory_arms, config.discount, seed);

        // The first factory seeds the initial champion. A search space that
        // cannot be sampled, or that yields configs `create()` rejects, is a
        // configuration problem and is reported as such instead of panicking.
        let champion_factory_idx = 0;
        let champion_space = self.factories[0].config_space();
        let champion_params = champion_space
            .sample(&mut sampler_rngs[0])
            .map_err(|e| {
                ConfigError::invalid(
                    "factories",
                    format!(
                        "initial champion search-space sample failed for factory '{}': {}",
                        self.factories[0].name(),
                        e
                    ),
                )
            })?;
        let champion = self.factories[0]
            .create(&champion_params)
            .map_err(|e| {
                ConfigError::invalid(
                    "factories",
                    format!(
                        "initial champion creation failed for factory '{}': {}. \
                         The factory's config_space() must produce configs that create() accepts.",
                        self.factories[0].name(),
                        e
                    ),
                )
            })?;
        let champion_ewma = EwmaRegressionMetrics::new(config.ewma_span);
        let effective_n_initial = config.n_initial;

        let drift_detector: Option<Box<dyn DriftDetector>> = if config.use_drift_rerace {
            Some(Box::new(Adwin::default()))
        } else {
            None
        };

        let adaptor = if !config.auto_builder {
            None
        } else if !self.factories[0].supports_auto_builder() {
            warn!(
                factory = self.factories[0].name(),
                "auto_builder=true has no effect for non-SGBT factories; \
                 the SPSA adaptor requires SGBT-family models. Skipping adaptor."
            );
            None
        } else {
            let n_feat = self.factories[0].n_features_hint().max(1);
            // NOTE(review): 100 and 1.0 are fixed hints passed to
            // `FeasibleRegion::from_data` — confirm their meaning there.
            let region = auto_builder::FeasibleRegion::from_data(100, n_feat, 1.0);
            Some(auto_builder::DiagnosticLearner::with_objective(
                region,
                config.meta_objective,
            ))
        };

        let mut tuner = AutoTuner {
            champion,
            champion_ewma,
            champion_params,
            champion_factory_idx,
            candidates: Vec::new(),
            current_round: 0,
            samples_in_round: 0,
            factories: self.factories,
            sampler_rngs,
            bandit,
            normalizer: RewardNormalizer::with_span(config.ewma_span),
            config,
            total_samples: 0,
            promotions: 0,
            tournaments_completed: 0,
            effective_n_initial,
            drift_detector,
            adaptor,
            last_replacement_count: 0,
            budget_ledger: BudgetLedger::new(),
        };
        tuner.start_tournament();
        Ok(tuner)
    }
}
#[cfg(test)]
mod tests {
    // Builder defaults, construction, snapshot and reset behavior, plus the
    // validation paths of `AutoTunerBuilder::build`.
    use super::*;
    use crate::automl::Factory;

    #[test]
    fn tournament_builder_default() {
        let tuner = AutoTuner::builder()
            .factory(Factory::sgbt(5))
            .build()
            .expect("valid config");
        assert_eq!(
            tuner.config.n_initial, 8,
            "default n_initial should be 8, got {}",
            tuner.config.n_initial
        );
        assert_eq!(
            tuner.config.round_budget, 100,
            "default round_budget should be 100, got {}",
            tuner.config.round_budget
        );
        assert_eq!(
            tuner.config.metric,
            AutoMetric::MAE,
            "default metric should be MAE"
        );
        assert_eq!(
            tuner.config.ewma_span, 50,
            "default ewma_span should be 50, got {}",
            tuner.config.ewma_span
        );
        assert!(
            (tuner.config.discount - 0.99).abs() < 1e-12,
            "default discount should be 0.99, got {}",
            tuner.config.discount
        );
    }

    #[test]
    fn tournament_builder_creates_champion_and_tournament() {
        let tuner = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .build()
            .expect("valid config");
        assert_eq!(
            tuner.total_samples, 0,
            "initial total_samples should be 0, got {}",
            tuner.total_samples
        );
        assert_eq!(
            tuner.candidates.len(),
            8,
            "first tournament should have 8 candidates, got {}",
            tuner.candidates.len()
        );
    }

    #[test]
    fn tournament_snapshot() {
        let tuner = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .n_initial(4)
            .build()
            .expect("valid config");
        let snap = tuner.snapshot();
        assert_eq!(snap.champion_factory, "SGBT");
        assert_eq!(snap.candidates.len(), 4);
    }

    #[test]
    fn tournament_reset() {
        let mut tuner = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .build()
            .expect("valid config");
        for i in 0..10 {
            let x = [i as f64, 0.5, 0.3];
            let y = i as f64 * 0.1;
            tuner.train(&x, y);
        }
        assert!(tuner.total_samples > 0);
        tuner.reset();
        assert_eq!(
            tuner.total_samples, 0,
            "total_samples should be 0 after reset"
        );
        assert_eq!(tuner.promotions, 0, "promotions should be 0 after reset");
        assert_eq!(tuner.candidates.len(), 8, "tournament should be restarted");
    }

    #[test]
    fn tournament_builder_rejects_missing_factory() {
        assert!(
            AutoTuner::builder().build().is_err(),
            "build without any factory must fail"
        );
    }

    #[test]
    fn tournament_builder_rejects_zero_sizes() {
        let zero_n_initial = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .n_initial(0)
            .build();
        assert!(zero_n_initial.is_err(), "n_initial = 0 must be rejected");

        let zero_round_budget = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .round_budget(0)
            .build();
        assert!(zero_round_budget.is_err(), "round_budget = 0 must be rejected");

        let zero_ewma_span = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .ewma_span(0)
            .build();
        assert!(zero_ewma_span.is_err(), "ewma_span = 0 must be rejected");
    }

    #[test]
    fn tournament_builder_rejects_discount_out_of_range() {
        for d in [0.0, -0.5, 1.5] {
            let res = AutoTuner::builder()
                .factory(Factory::sgbt(3))
                .discount(d)
                .build();
            assert!(res.is_err(), "discount {} must be rejected", d);
        }
    }

    #[test]
    fn tournament_builder_rejects_negative_perturb_sigma() {
        let res = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .perturb_sigma(-0.1)
            .build();
        assert!(res.is_err(), "negative perturb_sigma must be rejected");
    }

    #[test]
    fn tournament_builder_rejects_inconsistent_n_initial_bounds() {
        let inverted = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .min_n_initial(10)
            .max_n_initial(5)
            .build();
        assert!(inverted.is_err(), "min_n_initial > max_n_initial must be rejected");

        // Defaults bound n_initial to [4, 32].
        let below = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .n_initial(2)
            .build();
        assert!(below.is_err(), "n_initial below min_n_initial must be rejected");

        let above = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .n_initial(64)
            .build();
        assert!(above.is_err(), "n_initial above max_n_initial must be rejected");
    }

    #[test]
    fn tournament_builder_factory_replaces_add_factory_appends() {
        let replaced = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .factory(Factory::sgbt(5));
        assert_eq!(
            replaced.factories.len(),
            1,
            "factory() should replace previously registered factories"
        );

        let appended = AutoTuner::builder()
            .factory(Factory::sgbt(3))
            .add_factory(Factory::sgbt(5));
        assert_eq!(
            appended.factories.len(),
            2,
            "add_factory() should keep previously registered factories"
        );
    }
}