Skip to main content

irithyll_core/tree/hoeffding/
mod.rs

1//! Hoeffding-bound split decisions for streaming tree construction.
2//!
3//! [`HoeffdingTree`] is the core streaming decision tree. It grows incrementally:
4//! each sample updates per-leaf histogram accumulators, and splits are committed
5//! only when the Hoeffding bound guarantees the best candidate split is
6//! statistically superior to the runner-up (or a tie-breaking threshold is met).
7//!
8//! # Algorithm
9//!
10//! For each incoming `(features, gradient, hessian)` triple:
11//!
12//! 1. Route the sample from root to a leaf via threshold comparisons.
13//! 2. At the leaf, accumulate gradient/hessian into per-feature histograms.
14//! 3. Once enough samples arrive (grace period), evaluate candidate splits
15//!    using the XGBoost gain formula.
16//! 4. Apply the Hoeffding bound: if the gap between the best and second-best
17//!    gain exceeds `epsilon = sqrt(R^2 * ln(1/delta) / (2n))`, commit the split.
18//! 5. When splitting, use the histogram subtraction trick to initialize one
19//!    child's histograms for free.
20
21pub mod leaf;
22pub mod split_logic;
23
24use alloc::vec;
25use alloc::vec::Vec;
26
27use crate::histogram::bins::LeafHistograms;
28use crate::math;
29use crate::tree::builder::TreeConfig;
30use crate::tree::leaf_model::LeafModelType;
31use crate::tree::node::{NodeId, TreeArena};
32use crate::tree::split::{leaf_weight, XGBoostGain};
33use crate::tree::StreamingTree;
34
35use leaf::{adaptive_bound, clip_gradient, make_binners, update_output_stats, LeafState};
36
37/// A streaming decision tree that uses Hoeffding-bound split decisions.
38///
39/// The tree grows incrementally: each call to [`train_one`](StreamingTree::train_one)
40/// routes one sample to its leaf, updates histograms, and potentially triggers
41/// a split when statistical evidence is sufficient.
42///
43/// # Feature subsampling
44///
45/// When `config.feature_subsample_rate < 1.0`, each split evaluation considers
46/// only a random subset of features (selected via a deterministic xorshift64 RNG).
47/// This adds diversity when the tree is used inside an ensemble.
48pub struct HoeffdingTree {
49    /// Arena-allocated node storage.
50    pub(crate) arena: TreeArena,
51
52    /// Root node identifier.
53    pub(crate) root: NodeId,
54
55    /// Tree configuration / hyperparameters.
56    pub(crate) config: TreeConfig,
57
58    /// Per-leaf state indexed by `NodeId.0`. Dense Vec -- NodeIds are
59    /// contiguous u32 indices from TreeArena, so direct indexing is optimal.
60    pub(crate) leaf_states: Vec<Option<LeafState>>,
61
62    /// Number of features, learned from the first sample.
63    pub(crate) n_features: Option<usize>,
64
65    /// Total samples seen across all calls to `train_one`.
66    pub(crate) samples_seen: u64,
67
68    /// Split gain evaluator.
69    pub(crate) split_criterion: XGBoostGain,
70
71    /// Scratch buffer for the feature mask (avoids repeated allocation).
72    pub(crate) feature_mask: Vec<usize>,
73
74    /// Bitset scratch buffer for O(1) membership test during feature mask generation.
75    /// Each bit `i` indicates whether feature `i` is already in `feature_mask`.
76    pub(crate) feature_mask_bits: Vec<u64>,
77
78    /// xorshift64 RNG state for feature subsampling.
79    pub(crate) rng_state: u64,
80
81    /// Accumulated split gains per feature for importance tracking.
82    /// Indexed by feature index; grows lazily when n_features is learned.
83    pub(crate) split_gains: Vec<f64>,
84
85    /// Per-node auto-bandwidth for soft routing, indexed by `NodeId.0`.
86    /// Recomputed after every structural change (split).
87    pub(crate) node_bandwidths: Vec<f64>,
88}
89
90impl HoeffdingTree {
91    /// Create a new `HoeffdingTree` with the given configuration.
92    ///
93    /// The tree starts with a single root leaf and no feature information;
94    /// the number of features is inferred from the first training sample.
95    pub fn new(config: TreeConfig) -> Self {
96        let mut arena = TreeArena::new();
97        let root = arena.add_leaf(0);
98
99        // Insert a placeholder leaf state for the root. We don't know n_features
100        // yet, so give it 0 binners -- it will be properly initialized on the
101        // first sample.
102        let mut leaf_states = vec![None; root.0 as usize + 1];
103        let root_model = match config.leaf_model_type {
104            LeafModelType::ClosedForm => None,
105            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
106        };
107        leaf_states[root.0 as usize] = Some(LeafState {
108            histograms: None,
109            binners: Vec::new(),
110            bins_ready: false,
111            grad_sum: 0.0,
112            hess_sum: 0.0,
113            last_reeval_count: 0,
114            clip_grad_mean: 0.0,
115            clip_grad_m2: 0.0,
116            clip_grad_count: 0,
117            output_mean: 0.0,
118            output_m2: 0.0,
119            output_count: 0,
120            leaf_model: root_model,
121        });
122
123        let seed = config.seed;
124        Self {
125            arena,
126            root,
127            config,
128            leaf_states,
129            n_features: None,
130            samples_seen: 0,
131            split_criterion: XGBoostGain::default(),
132            feature_mask: Vec::new(),
133            feature_mask_bits: Vec::new(),
134            rng_state: seed,
135            split_gains: Vec::new(),
136            node_bandwidths: Vec::new(),
137        }
138    }
139
140    /// Create a leaf model for a new leaf if the config requires one.
141    ///
142    /// Returns `None` for `ClosedForm` (the default), which uses the existing
143    /// `leaf_weight()` path with zero overhead. For `Linear` and `MLP`, returns
144    /// a fresh model seeded deterministically from the config seed and node id.
145    fn make_leaf_model(
146        &self,
147        node: NodeId,
148    ) -> Option<alloc::boxed::Box<dyn crate::tree::leaf_model::LeafModel>> {
149        match self.config.leaf_model_type {
150            LeafModelType::ClosedForm => None,
151            _ => Some(
152                self.config
153                    .leaf_model_type
154                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
155            ),
156        }
157    }
158
159    /// Reconstruct a `HoeffdingTree` from a pre-built arena.
160    ///
161    /// Used during model deserialization. The tree is restored with node
162    /// topology and leaf values intact, but histogram accumulators are empty
163    /// (they will rebuild naturally from continued training).
164    ///
165    /// The root is assumed to be `NodeId(0)`. Leaf states are created empty
166    /// for all current leaf nodes in the arena.
167    pub fn from_arena(
168        config: TreeConfig,
169        arena: TreeArena,
170        n_features: Option<usize>,
171        samples_seen: u64,
172        rng_state: u64,
173    ) -> Self {
174        let root = if arena.n_nodes() > 0 {
175            NodeId(0)
176        } else {
177            // Empty arena -- add a root leaf (shouldn't normally happen in restore).
178            let mut arena_mut = arena;
179            let root = arena_mut.add_leaf(0);
180            return Self {
181                arena: arena_mut,
182                root,
183                config: config.clone(),
184                leaf_states: {
185                    let mut v = vec![None; root.0 as usize + 1];
186                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
187                    v
188                },
189                n_features,
190                samples_seen,
191                split_criterion: XGBoostGain::default(),
192                feature_mask: Vec::new(),
193                feature_mask_bits: Vec::new(),
194                rng_state,
195                split_gains: vec![0.0; n_features.unwrap_or(0)],
196                node_bandwidths: Vec::new(),
197            };
198        };
199
200        // Build leaf states for every leaf in the arena.
201        let nf = n_features.unwrap_or(0);
202        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
203        for (i, slot) in leaf_states.iter_mut().enumerate() {
204            if arena.is_leaf[i] {
205                *slot = Some(LeafState::new(nf));
206            }
207        }
208
209        Self {
210            arena,
211            root,
212            config,
213            leaf_states,
214            n_features,
215            samples_seen,
216            split_criterion: XGBoostGain::default(),
217            feature_mask: Vec::new(),
218            feature_mask_bits: Vec::new(),
219            rng_state,
220            split_gains: vec![0.0; nf],
221            node_bandwidths: Vec::new(),
222        }
223    }
224
225    /// Root node identifier.
226    #[inline]
227    pub fn root(&self) -> NodeId {
228        self.root
229    }
230
231    /// Immutable access to the underlying arena.
232    #[inline]
233    pub fn arena(&self) -> &TreeArena {
234        &self.arena
235    }
236
237    /// Immutable access to the tree configuration.
238    #[inline]
239    pub fn tree_config(&self) -> &TreeConfig {
240        &self.config
241    }
242
243    /// Number of features (learned from the first sample, `None` before any training).
244    #[inline]
245    pub fn n_features(&self) -> Option<usize> {
246        self.n_features
247    }
248
249    /// Current RNG state (for deterministic checkpoint/restore).
250    #[inline]
251    pub fn rng_state(&self) -> u64 {
252        self.rng_state
253    }
254
255    /// Read-only access to the gradient and hessian sums for a leaf node.
256    ///
257    /// Returns `Some((grad_sum, hess_sum))` if `node` is a leaf with an active
258    /// leaf state, or `None` if the node has no state (e.g. internal node
259    /// or freshly allocated).
260    ///
261    /// These sums enable inverse-hessian confidence estimation:
262    /// `confidence = 1.0 / (hess_sum + lambda)`. High hessian means the leaf
263    /// has seen consistent, informative data; low hessian means uncertainty.
264    #[inline]
265    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
266        self.leaf_states
267            .get(node.0 as usize)
268            .and_then(|o| o.as_ref())
269            .map(|state| (state.grad_sum, state.hess_sum))
270    }
271
272    /// Route a feature vector from the root down to a leaf, returning the leaf's NodeId.
273    pub(crate) fn route_to_leaf(&self, features: &[f64]) -> NodeId {
274        let mut current = self.root;
275        while !self.arena.is_leaf(current) {
276            let feat_idx = self.arena.get_feature_idx(current) as usize;
277            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
278                // Categorical split: use bitmask routing.
279                // The feature value is cast to a bin index. If that bin's bit is set
280                // in the mask, go left; otherwise go right.
281                // For categorical features, the bin index in the histogram corresponds
282                // to the sorted category position, but for bitmask routing we use
283                // the original bin index directly.
284                let cat_val = features[feat_idx] as u64;
285                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
286                    self.arena.get_left(current)
287                } else {
288                    self.arena.get_right(current)
289                }
290            } else {
291                // Continuous split: standard threshold comparison.
292                let threshold = self.arena.get_threshold(current);
293                if features[feat_idx] <= threshold {
294                    self.arena.get_left(current)
295                } else {
296                    self.arena.get_right(current)
297                }
298            };
299        }
300        current
301    }
302
303    /// Get the prediction value for a leaf node.
304    ///
305    /// Checks (in order): leaf model, live grad/hess statistics, stored leaf value.
306    /// Returns `0.0` if no leaf state exists.
307    #[inline]
308    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
309        let (raw, leaf_bound) = if let Some(state) = self
310            .leaf_states
311            .get(leaf_id.0 as usize)
312            .and_then(|o| o.as_ref())
313        {
314            // min_hessian_sum: suppress fresh leaves with insufficient samples
315            if let Some(min_h) = self.config.min_hessian_sum {
316                if state.hess_sum < min_h {
317                    return 0.0;
318                }
319            }
320            let val = if let Some(ref model) = state.leaf_model {
321                model.predict(features)
322            } else if state.hess_sum != 0.0 {
323                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
324            } else {
325                self.arena.leaf_value[leaf_id.0 as usize]
326            };
327
328            // Compute per-leaf adaptive bound while state is in scope
329            let bound = self
330                .config
331                .adaptive_leaf_bound
332                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));
333
334            (val, bound)
335        } else {
336            (0.0, None)
337        };
338
339        // Priority: per-leaf adaptive bound > global max_leaf_output > unclamped
340        if let Some(bound) = leaf_bound {
341            if bound < f64::MAX {
342                return raw.clamp(-bound, bound);
343            }
344        }
345        if let Some(max) = self.config.max_leaf_output {
346            raw.clamp(-max, max)
347        } else {
348            raw
349        }
350    }
351
352    /// Predict using sigmoid-blended soft routing for smooth interpolation.
353    ///
354    /// Instead of hard left/right routing at each split node, uses sigmoid
355    /// blending: `alpha = sigmoid((threshold - feature) / bandwidth)`. The
356    /// prediction is `alpha * left_pred + (1 - alpha) * right_pred`, computed
357    /// recursively from root to leaves.
358    ///
359    /// The result is a continuous function that varies smoothly with every
360    /// feature change — no bins, no boundaries, no jumps.
361    ///
362    /// # Arguments
363    ///
364    /// * `features` - Input feature vector.
365    /// * `bandwidth` - Controls transition sharpness. Smaller = sharper
366    ///   (closer to hard splits), larger = smoother.
367    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
368        self.predict_smooth_recursive(self.root, features, bandwidth)
369    }
370
371    /// Predict using per-feature auto-calibrated bandwidths.
372    ///
373    /// Each feature uses its own bandwidth derived from median split threshold
374    /// gaps. Features with `f64::INFINITY` bandwidth fall back to hard routing.
375    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
376        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
377    }
378
379    /// Predict with parent-leaf linear interpolation.
380    ///
381    /// Routes to the leaf but blends the leaf prediction with the parent node's
382    /// preserved prediction based on the leaf's hessian sum. Fresh leaves
383    /// (low hess_sum) smoothly transition from parent prediction to their own:
384    ///
385    /// `alpha = leaf_hess / (leaf_hess + lambda)`
386    /// `pred = alpha * leaf_pred + (1 - alpha) * parent_pred`
387    ///
388    /// This fixes static predictions from leaves that split but haven't
389    /// accumulated enough samples to outperform their parent.
390    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
391        let mut current = self.root;
392        let mut parent = None;
393        while !self.arena.is_leaf(current) {
394            parent = Some(current);
395            let feat_idx = self.arena.get_feature_idx(current) as usize;
396            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
397                let cat_val = features[feat_idx] as u64;
398                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
399                    self.arena.get_left(current)
400                } else {
401                    self.arena.get_right(current)
402                }
403            } else {
404                let threshold = self.arena.get_threshold(current);
405                if features[feat_idx] <= threshold {
406                    self.arena.get_left(current)
407                } else {
408                    self.arena.get_right(current)
409                }
410            };
411        }
412
413        let leaf_pred = self.leaf_prediction(current, features);
414
415        // No parent (root is leaf) → return leaf prediction directly
416        let parent_id = match parent {
417            Some(p) => p,
418            None => return leaf_pred,
419        };
420
421        // Get parent's preserved prediction from its old leaf state
422        let parent_pred = self.leaf_prediction(parent_id, features);
423
424        // Blend: alpha = leaf_hess / (leaf_hess + lambda)
425        let leaf_hess = self
426            .leaf_states
427            .get(current.0 as usize)
428            .and_then(|o| o.as_ref())
429            .map(|s| s.hess_sum)
430            .unwrap_or(0.0);
431
432        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
433        alpha * leaf_pred + (1.0 - alpha) * parent_pred
434    }
435
436    /// Predict with sibling-based interpolation for feature-continuous predictions.
437    ///
438    /// At the leaf's parent split, blends the leaf prediction with its sibling's
439    /// prediction based on the feature's distance from the split threshold:
440    ///
441    /// Within the margin `m` around the threshold:
442    /// `t = (feature - threshold + m) / (2 * m)`  (0 at left edge, 1 at right edge)
443    /// `pred = (1 - t) * left_pred + t * right_pred`
444    ///
445    /// Outside the margin, returns the routed child's prediction directly.
446    /// The margin `m` is derived from auto-bandwidths if available, otherwise
447    /// defaults to `feature_range / n_bins` heuristic per feature.
448    ///
449    /// This makes predictions vary continuously as features move near split
450    /// boundaries, eliminating step-function artifacts.
451    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
452        self.predict_sibling_recursive(self.root, features, bandwidths)
453    }
454
455    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
456        if self.arena.is_leaf(node) {
457            return self.leaf_prediction(node, features);
458        }
459
460        let feat_idx = self.arena.get_feature_idx(node) as usize;
461        let left = self.arena.get_left(node);
462        let right = self.arena.get_right(node);
463
464        // Categorical splits: always hard routing (no interpolation)
465        if let Some(mask) = self.arena.get_categorical_mask(node) {
466            let cat_val = features[feat_idx] as u64;
467            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
468                self.predict_sibling_recursive(left, features, bandwidths)
469            } else {
470                self.predict_sibling_recursive(right, features, bandwidths)
471            };
472        }
473
474        let threshold = self.arena.get_threshold(node);
475        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
476
477        // No valid margin or infinite → hard routing
478        if !margin.is_finite() || margin <= 0.0 {
479            return if features[feat_idx] <= threshold {
480                self.predict_sibling_recursive(left, features, bandwidths)
481            } else {
482                self.predict_sibling_recursive(right, features, bandwidths)
483            };
484        }
485
486        let dist = features[feat_idx] - threshold;
487
488        if dist < -margin {
489            // Firmly in left child territory
490            self.predict_sibling_recursive(left, features, bandwidths)
491        } else if dist > margin {
492            // Firmly in right child territory
493            self.predict_sibling_recursive(right, features, bandwidths)
494        } else {
495            // Within the interpolation margin: linear blend
496            let t = (dist + margin) / (2.0 * margin); // 0.0 at left edge, 1.0 at right edge
497            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
498            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
499            (1.0 - t) * left_pred + t * right_pred
500        }
501    }
502
503    /// Collect all split thresholds per feature from the tree arena.
504    ///
505    /// Returns a `Vec<Vec<f64>>` indexed by feature, containing all thresholds
506    /// used in continuous splits. Categorical splits are excluded.
507    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
508        let n = self.n_features.unwrap_or(0);
509        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];
510
511        for i in 0..self.arena.n_nodes() {
512            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
513                let feat_idx = self.arena.feature_idx[i] as usize;
514                if feat_idx < n {
515                    thresholds[feat_idx].push(self.arena.threshold[i]);
516                }
517            }
518        }
519
520        thresholds
521    }
522
523    /// Compute per-node bandwidth from nearest neighbor thresholds on the same feature.
524    fn compute_node_bandwidth(&self, node: NodeId, all_thresholds: &[Vec<f64>]) -> f64 {
525        let feat_idx = self.arena.get_feature_idx(node) as usize;
526        let threshold = self.arena.get_threshold(node);
527
528        let thresholds = if feat_idx < all_thresholds.len() {
529            &all_thresholds[feat_idx]
530        } else {
531            return f64::INFINITY;
532        };
533
534        // Find nearest neighbors (thresholds are sorted)
535        let below = thresholds.iter().rev().find(|&&t| t < threshold - 1e-15);
536        let above = thresholds.iter().find(|&&t| t > threshold + 1e-15);
537
538        match (below, above) {
539            (Some(&b), Some(&a)) => (threshold - b).min(a - threshold),
540            (Some(&b), None) => threshold - b,
541            (None, Some(&a)) => a - threshold,
542            (None, None) => f64::INFINITY,
543        }
544    }
545
546    /// Recompute all node bandwidths. Call after structural changes.
547    pub fn recompute_bandwidths(&mut self) {
548        let n = self.arena.n_nodes();
549        self.node_bandwidths.resize(n, f64::INFINITY);
550
551        // Collect and sort thresholds once
552        let mut all_thresholds = self.collect_split_thresholds_per_feature();
553        for v in &mut all_thresholds {
554            v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
555        }
556
557        for i in 0..n {
558            let nid = NodeId(i as u32);
559            if !self.arena.is_leaf(nid) {
560                self.node_bandwidths[i] = self.compute_node_bandwidth(nid, &all_thresholds);
561            } else {
562                self.node_bandwidths[i] = f64::INFINITY;
563            }
564        }
565    }
566
567    /// Predict using per-node auto-bandwidth soft routing.
568    /// Every prediction is a continuous weighted blend — no step discontinuities.
569    pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
570        self.predict_soft_recursive(self.root, features)
571    }
572
573    fn predict_soft_recursive(&self, node: NodeId, features: &[f64]) -> f64 {
574        if self.arena.is_leaf(node) {
575            return self.leaf_prediction(node, features);
576        }
577
578        let feat_idx = self.arena.get_feature_idx(node) as usize;
579        let left = self.arena.get_left(node);
580        let right = self.arena.get_right(node);
581
582        // Categorical: hard routing
583        if let Some(mask) = self.arena.get_categorical_mask(node) {
584            let cat_val = features[feat_idx] as u64;
585            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
586                self.predict_soft_recursive(left, features)
587            } else {
588                self.predict_soft_recursive(right, features)
589            };
590        }
591
592        let threshold = self.arena.get_threshold(node);
593        let margin = self
594            .node_bandwidths
595            .get(node.0 as usize)
596            .copied()
597            .unwrap_or(f64::INFINITY);
598
599        let left_pred = self.predict_soft_recursive(left, features);
600        let right_pred = self.predict_soft_recursive(right, features);
601
602        // Non-finite or zero margin: sigmoid fallback
603        if !margin.is_finite() || margin <= 0.0 {
604            let dist = features[feat_idx] - threshold;
605            let scale = math::abs(threshold) * 0.01 + 1e-10;
606            let z = (-dist / scale).clamp(-500.0, 500.0);
607            let t = 1.0 / (1.0 + math::exp(z));
608            return (1.0 - t) * left_pred + t * right_pred;
609        }
610
611        // Linear soft routing: always blend
612        let dist = features[feat_idx] - threshold;
613        let t = ((dist + margin) / (2.0 * margin)).clamp(0.0, 1.0);
614        (1.0 - t) * left_pred + t * right_pred
615    }
616
617    /// Recursive sigmoid-blended prediction traversal.
618    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
619        if self.arena.is_leaf(node) {
620            // At a leaf, return the leaf prediction (same as regular predict)
621            return self.leaf_prediction(node, features);
622        }
623
624        let feat_idx = self.arena.get_feature_idx(node) as usize;
625        let left = self.arena.get_left(node);
626        let right = self.arena.get_right(node);
627
628        // Categorical splits: hard routing (sigmoid blending is meaningless for categories)
629        if let Some(mask) = self.arena.get_categorical_mask(node) {
630            let cat_val = features[feat_idx] as u64;
631            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
632                self.predict_smooth_recursive(left, features, bandwidth)
633            } else {
634                self.predict_smooth_recursive(right, features, bandwidth)
635            };
636        }
637
638        // Continuous split: sigmoid blending for smooth transition around the threshold
639        let threshold = self.arena.get_threshold(node);
640        let z = (threshold - features[feat_idx]) / bandwidth;
641        let alpha = 1.0 / (1.0 + math::exp(-z));
642
643        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
644        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);
645
646        alpha * left_pred + (1.0 - alpha) * right_pred
647    }
648
649    /// Recursive per-feature-bandwidth smooth prediction traversal.
650    fn predict_smooth_auto_recursive(
651        &self,
652        node: NodeId,
653        features: &[f64],
654        bandwidths: &[f64],
655    ) -> f64 {
656        if self.arena.is_leaf(node) {
657            return self.leaf_prediction(node, features);
658        }
659
660        let feat_idx = self.arena.get_feature_idx(node) as usize;
661        let left = self.arena.get_left(node);
662        let right = self.arena.get_right(node);
663
664        // Categorical splits: always hard routing
665        if let Some(mask) = self.arena.get_categorical_mask(node) {
666            let cat_val = features[feat_idx] as u64;
667            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
668                self.predict_smooth_auto_recursive(left, features, bandwidths)
669            } else {
670                self.predict_smooth_auto_recursive(right, features, bandwidths)
671            };
672        }
673
674        let threshold = self.arena.get_threshold(node);
675        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
676
677        // Infinite bandwidth = feature never split on across ensemble → hard routing
678        if !bw.is_finite() {
679            return if features[feat_idx] <= threshold {
680                self.predict_smooth_auto_recursive(left, features, bandwidths)
681            } else {
682                self.predict_smooth_auto_recursive(right, features, bandwidths)
683            };
684        }
685
686        // Sigmoid-blended soft routing with per-feature bandwidth
687        let z = (threshold - features[feat_idx]) / bw;
688        let alpha = 1.0 / (1.0 + math::exp(-z));
689
690        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
691        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);
692
693        alpha * left_pred + (1.0 - alpha) * right_pred
694    }
695
696    /// Attempt a split at the given leaf node.
697    ///
698    /// Returns `true` if a split was performed.
699    pub(crate) fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
700        let depth = self.arena.get_depth(leaf_id);
701
702        // When adaptive_depth is enabled, max_depth * 2 is the hard safety ceiling;
703        // the per-split CIR test handles generalization. Otherwise, use static max_depth.
704        let hard_ceiling = if self.config.adaptive_depth.is_some() {
705            self.config.max_depth.saturating_mul(2)
706        } else {
707            self.config.max_depth
708        };
709        let at_max_depth = depth as usize >= hard_ceiling;
710
711        if at_max_depth {
712            // Only proceed if split re-evaluation is enabled and the interval
713            // has elapsed since the last evaluation at this leaf.
714            match self.config.split_reeval_interval {
715                None => return false,
716                Some(interval) => {
717                    let state = match self
718                        .leaf_states
719                        .get(leaf_id.0 as usize)
720                        .and_then(|o| o.as_ref())
721                    {
722                        Some(s) => s,
723                        None => return false,
724                    };
725                    let sample_count = self.arena.get_sample_count(leaf_id);
726                    if sample_count - state.last_reeval_count < interval as u64 {
727                        return false;
728                    }
729                    // Fall through to evaluate potential split.
730                }
731            }
732        }
733
734        let n_features = match self.n_features {
735            Some(n) => n,
736            None => return false,
737        };
738
739        let sample_count = self.arena.get_sample_count(leaf_id);
740        if sample_count < self.config.grace_period as u64 {
741            return false;
742        }
743
744        // Generate the feature mask for this split evaluation.
745        let (feature_mask, feature_mask_bits) = split_logic::generate_feature_mask(
746            core::mem::take(&mut self.feature_mask),
747            core::mem::take(&mut self.feature_mask_bits),
748            &mut self.rng_state,
749            self.config.feature_subsample_rate,
750            n_features,
751        );
752        self.feature_mask = feature_mask;
753        self.feature_mask_bits = feature_mask_bits;
754
755        // Materialize pending lazy decay before reading histogram data.
756        if self.config.leaf_decay_alpha.is_some() {
757            if let Some(state) = self
758                .leaf_states
759                .get_mut(leaf_id.0 as usize)
760                .and_then(|o| o.as_mut())
761            {
762                if let Some(ref mut histograms) = state.histograms {
763                    histograms.materialize_decay();
764                }
765            }
766        }
767
768        // Evaluate splits for each feature in the mask.
769        let state = match self
770            .leaf_states
771            .get(leaf_id.0 as usize)
772            .and_then(|o| o.as_ref())
773        {
774            Some(s) => s,
775            None => return false,
776        };
777
778        let histograms = match &state.histograms {
779            Some(h) => h,
780            None => return false,
781        };
782
783        let ctx = split_logic::private::SplitContext {
784            config: &self.config,
785            n_features: self.n_features,
786            n_feature_mask: &self.feature_mask,
787            split_criterion: &self.split_criterion,
788            rng_state: &mut self.rng_state,
789        };
790
791        let candidates = split_logic::private::evaluate_split_candidates(
792            histograms,
793            self.config.feature_types.as_deref(),
794            &ctx,
795        );
796
797        if candidates.is_empty() {
798            return false;
799        }
800
801        let best_gain = candidates[0].1.gain;
802        let second_best_gain = if candidates.len() > 1 {
803            candidates[1].1.gain
804        } else {
805            0.0
806        };
807
808        // Check Hoeffding bound and adaptive depth.
809        let ctx = split_logic::private::SplitContext {
810            config: &self.config,
811            n_features: self.n_features,
812            n_feature_mask: &self.feature_mask,
813            split_criterion: &self.split_criterion,
814            rng_state: &mut self.rng_state,
815        };
816
817        if !split_logic::private::should_split_hoeffding(
818            best_gain,
819            second_best_gain,
820            sample_count,
821            &ctx,
822        ) {
823            if at_max_depth {
824                if let Some(state) = self
825                    .leaf_states
826                    .get_mut(leaf_id.0 as usize)
827                    .and_then(|o| o.as_mut())
828                {
829                    state.last_reeval_count = sample_count;
830                }
831            }
832            return false;
833        }
834
835        // --- Execute the split ---
836        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];
837
838        // Track split gain for feature importance.
839        if best_feat_idx < self.split_gains.len() {
840            self.split_gains[best_feat_idx] += best_candidate.gain;
841        }
842
843        let best_hist = &histograms.histograms[best_feat_idx];
844
845        let left_value = leaf_weight(
846            best_candidate.left_grad,
847            best_candidate.left_hess,
848            self.config.lambda,
849        );
850        let right_value = leaf_weight(
851            best_candidate.right_grad,
852            best_candidate.right_hess,
853            self.config.lambda,
854        );
855
856        // Perform the split -- categorical or continuous.
857        let (left_id, right_id) = if let Some(ref order) = fisher_order {
858            let mut mask: u64 = 0;
859            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
860                if sorted_pos < 64 {
861                    mask |= 1u64 << sorted_pos;
862                }
863            }
864
865            self.arena.split_leaf_categorical(
866                leaf_id,
867                best_feat_idx as u32,
868                0.0,
869                left_value,
870                right_value,
871                mask,
872            )
873        } else {
874            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
875                best_hist.edges.edges[best_candidate.bin_idx]
876            } else {
877                f64::MAX
878            };
879
880            self.arena.split_leaf(
881                leaf_id,
882                best_feat_idx as u32,
883                threshold,
884                left_value,
885                right_value,
886            )
887        };
888
889        let parent_state = self
890            .leaf_states
891            .get_mut(leaf_id.0 as usize)
892            .and_then(|o| o.take());
893        let nf = n_features;
894
895        // Ensure Vec is large enough for child NodeIds.
896        let max_child = left_id.0.max(right_id.0) as usize;
897        if self.leaf_states.len() <= max_child {
898            self.leaf_states.resize_with(max_child + 1, || None);
899        }
900
901        if let Some(parent) = parent_state {
902            if let Some(parent_hists) = parent.histograms {
903                let edges_per_feature: Vec<crate::histogram::BinEdges> = parent_hists
904                    .histograms
905                    .iter()
906                    .map(|h| h.edges.clone())
907                    .collect();
908
909                let left_hists = LeafHistograms::new(&edges_per_feature);
910                let right_hists = LeafHistograms::new(&edges_per_feature);
911
912                let ft = self.config.feature_types.as_deref();
913                let child_binners_l = make_binners(nf, ft);
914                let child_binners_r = make_binners(nf, ft);
915
916                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
917                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
918
919                let left_state = LeafState {
920                    histograms: Some(left_hists),
921                    binners: child_binners_l,
922                    bins_ready: true,
923                    grad_sum: 0.0,
924                    hess_sum: 0.0,
925                    last_reeval_count: 0,
926                    clip_grad_mean: 0.0,
927                    clip_grad_m2: 0.0,
928                    clip_grad_count: 0,
929                    output_mean: 0.0,
930                    output_m2: 0.0,
931                    output_count: 0,
932                    leaf_model: left_model,
933                };
934
935                let right_state = LeafState {
936                    histograms: Some(right_hists),
937                    binners: child_binners_r,
938                    bins_ready: true,
939                    grad_sum: 0.0,
940                    hess_sum: 0.0,
941                    last_reeval_count: 0,
942                    clip_grad_mean: 0.0,
943                    clip_grad_m2: 0.0,
944                    clip_grad_count: 0,
945                    output_mean: 0.0,
946                    output_m2: 0.0,
947                    output_count: 0,
948                    leaf_model: right_model,
949                };
950
951                self.leaf_states[left_id.0 as usize] = Some(left_state);
952                self.leaf_states[right_id.0 as usize] = Some(right_state);
953            } else {
954                let ft = self.config.feature_types.as_deref();
955                let mut ls = LeafState::new_with_types(nf, ft);
956                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
957                self.leaf_states[left_id.0 as usize] = Some(ls);
958                let mut rs = LeafState::new_with_types(nf, ft);
959                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
960                self.leaf_states[right_id.0 as usize] = Some(rs);
961            }
962        } else {
963            let ft = self.config.feature_types.as_deref();
964            let mut ls = LeafState::new_with_types(nf, ft);
965            ls.leaf_model = self.make_leaf_model(left_id);
966            self.leaf_states[left_id.0 as usize] = Some(ls);
967            let mut rs = LeafState::new_with_types(nf, ft);
968            rs.leaf_model = self.make_leaf_model(right_id);
969            self.leaf_states[right_id.0 as usize] = Some(rs);
970        }
971
972        self.recompute_bandwidths();
973        true
974    }
975}
976
977impl StreamingTree for HoeffdingTree {
978    fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
979        self.samples_seen += 1;
980
981        let n_features = if let Some(n) = self.n_features {
982            n
983        } else {
984            let n = features.len();
985            self.n_features = Some(n);
986            self.split_gains.resize(n, 0.0);
987
988            if let Some(state) = self
989                .leaf_states
990                .get_mut(self.root.0 as usize)
991                .and_then(|o| o.as_mut())
992            {
993                state.binners = make_binners(n, self.config.feature_types.as_deref());
994            }
995            n
996        };
997
998        debug_assert_eq!(
999            features.len(),
1000            n_features,
1001            "feature count mismatch: got {} but expected {}",
1002            features.len(),
1003            n_features,
1004        );
1005
1006        let leaf_id = self.route_to_leaf(features);
1007        self.arena.increment_sample_count(leaf_id);
1008        let sample_count = self.arena.get_sample_count(leaf_id);
1009
1010        let idx = leaf_id.0 as usize;
1011        if self.leaf_states.len() <= idx {
1012            self.leaf_states.resize_with(idx + 1, || None);
1013        }
1014        if self.leaf_states[idx].is_none() {
1015            self.leaf_states[idx] = Some(LeafState::new_with_types(
1016                n_features,
1017                self.config.feature_types.as_deref(),
1018            ));
1019        }
1020        let state = self.leaf_states[idx].as_mut().unwrap();
1021
1022        let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
1023            clip_gradient(state, gradient, sigma)
1024        } else {
1025            gradient
1026        };
1027
1028        if !state.bins_ready {
1029            for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
1030                binner.observe(val);
1031            }
1032
1033            if let Some(alpha) = self.config.leaf_decay_alpha {
1034                state.grad_sum = alpha * state.grad_sum + gradient;
1035                state.hess_sum = alpha * state.hess_sum + hessian;
1036            } else {
1037                state.grad_sum += gradient;
1038                state.hess_sum += hessian;
1039            }
1040
1041            let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1042            self.arena.set_leaf_value(leaf_id, lw);
1043
1044            if self.config.adaptive_leaf_bound.is_some() {
1045                update_output_stats(state, lw, self.config.leaf_decay_alpha);
1046            }
1047
1048            if let Some(ref mut model) = state.leaf_model {
1049                model.update(features, gradient, hessian, self.config.lambda);
1050            }
1051
1052            if sample_count >= self.config.grace_period as u64 {
1053                let edges_per_feature: Vec<crate::histogram::BinEdges> = state
1054                    .binners
1055                    .iter()
1056                    .map(|b| b.compute_edges(self.config.n_bins))
1057                    .collect();
1058
1059                let mut histograms = LeafHistograms::new(&edges_per_feature);
1060
1061                if let Some(alpha) = self.config.leaf_decay_alpha {
1062                    histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1063                } else {
1064                    histograms.accumulate(features, gradient, hessian);
1065                }
1066
1067                state.histograms = Some(histograms);
1068                state.bins_ready = true;
1069            }
1070
1071            return;
1072        }
1073
1074        if let Some(ref mut histograms) = state.histograms {
1075            if let Some(alpha) = self.config.leaf_decay_alpha {
1076                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1077            } else {
1078                histograms.accumulate(features, gradient, hessian);
1079            }
1080        }
1081
1082        if let Some(alpha) = self.config.leaf_decay_alpha {
1083            state.grad_sum = alpha * state.grad_sum + gradient;
1084            state.hess_sum = alpha * state.hess_sum + hessian;
1085        } else {
1086            state.grad_sum += gradient;
1087            state.hess_sum += hessian;
1088        }
1089        let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1090        self.arena.set_leaf_value(leaf_id, lw);
1091
1092        if self.config.adaptive_leaf_bound.is_some() {
1093            update_output_stats(state, lw, self.config.leaf_decay_alpha);
1094        }
1095
1096        if let Some(ref mut model) = state.leaf_model {
1097            model.update(features, gradient, hessian, self.config.lambda);
1098        }
1099
1100        if sample_count % (self.config.grace_period as u64) == 0 {
1101            self.attempt_split(leaf_id);
1102        }
1103    }
1104
1105    fn predict(&self, features: &[f64]) -> f64 {
1106        let leaf_id = self.route_to_leaf(features);
1107        self.leaf_prediction(leaf_id, features)
1108    }
1109
1110    #[inline]
1111    fn n_leaves(&self) -> usize {
1112        self.arena.n_leaves()
1113    }
1114
1115    #[inline]
1116    fn n_samples_seen(&self) -> u64 {
1117        self.samples_seen
1118    }
1119
1120    fn reset(&mut self) {
1121        self.arena.reset();
1122        let root = self.arena.add_leaf(0);
1123        self.root = root;
1124        self.leaf_states.clear();
1125
1126        let n_features = self.n_features.unwrap_or(0);
1127        self.leaf_states.resize_with(root.0 as usize + 1, || None);
1128        let mut root_state =
1129            LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
1130        root_state.leaf_model = self.make_leaf_model(root);
1131        self.leaf_states[root.0 as usize] = Some(root_state);
1132
1133        self.samples_seen = 0;
1134        self.feature_mask.clear();
1135        self.feature_mask_bits.clear();
1136        self.rng_state = self.config.seed;
1137        self.split_gains.iter_mut().for_each(|g| *g = 0.0);
1138        self.node_bandwidths.clear();
1139    }
1140
1141    fn split_gains(&self) -> &[f64] {
1142        &self.split_gains
1143    }
1144
1145    fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
1146        let leaf_id = self.route_to_leaf(features);
1147        let value = self.leaf_prediction(leaf_id, features);
1148        if let Some(state) = self
1149            .leaf_states
1150            .get(leaf_id.0 as usize)
1151            .and_then(|o| o.as_ref())
1152        {
1153            let variance = 1.0 / (state.hess_sum + self.config.lambda);
1154            (value, variance)
1155        } else {
1156            (value, f64::INFINITY)
1157        }
1158    }
1159}
1160
1161impl Clone for HoeffdingTree {
1162    fn clone(&self) -> Self {
1163        Self {
1164            arena: self.arena.clone(),
1165            root: self.root,
1166            config: self.config.clone(),
1167            leaf_states: self.leaf_states.clone(),
1168            n_features: self.n_features,
1169            samples_seen: self.samples_seen,
1170            split_criterion: self.split_criterion,
1171            feature_mask: self.feature_mask.clone(),
1172            feature_mask_bits: self.feature_mask_bits.clone(),
1173            rng_state: self.rng_state,
1174            split_gains: self.split_gains.clone(),
1175            node_bandwidths: self.node_bandwidths.clone(),
1176        }
1177    }
1178}
1179
// SAFETY: All fields are Send + Sync. BinnerKind is a concrete enum with
// Send + Sync variants. XGBoostGain is stateless. Vec<Option<LeafState>>
// and Vec fields are trivially Send + Sync.
//
// NOTE(review): if every field really is Send + Sync, the compiler already
// implements both auto traits and these `unsafe impl`s are redundant; removing
// them would let the compiler re-verify the claim whenever a field changes.
// If some field is *not* automatically Send/Sync (presumably the per-leaf
// model, if it is a boxed trait object without `Send + Sync` bounds -- TODO
// confirm its type), then these impls are what make the tree thread-safe, and
// that field's thread-safety must be justified explicitly in this comment.
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}
1185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::builder::TreeConfig;
    use crate::tree::StreamingTree;

    /// After a single training sample (well inside the grace period) the tree
    /// must already return a finite prediction equal to that sample's leaf
    /// weight.
    #[test]
    fn single_sample_predict_not_nan() {
        let mut tree = HoeffdingTree::new(TreeConfig::new().grace_period(10));

        let x = [1.0, 2.0, 3.0];
        tree.train_one(&x, -0.5, 1.0);

        let prediction = tree.predict(&x);
        assert!(
            !prediction.is_nan(),
            "prediction should not be NaN, got {}",
            prediction
        );
        assert!(
            prediction.is_finite(),
            "prediction should be finite, got {}",
            prediction
        );

        // NOTE(review): 0.25 is presumably -G / (H + lambda) = 0.5 / (1.0 + 1.0),
        // i.e. a default lambda of 1.0 -- confirm against `TreeConfig::new()`.
        let error = (prediction - 0.25).abs();
        assert!(error < 1e-10, "expected ~0.25, got {}", prediction);
    }
}