pub mod leaf;
pub mod split_logic;
use alloc::vec;
use alloc::vec::Vec;
use crate::histogram::bins::LeafHistograms;
use crate::math;
use crate::tree::builder::TreeConfig;
use crate::tree::leaf_model::LeafModelType;
use crate::tree::node::{NodeId, TreeArena};
use crate::tree::split::{leaf_weight, XGBoostGain};
use crate::tree::StreamingTree;
use leaf::{adaptive_bound, clip_gradient, make_binners, update_output_stats, LeafState};
/// Streaming decision tree whose leaves accumulate gradient/hessian
/// statistics and whose splits are accepted through
/// `split_logic::private::should_split_hoeffding`.
pub struct HoeffdingTree {
/// Flat storage for every node (leaves and internal nodes).
pub(crate) arena: TreeArena,
/// Id of the root node.
pub(crate) root: NodeId,
/// Hyperparameters controlling growth, regularisation and prediction.
pub(crate) config: TreeConfig,
/// Per-node training state indexed by `NodeId.0`. Leaves hold `Some`; a
/// leaf's state is taken (left as `None`) when `attempt_split` splits it.
pub(crate) leaf_states: Vec<Option<LeafState>>,
/// Number of input features; `None` until latched from the first training
/// sample (or supplied to `from_arena`).
pub(crate) n_features: Option<usize>,
/// Running count of training samples (incremented per `train_one`, zeroed by `reset`).
pub(crate) samples_seen: u64,
/// Gain criterion handed to the split evaluator.
pub(crate) split_criterion: XGBoostGain,
/// Feature subset regenerated by `generate_feature_mask` on each split attempt.
pub(crate) feature_mask: Vec<usize>,
/// Companion representation of `feature_mask` produced by
/// `generate_feature_mask` (presumably a bitset — confirm).
pub(crate) feature_mask_bits: Vec<u64>,
/// RNG state, seeded from `config.seed`, consumed by feature subsampling
/// and split evaluation.
pub(crate) rng_state: u64,
/// Cumulative gain of accepted splits, indexed by feature.
pub(crate) split_gains: Vec<f64>,
/// Per-node soft-routing margin used by `predict_soft_routed`;
/// `f64::INFINITY` for leaves and for splits with no neighbouring threshold.
pub(crate) node_bandwidths: Vec<f64>,
}
impl HoeffdingTree {
pub fn new(config: TreeConfig) -> Self {
let mut arena = TreeArena::new();
let root = arena.add_leaf(0);
let mut leaf_states = vec![None; root.0 as usize + 1];
let root_model = match config.leaf_model_type {
LeafModelType::ClosedForm => None,
_ => Some(config.leaf_model_type.create(config.seed, config.delta)),
};
leaf_states[root.0 as usize] = Some(LeafState {
histograms: None,
binners: Vec::new(),
bins_ready: false,
grad_sum: 0.0,
hess_sum: 0.0,
last_reeval_count: 0,
clip_grad_mean: 0.0,
clip_grad_m2: 0.0,
clip_grad_count: 0,
output_mean: 0.0,
output_m2: 0.0,
output_count: 0,
leaf_model: root_model,
});
let seed = config.seed;
Self {
arena,
root,
config,
leaf_states,
n_features: None,
samples_seen: 0,
split_criterion: XGBoostGain::default(),
feature_mask: Vec::new(),
feature_mask_bits: Vec::new(),
rng_state: seed,
split_gains: Vec::new(),
node_bandwidths: Vec::new(),
}
}
fn make_leaf_model(
    &self,
    node: NodeId,
) -> Option<alloc::boxed::Box<dyn crate::tree::leaf_model::LeafModel>> {
    // Builds the auxiliary leaf model for `node`, or `None` for closed-form leaves.
    if matches!(self.config.leaf_model_type, LeafModelType::ClosedForm) {
        return None;
    }
    // XOR the node id into the configured seed so each node gets its own seed.
    let node_seed = self.config.seed ^ u64::from(node.0);
    Some(self.config.leaf_model_type.create(node_seed, self.config.delta))
}
pub fn from_arena(
    config: TreeConfig,
    arena: TreeArena,
    n_features: Option<usize>,
    samples_seen: u64,
    rng_state: u64,
) -> Self {
    // Rebuilds a tree around an existing node arena (e.g. after deserialisation).
    //
    // Node 0 is taken as the root; an empty arena receives a fresh root leaf.
    // Every leaf gets a fresh `LeafState`. Accumulated statistics and leaf
    // models are not restored, so predictions fall back to the arena's stored
    // leaf values until new samples arrive.
    let mut arena = arena;
    let root = if arena.n_nodes() > 0 {
        NodeId(0)
    } else {
        arena.add_leaf(0)
    };
    let nf = n_features.unwrap_or(0);

    let leaf_states = {
        // Honour the configured feature types, exactly like `train_one`,
        // `reset` and `attempt_split` do, so restored leaves get the same
        // (e.g. categorical) binners as leaves created during training.
        let ft = config.feature_types.as_deref();
        let n_nodes = arena.n_nodes();
        let n_slots = n_nodes.max(root.0 as usize + 1);
        let mut states: Vec<Option<LeafState>> = vec![None; n_slots];
        for (i, slot) in states.iter_mut().enumerate().take(n_nodes) {
            if arena.is_leaf[i] {
                *slot = Some(LeafState::new_with_types(nf, ft));
            }
        }
        states
    };

    let mut tree = Self {
        arena,
        root,
        config,
        leaf_states,
        n_features,
        samples_seen,
        split_criterion: XGBoostGain::default(),
        feature_mask: Vec::new(),
        feature_mask_bits: Vec::new(),
        rng_state,
        split_gains: vec![0.0; nf],
        node_bandwidths: Vec::new(),
    };
    // `attempt_split` keeps `node_bandwidths` in sync after every split; a
    // restored tree must rebuild them too, otherwise `predict_soft_routed`
    // would silently fall back to its sigmoid path for every node.
    tree.recompute_bandwidths();
    tree
}
/// Id of the root node.
#[inline]
pub fn root(&self) -> NodeId {
self.root
}
/// Read-only access to the node arena.
#[inline]
pub fn arena(&self) -> &TreeArena {
&self.arena
}
/// The configuration this tree was built with.
#[inline]
pub fn tree_config(&self) -> &TreeConfig {
&self.config
}
/// Number of input features, or `None` before any has been recorded.
#[inline]
pub fn n_features(&self) -> Option<usize> {
self.n_features
}
/// Current RNG state (useful for checkpointing alongside the arena).
#[inline]
pub fn rng_state(&self) -> u64 {
self.rng_state
}
/// Accumulated `(gradient_sum, hessian_sum)` of `node`, or `None` when the
/// node has no leaf state (out of range, or already split).
#[inline]
pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
    match self.leaf_states.get(node.0 as usize) {
        Some(Some(leaf)) => Some((leaf.grad_sum, leaf.hess_sum)),
        _ => None,
    }
}
/// Walks from the root to the leaf that `features` falls into.
///
/// Categorical splits send a value `v` left when `v < 64` and bit `v` of
/// the node's mask is set; numeric splits send `x <= threshold` left.
pub(crate) fn route_to_leaf(&self, features: &[f64]) -> NodeId {
    let mut node = self.root;
    loop {
        if self.arena.is_leaf(node) {
            return node;
        }
        let x = features[self.arena.get_feature_idx(node) as usize];
        let goes_left = match self.arena.get_categorical_mask(node) {
            Some(mask) => {
                let category = x as u64;
                category < 64 && (mask >> category) & 1 == 1
            }
            None => x <= self.arena.get_threshold(node),
        };
        node = if goes_left {
            self.arena.get_left(node)
        } else {
            self.arena.get_right(node)
        };
    }
}
/// Prediction of a single leaf, after the configured output bounds.
///
/// Returns 0.0 when the leaf's hessian sum is below `min_hessian_sum`.
/// Otherwise the raw value comes from the leaf model when present, else
/// from the closed-form `leaf_weight` when the hessian sum is non-zero,
/// else from the value stored in the arena. A finite adaptive bound takes
/// precedence over `max_leaf_output` for clamping.
#[inline]
fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
    let state = self
        .leaf_states
        .get(leaf_id.0 as usize)
        .and_then(|slot| slot.as_ref());

    let (raw, leaf_bound) = match state {
        None => (0.0, None),
        Some(state) => {
            let starved = self
                .config
                .min_hessian_sum
                .map_or(false, |min_h| state.hess_sum < min_h);
            if starved {
                return 0.0;
            }
            let value = match state.leaf_model {
                Some(ref model) => model.predict(features),
                None if state.hess_sum != 0.0 => {
                    leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
                }
                None => self.arena.leaf_value[leaf_id.0 as usize],
            };
            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));
            (value, bound)
        }
    };

    match leaf_bound {
        Some(bound) if bound < f64::MAX => raw.clamp(-bound, bound),
        _ => match self.config.max_leaf_output {
            Some(max) => raw.clamp(-max, max),
            None => raw,
        },
    }
}
/// Soft prediction: every numeric split blends its subtrees with a sigmoid
/// of `(threshold - x) / bandwidth`, using the same `bandwidth` everywhere.
/// Categorical splits are still routed hard.
pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
self.predict_smooth_recursive(self.root, features, bandwidth)
}
/// Like `predict_smooth`, but with a per-feature bandwidth taken from
/// `bandwidths[feature_idx]`; features without a finite entry are routed hard.
pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
self.predict_smooth_auto_recursive(self.root, features, bandwidths)
}
pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
    // Blends the reached leaf's prediction with its parent's, weighting the
    // leaf by `h / (h + lambda)` where `h` is the leaf's hessian sum.
    //
    // NOTE(review): `attempt_split` takes the parent's `LeafState` when it
    // splits, so `leaf_prediction(parent)` finds no state and returns 0.0
    // (clamped by `max_leaf_output`). In practice this shrinks the leaf
    // prediction toward zero rather than toward a parent value — confirm
    // this is intended.
    let mut current = self.root;
    let mut parent = None;
    while !self.arena.is_leaf(current) {
        parent = Some(current);
        let feat_idx = self.arena.get_feature_idx(current) as usize;
        current = if let Some(mask) = self.arena.get_categorical_mask(current) {
            let cat_val = features[feat_idx] as u64;
            if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.arena.get_left(current)
            } else {
                self.arena.get_right(current)
            }
        } else {
            let threshold = self.arena.get_threshold(current);
            if features[feat_idx] <= threshold {
                self.arena.get_left(current)
            } else {
                self.arena.get_right(current)
            }
        };
    }
    let leaf_pred = self.leaf_prediction(current, features);
    let parent_id = match parent {
        Some(p) => p,
        None => return leaf_pred,
    };
    let parent_pred = self.leaf_prediction(parent_id, features);
    let leaf_hess = self
        .leaf_states
        .get(current.0 as usize)
        .and_then(|o| o.as_ref())
        .map(|s| s.hess_sum)
        .unwrap_or(0.0);
    let denom = leaf_hess + self.config.lambda;
    // With `lambda == 0` and an empty leaf the weight would be 0/0 = NaN.
    // For lambda == 0 the weight is 1 for any positive hessian, so the
    // continuous extension is "trust the leaf".
    if denom == 0.0 {
        return leaf_pred;
    }
    let alpha = leaf_hess / denom;
    alpha * leaf_pred + (1.0 - alpha) * parent_pred
}
/// Prediction that linearly blends the two children of a numeric split when
/// the sample lies within `bandwidths[feature_idx]` of the threshold, and
/// routes hard otherwise.
pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
self.predict_sibling_recursive(self.root, features, bandwidths)
}
fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
    // Recursive worker for `predict_sibling_interpolated`.
    if self.arena.is_leaf(node) {
        return self.leaf_prediction(node, features);
    }
    let feat_idx = self.arena.get_feature_idx(node) as usize;
    let x = features[feat_idx];
    let (left, right) = (self.arena.get_left(node), self.arena.get_right(node));
    let recurse = |child: NodeId| self.predict_sibling_recursive(child, features, bandwidths);

    // Categorical splits are always routed hard.
    if let Some(mask) = self.arena.get_categorical_mask(node) {
        let category = x as u64;
        let in_left = category < 64 && (mask >> category) & 1 == 1;
        return recurse(if in_left { left } else { right });
    }

    let threshold = self.arena.get_threshold(node);
    let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
    // Without a usable positive margin there is nothing to interpolate.
    if !margin.is_finite() || margin <= 0.0 {
        return recurse(if x <= threshold { left } else { right });
    }

    let dist = x - threshold;
    if dist < -margin {
        recurse(left)
    } else if dist > margin {
        recurse(right)
    } else {
        // t ramps linearly from 0 at `threshold - margin` to 1 at `threshold + margin`.
        let t = (dist + margin) / (2.0 * margin);
        let left_pred = recurse(left);
        let right_pred = recurse(right);
        (1.0 - t) * left_pred + t * right_pred
    }
}
pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
let n = self.n_features.unwrap_or(0);
let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];
for i in 0..self.arena.n_nodes() {
if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
let feat_idx = self.arena.feature_idx[i] as usize;
if feat_idx < n {
thresholds[feat_idx].push(self.arena.threshold[i]);
}
}
}
thresholds
}
fn compute_node_bandwidth(&self, node: NodeId, all_thresholds: &[Vec<f64>]) -> f64 {
    // Distance from this node's threshold to the nearest *other* threshold
    // on the same feature (either side). `all_thresholds` must be sorted
    // ascending per feature. Returns infinity when there is no neighbour.
    let feat_idx = self.arena.get_feature_idx(node) as usize;
    let threshold = self.arena.get_threshold(node);
    let sorted = match all_thresholds.get(feat_idx) {
        Some(v) => v,
        None => return f64::INFINITY,
    };
    let gap_below = sorted
        .iter()
        .rev()
        .find(|&&t| t < threshold - 1e-15)
        .map(|&b| threshold - b);
    let gap_above = sorted
        .iter()
        .find(|&&t| t > threshold + 1e-15)
        .map(|&a| a - threshold);
    match (gap_below, gap_above) {
        (Some(below), Some(above)) => below.min(above),
        (Some(gap), None) | (None, Some(gap)) => gap,
        (None, None) => f64::INFINITY,
    }
}
/// Rebuilds `node_bandwidths` for every node in the arena: leaves get
/// `f64::INFINITY`, internal nodes get `compute_node_bandwidth`.
pub fn recompute_bandwidths(&mut self) {
    let mut per_feature = self.collect_split_thresholds_per_feature();
    for thresholds in per_feature.iter_mut() {
        thresholds.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
    }
    let bandwidths: Vec<f64> = (0..self.arena.n_nodes())
        .map(|i| {
            let nid = NodeId(i as u32);
            if self.arena.is_leaf(nid) {
                f64::INFINITY
            } else {
                self.compute_node_bandwidth(nid, &per_feature)
            }
        })
        .collect();
    self.node_bandwidths = bandwidths;
}
/// Soft prediction using the per-node margins in `node_bandwidths`
/// (maintained by `recompute_bandwidths`); categorical splits stay hard.
pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
self.predict_soft_recursive(self.root, features)
}
fn predict_soft_recursive(&self, node: NodeId, features: &[f64]) -> f64 {
    // Recursive worker for `predict_soft_routed`. Numeric splits blend both
    // subtrees: linearly across `threshold ± margin` when the node has a
    // positive finite margin, otherwise via a steep sigmoid whose scale is
    // 1% of |threshold| (plus 1e-10).
    if self.arena.is_leaf(node) {
        return self.leaf_prediction(node, features);
    }
    let feat_idx = self.arena.get_feature_idx(node) as usize;
    let (left, right) = (self.arena.get_left(node), self.arena.get_right(node));

    if let Some(mask) = self.arena.get_categorical_mask(node) {
        let category = features[feat_idx] as u64;
        let in_left = category < 64 && (mask >> category) & 1 == 1;
        return self.predict_soft_recursive(if in_left { left } else { right }, features);
    }

    let threshold = self.arena.get_threshold(node);
    let margin = self
        .node_bandwidths
        .get(node.0 as usize)
        .copied()
        .unwrap_or(f64::INFINITY);
    let left_pred = self.predict_soft_recursive(left, features);
    let right_pred = self.predict_soft_recursive(right, features);
    let dist = features[feat_idx] - threshold;

    let t = if margin.is_finite() && margin > 0.0 {
        ((dist + margin) / (2.0 * margin)).clamp(0.0, 1.0)
    } else {
        let scale = math::abs(threshold) * 0.01 + 1e-10;
        // Clamp the exponent so `exp` stays finite.
        let z = (-dist / scale).clamp(-500.0, 500.0);
        1.0 / (1.0 + math::exp(z))
    };
    (1.0 - t) * left_pred + t * right_pred
}
fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
    // Recursive worker for `predict_smooth`: numeric splits weight the left
    // subtree by sigmoid((threshold - x) / bandwidth).
    if self.arena.is_leaf(node) {
        return self.leaf_prediction(node, features);
    }
    let feat_idx = self.arena.get_feature_idx(node) as usize;
    let left = self.arena.get_left(node);
    let right = self.arena.get_right(node);
    if let Some(mask) = self.arena.get_categorical_mask(node) {
        let cat_val = features[feat_idx] as u64;
        return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
            self.predict_smooth_recursive(left, features, bandwidth)
        } else {
            self.predict_smooth_recursive(right, features, bandwidth)
        };
    }
    let threshold = self.arena.get_threshold(node);
    // A zero bandwidth yields 0/0 = NaN when x == threshold, a negative one
    // would send samples to the wrong side, and NaN poisons everything. In
    // all those cases use hard routing (the bandwidth -> 0+ limit), with the
    // same `<=` convention as `route_to_leaf`.
    if bandwidth.is_nan() || bandwidth <= 0.0 {
        return if features[feat_idx] <= threshold {
            self.predict_smooth_recursive(left, features, bandwidth)
        } else {
            self.predict_smooth_recursive(right, features, bandwidth)
        };
    }
    let z = (threshold - features[feat_idx]) / bandwidth;
    let alpha = 1.0 / (1.0 + math::exp(-z));
    let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
    let right_pred = self.predict_smooth_recursive(right, features, bandwidth);
    alpha * left_pred + (1.0 - alpha) * right_pred
}
fn predict_smooth_auto_recursive(
    &self,
    node: NodeId,
    features: &[f64],
    bandwidths: &[f64],
) -> f64 {
    // Recursive worker for `predict_smooth_auto`: like
    // `predict_smooth_recursive`, but the bandwidth is looked up per feature.
    if self.arena.is_leaf(node) {
        return self.leaf_prediction(node, features);
    }
    let feat_idx = self.arena.get_feature_idx(node) as usize;
    let left = self.arena.get_left(node);
    let right = self.arena.get_right(node);
    if let Some(mask) = self.arena.get_categorical_mask(node) {
        let cat_val = features[feat_idx] as u64;
        return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
            self.predict_smooth_auto_recursive(left, features, bandwidths)
        } else {
            self.predict_smooth_auto_recursive(right, features, bandwidths)
        };
    }
    let threshold = self.arena.get_threshold(node);
    let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
    // Route hard when there is no usable bandwidth. Non-positive values are
    // rejected too (as in `predict_sibling_recursive`): a zero bandwidth
    // gives 0/0 = NaN at x == threshold and a negative one flips the sides.
    if !bw.is_finite() || bw <= 0.0 {
        return if features[feat_idx] <= threshold {
            self.predict_smooth_auto_recursive(left, features, bandwidths)
        } else {
            self.predict_smooth_auto_recursive(right, features, bandwidths)
        };
    }
    let z = (threshold - features[feat_idx]) / bw;
    let alpha = 1.0 / (1.0 + math::exp(-z));
    let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
    let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);
    alpha * left_pred + (1.0 - alpha) * right_pred
}
/// Tries to split leaf `leaf_id`; returns `true` when a split was made.
///
/// Returns `false` without splitting when:
/// - the leaf is at the depth ceiling and no re-evaluation is due;
/// - the feature count is still unknown;
/// - fewer than `grace_period` samples reached the leaf;
/// - the leaf has no histograms yet, or no split candidates exist;
/// - `should_split_hoeffding` rejects the best candidate.
///
/// On success the leaf becomes an internal node and its `LeafState` is
/// removed. Both children receive fresh states, and node bandwidths are
/// recomputed.
pub(crate) fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
let depth = self.arena.get_depth(leaf_id);
// With `adaptive_depth` configured, the hard depth ceiling is doubled.
let hard_ceiling = if self.config.adaptive_depth.is_some() {
self.config.max_depth.saturating_mul(2)
} else {
self.config.max_depth
};
let at_max_depth = depth as usize >= hard_ceiling;
// Leaves at the ceiling are only re-evaluated when `split_reeval_interval`
// is set and at least that many samples arrived since the last attempt.
// NOTE(review): a re-evaluation that passes the Hoeffding test below does
// split the leaf, so depth can exceed `hard_ceiling` — confirm intended.
if at_max_depth {
match self.config.split_reeval_interval {
None => return false,
Some(interval) => {
let state = match self
.leaf_states
.get(leaf_id.0 as usize)
.and_then(|o| o.as_ref())
{
Some(s) => s,
None => return false,
};
let sample_count = self.arena.get_sample_count(leaf_id);
// Plain u64 subtraction: relies on `last_reeval_count` (set to the
// leaf's sample count further below) never exceeding the current count.
if sample_count - state.last_reeval_count < interval as u64 {
return false;
}
}
}
}
let n_features = match self.n_features {
Some(n) => n,
None => return false,
};
let sample_count = self.arena.get_sample_count(leaf_id);
if sample_count < self.config.grace_period as u64 {
return false;
}
// Draw a fresh feature subset for this attempt (buffers are reused).
let (feature_mask, feature_mask_bits) = split_logic::generate_feature_mask(
core::mem::take(&mut self.feature_mask),
core::mem::take(&mut self.feature_mask_bits),
&mut self.rng_state,
self.config.feature_subsample_rate,
n_features,
);
self.feature_mask = feature_mask;
self.feature_mask_bits = feature_mask_bits;
// With leaf decay enabled, histograms are materialised before evaluation.
if self.config.leaf_decay_alpha.is_some() {
if let Some(state) = self
.leaf_states
.get_mut(leaf_id.0 as usize)
.and_then(|o| o.as_mut())
{
if let Some(ref mut histograms) = state.histograms {
histograms.materialize_decay();
}
}
}
let state = match self
.leaf_states
.get(leaf_id.0 as usize)
.and_then(|o| o.as_ref())
{
Some(s) => s,
None => return false,
};
// A leaf still learning its bin edges has no histograms and cannot split.
let histograms = match &state.histograms {
Some(h) => h,
None => return false,
};
let ctx = split_logic::private::SplitContext {
config: &self.config,
n_features: self.n_features,
n_feature_mask: &self.feature_mask,
split_criterion: &self.split_criterion,
rng_state: &mut self.rng_state,
};
// Candidates are consumed with the best at index 0 and the runner-up at index 1.
let candidates = split_logic::private::evaluate_split_candidates(
histograms,
self.config.feature_types.as_deref(),
&ctx,
);
if candidates.is_empty() {
return false;
}
let best_gain = candidates[0].1.gain;
let second_best_gain = if candidates.len() > 1 {
candidates[1].1.gain
} else {
0.0
};
// Fresh context (same fields as above) for the Hoeffding-bound test.
let ctx = split_logic::private::SplitContext {
config: &self.config,
n_features: self.n_features,
n_feature_mask: &self.feature_mask,
split_criterion: &self.split_criterion,
rng_state: &mut self.rng_state,
};
if !split_logic::private::should_split_hoeffding(
best_gain,
second_best_gain,
sample_count,
&ctx,
) {
// Remember when a max-depth leaf was last re-evaluated.
if at_max_depth {
if let Some(state) = self
.leaf_states
.get_mut(leaf_id.0 as usize)
.and_then(|o| o.as_mut())
{
state.last_reeval_count = sample_count;
}
}
return false;
}
let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];
// Accumulate per-feature gain (exposed via `split_gains`).
if best_feat_idx < self.split_gains.len() {
self.split_gains[best_feat_idx] += best_candidate.gain;
}
let best_hist = &histograms.histograms[best_feat_idx];
let left_value = leaf_weight(
best_candidate.left_grad,
best_candidate.left_hess,
self.config.lambda,
);
let right_value = leaf_weight(
best_candidate.right_grad,
best_candidate.right_hess,
self.config.lambda,
);
let (left_id, right_id) = if let Some(ref order) = fisher_order {
// Categorical split: set bit `p` for each of the first `bin_idx + 1`
// positions of the ordering; positions >= 64 cannot be stored and are
// dropped. `route_to_leaf` sends value `v` left when bit `v` is set.
let mut mask: u64 = 0;
for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
if sorted_pos < 64 {
mask |= 1u64 << sorted_pos;
}
}
self.arena.split_leaf_categorical(
leaf_id,
best_feat_idx as u32,
0.0,
left_value,
right_value,
mask,
)
} else {
// Numeric split at the chosen bin edge; an out-of-range bin index
// yields f64::MAX, which sends every finite value left.
let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
best_hist.edges.edges[best_candidate.bin_idx]
} else {
f64::MAX
};
self.arena.split_leaf(
leaf_id,
best_feat_idx as u32,
threshold,
left_value,
right_value,
)
};
// The split node no longer needs leaf state; take it to seed the children.
let parent_state = self
.leaf_states
.get_mut(leaf_id.0 as usize)
.and_then(|o| o.take());
let nf = n_features;
let max_child = left_id.0.max(right_id.0) as usize;
if self.leaf_states.len() <= max_child {
self.leaf_states.resize_with(max_child + 1, || None);
}
if let Some(parent) = parent_state {
if let Some(parent_hists) = parent.histograms {
// Children reuse the parent's bin edges with empty histograms and
// `bins_ready = true`, skipping the edge-learning warm-up. Their
// leaf models start as `clone_warm()` copies of the parent's.
let edges_per_feature: Vec<crate::histogram::BinEdges> = parent_hists
.histograms
.iter()
.map(|h| h.edges.clone())
.collect();
let left_hists = LeafHistograms::new(&edges_per_feature);
let right_hists = LeafHistograms::new(&edges_per_feature);
let ft = self.config.feature_types.as_deref();
let child_binners_l = make_binners(nf, ft);
let child_binners_r = make_binners(nf, ft);
let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
let left_state = LeafState {
histograms: Some(left_hists),
binners: child_binners_l,
bins_ready: true,
grad_sum: 0.0,
hess_sum: 0.0,
last_reeval_count: 0,
clip_grad_mean: 0.0,
clip_grad_m2: 0.0,
clip_grad_count: 0,
output_mean: 0.0,
output_m2: 0.0,
output_count: 0,
leaf_model: left_model,
};
let right_state = LeafState {
histograms: Some(right_hists),
binners: child_binners_r,
bins_ready: true,
grad_sum: 0.0,
hess_sum: 0.0,
last_reeval_count: 0,
clip_grad_mean: 0.0,
clip_grad_m2: 0.0,
clip_grad_count: 0,
output_mean: 0.0,
output_m2: 0.0,
output_count: 0,
leaf_model: right_model,
};
self.leaf_states[left_id.0 as usize] = Some(left_state);
self.leaf_states[right_id.0 as usize] = Some(right_state);
} else {
// Parent had no histograms: children start from scratch but still
// inherit warm copies of the parent's leaf model.
let ft = self.config.feature_types.as_deref();
let mut ls = LeafState::new_with_types(nf, ft);
ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
self.leaf_states[left_id.0 as usize] = Some(ls);
let mut rs = LeafState::new_with_types(nf, ft);
rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
self.leaf_states[right_id.0 as usize] = Some(rs);
}
} else {
// No parent state available: create fresh states and fresh leaf models.
let ft = self.config.feature_types.as_deref();
let mut ls = LeafState::new_with_types(nf, ft);
ls.leaf_model = self.make_leaf_model(left_id);
self.leaf_states[left_id.0 as usize] = Some(ls);
let mut rs = LeafState::new_with_types(nf, ft);
rs.leaf_model = self.make_leaf_model(right_id);
self.leaf_states[right_id.0 as usize] = Some(rs);
}
// The threshold set changed, so soft-routing margins must be refreshed.
self.recompute_bandwidths();
true
}
}
impl StreamingTree for HoeffdingTree {
fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
    // Routes one sample to its leaf and folds (gradient, hessian) into that
    // leaf's statistics. A leaf first observes values to learn bin edges.
    // Once `grace_period` samples have arrived it builds histograms, and it
    // then attempts a split every `grace_period` samples.
    self.samples_seen += 1;
    // The feature count is latched from the first sample; the root's binners
    // are created at that point because `new` does not know the width.
    let n_features = match self.n_features {
        Some(n) => n,
        None => {
            let n = features.len();
            self.n_features = Some(n);
            self.split_gains.resize(n, 0.0);
            if let Some(root_state) = self
                .leaf_states
                .get_mut(self.root.0 as usize)
                .and_then(|o| o.as_mut())
            {
                root_state.binners = make_binners(n, self.config.feature_types.as_deref());
            }
            n
        }
    };
    debug_assert_eq!(
        features.len(),
        n_features,
        "feature count mismatch: got {} but expected {}",
        features.len(),
        n_features,
    );
    // `grace_period` doubles as the split-check cadence (`%` below); clamp it
    // to at least 1 so a zero grace period cannot panic with a division by
    // zero. For grace_period >= 1 this is a no-op.
    let grace = (self.config.grace_period as u64).max(1);
    let decay = self.config.leaf_decay_alpha;
    let lambda = self.config.lambda;

    let leaf_id = self.route_to_leaf(features);
    self.arena.increment_sample_count(leaf_id);
    let sample_count = self.arena.get_sample_count(leaf_id);

    let idx = leaf_id.0 as usize;
    if self.leaf_states.len() <= idx {
        self.leaf_states.resize_with(idx + 1, || None);
    }
    let feature_types = self.config.feature_types.as_deref();
    let state = self.leaf_states[idx]
        .get_or_insert_with(|| LeafState::new_with_types(n_features, feature_types));

    let gradient = match self.config.gradient_clip_sigma {
        Some(sigma) => clip_gradient(state, gradient, sigma),
        None => gradient,
    };

    let warming_up = !state.bins_ready;
    if warming_up {
        // Still learning bin edges: feed raw values to the binners.
        for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
            binner.observe(val);
        }
    } else if let Some(ref mut histograms) = state.histograms {
        if let Some(alpha) = decay {
            histograms.accumulate_with_decay(features, gradient, hessian, alpha);
        } else {
            histograms.accumulate(features, gradient, hessian);
        }
    }

    // Running (optionally exponentially decayed) gradient/hessian sums.
    if let Some(alpha) = decay {
        state.grad_sum = alpha * state.grad_sum + gradient;
        state.hess_sum = alpha * state.hess_sum + hessian;
    } else {
        state.grad_sum += gradient;
        state.hess_sum += hessian;
    }
    let lw = leaf_weight(state.grad_sum, state.hess_sum, lambda);
    self.arena.set_leaf_value(leaf_id, lw);
    if self.config.adaptive_leaf_bound.is_some() {
        update_output_stats(state, lw, decay);
    }
    if let Some(ref mut model) = state.leaf_model {
        model.update(features, gradient, hessian, lambda);
    }

    if warming_up {
        // After the grace period, freeze bin edges and start histograms with
        // the current sample; the leaf becomes eligible for splitting.
        if sample_count >= grace {
            let n_bins = self.config.n_bins;
            let edges_per_feature: Vec<crate::histogram::BinEdges> = state
                .binners
                .iter()
                .map(|b| b.compute_edges(n_bins))
                .collect();
            let mut histograms = LeafHistograms::new(&edges_per_feature);
            if let Some(alpha) = decay {
                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
            } else {
                histograms.accumulate(features, gradient, hessian);
            }
            state.histograms = Some(histograms);
            state.bins_ready = true;
        }
    } else if sample_count % grace == 0 {
        self.attempt_split(leaf_id);
    }
}
fn predict(&self, features: &[f64]) -> f64 {
    // Hard-routed prediction: the (bounded) value of the leaf `features` reaches.
    self.leaf_prediction(self.route_to_leaf(features), features)
}
/// Number of leaves currently in the arena.
#[inline]
fn n_leaves(&self) -> usize {
self.arena.n_leaves()
}
/// Number of samples passed to `train_one` since creation or the last `reset`.
#[inline]
fn n_samples_seen(&self) -> u64 {
self.samples_seen
}
fn reset(&mut self) {
    // Returns the tree to a single fresh root leaf. The learned feature
    // count and the length of `split_gains` are kept; the RNG restarts from
    // the configured seed.
    self.arena.reset();
    self.root = self.arena.add_leaf(0);
    let root_idx = self.root.0 as usize;

    let mut fresh_root = LeafState::new_with_types(
        self.n_features.unwrap_or(0),
        self.config.feature_types.as_deref(),
    );
    fresh_root.leaf_model = self.make_leaf_model(self.root);

    self.leaf_states.clear();
    self.leaf_states.resize_with(root_idx + 1, || None);
    self.leaf_states[root_idx] = Some(fresh_root);

    self.samples_seen = 0;
    self.rng_state = self.config.seed;
    self.feature_mask.clear();
    self.feature_mask_bits.clear();
    for gain in self.split_gains.iter_mut() {
        *gain = 0.0;
    }
    self.node_bandwidths.clear();
}
/// Cumulative gain of accepted splits per feature index.
fn split_gains(&self) -> &[f64] {
&self.split_gains
}
fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
    // Returns the hard-routed prediction and `1 / (hess_sum + lambda)` for
    // the reached leaf, or infinity when that leaf has no state.
    let leaf_id = self.route_to_leaf(features);
    let value = self.leaf_prediction(leaf_id, features);
    let variance = match self.leaf_states.get(leaf_id.0 as usize) {
        Some(Some(state)) => 1.0 / (state.hess_sum + self.config.lambda),
        _ => f64::INFINITY,
    };
    (value, variance)
}
}
impl Clone for HoeffdingTree {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to `HoeffdingTree` without
        // handling it here becomes a compile error.
        let Self {
            arena,
            root,
            config,
            leaf_states,
            n_features,
            samples_seen,
            split_criterion,
            feature_mask,
            feature_mask_bits,
            rng_state,
            split_gains,
            node_bandwidths,
        } = self;
        Self {
            arena: arena.clone(),
            root: *root,
            config: config.clone(),
            leaf_states: leaf_states.clone(),
            n_features: *n_features,
            samples_seen: *samples_seen,
            split_criterion: *split_criterion,
            feature_mask: feature_mask.clone(),
            feature_mask_bits: feature_mask_bits.clone(),
            rng_state: *rng_state,
            split_gains: split_gains.clone(),
            node_bandwidths: node_bandwidths.clone(),
        }
    }
}
// SAFETY: NOTE(review): these impls override the compiler's auto-trait
// analysis. `HoeffdingTree` stores `Option<Box<dyn LeafModel>>` values inside
// `leaf_states` (see `make_leaf_model`). If the `LeafModel` trait does not
// require `Send + Sync` as supertraits, these impls are unsound. If it does,
// they are redundant and the auto traits would be derived anyway. Confirm
// the trait bounds and either remove the impls or document the invariant
// that makes them sound.
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}
#[cfg(test)]
mod tests {
use super::*;
use crate::tree::builder::TreeConfig;
use crate::tree::StreamingTree;
/// A single training sample must produce a finite prediction equal to the
/// closed-form leaf weight for that sample.
#[test]
fn single_sample_predict_not_nan() {
let config = TreeConfig::new().grace_period(10);
let mut tree = HoeffdingTree::new(config);
let features = vec![1.0, 2.0, 3.0];
tree.train_one(&features, -0.5, 1.0);
let pred = tree.predict(&features);
assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
assert!(
pred.is_finite(),
"prediction should be finite, got {}",
pred
);
// NOTE(review): 0.25 presumably equals -G / (H + lambda) = 0.5 / 2.0, which
// assumes the default `lambda` is 1.0 — confirm against `TreeConfig`.
assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
}
}