mod feasible;
#[cfg(feature = "distill")]
mod distillation;
#[cfg(feature = "distill")]
#[cfg_attr(docsrs, doc(cfg(feature = "distill")))]
pub use distillation::{DistillationConfig, DistillationStats};
#[cfg(feature = "distill")]
use distillation::CandidateDistillState;
pub use feasible::FeasibleRegion;
use crate::automl::ModelFactory;
use crate::ensemble::config::SGBTConfig;
use crate::learner::SGBTLearner;
use irithyll_core::learner::StreamingLearner;
use std::collections::VecDeque;
use tracing::warn;
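/// Sliding-window length for the drift score: the newer half of the window is
/// compared against the older half (see `race_drift_score`).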
const DRIFT_WINDOW: usize = 1024;
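/// Per-step diagnostic signals consumed by `DiagnosticLearner::after_train`;
/// each field is tracked there as an EWMA.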
#[derive(Debug, Clone, Default)]
pub struct ConfigDiagnostics {
pub residual_alignment: f64,
pub regularization_sensitivity: f64,
pub depth_sufficiency: f64,
pub effective_dof: f64,
pub uncertainty: f64,
}
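/// `(lower, upper)` ranges per hyperparameter defining the feasible search box,
/// as reported by `FeasibleRegion::config_bounds`.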
#[derive(Debug, Clone)]
pub struct ConfigBounds {
pub max_depth: (usize, usize),
pub n_steps: (usize, usize),
pub grace_period: (usize, usize),
pub learning_rate: (f64, f64),
pub lambda: (f64, f64),
pub n_bins: (usize, usize),
pub feature_subsample: (f64, f64),
}
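/// Single-pass mean/variance accumulator (Welford's algorithm) with an extra
/// directional-hit counter. `m2` is the running sum of squared deviations, so
/// `m2 / (n - 1)` is the unbiased sample variance.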
#[derive(Debug, Clone, Default)]
pub struct WelfordStats {
pub n: u64,
pub mean_error: f64,
pub m2: f64,
pub dir_correct: u64,
}
impl WelfordStats {
pub fn update(&mut self, error: f64) {
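        // Welford step: `delta` uses the old mean, `delta2` the updated one;
        // their product accumulates squared deviations without cancellation.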
self.n += 1;
let delta = error - self.mean_error;
self.mean_error += delta / self.n as f64;
let delta2 = error - self.mean_error;
self.m2 += delta * delta2;
}
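    /// Counts sign agreement between prediction and target. Call exactly once
    /// per `update` call: `dir_accuracy` divides `dir_correct` by the shared `n`.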
pub fn update_dir(&mut self, prediction: f64, target: f64) {
if (prediction >= 0.0) == (target >= 0.0) {
self.dir_correct += 1;
}
}
pub fn dir_accuracy(&self) -> f64 {
if self.n == 0 {
f64::NAN
} else {
self.dir_correct as f64 / self.n as f64
}
}
pub fn variance(&self) -> f64 {
if self.n > 1 {
self.m2 / (self.n - 1) as f64
} else {
0.0
}
}
pub fn std_error(&self) -> f64 {
if self.n > 1 {
(self.variance() / self.n as f64).sqrt()
} else {
f64::INFINITY
}
}
}
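/// One racing slot: a boxed streaming model, its running error statistics, and
/// the index of the config it was built from.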
struct RaceCandidate {
model: Box<dyn StreamingLearner>,
stats: WelfordStats,
config_idx: usize,
}
impl core::fmt::Debug for RaceCandidate {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("RaceCandidate")
.field("stats", &self.stats)
.field("config_idx", &self.config_idx)
.finish_non_exhaustive()
}
}
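/// Outcome of `select_winner`: `all_results` holds
/// `(config_idx, mean_error, std_error, n)` per candidate, sorted by mean error.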
#[derive(Debug, Clone)]
pub struct RaceResults {
pub winner_idx: usize,
pub winner_mean_error: f64,
pub all_results: Vec<(usize, f64, f64, u64)>,
}
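/// Stop criterion for a `WelfordRace`; once met, `feed` and `signal_correction`
/// become no-ops. Marked `#[non_exhaustive]` so variants can be added.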
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum TerminateAfter {
Samples(u64),
Corrections(usize),
Duration(std::time::Duration),
Never,
}
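/// Races candidate learners on a shared stream: every `feed` trains every
/// candidate on the same sample (predict-before-train), accumulating Welford
/// error statistics per candidate, plus optional distillation, winner-change
/// tracking, and a drift score over the current winner's squared errors.
///
/// A minimal usage sketch (values are illustrative and mirror the unit tests
/// below):
///
/// ```ignore
/// let region = FeasibleRegion::from_data(200, 2, 1.0);
/// let mut race = WelfordRace::new(region.perturbation_configs())
///     .with_termination(TerminateAfter::Samples(500));
/// for i in 0..500 {
///     let x = i as f64 * 0.01;
///     race.feed(&[x, x * 0.5], 2.0 * x + 1.0);
/// }
/// let (winner, results) = race.select_winner();
/// ```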
pub struct WelfordRace {
candidates: Vec<RaceCandidate>,
termination: TerminateAfter,
terminated: bool,
correction_count: usize,
first_feed_at: Option<std::time::Instant>,
#[cfg(feature = "distill")]
distill_cfg: Option<DistillationConfig>,
#[cfg(feature = "distill")]
distill_state: Vec<CandidateDistillState>,
#[cfg(feature = "distill")]
distill_stats: DistillationStats,
last_winner_idx: Option<usize>,
winner_change_count: u64,
samples_at_last_winner_change: u64,
drift_recent_errors: VecDeque<f64>,
}
impl core::fmt::Debug for WelfordRace {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("WelfordRace")
.field("n_candidates", &self.candidates.len())
.field(
"n_samples",
&self.candidates.first().map(|c| c.stats.n).unwrap_or(0),
)
.field("terminated", &self.terminated)
.field("termination", &self.termination)
.finish()
}
}
impl WelfordRace {
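    /// Builds a race with one `SGBTLearner` per config; candidates keep the
    /// index of the config they were built from. Defaults to never terminating
    /// and no distillation.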
pub fn new(configs: Vec<SGBTConfig>) -> Self {
let candidates = configs
.into_iter()
.enumerate()
.map(|(i, config)| RaceCandidate {
model: Box::new(SGBTLearner::from_config(config)),
stats: WelfordStats::default(),
config_idx: i,
})
.collect();
Self {
candidates,
termination: TerminateAfter::Never,
terminated: false,
correction_count: 0,
first_feed_at: None,
#[cfg(feature = "distill")]
distill_cfg: None,
#[cfg(feature = "distill")]
distill_state: Vec::new(),
#[cfg(feature = "distill")]
distill_stats: DistillationStats {
disabled: true,
n_distillations_triggered: 0,
last_distillation_at_samples: None,
candidates_currently_distilling: Vec::new(),
},
last_winner_idx: None,
winner_change_count: 0,
samples_at_last_winner_change: 0,
drift_recent_errors: VecDeque::with_capacity(DRIFT_WINDOW),
}
}
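    /// Samples up to `k` configs from the factory's search space (the RNG state
    /// starts at `seed`, remapped to 1 when zero) and races whatever the factory
    /// accepts; unsatisfiable or rejected slots are skipped with a warning, so
    /// the race may hold fewer than `k` candidates.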
pub fn from_factory(factory: &dyn ModelFactory, k: usize, seed: u64) -> Self {
let space = factory.config_space();
let mut rng = if seed == 0 { 1 } else { seed };
let mut candidates = Vec::with_capacity(k);
for i in 0..k {
let params = match space.sample(&mut rng) {
Ok(p) => p,
Err(e) => {
warn!(
factory = factory.name(),
error = %e,
"search-space sampler unsatisfiable in WelfordRace::from_factory; skipping slot"
);
continue;
}
};
        match factory.create(&params) {
Ok(model) => {
candidates.push(RaceCandidate {
model,
stats: WelfordStats::default(),
config_idx: i,
});
}
Err(e) => {
warn!(
factory = factory.name(),
error = %e,
"factory rejected config in WelfordRace::from_factory; skipping slot"
);
}
}
}
Self {
candidates,
termination: TerminateAfter::Never,
terminated: false,
correction_count: 0,
first_feed_at: None,
#[cfg(feature = "distill")]
distill_cfg: None,
#[cfg(feature = "distill")]
distill_state: Vec::new(),
#[cfg(feature = "distill")]
distill_stats: DistillationStats {
disabled: true,
n_distillations_triggered: 0,
last_distillation_at_samples: None,
candidates_currently_distilling: Vec::new(),
},
last_winner_idx: None,
winner_change_count: 0,
samples_at_last_winner_change: 0,
drift_recent_errors: VecDeque::with_capacity(DRIFT_WINDOW),
}
}
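    /// Feeds one labeled sample to every candidate, predicting before training
    /// so recorded errors are out-of-sample, then updates distillation (if
    /// enabled), winner-change bookkeeping, the winner's squared-error window,
    /// and the termination state. No-op once terminated.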
pub fn feed(&mut self, features: &[f64], target: f64) {
if self.terminated {
return;
}
if self.first_feed_at.is_none() {
self.first_feed_at = Some(std::time::Instant::now());
}
for c in &mut self.candidates {
let pred = c.model.predict(features);
let error = (target - pred).abs();
c.stats.update(error);
c.stats.update_dir(pred, target);
c.model.train_one(features, target, 1.0);
}
#[cfg(feature = "distill")]
if self.distill_cfg.is_some() {
let winner_idx = self
.candidates
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
a.stats
.mean_error
.partial_cmp(&b.stats.mean_error)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0);
let pareto_front = vec![winner_idx];
self.run_distillation_pass(features, target, &pareto_front, winner_idx);
}
let current_winner = self.current_winner_idx_scalar();
if current_winner != self.last_winner_idx {
self.winner_change_count += 1;
self.samples_at_last_winner_change = self.n_samples();
self.last_winner_idx = current_winner;
}
if let Some(winner_cfg_idx) = current_winner {
let winner_pos = self
.candidates
.iter()
.position(|c| c.config_idx == winner_cfg_idx);
if let Some(pos) = winner_pos {
let sq_err = {
let winner_pred = self.candidates[pos].model.predict(features);
(target - winner_pred).powi(2)
};
if self.drift_recent_errors.len() >= DRIFT_WINDOW {
self.drift_recent_errors.pop_front();
}
self.drift_recent_errors.push_back(sq_err);
}
}
self.recompute_termination();
}
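    /// Consumes the race, returning the candidate with the lowest mean absolute
    /// error together with per-candidate results sorted by mean error.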
pub fn select_winner(self) -> (Box<dyn StreamingLearner>, RaceResults) {
let mut results: Vec<(usize, f64, f64, u64)> = self
.candidates
.iter()
.map(|c| {
(
c.config_idx,
c.stats.mean_error,
c.stats.std_error(),
c.stats.n,
)
})
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let winner_idx = results[0].0;
let winner_mean = results[0].1;
let winner_model = self
.candidates
.into_iter()
.find(|c| c.config_idx == winner_idx)
.map(|c| c.model)
.expect("winner must exist in candidates");
(
winner_model,
RaceResults {
winner_idx,
winner_mean_error: winner_mean,
all_results: results,
},
)
}
pub fn n_candidates(&self) -> usize {
self.candidates.len()
}
pub fn n_samples(&self) -> u64 {
self.candidates.first().map(|c| c.stats.n).unwrap_or(0)
}
fn current_winner_idx_scalar(&self) -> Option<usize> {
self.candidates
.iter()
.filter(|c| c.stats.n > 0)
.min_by(|a, b| {
a.stats
.mean_error
.partial_cmp(&b.stats.mean_error)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|c| c.config_idx)
}
}
impl WelfordRace {
pub fn with_termination(mut self, criterion: TerminateAfter) -> Self {
self.termination = criterion;
self
}
#[inline]
pub fn is_terminated(&self) -> bool {
self.terminated
}
pub fn signal_correction(&mut self) {
if self.terminated {
return;
}
self.correction_count += 1;
self.recompute_termination();
}
pub fn samples_until_termination(&self) -> Option<u64> {
match self.termination {
TerminateAfter::Samples(n) => {
let seen = self.n_samples();
Some(n.saturating_sub(seen))
}
_ => None,
}
}
fn recompute_termination(&mut self) {
if self.terminated {
            return;
        }
self.terminated = match self.termination {
TerminateAfter::Samples(n) => self.n_samples() >= n,
TerminateAfter::Corrections(k) => self.correction_count >= k,
TerminateAfter::Duration(d) => self
.first_feed_at
.map(|t| t.elapsed() >= d)
.unwrap_or(false),
TerminateAfter::Never => false,
};
}
}
impl WelfordRace {
pub fn samples_since_last_winner_change(&self) -> u64 {
self.n_samples()
.saturating_sub(self.samples_at_last_winner_change)
}
pub fn winner_change_count(&self) -> u64 {
self.winner_change_count
}
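    /// Relative change in mean squared error between the newer and older halves
    /// of the winner-error window: 0.0 until the window is at least half full,
    /// positive when error is growing, negative when it is shrinking.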
pub fn race_drift_score(&self) -> f64 {
        let half = DRIFT_WINDOW / 2;
        if self.drift_recent_errors.len() < half {
return 0.0;
}
let mid = self.drift_recent_errors.len() / 2;
let baseline: f64 = self.drift_recent_errors.iter().take(mid).sum::<f64>() / mid as f64;
let recent: f64 = self.drift_recent_errors.iter().skip(mid).sum::<f64>()
/ (self.drift_recent_errors.len() - mid) as f64;
(recent - baseline) / (baseline.abs() + 1e-12)
}
}
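/// Snapshot of the five Pareto signals for one candidate; any NaN disqualifies
/// the candidate from front membership.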
#[derive(Debug, Clone, Copy)]
struct CandidateSignals {
mean_error: f64,
se_error: f64,
empirical_sigma: f64,
n_steps: f64,
dir_accuracy: f64,
}
impl CandidateSignals {
#[allow(deprecated)]
fn from_candidate(c: &RaceCandidate) -> Self {
let diag = c.model.diagnostics_array();
CandidateSignals {
mean_error: c.stats.mean_error,
se_error: c.stats.std_error(),
empirical_sigma: diag[4],
n_steps: c.stats.n as f64,
dir_accuracy: c.stats.dir_accuracy(),
}
}
fn has_nan(&self) -> bool {
self.mean_error.is_nan()
|| self.se_error.is_nan()
|| self.empirical_sigma.is_nan()
|| self.n_steps.is_nan()
|| self.dir_accuracy.is_nan()
}
}
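/// `a` dominates `b` iff it is no worse on every signal (lower mean error,
/// standard error, and sigma; more steps and higher directional accuracy) and
/// strictly better on at least one.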
fn pareto_dominates(a: &CandidateSignals, b: &CandidateSignals) -> bool {
let no_worse_lower = a.mean_error <= b.mean_error
&& a.se_error <= b.se_error
&& a.empirical_sigma <= b.empirical_sigma;
let no_worse_higher = a.n_steps >= b.n_steps && a.dir_accuracy >= b.dir_accuracy;
let strictly_better = a.mean_error < b.mean_error
|| a.se_error < b.se_error
|| a.empirical_sigma < b.empirical_sigma
|| a.n_steps > b.n_steps
|| a.dir_accuracy > b.dir_accuracy;
no_worse_lower && no_worse_higher && strictly_better
}
impl WelfordRace {
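    /// Positions of the non-dominated candidates. NaN-signal candidates are
    /// excluded from the front and never dominate others.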
pub fn pareto_front(&self) -> Vec<usize> {
let signals: Vec<Option<CandidateSignals>> = self
.candidates
.iter()
.map(|c| {
let sig = CandidateSignals::from_candidate(c);
if sig.has_nan() {
None
} else {
Some(sig)
}
})
.collect();
(0..self.candidates.len())
.filter(|&i| {
let Some(sig_i) = signals[i] else {
return false;
};
!signals.iter().enumerate().any(|(j, sig_j_opt)| {
if j == i {
return false;
}
match sig_j_opt {
Some(sig_j) => pareto_dominates(sig_j, &sig_i),
None => false,
}
})
})
.collect()
}
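    /// Resolves the front to one index: trivial for fronts of size zero or one,
    /// then a Bernstein confidence comparison, falling back to minimum mean error.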
pub fn pareto_winner_idx(&self) -> Option<usize> {
let front = self.pareto_front();
if front.is_empty() {
return None;
}
if front.len() == 1 {
return Some(front[0]);
}
use crate::automl::racing::{bernstein_compare, BERNSTEIN_DELTA};
let arm_stats: Vec<crate::automl::racing::ArmStats> = front
.iter()
.map(|&idx| {
let c = &self.candidates[idx];
let n = c.stats.n;
let range = if n > 1 {
4.0 * (c.stats.m2 / (n - 1) as f64).sqrt()
} else {
0.0
};
(c.stats.mean_error, c.stats.m2, n, range)
})
.collect();
if let Some(front_slot) = bernstein_compare(&arm_stats, BERNSTEIN_DELTA) {
return Some(front[front_slot]);
}
front.into_iter().min_by(|&a, &b| {
self.candidates[a]
.stats
.mean_error
.partial_cmp(&self.candidates[b].stats.mean_error)
.unwrap_or(std::cmp::Ordering::Equal)
})
}
#[cfg(test)]
pub(super) fn inject_stats_for_test(
&mut self,
pos: usize,
n: u64,
mean_error: f64,
m2: f64,
dir_correct: u64,
) {
self.candidates[pos].stats.n = n;
self.candidates[pos].stats.mean_error = mean_error;
self.candidates[pos].stats.m2 = m2;
self.candidates[pos].stats.dir_correct = dir_correct;
}
}
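/// Dampened nudges emitted at SPSA phase boundaries: a multiplicative
/// learning-rate factor and an additive lambda step.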
#[derive(Debug, Clone, Default)]
pub struct SmoothAdjustments {
pub lr_multiplier: f64,
pub lambda_direction: f64,
}
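/// Discrete config deltas suggested at model-replacement boundaries.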
#[derive(Debug, Clone, Default)]
pub struct StructuralChange {
pub depth_delta: i32,
pub steps_delta: i32,
}
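/// Scalar performance objective maximized by the SPSA loop; RMSE is negated so
/// "higher is better" holds for every variant.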
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub enum MetaObjective {
#[default]
MinimizeRMSE,
MaximizeR2,
MaximizeDirection,
MaximizeF1,
MaximizeKappa,
Composite {
rmse_weight: f64,
r2_weight: f64,
dir_weight: f64,
},
}
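/// SPSA measurement cycle: `Init` calibrates the perturbation size from
/// observed performance noise, then the loop alternates +/- perturbation
/// measurements.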
#[derive(Debug, Clone, Copy, PartialEq)]
enum SPSAPhase {
Init,
PerturbPlus,
PerturbMinus,
}
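/// Online hyperparameter tuner driven by simultaneous perturbation stochastic
/// approximation (SPSA) over `theta` in `[0, 1]^2`: `theta[0]` maps to the
/// learning rate on a geometric scale, `theta[1]` to lambda on a linear scale.
/// Each iteration measures performance at `theta +/- c_k * delta` (with `delta`
/// a random +/-1 vector) and steps along the gradient estimate
/// `g_hat_i = (perf_plus - perf_minus) / (2 * c_k * delta_i)`. A CUSUM test on
/// the performance EWMA resets the gain schedule when drift is detected.
///
/// A minimal sketch of the per-step loop (`diag`, `prediction`, `target`, `lr`,
/// and `lambda` are illustrative names from the caller's side):
///
/// ```ignore
/// let mut tuner = DiagnosticLearner::new(FeasibleRegion::from_data(1_000, 3, 1.0));
/// let adj = tuner.after_train(&diag, prediction, target);
/// lr *= adj.lr_multiplier;        // multiplicative learning-rate nudge
/// lambda += adj.lambda_direction; // additive lambda nudge
/// ```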
#[derive(Debug)]
pub struct DiagnosticLearner {
uncertainty_ewma: f64,
alignment_ewma: f64,
reg_sensitivity_ewma: f64,
depth_signal_ewma: f64,
dof_ewma: f64,
alpha: f64,
region: FeasibleRegion,
n_samples: u64,
initialized: bool,
observation_interval: u64,
theta: [f64; 2],
theta_best: [f64; 2],
best_performance: f64,
a_init: f64,
a: f64,
c_init: f64,
c_floor: [f64; 2],
k_local: u64,
big_a: f64,
phase: SPSAPhase,
current_delta: [f64; 2],
perf_plus: f64,
perf_minus: f64,
samples_in_phase: u64,
cusum_s: f64,
perf_ewma_baseline: f64,
perf_variance: f64,
last_emitted_theta: [f64; 2],
total_steps: u64,
rng_state: u64,
objective: MetaObjective,
squared_error_ewma: f64,
target_ewma: f64,
target_var_ewma: f64,
direction_ewma: f64,
tp_ewma: f64,
fp_ewma: f64,
fn_ewma: f64,
accuracy_ewma: f64,
pos_rate_ewma: f64,
pred_pos_rate_ewma: f64,
}
impl DiagnosticLearner {
pub fn new(region: FeasibleRegion) -> Self {
Self::with_objective(region, MetaObjective::default())
}
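    /// Like `new` but with an explicit objective. The SPSA observation interval
    /// is the midpoint of the region's grace-period bounds, clamped to [1, 50].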
pub fn with_objective(region: FeasibleRegion, objective: MetaObjective) -> Self {
let bounds = region.config_bounds();
let observation_interval =
((bounds.grace_period.0 + bounds.grace_period.1) / 2).clamp(1, 50) as u64;
let big_a = 10.0;
let a_init = 0.05 * (big_a + 1.0_f64).powf(0.602);
let c_floor = [0.001; 2];
Self {
uncertainty_ewma: 0.0,
alignment_ewma: 0.0,
reg_sensitivity_ewma: 0.0,
depth_signal_ewma: 0.0,
dof_ewma: 0.0,
alpha: 1.0 - (-2.0 / observation_interval as f64).exp(),
region,
n_samples: 0,
initialized: false,
observation_interval,
theta: [0.5, 0.5],
theta_best: [0.5, 0.5],
best_performance: f64::NEG_INFINITY,
a_init,
a: a_init,
c_init: 0.1,
c_floor,
k_local: 0,
big_a,
phase: SPSAPhase::Init,
current_delta: [0.0; 2],
perf_plus: 0.0,
perf_minus: 0.0,
samples_in_phase: 0,
cusum_s: 0.0,
perf_ewma_baseline: 0.0,
perf_variance: 0.0,
last_emitted_theta: [0.5, 0.5],
total_steps: 0,
rng_state: 0xDEAD_BEEF_CAFE_1234,
objective,
squared_error_ewma: 0.0,
target_ewma: 0.0,
target_var_ewma: 0.0,
direction_ewma: 0.5,
tp_ewma: 0.0,
fp_ewma: 0.0,
fn_ewma: 0.0,
accuracy_ewma: 0.5,
pos_rate_ewma: 0.5,
pred_pos_rate_ewma: 0.5,
}
}
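    /// Updates all diagnostic and objective EWMAs from one sample, then advances
    /// the SPSA phase machine; returns a no-op adjustment except at phase
    /// boundaries, where the next perturbed theta is emitted.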
pub fn after_train(
&mut self,
diagnostics: &ConfigDiagnostics,
prediction: f64,
target: f64,
) -> SmoothAdjustments {
self.n_samples += 1;
let a = self.alpha;
self.uncertainty_ewma = a * diagnostics.uncertainty + (1.0 - a) * self.uncertainty_ewma;
self.alignment_ewma = a * diagnostics.residual_alignment + (1.0 - a) * self.alignment_ewma;
self.reg_sensitivity_ewma =
a * diagnostics.regularization_sensitivity + (1.0 - a) * self.reg_sensitivity_ewma;
self.depth_signal_ewma =
a * diagnostics.depth_sufficiency + (1.0 - a) * self.depth_signal_ewma;
self.dof_ewma = a * diagnostics.effective_dof + (1.0 - a) * self.dof_ewma;
let error = target - prediction;
self.squared_error_ewma = a * (error * error) + (1.0 - a) * self.squared_error_ewma;
let old_target_ewma = self.target_ewma;
self.target_ewma = a * target + (1.0 - a) * self.target_ewma;
let dev = target - old_target_ewma;
self.target_var_ewma = a * (dev * dev) + (1.0 - a) * self.target_var_ewma;
let correct_dir = if (prediction * target) >= 0.0 {
1.0
} else {
0.0
};
self.direction_ewma = a * correct_dir + (1.0 - a) * self.direction_ewma;
let predicted_positive = prediction > 0.5;
let actual_positive = target > 0.5;
let tp = if predicted_positive && actual_positive {
1.0
} else {
0.0
};
let fp = if predicted_positive && !actual_positive {
1.0
} else {
0.0
};
let fn_ = if !predicted_positive && actual_positive {
1.0
} else {
0.0
};
self.tp_ewma = a * tp + (1.0 - a) * self.tp_ewma;
self.fp_ewma = a * fp + (1.0 - a) * self.fp_ewma;
self.fn_ewma = a * fn_ + (1.0 - a) * self.fn_ewma;
let correct = if (predicted_positive && actual_positive)
|| (!predicted_positive && !actual_positive)
{
1.0
} else {
0.0
};
self.accuracy_ewma = a * correct + (1.0 - a) * self.accuracy_ewma;
self.pos_rate_ewma =
a * (if actual_positive { 1.0 } else { 0.0 }) + (1.0 - a) * self.pos_rate_ewma;
self.pred_pos_rate_ewma =
a * (if predicted_positive { 1.0 } else { 0.0 }) + (1.0 - a) * self.pred_pos_rate_ewma;
self.samples_in_phase += 1;
let no_op = SmoothAdjustments {
lr_multiplier: 1.0,
lambda_direction: 0.0,
};
match self.phase {
SPSAPhase::Init => {
let perf = self.current_performance();
self.perf_variance =
a * (perf - self.perf_ewma_baseline).powi(2) + (1.0 - a) * self.perf_variance;
self.perf_ewma_baseline = a * perf + (1.0 - a) * self.perf_ewma_baseline;
if self.samples_in_phase >= 50 {
let noise_std = self.perf_variance.sqrt();
self.c_init = (2.0 * noise_std).clamp(0.005, 0.08);
self.initialized = true;
self.generate_delta();
self.phase = SPSAPhase::PerturbPlus;
self.samples_in_phase = 0;
let target_theta = self.perturbed_theta(1.0);
return self.adjustment_for_theta(&target_theta);
}
no_op
}
SPSAPhase::PerturbPlus => {
if self.samples_in_phase >= self.observation_interval {
self.perf_plus = self.current_performance();
self.phase = SPSAPhase::PerturbMinus;
self.samples_in_phase = 0;
let target_theta = self.perturbed_theta(-1.0);
return self.adjustment_for_theta(&target_theta);
}
no_op
}
SPSAPhase::PerturbMinus => {
if self.samples_in_phase >= self.observation_interval {
self.perf_minus = self.current_performance();
self.do_spsa_update();
self.generate_delta();
self.phase = SPSAPhase::PerturbPlus;
self.samples_in_phase = 0;
let target_theta = self.perturbed_theta(1.0);
return self.adjustment_for_theta(&target_theta);
}
no_op
}
}
}
pub fn after_train_diagnostics_only(
&mut self,
diagnostics: &ConfigDiagnostics,
) -> SmoothAdjustments {
self.after_train(diagnostics, 0.0, 0.0)
}
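    /// Compares depth and effective-DOF signals against the refreshed feasible
    /// region and suggests discrete depth/steps deltas, or `None` when the
    /// current structure suffices (or before SPSA initialization completes).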
pub fn at_replacement(&mut self, diagnostics: &ConfigDiagnostics) -> Option<StructuralChange> {
if !self.initialized {
return None;
}
self.region.update(self.n_samples as usize);
let bounds = self.region.config_bounds();
let needs_more_depth = diagnostics.depth_sufficiency > self.depth_signal_ewma * 1.5
&& bounds.max_depth.1 > bounds.max_depth.0;
let dof_ratio = if self.n_samples > 0 {
diagnostics.effective_dof / self.n_samples as f64
} else {
0.0
};
let target_dof_ratio = (self.region.budget() / self.n_samples as f64).clamp(0.01, 0.5);
let needs_more_steps = dof_ratio < target_dof_ratio * 0.5;
let needs_fewer_steps = dof_ratio > target_dof_ratio * 2.0;
if needs_more_depth || needs_more_steps || needs_fewer_steps {
Some(StructuralChange {
depth_delta: if needs_more_depth { 1 } else { 0 },
steps_delta: if needs_more_steps {
2
} else if needs_fewer_steps {
-2
} else {
0
},
})
} else {
None
}
}
pub fn reset(&mut self) {
self.uncertainty_ewma = 0.0;
self.alignment_ewma = 0.0;
self.reg_sensitivity_ewma = 0.0;
self.depth_signal_ewma = 0.0;
self.dof_ewma = 0.0;
self.n_samples = 0;
self.initialized = false;
self.theta = [0.5, 0.5];
self.theta_best = [0.5, 0.5];
self.best_performance = f64::NEG_INFINITY;
self.a = self.a_init;
self.c_init = 0.1;
self.k_local = 0;
self.phase = SPSAPhase::Init;
self.current_delta = [0.0; 2];
self.perf_plus = 0.0;
self.perf_minus = 0.0;
self.samples_in_phase = 0;
self.cusum_s = 0.0;
self.perf_ewma_baseline = 0.0;
self.perf_variance = 0.0;
self.last_emitted_theta = [0.5, 0.5];
self.total_steps = 0;
self.rng_state = 0xDEAD_BEEF_CAFE_1234;
self.squared_error_ewma = 0.0;
self.target_ewma = 0.0;
self.target_var_ewma = 0.0;
self.direction_ewma = 0.5;
self.tp_ewma = 0.0;
self.fp_ewma = 0.0;
self.fn_ewma = 0.0;
self.accuracy_ewma = 0.5;
self.pos_rate_ewma = 0.5;
self.pred_pos_rate_ewma = 0.5;
}
pub fn region(&self) -> &FeasibleRegion {
&self.region
}
pub fn update_region(&mut self, n_samples: usize, target_variance: f64) {
self.region.update(n_samples);
self.region.update_variance(target_variance);
}
pub fn total_steps(&self) -> u64 {
self.total_steps
}
pub fn objective(&self) -> MetaObjective {
self.objective
}
#[cfg(test)]
fn phase(&self) -> SPSAPhase {
self.phase
}
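    /// Scalar "higher is better" performance under the configured objective
    /// (RMSE negated; R2, direction, F1, and kappa used as-is).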
fn current_performance(&self) -> f64 {
match self.objective {
MetaObjective::MinimizeRMSE => -self.squared_error_ewma.sqrt(),
MetaObjective::MaximizeR2 => {
if self.target_var_ewma > 1e-15 {
1.0 - self.squared_error_ewma / self.target_var_ewma
} else {
0.0
}
}
MetaObjective::MaximizeDirection => self.direction_ewma,
MetaObjective::MaximizeF1 => {
let denom = 2.0 * self.tp_ewma + self.fp_ewma + self.fn_ewma;
if denom > 1e-15 {
2.0 * self.tp_ewma / denom
} else {
0.0
}
}
MetaObjective::MaximizeKappa => {
let expected = self.pos_rate_ewma * self.pred_pos_rate_ewma
+ (1.0 - self.pos_rate_ewma) * (1.0 - self.pred_pos_rate_ewma);
if (1.0 - expected).abs() > 1e-15 {
(self.accuracy_ewma - expected) / (1.0 - expected)
} else {
0.0
}
}
MetaObjective::Composite {
rmse_weight,
r2_weight,
dir_weight,
} => {
let rmse_score = -self.squared_error_ewma.sqrt();
let r2_score = if self.target_var_ewma > 1e-15 {
1.0 - self.squared_error_ewma / self.target_var_ewma
} else {
0.0
};
let dir_score = self.direction_ewma;
rmse_weight * rmse_score + r2_weight * r2_score + dir_weight * dir_score
}
}
}
fn do_spsa_update(&mut self) {
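        // Gain decay exponents 0.602 and 0.101 are the standard SPSA choices
        // recommended by Spall.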
let a_k = self.a / (self.big_a + self.k_local as f64 + 1.0).powf(0.602);
let c_k_base = self.c_init / (self.k_local as f64 + 1.0).powf(0.101);
for i in 0..2 {
let c_k = c_k_base.max(self.c_floor[i]);
if self.current_delta[i].abs() > 0.5 {
let g_hat =
(self.perf_plus - self.perf_minus) / (2.0 * c_k * self.current_delta[i]);
                self.theta[i] += a_k * g_hat;
                self.theta[i] = self.theta[i].clamp(0.0, 1.0);
}
}
if self.perf_plus < self.best_performance && self.perf_minus < self.best_performance {
self.a *= 0.5;
self.theta = self.theta_best;
} else {
let best = self.perf_plus.max(self.perf_minus);
if best > self.best_performance {
self.best_performance = best;
self.theta_best = self.theta;
}
}
let drift_margin = 0.5 * self.perf_variance.sqrt();
let drift_threshold = 5.0 * self.perf_variance.sqrt();
let current = self.current_performance();
self.cusum_s = (self.cusum_s + (self.perf_ewma_baseline - current) - drift_margin).max(0.0);
if self.cusum_s > drift_threshold && drift_threshold > 1e-15 {
self.k_local = 0;
self.a = self.a_init;
self.cusum_s = 0.0;
self.perf_ewma_baseline = current;
}
self.k_local += 1;
self.total_steps += 1;
let perf = self.current_performance();
let a = self.alpha;
self.perf_variance =
a * (perf - self.perf_ewma_baseline).powi(2) + (1.0 - a) * self.perf_variance;
self.perf_ewma_baseline = a * perf + (1.0 - a) * self.perf_ewma_baseline;
}
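    /// Maps theta onto the feasible box: `theta[0]` interpolates the learning
    /// rate geometrically between its bounds, `theta[1]` interpolates lambda
    /// linearly.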
fn theta_to_config(&self, theta: &[f64; 2]) -> (f64, f64) {
let bounds = self.region.config_bounds();
let lr = bounds.learning_rate.0
* (bounds.learning_rate.1 / bounds.learning_rate.0.max(1e-15)).powf(theta[0]);
let lambda = bounds.lambda.0 + theta[1] * (bounds.lambda.1 - bounds.lambda.0);
(lr.max(1e-10), lambda.max(0.0))
}
fn adjustment_for_theta(&mut self, target: &[f64; 2]) -> SmoothAdjustments {
let (target_lr, target_lambda) = self.theta_to_config(target);
let (last_lr, last_lambda) = self.theta_to_config(&self.last_emitted_theta);
self.last_emitted_theta = *target;
let raw_mult = target_lr / last_lr.max(1e-15);
let dampened_mult = 1.0 + 0.3 * (raw_mult - 1.0);
let raw_dir = target_lambda - last_lambda;
let dampened_dir = 0.3 * raw_dir;
SmoothAdjustments {
lr_multiplier: dampened_mult,
lambda_direction: dampened_dir,
}
}
fn perturbed_theta(&self, sign: f64) -> [f64; 2] {
let c_k_base = self.c_init / (self.k_local as f64 + 1.0).powf(0.101);
let mut result = [0.0; 2];
for (i, val) in result.iter_mut().enumerate() {
let c_k = c_k_base.max(self.c_floor[i]);
*val = (self.theta[i] + sign * c_k * self.current_delta[i]).clamp(0.0, 1.0);
}
result
}
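    /// Draws a Rademacher (+/-1 per coordinate) perturbation direction from a
    /// xorshift64 state.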
fn generate_delta(&mut self) {
for d in &mut self.current_delta {
self.rng_state ^= self.rng_state << 13;
self.rng_state ^= self.rng_state >> 7;
self.rng_state ^= self.rng_state << 17;
*d = if self.rng_state % 2 == 0 { 1.0 } else { -1.0 };
}
}
}
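/// Backward-compatibility alias for `DiagnosticLearner`.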
pub type DiagnosticAdaptor = DiagnosticLearner;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn feasible_region_sparse_data() {
let region = FeasibleRegion::from_data(100, 3, 1.0);
let bounds = region.config_bounds();
assert!(
bounds.max_depth.1 <= 4,
"sparse data (n=100) should have tight depth: got max {}",
bounds.max_depth.1
);
assert!(
bounds.n_steps.1 <= 15,
"sparse data (n=100) should have tight n_steps: got max {}",
bounds.n_steps.1
);
}
#[test]
fn feasible_region_abundant_data() {
let region = FeasibleRegion::from_data(10_000, 3, 1.0);
let bounds = region.config_bounds();
assert!(
bounds.max_depth.1 >= 4,
"abundant data (n=10000) should allow deeper trees: got max {}",
bounds.max_depth.1
);
assert!(
bounds.n_steps.1 >= 20,
"abundant data (n=10000) should allow more steps: got max {}",
bounds.n_steps.1
);
}
#[test]
fn feasible_region_center_config_valid() {
let region = FeasibleRegion::from_data(500, 5, 2.0);
let config = region.center_config();
assert!(config.n_steps > 0, "center n_steps must be > 0");
assert!(config.max_depth > 0, "center max_depth must be > 0");
assert!(
config.learning_rate > 0.0 && config.learning_rate <= 1.0,
"center learning_rate must be in (0, 1]"
);
}
#[test]
fn feasible_region_perturbations() {
let region = FeasibleRegion::from_data(500, 5, 2.0);
let configs = region.perturbation_configs();
assert!(
configs.len() > 1,
"perturbation_configs should produce > 1 configs, got {}",
configs.len()
);
for (i, cfg) in configs.iter().enumerate() {
assert!(cfg.n_steps > 0, "config[{i}] n_steps must be > 0");
assert!(cfg.max_depth > 0, "config[{i}] max_depth must be > 0");
assert!(
cfg.learning_rate > 0.0,
"config[{i}] learning_rate must be > 0"
);
}
}
#[test]
fn feasible_region_update_expands() {
let mut region = FeasibleRegion::from_data(100, 3, 1.0);
let budget_before = region.budget();
region.update(10_000);
assert!(
region.budget() > budget_before,
"budget should increase with more data: before={budget_before}, after={}",
region.budget()
);
}
#[test]
fn welford_race_all_see_all() {
let region = FeasibleRegion::from_data(200, 2, 1.0);
let configs = region.perturbation_configs();
let n_configs = configs.len();
let mut race = WelfordRace::new(configs);
for i in 0..100 {
let x = i as f64 * 0.1;
race.feed(&[x, x * 0.5], x * 2.0 + 1.0);
}
assert_eq!(race.n_candidates(), n_configs);
assert_eq!(race.n_samples(), 100);
let (_winner, results) = race.select_winner();
for (idx, _mean, _se, n) in &results.all_results {
assert_eq!(
*n, 100,
"config {idx} should have seen 100 samples, got {n}"
);
}
}
#[test]
fn welford_race_selects_best() {
let region = FeasibleRegion::from_data(500, 1, 1.0);
let configs = region.perturbation_configs();
let mut race = WelfordRace::new(configs);
for i in 0..200 {
let x = i as f64 * 0.01;
let noise = ((i * 7 + 3) % 11) as f64 * 0.001 - 0.005;
race.feed(&[x], 2.0 * x + noise);
}
let (_winner, results) = race.select_winner();
let winner_mean = results.winner_mean_error;
for (_, mean, _, _) in &results.all_results {
assert!(
winner_mean <= *mean + 1e-12,
"winner mean {winner_mean} should be <= all others, found {mean}"
);
}
}
#[test]
fn welford_stats_accuracy() {
let mut stats = WelfordStats::default();
let values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
for v in &values {
stats.update(*v);
}
let expected_mean = 5.0;
let expected_variance = 4.571428571428571;
assert!(
(stats.mean_error - expected_mean).abs() < 1e-10,
"mean should be {expected_mean}, got {}",
stats.mean_error
);
assert!(
(stats.variance() - expected_variance).abs() < 1e-10,
"variance should be {expected_variance}, got {}",
stats.variance()
);
}
fn make_race() -> WelfordRace {
let region = FeasibleRegion::from_data(200, 2, 1.0);
WelfordRace::new(region.perturbation_configs())
}
#[test]
fn terminate_after_samples_freezes_state() {
let mut race = make_race().with_termination(TerminateAfter::Samples(500));
assert!(
!race.is_terminated(),
"should not be terminated before any feed"
);
assert_eq!(
race.samples_until_termination(),
Some(500),
"should report 500 remaining at start"
);
for i in 0..499 {
let x = i as f64 * 0.001;
race.feed(&[x, x], x * 2.0);
}
assert!(
!race.is_terminated(),
"should not terminate after 499 samples (threshold is 500)"
);
assert_eq!(
race.samples_until_termination(),
Some(1),
"should report 1 remaining after 499 feeds"
);
race.feed(&[1.0, 1.0], 2.0);
assert!(
race.is_terminated(),
"should terminate after exactly 500 samples"
);
assert_eq!(
race.samples_until_termination(),
Some(0),
"samples_until_termination should be 0 after termination"
);
let n_before = race.n_samples();
race.feed(&[2.0, 2.0], 4.0);
race.feed(&[3.0, 3.0], 6.0);
assert_eq!(
race.n_samples(),
n_before,
"feed after termination must be a no-op: samples should not increase"
);
assert!(
race.is_terminated(),
"is_terminated must remain true after extra feeds (monotonic)"
);
}
#[test]
fn terminate_after_corrections_freezes_state() {
const K: usize = 5;
let mut race = make_race().with_termination(TerminateAfter::Corrections(K));
for i in 0..50 {
let x = i as f64 * 0.01;
race.feed(&[x, x], x * 2.0);
}
assert!(
!race.is_terminated(),
"should not terminate before K corrections"
);
for _ in 0..(K - 1) {
race.signal_correction();
}
assert!(
!race.is_terminated(),
"should not terminate after K-1 corrections (threshold is K)"
);
race.signal_correction();
assert!(
race.is_terminated(),
"should terminate on the K-th signal_correction call"
);
race.signal_correction();
assert!(
race.is_terminated(),
"is_terminated must remain true after extra signal_correction (monotonic)"
);
}
#[test]
#[ignore = "Duration-based termination is non-deterministic on real wall-clock time"]
fn terminate_after_duration_freezes_state() {
use std::time::Duration;
let mut race =
make_race().with_termination(TerminateAfter::Duration(Duration::from_millis(50)));
race.feed(&[1.0, 1.0], 2.0);
std::thread::sleep(Duration::from_millis(100));
race.feed(&[2.0, 2.0], 4.0);
assert!(
race.is_terminated(),
"should terminate after wall-clock duration elapses"
);
let n_before = race.n_samples();
race.feed(&[3.0, 3.0], 6.0);
assert_eq!(
race.n_samples(),
n_before,
"feed after duration-termination must be a no-op"
);
}
#[test]
fn terminate_after_never_default_back_compat() {
let mut race = make_race();
for i in 0..1_000 {
let x = i as f64 * 0.001;
race.feed(&[x, x], x * 2.0);
}
assert!(
!race.is_terminated(),
"TerminateAfter::Never must never terminate, even after 1000 feeds"
);
assert_eq!(
race.samples_until_termination(),
None,
"samples_until_termination must be None for TerminateAfter::Never"
);
}
#[test]
fn is_terminated_is_monotonic() {
let mut race = make_race().with_termination(TerminateAfter::Samples(10));
for i in 0..10 {
race.feed(&[i as f64, i as f64], i as f64 * 2.0);
}
assert!(race.is_terminated(), "should be terminated after 10 feeds");
race.feed(&[99.0, 99.0], 198.0);
race.signal_correction();
race.feed(&[100.0, 100.0], 200.0);
assert!(
race.is_terminated(),
"is_terminated must remain true (monotonic invariant violated)"
);
}
#[test]
fn samples_until_termination_decrements_correctly() {
let mut race = make_race().with_termination(TerminateAfter::Samples(5));
assert_eq!(
race.samples_until_termination(),
Some(5),
"remaining should be 5 at start"
);
for step in 0..5usize {
let x = step as f64;
race.feed(&[x, x], x * 2.0);
let remaining = race.samples_until_termination();
let expected = Some((4 - step) as u64);
assert_eq!(
remaining,
expected,
"after {} feeds, remaining should be {:?}, got {:?}",
step + 1,
expected,
remaining
);
}
assert!(
race.is_terminated(),
"race must be terminated after all 5 feeds"
);
}
#[test]
fn diagnostic_learner_init_phase_no_adjustments() {
let region = FeasibleRegion::from_data(200, 3, 1.0);
let mut learner = DiagnosticLearner::new(region);
let diag = ConfigDiagnostics {
residual_alignment: 0.5,
regularization_sensitivity: 1.0,
depth_sufficiency: 0.5,
effective_dof: 10.0,
uncertainty: 0.1,
};
for _ in 0..49 {
let adj = learner.after_train(&diag, 0.5, 1.0);
assert_eq!(
adj.lr_multiplier, 1.0,
"during init phase, lr_multiplier should be 1.0"
);
assert_eq!(
adj.lambda_direction, 0.0,
"during init phase, lambda_direction should be 0.0"
);
}
assert_eq!(
learner.phase(),
SPSAPhase::Init,
"should still be in Init phase after 49 samples"
);
}
#[test]
fn diagnostic_learner_phase_cycling() {
let region = FeasibleRegion::from_data(200, 3, 1.0);
let bounds = region.config_bounds();
let interval = ((bounds.grace_period.0 + bounds.grace_period.1) / 2).clamp(1, 50) as u64;
let mut learner = DiagnosticLearner::new(region);
let diag = ConfigDiagnostics {
residual_alignment: 0.5,
regularization_sensitivity: 1.0,
depth_sufficiency: 0.5,
effective_dof: 10.0,
uncertainty: 0.1,
};
for i in 0..50 {
learner.after_train(&diag, i as f64 * 0.01, i as f64 * 0.01 + 0.1);
}
assert_eq!(
learner.phase(),
SPSAPhase::PerturbPlus,
"should be PerturbPlus after init phase"
);
for i in 0..interval {
let idx = 50 + i;
learner.after_train(&diag, idx as f64 * 0.01, idx as f64 * 0.01 + 0.1);
}
assert_eq!(
learner.phase(),
SPSAPhase::PerturbMinus,
"should be PerturbMinus after completing PerturbPlus"
);
for i in 0..interval {
let idx = 50 + interval + i;
learner.after_train(&diag, idx as f64 * 0.01, idx as f64 * 0.01 + 0.1);
}
assert_eq!(
learner.phase(),
SPSAPhase::PerturbPlus,
"should cycle back to PerturbPlus after PerturbMinus"
);
}
#[test]
fn diagnostic_learner_theta_bounds_clamping() {
let region = FeasibleRegion::from_data(10_000, 5, 1.0);
let mut learner = DiagnosticLearner::new(region);
let diag = ConfigDiagnostics {
residual_alignment: 0.9,
regularization_sensitivity: 0.1,
depth_sufficiency: 0.5,
effective_dof: 10.0,
uncertainty: 0.1,
};
for i in 0..5_000 {
let pred = i as f64 * 0.01;
let target = pred + 0.01;
learner.after_train(&diag, pred, target);
}
assert!(
learner.theta[0] >= 0.0 && learner.theta[0] <= 1.0,
"theta[0] must be in [0, 1], got {}",
learner.theta[0]
);
assert!(
learner.theta[1] >= 0.0 && learner.theta[1] <= 1.0,
"theta[1] must be in [0, 1], got {}",
learner.theta[1]
);
}
#[test]
fn diagnostic_learner_backward_compat_alias() {
let region = FeasibleRegion::from_data(200, 3, 1.0);
let mut adaptor: DiagnosticAdaptor = DiagnosticAdaptor::new(region);
let diag = ConfigDiagnostics {
residual_alignment: 0.5,
..Default::default()
};
let adj = adaptor.after_train(&diag, 0.0, 0.0);
assert_eq!(
adj.lr_multiplier, 1.0,
"backward compat alias: init phase should return no-op"
);
let adj2 = adaptor.after_train_diagnostics_only(&diag);
assert_eq!(
adj2.lr_multiplier, 1.0,
"backward compat alias: diagnostics_only should return no-op"
);
}
#[test]
fn diagnostic_learner_structural_change() {
let region = FeasibleRegion::from_data(10_000, 3, 10.0);
let mut learner = DiagnosticLearner::new(region);
let init_diag = ConfigDiagnostics {
depth_sufficiency: 0.1,
effective_dof: 5.0,
..Default::default()
};
for _ in 0..500 {
learner.after_train(&init_diag, 0.5, 1.0);
}
let mut check_region = learner.region().clone();
check_region.update(500);
let bounds = check_region.config_bounds();
assert!(
bounds.max_depth.1 > bounds.max_depth.0,
"region must have depth headroom for this test: bounds={:?}",
bounds.max_depth
);
let high_depth_diag = ConfigDiagnostics {
depth_sufficiency: 1.0,
effective_dof: 5.0,
..Default::default()
};
let change = learner.at_replacement(&high_depth_diag);
assert!(
change.is_some(),
"high depth_sufficiency should trigger structural change"
);
let change = change.unwrap();
assert!(
change.depth_delta > 0,
"should suggest increasing depth, got delta={}",
change.depth_delta
);
}
#[test]
fn diagnostic_learner_reset_clears_state() {
let region = FeasibleRegion::from_data(1_000, 3, 1.0);
let mut learner = DiagnosticLearner::new(region);
let diag = ConfigDiagnostics {
residual_alignment: 0.5,
..Default::default()
};
for i in 0..200 {
learner.after_train(&diag, i as f64 * 0.01, i as f64 * 0.01 + 0.1);
}
learner.reset();
assert_eq!(
learner.total_steps(),
0,
"total_steps should be 0 after reset"
);
assert_eq!(
learner.phase(),
SPSAPhase::Init,
"phase should be Init after reset"
);
assert_eq!(
learner.theta,
[0.5, 0.5],
"theta should be [0.5, 0.5] after reset"
);
let adj = learner.after_train(&diag, 0.0, 0.0);
assert_eq!(
adj.lr_multiplier, 1.0,
"after reset, first sample should be init-phase no-op, got lr={}",
adj.lr_multiplier
);
assert_eq!(
adj.lambda_direction, 0.0,
"after reset, first sample should be init-phase no-op, got lambda={}",
adj.lambda_direction
);
}
#[test]
fn diagnostic_learner_meta_objective_default() {
let region = FeasibleRegion::from_data(200, 3, 1.0);
let learner = DiagnosticLearner::new(region);
assert!(
matches!(learner.objective(), MetaObjective::MinimizeRMSE),
"default objective should be MinimizeRMSE"
);
}
#[test]
fn diagnostic_learner_with_custom_objective() {
let region = FeasibleRegion::from_data(200, 3, 1.0);
let learner = DiagnosticLearner::with_objective(region, MetaObjective::MaximizeF1);
assert!(
matches!(learner.objective(), MetaObjective::MaximizeF1),
"objective should be MaximizeF1"
);
}
fn run_learner_total_lr(n_calls: u64, diag: &ConfigDiagnostics) -> f64 {
let region = FeasibleRegion::from_data(50_000, 5, 1.0);
let mut learner = DiagnosticLearner::new(region);
for i in 0..50 {
learner.after_train(diag, i as f64 * 0.01, i as f64 * 0.01 + 0.1);
}
let mut total_lr_log = 0.0_f64;
for i in 0..n_calls {
let pred = (50 + i) as f64 * 0.01;
let target = pred + 0.1;
let adj = learner.after_train(diag, pred, target);
total_lr_log += adj.lr_multiplier.ln();
}
total_lr_log.exp()
}
#[test]
fn spsa_bounded_total_adjustment() {
let diag = ConfigDiagnostics {
residual_alignment: 0.5,
..Default::default()
};
let total_100 = run_learner_total_lr(100, &diag);
let total_10000 = run_learner_total_lr(10_000, &diag);
let total_40000 = run_learner_total_lr(40_000, &diag);
assert!(
total_100.is_finite(),
"total LR after 100 calls must be finite, got {total_100}"
);
assert!(
total_10000.is_finite(),
"total LR after 10000 calls must be finite, got {total_10000}"
);
assert!(
total_40000.is_finite(),
"total LR after 40000 calls must be finite, got {total_40000}"
);
}
#[test]
fn train_one_weighted_default_ignores_weight() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
race.feed(&[1.0], 2.0);
let n_before = race.n_samples();
race.feed(&[1.0], 2.0);
let n_after = race.n_samples();
assert_eq!(
n_after,
n_before + 1,
"n_samples must increment by 1 per feed regardless of distillation state"
);
}
#[test]
fn distillation_disabled_by_default() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
for i in 0..10 {
race.feed(&[i as f64 * 0.1], i as f64 * 0.2);
}
#[cfg(feature = "distill")]
{
let stats = race.distillation_stats();
assert!(
stats.disabled,
"distillation_stats().disabled must be true when with_distillation not called"
);
assert_eq!(
stats.n_distillations_triggered, 0,
"no distillations should have triggered without configuration"
);
}
assert_eq!(
race.n_samples(),
10,
"n_samples must equal 10 after 10 feeds"
);
}
#[cfg(feature = "distill")]
#[test]
fn distillation_replay_buffer_respects_size_limit() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let cfg = DistillationConfig {
trigger_after_dominated_samples: 500,
replay_buffer_size: 5,
distill_weight: 0.3,
};
let mut race = WelfordRace::new(region.perturbation_configs()).with_distillation(cfg);
for i in 0..20 {
race.feed(&[i as f64 * 0.1], i as f64 * 0.2);
}
let stats = race.distillation_stats();
assert!(
!stats.disabled,
"distillation_stats().disabled must be false after with_distillation"
);
assert!(
stats.candidates_currently_distilling.is_empty(),
"no candidates should be distilling after 20 samples with trigger=500"
);
}
#[cfg(feature = "distill")]
#[test]
fn distillation_triggers_after_dominated_samples() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let trigger = 3u64;
let cfg = DistillationConfig {
trigger_after_dominated_samples: trigger,
replay_buffer_size: 100,
distill_weight: 0.3,
};
let mut race = WelfordRace::new(region.perturbation_configs()).with_distillation(cfg);
for i in 0..(trigger + 5) {
let x = i as f64 * 0.1;
race.feed(&[x], x * 2.0 + 1.0);
}
let stats = race.distillation_stats();
assert!(
stats.n_distillations_triggered > 0,
"distillation must trigger after {} dominated samples; got {} triggers",
trigger,
stats.n_distillations_triggered
);
}
#[cfg(feature = "distill")]
#[test]
fn distillation_distill_weight_clamped_to_unit_interval() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let cfg = DistillationConfig {
trigger_after_dominated_samples: 500,
replay_buffer_size: 100,
distill_weight: 5.0,
};
let race = WelfordRace::new(region.perturbation_configs()).with_distillation(cfg);
assert!(
!race.distillation_stats().disabled,
"race with clamped weight must be active (disabled=false)"
);
let region2 = FeasibleRegion::from_data(200, 1, 1.0);
let cfg2 = DistillationConfig {
trigger_after_dominated_samples: 500,
replay_buffer_size: 100,
distill_weight: -1.0,
};
let race2 = WelfordRace::new(region2.perturbation_configs()).with_distillation(cfg2);
assert!(
!race2.distillation_stats().disabled,
"race with negative weight (clamped) must be active"
);
}
#[test]
fn pareto_single_candidate_wins() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let configs = vec![region.center_config()];
let mut race = WelfordRace::new(configs);
for i in 0..50 {
let x = i as f64 * 0.01;
race.feed(&[x], 2.0 * x + 0.1);
}
let front = race.pareto_front();
assert_eq!(
front.len(),
1,
"single-candidate race must have front of size 1, got {}",
front.len()
);
assert_eq!(front[0], 0, "single candidate must be at position 0");
let winner = race.pareto_winner_idx();
assert_eq!(
winner,
Some(0),
"single-candidate pareto_winner_idx must be Some(0), got {winner:?}"
);
}
#[test]
fn pareto_no_samples_returns_none() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let race = WelfordRace::new(region.perturbation_configs());
let winner = race.pareto_winner_idx();
assert_eq!(
winner, None,
"pareto_winner_idx before any feed must be None, got {winner:?}"
);
}
#[test]
fn pareto_dominated_candidate_excluded() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let configs = vec![region.center_config(), region.center_config()];
let mut race = WelfordRace::new(configs);
let n = 200u64;
        let m2_a = 0.01 * (n - 1) as f64;
        let m2_b = 0.50 * (n - 1) as f64;
        race.inject_stats_for_test(0, n, 0.10, m2_a, 180);
        race.inject_stats_for_test(1, n, 0.80, m2_b, 100);
let front = race.pareto_front();
assert!(
front.contains(&0),
"candidate 0 (lower error + higher dir_accuracy) must be on front: front={front:?}"
);
assert!(
!front.contains(&1),
"candidate 1 (dominated) must NOT be on front: front={front:?}"
);
let winner = race.pareto_winner_idx();
assert_eq!(
winner,
Some(0),
"pareto_winner_idx must pick the dominating candidate (0): got {winner:?}"
);
}
#[test]
fn pareto_non_dominated_pair_returns_both_in_front() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let configs = vec![region.center_config(), region.center_config()];
let mut race = WelfordRace::new(configs);
let n = 200u64;
let m2 = 0.10 * (n - 1) as f64;
        race.inject_stats_for_test(0, n, 0.20, m2, 100);
        race.inject_stats_for_test(1, n, 0.50, m2, 180);
let front = race.pareto_front();
assert!(
front.contains(&0),
"candidate 0 (lower mean_error) must be on front: front={front:?}"
);
assert!(
front.contains(&1),
"candidate 1 (better dir_accuracy) must be on front: front={front:?}"
);
assert_eq!(
front.len(),
2,
"non-dominated pair must produce front of size 2, got {}",
front.len()
);
let winner = race.pareto_winner_idx();
assert_eq!(
winner,
Some(0),
"min(mean_error) tiebreak must select candidate 0 (0.20 < 0.50): got {winner:?}"
);
}
#[test]
fn pareto_matches_scalar_when_single_metric_varies() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let configs = vec![region.center_config(), region.center_config()];
let mut race = WelfordRace::new(configs);
let n = 200u64;
let m2 = 0.10 * (n - 1) as f64;
        let dir_correct = 150u64;
        race.inject_stats_for_test(0, n, 0.15, m2, dir_correct);
race.inject_stats_for_test(1, n, 0.45, m2, dir_correct);
let pareto_winner = race.pareto_winner_idx();
let scalar_winner = race.current_winner_idx_scalar();
assert_eq!(
pareto_winner, scalar_winner,
"when only mean_error differs, Pareto winner must match scalar winner: \
pareto={pareto_winner:?}, scalar={scalar_winner:?}"
);
assert_eq!(
pareto_winner,
Some(0),
"lower mean_error (0.15) must win: got {pareto_winner:?}"
);
}
#[test]
fn pareto_handles_nan_signals() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let configs = vec![region.center_config(), region.center_config()];
let mut race = WelfordRace::new(configs);
let n = 200u64;
let m2 = 0.10 * (n - 1) as f64;
race.inject_stats_for_test(0, n, 0.20, m2, 150);
let front = race.pareto_front();
assert!(
front.contains(&0),
"valid candidate must be on front: front={front:?}"
);
assert!(
!front.contains(&1),
"NaN-signal candidate (n=0 -> NaN dir_accuracy) must be excluded: front={front:?}"
);
let winner = race.pareto_winner_idx();
assert_eq!(
winner,
Some(0),
"valid candidate must win when other has NaN signal: got {winner:?}"
);
}
#[test]
fn pareto_invokes_bernstein_tiebreak_for_multi_front() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let configs = vec![region.center_config(), region.center_config()];
let mut race = WelfordRace::new(configs);
let n = 2000u64;
let var = 0.0001f64;
let m2 = var * (n - 1) as f64;
        race.inject_stats_for_test(0, n, 0.10, m2, 1000);
        race.inject_stats_for_test(1, n, 0.50, m2, 1800);
let front = race.pareto_front();
assert_eq!(
front.len(),
2,
"both candidates must be on the front for Bernstein tiebreak to trigger: front={front:?}"
);
let winner = race.pareto_winner_idx();
assert_eq!(
winner,
Some(0),
"Bernstein tiebreak must select arm 0 (hi_ci_0 < lo_ci_1): got {winner:?}"
);
}
#[test]
fn samples_since_change_increments_per_feed() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
assert_eq!(
race.samples_since_last_winner_change(),
0,
"before feeds, samples_since_last_winner_change should be 0"
);
let mut prev = race.samples_since_last_winner_change();
let mut any_increment = false;
for i in 0..50 {
let x = i as f64 * 0.1;
race.feed(&[x], x * 2.0 + 1.0);
let now = race.samples_since_last_winner_change();
if now > prev {
any_increment = true;
}
prev = now;
}
assert!(
any_increment,
"samples_since_last_winner_change must increment per feed during stable regime"
);
}
#[test]
fn samples_since_change_resets_on_winner_flip() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
for i in 0..20 {
let x = i as f64 * 0.1;
race.feed(&[x], x * 2.0 + 1.0);
}
let counter = race.samples_since_last_winner_change();
let total = race.n_samples();
assert!(
counter <= total,
"samples_since_last_winner_change ({counter}) must be <= n_samples ({total})"
);
}
#[test]
fn winner_change_count_is_monotonically_non_decreasing() {
let region = FeasibleRegion::from_data(200, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
let mut prev_count = race.winner_change_count();
for i in 0..100 {
let x = i as f64 * 0.1;
let target = if i % 20 < 10 { x * 2.0 + 1.0 } else { -x * 0.5 };
race.feed(&[x], target);
let now = race.winner_change_count();
assert!(
now >= prev_count,
"winner_change_count must be monotonically non-decreasing: was {prev_count}, now {now} at step {i}"
);
prev_count = now;
}
}
#[test]
fn race_drift_score_returns_zero_until_buffer_half_full() {
let region = FeasibleRegion::from_data(5000, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
for i in 0..511 {
let x = i as f64 * 0.01;
race.feed(&[x], x * 2.0 + 1.0);
}
assert_eq!(
race.race_drift_score(),
0.0,
"drift_score must be 0.0 when fewer than 512 errors collected (cold-start convention)"
);
}
#[test]
fn race_drift_score_positive_when_error_growing() {
let region = FeasibleRegion::from_data(5000, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
for i in 0..1200 {
let x = i as f64 * 0.001;
race.feed(&[x], x * 2.0 + 0.5);
}
for i in 0..512 {
let sign = if i % 2 == 0 { 1.0_f64 } else { -1.0_f64 };
race.feed(&[i as f64 * 0.001], sign * 1000.0);
}
let score = race.race_drift_score();
assert!(
score > 0.0,
"drift_score should be positive when errors are growing (got {score})"
);
}
#[test]
fn race_drift_score_negative_when_winner_improving() {
let region = FeasibleRegion::from_data(5000, 1, 1.0);
let mut race = WelfordRace::new(region.perturbation_configs());
for i in 0..700 {
let x = i as f64 * 0.01;
let noise = ((i * 31 + 7) % 100) as f64 * 10.0;
race.feed(&[x], x * 2.0 + noise);
}
for i in 700..1400 {
let x = i as f64 * 0.01;
race.feed(&[x], x * 2.0 + 1.0);
}
let score = race.race_drift_score();
assert!(
score < 0.0,
"drift_score should be negative when winner is improving (got {score})"
);
}
}