use crate::ensemble::SGBT;
use crate::loss::Loss;
use crate::sample::Observation;
use crate::tree::node::NodeId;
impl<L: Loss> SGBT<L> {
/// Consumes one labeled observation and runs the full online update:
/// base-prediction warm-up, adaptive-MTS sigma tracking, a single boosting
/// pass over every step, contribution-based step pruning, proactive-prune
/// bookkeeping, and a diagnostics refresh.
pub fn train_one(&mut self, sample: &impl Observation) {
    self.samples_seen += 1;
    let target = sample.target();
    let features = sample.features();
    // Warm-up: buffer the first `initial_target_count` targets, then let the
    // loss derive the constant base prediction and release the buffer.
    if !self.base_initialized {
        self.initial_targets.push(target);
        if self.initial_targets.len() >= self.initial_target_count {
            self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
            self.base_initialized = true;
            self.initial_targets.clear();
            self.initial_targets.shrink_to_fit();
        }
    }
    // Adaptive max-tree-samples (MTS) tracking: compare the instantaneous
    // contribution spread against its long-run EWMA and keep the recent
    // ratios in a bounded ring for later averaging (applied further below).
    if self.config.adaptive_mts.is_some() {
        let contribution_sigma = self.compute_contribution_sigma(features);
        // Very slow EWMA (alpha = 0.001) so the baseline reflects long-run spread.
        const CONTRIBUTION_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_contribution_sigma = (1.0 - CONTRIBUTION_SIGMA_ALPHA)
            * self.rolling_contribution_sigma
            + CONTRIBUTION_SIGMA_ALPHA * contribution_sigma;
        // Ratio of the instantaneous sigma to its rolling baseline; neutral
        // (1.0) while the baseline is still effectively zero.
        let sigma_ratio = if self.rolling_contribution_sigma > 1e-12 {
            contribution_sigma / self.rolling_contribution_sigma
        } else {
            1.0
        };
        // Bounded ring of recent ratios, capacity = grace_period: evict the
        // oldest entry once full.
        let cap = self.config.grace_period;
        if self.sigma_ring.len() >= cap {
            self.sigma_ring.pop_front();
        }
        self.sigma_ring.push_back(sigma_ratio);
    }
    let mut current_pred = self.base_prediction;
    let prune_alpha = self.config.quality_prune_alpha;
    let prune_threshold = self.config.quality_prune_threshold;
    let prune_patience = self.config.quality_prune_patience;
    // Optional error weighting: samples whose absolute error — measured
    // against the base prediction only, before any boosting step — exceeds
    // the rolling mean error get a proportionally larger weight, capped at 10x.
    let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
        let abs_error = (target - current_pred).abs();
        if self.rolling_mean_error > 1e-15 {
            let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
            // Update the rolling mean only after the weight has been computed.
            self.rolling_mean_error =
                ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
            w
        } else {
            // First usable sample: seed the rolling mean, use a neutral weight.
            self.rolling_mean_error = abs_error.max(1e-15);
            1.0
        }
    } else {
        1.0
    };
    // One boosting pass: each step trains on the (weighted) gradient and
    // hessian evaluated at the running prediction, then its output is folded
    // into that prediction before the next step sees it.
    for s in 0..self.steps.len() {
        let gradient = self.loss.gradient(target, current_pred) * error_weight;
        let hessian = self.loss.hessian(target, current_pred) * error_weight;
        // The configured boosting variant decides how many training
        // repetitions this sample gets, possibly stochastically (it is
        // handed the shared RNG state and the hessian).
        let train_count = self
            .config
            .variant
            .train_count(hessian, &mut self.rng_state);
        let step_pred =
            self.steps[s].train_and_predict(features, gradient, hessian, train_count);
        current_pred += self.config.learning_rate * step_pred;
        // Quality pruning: EWMA of each step's absolute contribution; a step
        // stuck below the threshold for `prune_patience` consecutive samples
        // is reset and its counters cleared.
        if let Some(alpha) = prune_alpha {
            let contribution = (self.config.learning_rate * step_pred).abs();
            self.contribution_ewma[s] =
                alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];
            if self.contribution_ewma[s] < prune_threshold {
                self.low_contrib_count[s] += 1;
                if self.low_contrib_count[s] >= prune_patience {
                    self.steps[s].reset();
                    self.contribution_ewma[s] = 0.0;
                    self.low_contrib_count[s] = 0;
                }
            } else {
                // Any healthy sample clears the patience counter.
                self.low_contrib_count[s] = 0;
            }
        }
    }
    // Proactive pruning: accumulate per-step alignment diagnostics and run
    // the actual prune check periodically.
    if let Some(interval) = self.config.proactive_prune_interval {
        if self.config.accuracy_based_pruning {
            // Score each step by how well its contribution agrees with the
            // sign of the full-ensemble residual (positive = pushes the
            // prediction the right way), smoothed with `diag.prune_alpha`.
            let mut ensemble_pred = self.base_prediction;
            for step in self.steps.iter() {
                ensemble_pred += self.config.learning_rate * step.predict(features);
            }
            let residual = target - ensemble_pred;
            let sign = residual.signum();
            for (i, step) in self.steps.iter().enumerate() {
                let contribution = self.config.learning_rate * step.predict(features);
                let alignment = contribution * sign;
                self.diag.contribution_accuracy[i] = self.diag.prune_alpha * alignment
                    + (1.0 - self.diag.prune_alpha) * self.diag.contribution_accuracy[i];
            }
        }
        // Fire the prune check once every `interval` samples (interval 0 disables).
        if interval > 0 && self.samples_seen % interval == 0 {
            self.check_proactive_prune();
        }
    }
    // Adaptive MTS application: only when some step replaced a tree since the
    // last check (replacement counter moved), recompute the effective
    // max-tree-samples from the mean sigma ratio and push it to every slot.
    if let Some((base_mts, k)) = self.config.adaptive_mts {
        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
        if current_sum != self.mts_replacement_sum {
            self.mts_replacement_sum = current_sum;
            if !self.sigma_ring.is_empty() {
                let mean_sigma =
                    self.sigma_ring.iter().sum::<f64>() / self.sigma_ring.len() as f64;
                // Higher volatility shrinks the sample budget, but never below
                // `adaptive_mts_floor` of the base (and never below 100).
                let floor = (base_mts as f64 * self.config.adaptive_mts_floor).max(100.0);
                let effective_mts =
                    (base_mts as f64 / (1.0 + k * mean_sigma)).max(floor) as u64;
                for step in &mut self.steps {
                    step.slot_mut().set_max_tree_samples(Some(effective_mts));
                }
            }
        }
    }
    self.update_diagnostic_cache(features);
    self.refresh_bandwidths();
}
pub(crate) fn update_diagnostic_cache(&mut self, features: &[f64]) {
let lambda = self.config.lambda;
let lr = self.config.learning_rate;
let n_steps = self.steps.len();
let mut contributions = Vec::with_capacity(n_steps);
for step in &self.steps {
contributions.push(lr * step.predict(features));
}
if !self.diag.prev_contributions.is_empty()
&& self.diag.prev_contributions.len() == contributions.len()
&& !self.diag.prev_prev_contributions.is_empty()
&& self.diag.prev_prev_contributions.len() == contributions.len()
{
let delta_curr: Vec<f64> = contributions
.iter()
.zip(&self.diag.prev_contributions)
.map(|(a, b)| a - b)
.collect();
let delta_prev: Vec<f64> = self
.diag
.prev_contributions
.iter()
.zip(&self.diag.prev_prev_contributions)
.map(|(a, b)| a - b)
.collect();
let dot: f64 = delta_curr.iter().zip(&delta_prev).map(|(a, b)| a * b).sum();
let norm_curr: f64 = delta_curr.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm_prev: f64 = delta_prev.iter().map(|x| x * x).sum::<f64>().sqrt();
self.diag.cached_residual_alignment = if norm_curr > 1e-15 && norm_prev > 1e-15 {
dot / (norm_curr * norm_prev)
} else {
0.0
};
}
self.diag.prev_prev_contributions =
std::mem::replace(&mut self.diag.prev_contributions, contributions);
let mut total_sensitivity = 0.0;
let mut total_dof = 0.0;
let mut leaf_weights: Vec<f64> = Vec::new();
let mut leaf_within_vars: Vec<f64> = Vec::new();
let mut n_leaves_total: u64 = 0;
for step in &self.steps {
let tree = step.slot().active_tree();
let arena = tree.arena();
for node_idx in 0..arena.n_nodes() {
let nid = NodeId(node_idx as u32);
if arena.is_leaf(nid) {
if let Some((g, h)) = tree.leaf_grad_hess(nid) {
let denom = h + lambda;
if denom.abs() > 1e-15 {
total_sensitivity += g.abs() / (denom * denom);
total_dof += h / denom;
leaf_weights.push(-g / denom);
leaf_within_vars.push(1.0 / denom);
n_leaves_total += 1;
}
}
}
}
}
if n_leaves_total > 0 {
let n = n_leaves_total as f64;
self.diag.cached_reg_sensitivity = total_sensitivity / n;
self.diag.cached_effective_dof = total_dof;
let mean_weight = leaf_weights.iter().sum::<f64>() / n;
let between_var = leaf_weights
.iter()
.map(|w| (w - mean_weight).powi(2))
.sum::<f64>()
/ (n - 1.0).max(1.0);
let within_var = leaf_within_vars.iter().sum::<f64>() / n;
self.diag.cached_depth_sufficiency = between_var / within_var.max(1e-15);
}
}
/// Trains on every observation in `samples`, in order.
pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
    samples.iter().for_each(|sample| self.train_one(sample));
}
/// Trains on every observation in `samples`, invoking `callback` with the
/// number of samples processed every `interval` samples (an interval of 0 is
/// treated as 1), plus one final call for any trailing partial interval.
pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    interval: usize,
    mut callback: F,
) {
    let stride = interval.max(1);
    let mut processed = 0usize;
    for sample in samples {
        self.train_one(sample);
        processed += 1;
        if processed % stride == 0 {
            callback(processed);
        }
    }
    // Report the leftover partial interval, if any.
    if processed % stride != 0 {
        callback(processed);
    }
}
/// Trains on at most `max_samples` observations chosen uniformly at random
/// from `samples` via reservoir sampling; selected indices are replayed in
/// ascending order. Falls back to the full batch when it already fits.
///
/// Delegates to [`Self::train_batch_subsampled_with_callback`] with a no-op
/// callback so the xorshift/reservoir logic lives in exactly one place (it
/// was previously duplicated verbatim between the two methods). The RNG
/// advancement and training sequence are identical to the inlined version.
pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
    // `usize::MAX` interval: the periodic callback can never trigger inside
    // the loop, and the single trailing call hits the no-op closure.
    self.train_batch_subsampled_with_callback(samples, max_samples, usize::MAX, |_| {});
}
/// Trains on at most `max_samples` observations chosen uniformly at random
/// from `samples` (reservoir sampling over indices, replayed in ascending
/// order), invoking `callback` with the processed count every `interval`
/// samples plus once for any trailing partial interval.
pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    max_samples: usize,
    interval: usize,
    mut callback: F,
) {
    // No subsampling needed: hand off to the plain callback variant.
    if max_samples >= samples.len() {
        self.train_batch_with_callback(samples, interval, callback);
        return;
    }
    // Reservoir of indices seeded with the first `max_samples` positions;
    // each later candidate replaces a random slot with the proper probability.
    let mut picked: Vec<usize> = (0..max_samples).collect();
    let mut state = self.rng_state;
    for candidate in max_samples..samples.len() {
        // xorshift64 step (shift triple 13/7/17).
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        let slot = (state % (candidate as u64 + 1)) as usize;
        if slot < max_samples {
            picked[slot] = candidate;
        }
    }
    self.rng_state = state;
    // Replay the chosen samples in their original order.
    picked.sort_unstable();
    let stride = interval.max(1);
    let mut done = 0usize;
    for &idx in &picked {
        self.train_one(&samples[idx]);
        done += 1;
        if done % stride == 0 {
            callback(done);
        }
    }
    // Report the leftover partial interval, if any.
    if done % stride != 0 {
        callback(done);
    }
}
}