use alloc::boxed::Box;
use alloc::vec::Vec;
use crate::feature::FeatureType;
use crate::histogram::bins::LeafHistograms;
use crate::histogram::BinnerKind;
use crate::math;
use crate::tree::leaf_model::LeafModel;
/// Mutable per-leaf training state.
///
/// Bundles the leaf's histograms and per-feature binners, its accumulated
/// gradient/hessian sums, running statistics used by `clip_gradient` and by
/// `update_output_stats` / `adaptive_bound`, and an optional leaf model.
pub(crate) struct LeafState {
/// Per-feature histograms: `None` for a leaf built with `new` /
/// `new_with_types`, `Some` when built with `with_histograms`.
pub histograms: Option<LeafHistograms>,
/// One binner per feature, categorical or uniform (see `make_binners`).
pub binners: Vec<BinnerKind>,
/// `false` for a fresh leaf, `true` when built from existing histograms.
/// NOTE(review): presumably means "bin layout is finalized" — confirm.
pub bins_ready: bool,
/// Accumulated gradient sum; seeded from the first histogram's total
/// gradient in `with_histograms`, otherwise starts at 0.0.
pub grad_sum: f64,
/// Accumulated hessian sum; seeded from the first histogram's total
/// hessian in `with_histograms`, otherwise starts at 0.0.
pub hess_sum: f64,
/// Not read in this file. NOTE(review): presumably the sample count at the
/// last split re-evaluation — confirm against callers.
pub last_reeval_count: u64,
/// Welford running mean of the gradients folded in by `clip_gradient`.
pub clip_grad_mean: f64,
/// Welford M2 (sum of squared deviations) of those gradients.
pub clip_grad_m2: f64,
/// Number of gradients folded into the clipping statistics.
pub clip_grad_count: u64,
/// Running mean of leaf outputs, maintained by `update_output_stats`
/// (arithmetic mean without decay, exponentially weighted with decay).
pub output_mean: f64,
/// Without decay: Welford M2 (sum of squared deviations) of leaf outputs.
/// With decay: an exponentially weighted variance, which `adaptive_bound`
/// uses directly as the variance.
pub output_m2: f64,
/// Number of outputs passed to `update_output_stats`.
pub output_count: u64,
/// Optional per-leaf model; duplicated via `clone_warm` when the state is
/// cloned.
pub leaf_model: Option<Box<dyn LeafModel>>,
}
impl Clone for LeafState {
    /// Returns an independent copy of this leaf's state.
    ///
    /// Scalars are copied and collections are cloned. The boxed leaf model
    /// (a trait object, so not `Clone`) is duplicated through `clone_warm`.
    /// `self` is destructured exhaustively, so adding a field to `LeafState`
    /// without updating this impl is a compile error.
    fn clone(&self) -> Self {
        let Self {
            histograms,
            binners,
            bins_ready,
            grad_sum,
            hess_sum,
            last_reeval_count,
            clip_grad_mean,
            clip_grad_m2,
            clip_grad_count,
            output_mean,
            output_m2,
            output_count,
            leaf_model,
        } = self;

        Self {
            histograms: histograms.clone(),
            binners: binners.clone(),
            bins_ready: *bins_ready,
            grad_sum: *grad_sum,
            hess_sum: *hess_sum,
            last_reeval_count: *last_reeval_count,
            clip_grad_mean: *clip_grad_mean,
            clip_grad_m2: *clip_grad_m2,
            clip_grad_count: *clip_grad_count,
            output_mean: *output_mean,
            output_m2: *output_m2,
            output_count: *output_count,
            leaf_model: leaf_model.as_ref().map(|boxed| boxed.clone_warm()),
        }
    }
}
/// Folds `gradient` into the leaf's running gradient statistics and returns
/// it clipped to `mean ± sigma * std_dev` of the gradients seen so far.
///
/// The statistics (`clip_grad_mean`, `clip_grad_m2`, `clip_grad_count`) are
/// maintained with Welford's online algorithm. The gradient is returned
/// unchanged when:
/// * fewer than 10 gradients have been folded in (warm-up);
/// * the sample standard deviation is below `1e-15` (no spread) or NaN;
/// * `sigma` yields an empty or NaN interval (negative or NaN `sigma`),
///   which would otherwise make `f64::clamp` panic;
/// * `gradient` itself is NaN or ±infinity. Such values are passed through
///   and NOT folded into the statistics. A single non-finite sample would
///   otherwise turn the running mean/M2 into NaN permanently, and every later
///   call would panic inside `f64::clamp`.
#[inline]
pub(crate) fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
    if !gradient.is_finite() {
        return gradient;
    }

    // Welford update: delta uses the old mean, delta2 the new one.
    state.clip_grad_count += 1;
    let n = state.clip_grad_count as f64;
    let delta = gradient - state.clip_grad_mean;
    state.clip_grad_mean += delta / n;
    let delta2 = gradient - state.clip_grad_mean;
    state.clip_grad_m2 += delta * delta2;

    if state.clip_grad_count < 10 {
        return gradient;
    }

    // Sample variance (n - 1 denominator); n >= 10 here.
    let variance = state.clip_grad_m2 / (n - 1.0);
    let std_dev = math::sqrt(variance);
    if std_dev.is_nan() || std_dev < 1e-15 {
        return gradient;
    }

    let lo = state.clip_grad_mean - sigma * std_dev;
    let hi = state.clip_grad_mean + sigma * std_dev;
    // `f64::clamp` panics on NaN bounds or lo > hi; treat those as "no clip".
    if lo.is_nan() || hi.is_nan() || lo > hi {
        return gradient;
    }
    gradient.clamp(lo, hi)
}
/// Folds one leaf output `weight` into the leaf's running output statistics.
///
/// * `decay_alpha == None`: Welford accumulation. `output_mean` is the
///   arithmetic mean and `output_m2` the sum of squared deviations.
/// * `decay_alpha == Some(alpha)`: exponentially weighted statistics. The
///   first sample seeds the mean and zeroes the spread. After that,
///   `mean <- alpha*mean + (1-alpha)*weight` and
///   `m2 <- alpha*m2 + (1-alpha)*(weight-old_mean)*(weight-new_mean)`,
///   so `output_m2` holds an exponentially weighted variance, not a sum.
///
/// `output_count` is incremented in both modes.
#[inline]
pub(crate) fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
    state.output_count += 1;
    let count = state.output_count;

    match decay_alpha {
        Some(_) if count == 1 => {
            state.output_mean = weight;
            state.output_m2 = 0.0;
        }
        Some(alpha) => {
            let before = weight - state.output_mean;
            let mean = alpha * state.output_mean + (1.0 - alpha) * weight;
            let after = weight - mean;
            state.output_mean = mean;
            state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * before * after;
        }
        None => {
            let before = weight - state.output_mean;
            state.output_mean += before / count as f64;
            let after = weight - state.output_mean;
            state.output_m2 += before * after;
        }
    }
}
/// Returns the adaptive magnitude bound `|mean| + k * std` for leaf outputs,
/// floored at `0.01`.
///
/// Until at least 10 outputs have been recorded the bound is `f64::MAX`,
/// i.e. effectively unbounded. With `decay_alpha` set, `output_m2` is
/// already a variance, with negative values clamped to zero. Otherwise it
/// is a Welford M2 and is divided by `count - 1`.
#[inline]
pub(crate) fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
    const MIN_SAMPLES: u64 = 10;
    const FLOOR: f64 = 0.01;

    if state.output_count >= MIN_SAMPLES {
        let variance = match decay_alpha {
            Some(_) => state.output_m2.max(0.0),
            None => state.output_m2 / (state.output_count as f64 - 1.0),
        };
        let spread = k * math::sqrt(variance);
        (math::abs(state.output_mean) + spread).max(FLOOR)
    } else {
        f64::MAX
    }
}
/// Builds one binner per feature.
///
/// A feature gets a categorical binner when its entry in `feature_types` is
/// `FeatureType::Categorical`. Every other feature gets a uniform binner,
/// including all of them when `feature_types` is `None`, and any feature
/// past the end of a too-short `feature_types` slice.
pub(crate) fn make_binners(
    n_features: usize,
    feature_types: Option<&[FeatureType]>,
) -> Vec<BinnerKind> {
    let is_categorical = |idx: usize| {
        feature_types.and_then(|types| types.get(idx)) == Some(&FeatureType::Categorical)
    };

    let mut binners = Vec::with_capacity(n_features);
    for idx in 0..n_features {
        let binner = if is_categorical(idx) {
            BinnerKind::categorical()
        } else {
            BinnerKind::uniform()
        };
        binners.push(binner);
    }
    binners
}
impl LeafState {
    /// Creates an empty leaf with a uniform binner for each of `n_features`.
    pub(crate) fn new(n_features: usize) -> Self {
        Self::new_with_types(n_features, None)
    }

    /// Creates an empty leaf whose binners follow `feature_types`
    /// (categorical where marked, uniform otherwise; see `make_binners`).
    ///
    /// The leaf has no histograms, `bins_ready == false`, all sums and running
    /// statistics at zero, and no leaf model.
    pub(crate) fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
        Self {
            histograms: None,
            binners: make_binners(n_features, feature_types),
            bins_ready: false,
            grad_sum: 0.0,
            hess_sum: 0.0,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: None,
        }
    }

    /// Creates a leaf that starts from pre-built `histograms`.
    ///
    /// `grad_sum` / `hess_sum` are seeded from the total gradient / hessian
    /// of the first histogram (0.0 each when there is none). `bins_ready` is
    /// set, every binner is uniform, and everything else starts as in
    /// `new`.
    // NOTE(review): presumably only used by some callers (e.g. tests), hence
    // the dead_code allowance — confirm before removing it.
    #[allow(dead_code)]
    pub(crate) fn with_histograms(histograms: LeafHistograms) -> Self {
        let (grad_sum, hess_sum) = histograms
            .histograms
            .first()
            .map_or((0.0, 0.0), |h| (h.total_gradient(), h.total_hessian()));
        let n_features = histograms.n_features();
        Self {
            histograms: Some(histograms),
            bins_ready: true,
            grad_sum,
            hess_sum,
            // `new` builds all-uniform binners and zeroed statistics.
            ..Self::new(n_features)
        }
    }
}