// irithyll_core/tree/hoeffding.rs

1//! Hoeffding-bound split decisions for streaming tree construction.
2//!
3//! [`HoeffdingTree`] is the core streaming decision tree. It grows incrementally:
4//! each sample updates per-leaf histogram accumulators, and splits are committed
5//! only when the Hoeffding bound guarantees the best candidate split is
6//! statistically superior to the runner-up (or a tie-breaking threshold is met).
7//!
8//! # Algorithm
9//!
10//! For each incoming `(features, gradient, hessian)` triple:
11//!
12//! 1. Route the sample from root to a leaf via threshold comparisons.
13//! 2. At the leaf, accumulate gradient/hessian into per-feature histograms.
14//! 3. Once enough samples arrive (grace period), evaluate candidate splits
15//!    using the XGBoost gain formula.
16//! 4. Apply the Hoeffding bound: if the gap between the best and second-best
17//!    gain exceeds `epsilon = sqrt(R^2 * ln(1/delta) / (2n))`, commit the split.
18//! 5. When splitting, use the histogram subtraction trick to initialize one
19//!    child's histograms for free.
20
21use alloc::boxed::Box;
22use alloc::vec;
23use alloc::vec::Vec;
24
25use crate::feature::FeatureType;
26use crate::histogram::bins::LeafHistograms;
27use crate::histogram::{BinEdges, BinnerKind};
28use crate::math;
29use crate::tree::builder::TreeConfig;
30use crate::tree::leaf_model::{LeafModel, LeafModelType};
31use crate::tree::node::{NodeId, TreeArena};
32use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
33use crate::tree::StreamingTree;
34
/// Tie-breaking threshold (tau). When `epsilon < tau`, we accept the best split
/// even if the gap between best and second-best gain is small, because the
/// Hoeffding bound is already tight enough that further samples won't help.
/// (See step 4 of the algorithm in the module docs.)
const TAU: f64 = 0.05;
39
40// ---------------------------------------------------------------------------
41// xorshift64 -- minimal deterministic PRNG
42// ---------------------------------------------------------------------------
43
/// Advance an xorshift64 state in place and return the freshly generated value.
///
/// Marsaglia xorshift with the classic (13, 7, 17) shift triple. A zero seed is
/// a fixed point (always yields zero); callers are expected to seed nonzero.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let x = *state;
    let a = x ^ (x << 13);
    let b = a ^ (a >> 7);
    let next = b ^ (b << 17);
    *state = next;
    next
}
54
55// ---------------------------------------------------------------------------
56// LeafState -- per-leaf bookkeeping
57// ---------------------------------------------------------------------------
58
/// State tracked per leaf node for split decisions.
///
/// Each active leaf owns its own set of histogram accumulators (one per feature)
/// and running gradient/hessian sums for leaf weight updates.
struct LeafState {
    /// Histogram accumulators for this leaf. `None` until bin edges are computed
    /// (after the grace period).
    histograms: Option<LeafHistograms>,

    /// Per-feature binning strategies that collect observed values to compute
    /// bin edges. Uses `BinnerKind` enum dispatch instead of `Box<dyn>` to
    /// eliminate N_features heap allocations per new leaf.
    binners: Vec<BinnerKind>,

    /// Whether bin edges have been computed (after grace period samples).
    bins_ready: bool,

    /// Running gradient sum for leaf weight updates.
    grad_sum: f64,

    /// Running hessian sum for leaf weight updates.
    hess_sum: f64,

    /// Sample count at last split re-evaluation (for EFDT-inspired re-eval).
    last_reeval_count: u64,

    /// Running gradient mean for gradient clipping (Welford online algorithm;
    /// see `clip_gradient`).
    clip_grad_mean: f64,

    /// Gradient M2 accumulator (Welford sum of squared deviations) for
    /// variance estimation in `clip_gradient`.
    clip_grad_m2: f64,

    /// Number of gradients observed for clipping statistics.
    clip_grad_count: u64,

    /// EWMA/Welford mean of this leaf's output weight (for adaptive bounds;
    /// see `update_output_stats` / `adaptive_bound`).
    output_mean: f64,

    /// EWMA/Welford M2 (variance accumulator) of this leaf's output weight.
    /// NOTE: under EWMA this holds the variance directly, not a raw M2 sum.
    output_m2: f64,

    /// Number of output weight observations.
    output_count: u64,

    /// Optional trainable leaf model (linear / MLP). `None` for closed-form leaves.
    leaf_model: Option<Box<dyn LeafModel>>,
}
106
// Manual `Clone` impl: `Box<dyn LeafModel>` cannot `#[derive(Clone)]`, so the
// trainable model is duplicated through its `clone_warm` hook instead
// (presumably preserving learned parameters — confirm in leaf_model.rs).
// All other fields are plain data and are copied field-by-field.
impl Clone for LeafState {
    fn clone(&self) -> Self {
        Self {
            histograms: self.histograms.clone(),
            binners: self.binners.clone(),
            bins_ready: self.bins_ready,
            grad_sum: self.grad_sum,
            hess_sum: self.hess_sum,
            last_reeval_count: self.last_reeval_count,
            clip_grad_mean: self.clip_grad_mean,
            clip_grad_m2: self.clip_grad_m2,
            clip_grad_count: self.clip_grad_count,
            output_mean: self.output_mean,
            output_m2: self.output_m2,
            output_count: self.output_count,
            leaf_model: self.leaf_model.as_ref().map(|m| m.clone_warm()),
        }
    }
}
126
127/// Clip a gradient using Welford online stats tracked per leaf.
128///
129/// Updates running mean/variance, then clamps the gradient to `mean ± sigma * std_dev`.
130/// Returns the (possibly clamped) gradient. During warmup (< 10 samples), no clipping
131/// is applied to let the statistics stabilize.
132#[inline]
133fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
134    state.clip_grad_count += 1;
135    let n = state.clip_grad_count as f64;
136
137    // Welford online update
138    let delta = gradient - state.clip_grad_mean;
139    state.clip_grad_mean += delta / n;
140    let delta2 = gradient - state.clip_grad_mean;
141    state.clip_grad_m2 += delta * delta2;
142
143    // No clipping during warmup
144    if state.clip_grad_count < 10 {
145        return gradient;
146    }
147
148    let variance = state.clip_grad_m2 / (n - 1.0);
149    let std_dev = math::sqrt(variance);
150
151    if std_dev < 1e-15 {
152        return gradient; // All gradients identical -- no clipping needed
153    }
154
155    let lo = state.clip_grad_mean - sigma * std_dev;
156    let hi = state.clip_grad_mean + sigma * std_dev;
157    gradient.clamp(lo, hi)
158}
159
160/// Update per-leaf output weight tracking for adaptive bounds.
161///
162/// If `decay_alpha` is Some, uses EWMA synchronized with leaf_decay_alpha.
163/// Otherwise uses Welford online algorithm (batch scenarios).
164#[inline]
165fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
166    state.output_count += 1;
167
168    if let Some(alpha) = decay_alpha {
169        // EWMA — synchronized with leaf gradient decay
170        if state.output_count == 1 {
171            state.output_mean = weight;
172            state.output_m2 = 0.0;
173        } else {
174            let diff = weight - state.output_mean;
175            state.output_mean = alpha * state.output_mean + (1.0 - alpha) * weight;
176            let diff2 = weight - state.output_mean;
177            state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * diff * diff2;
178        }
179    } else {
180        // Welford online (no decay — batch scenarios)
181        let delta = weight - state.output_mean;
182        state.output_mean += delta / (state.output_count as f64);
183        let delta2 = weight - state.output_mean;
184        state.output_m2 += delta * delta2;
185    }
186}
187
188/// Get the adaptive output bound for this leaf.
189///
190/// Returns `|mean| + k * std`, with a floor of 0.01 to never fully suppress a leaf.
191/// During warmup (< 10 samples), returns `f64::MAX` (no bound).
192#[inline]
193fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
194    if state.output_count < 10 {
195        return f64::MAX; // warmup — no bound yet
196    }
197
198    let variance = if decay_alpha.is_some() {
199        // EWMA variance is stored directly in output_m2
200        state.output_m2.max(0.0)
201    } else {
202        // Welford: variance = M2 / (n - 1)
203        state.output_m2 / (state.output_count as f64 - 1.0)
204    };
205    let std = math::sqrt(variance);
206
207    // Bound = |mean| + k * std, floor at 0.01 to never fully suppress a leaf
208    (math::abs(state.output_mean) + k * std).max(0.01)
209}
210
211/// Create binners according to feature types.
212fn make_binners(n_features: usize, feature_types: Option<&[FeatureType]>) -> Vec<BinnerKind> {
213    (0..n_features)
214        .map(|i| {
215            if let Some(ft) = feature_types {
216                if i < ft.len() && ft[i] == FeatureType::Categorical {
217                    return BinnerKind::categorical();
218                }
219            }
220            BinnerKind::uniform()
221        })
222        .collect()
223}
224
225impl LeafState {
226    /// Create a fresh leaf state for a leaf with `n_features` features.
227    fn new(n_features: usize) -> Self {
228        Self::new_with_types(n_features, None)
229    }
230
231    /// Create a fresh leaf state respecting per-feature type declarations.
232    fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
233        let binners = make_binners(n_features, feature_types);
234
235        Self {
236            histograms: None,
237            binners,
238            bins_ready: false,
239            grad_sum: 0.0,
240            hess_sum: 0.0,
241            last_reeval_count: 0,
242            clip_grad_mean: 0.0,
243            clip_grad_m2: 0.0,
244            clip_grad_count: 0,
245            output_mean: 0.0,
246            output_m2: 0.0,
247            output_count: 0,
248            leaf_model: None,
249        }
250    }
251
252    /// Create a leaf state with pre-computed histograms (used after a split
253    /// when we can initialize the child via histogram subtraction).
254    #[allow(dead_code)]
255    fn with_histograms(histograms: LeafHistograms) -> Self {
256        let n_features = histograms.n_features();
257        let binners: Vec<BinnerKind> = (0..n_features).map(|_| BinnerKind::uniform()).collect();
258
259        // Recover grad/hess sums from the histograms.
260        let grad_sum: f64 = histograms
261            .histograms
262            .first()
263            .map_or(0.0, |h| h.total_gradient());
264        let hess_sum: f64 = histograms
265            .histograms
266            .first()
267            .map_or(0.0, |h| h.total_hessian());
268
269        Self {
270            histograms: Some(histograms),
271            binners,
272            bins_ready: true,
273            grad_sum,
274            hess_sum,
275            last_reeval_count: 0,
276            clip_grad_mean: 0.0,
277            clip_grad_m2: 0.0,
278            clip_grad_count: 0,
279            output_mean: 0.0,
280            output_m2: 0.0,
281            output_count: 0,
282            leaf_model: None,
283        }
284    }
285}
286
287// ---------------------------------------------------------------------------
288// HoeffdingTree
289// ---------------------------------------------------------------------------
290
/// A streaming decision tree that uses Hoeffding-bound split decisions.
///
/// The tree grows incrementally: each call to [`train_one`](StreamingTree::train_one)
/// routes one sample to its leaf, updates histograms, and potentially triggers
/// a split when statistical evidence is sufficient.
///
/// # Feature subsampling
///
/// When `config.feature_subsample_rate < 1.0`, each split evaluation considers
/// only a random subset of features (selected via a deterministic xorshift64 RNG).
/// This adds diversity when the tree is used inside an ensemble.
pub struct HoeffdingTree {
    /// Arena-allocated node storage.
    arena: TreeArena,

    /// Root node identifier.
    root: NodeId,

    /// Tree configuration / hyperparameters.
    config: TreeConfig,

    /// Per-leaf state indexed by `NodeId.0`. Dense Vec -- NodeIds are
    /// contiguous u32 indices from TreeArena, so direct indexing is optimal.
    /// Internal nodes hold `None`.
    leaf_states: Vec<Option<LeafState>>,

    /// Number of features, learned from the first sample (`None` before training).
    n_features: Option<usize>,

    /// Total samples seen across all calls to `train_one`.
    samples_seen: u64,

    /// Split gain evaluator.
    split_criterion: XGBoostGain,

    /// Scratch buffer for the feature mask (avoids repeated allocation).
    feature_mask: Vec<usize>,

    /// Bitset scratch buffer for O(1) membership test during feature mask generation.
    /// Each bit `i` indicates whether feature `i` is already in `feature_mask`.
    feature_mask_bits: Vec<u64>,

    /// xorshift64 RNG state for feature subsampling (checkpointed for determinism).
    rng_state: u64,

    /// Accumulated split gains per feature for importance tracking.
    /// Indexed by feature index; grows lazily when n_features is learned.
    split_gains: Vec<f64>,
}
339
340impl HoeffdingTree {
341    /// Create a new `HoeffdingTree` with the given configuration.
342    ///
343    /// The tree starts with a single root leaf and no feature information;
344    /// the number of features is inferred from the first training sample.
345    pub fn new(config: TreeConfig) -> Self {
346        let mut arena = TreeArena::new();
347        let root = arena.add_leaf(0);
348
349        // Insert a placeholder leaf state for the root. We don't know n_features
350        // yet, so give it 0 binners -- it will be properly initialized on the
351        // first sample.
352        let mut leaf_states = vec![None; root.0 as usize + 1];
353        let root_model = match config.leaf_model_type {
354            LeafModelType::ClosedForm => None,
355            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
356        };
357        leaf_states[root.0 as usize] = Some(LeafState {
358            histograms: None,
359            binners: Vec::new(),
360            bins_ready: false,
361            grad_sum: 0.0,
362            hess_sum: 0.0,
363            last_reeval_count: 0,
364            clip_grad_mean: 0.0,
365            clip_grad_m2: 0.0,
366            clip_grad_count: 0,
367            output_mean: 0.0,
368            output_m2: 0.0,
369            output_count: 0,
370            leaf_model: root_model,
371        });
372
373        let seed = config.seed;
374        Self {
375            arena,
376            root,
377            config,
378            leaf_states,
379            n_features: None,
380            samples_seen: 0,
381            split_criterion: XGBoostGain::default(),
382            feature_mask: Vec::new(),
383            feature_mask_bits: Vec::new(),
384            rng_state: seed,
385            split_gains: Vec::new(),
386        }
387    }
388
389    /// Create a leaf model for a new leaf if the config requires one.
390    ///
391    /// Returns `None` for `ClosedForm` (the default), which uses the existing
392    /// `leaf_weight()` path with zero overhead. For `Linear` and `MLP`, returns
393    /// a fresh model seeded deterministically from the config seed and node id.
394    fn make_leaf_model(&self, node: NodeId) -> Option<Box<dyn LeafModel>> {
395        match self.config.leaf_model_type {
396            LeafModelType::ClosedForm => None,
397            _ => Some(
398                self.config
399                    .leaf_model_type
400                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
401            ),
402        }
403    }
404
405    /// Reconstruct a `HoeffdingTree` from a pre-built arena.
406    ///
407    /// Used during model deserialization. The tree is restored with node
408    /// topology and leaf values intact, but histogram accumulators are empty
409    /// (they will rebuild naturally from continued training).
410    ///
411    /// The root is assumed to be `NodeId(0)`. Leaf states are created empty
412    /// for all current leaf nodes in the arena.
413    pub fn from_arena(
414        config: TreeConfig,
415        arena: TreeArena,
416        n_features: Option<usize>,
417        samples_seen: u64,
418        rng_state: u64,
419    ) -> Self {
420        let root = if arena.n_nodes() > 0 {
421            NodeId(0)
422        } else {
423            // Empty arena -- add a root leaf (shouldn't normally happen in restore).
424            let mut arena_mut = arena;
425            let root = arena_mut.add_leaf(0);
426            return Self {
427                arena: arena_mut,
428                root,
429                config: config.clone(),
430                leaf_states: {
431                    let mut v = vec![None; root.0 as usize + 1];
432                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
433                    v
434                },
435                n_features,
436                samples_seen,
437                split_criterion: XGBoostGain::default(),
438                feature_mask: Vec::new(),
439                feature_mask_bits: Vec::new(),
440                rng_state,
441                split_gains: vec![0.0; n_features.unwrap_or(0)],
442            };
443        };
444
445        // Build leaf states for every leaf in the arena.
446        let nf = n_features.unwrap_or(0);
447        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
448        for (i, slot) in leaf_states.iter_mut().enumerate() {
449            if arena.is_leaf[i] {
450                *slot = Some(LeafState::new(nf));
451            }
452        }
453
454        Self {
455            arena,
456            root,
457            config,
458            leaf_states,
459            n_features,
460            samples_seen,
461            split_criterion: XGBoostGain::default(),
462            feature_mask: Vec::new(),
463            feature_mask_bits: Vec::new(),
464            rng_state,
465            split_gains: vec![0.0; nf],
466        }
467    }
468
    /// Root node identifier.
    #[inline]
    pub fn root(&self) -> NodeId {
        self.root
    }

    /// Immutable access to the underlying arena.
    #[inline]
    pub fn arena(&self) -> &TreeArena {
        &self.arena
    }

    /// Immutable access to the tree configuration.
    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.config
    }

    /// Number of features (learned from the first sample, `None` before any training).
    #[inline]
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    /// Current RNG state (for deterministic checkpoint/restore).
    #[inline]
    pub fn rng_state(&self) -> u64 {
        self.rng_state
    }
498
499    /// Read-only access to the gradient and hessian sums for a leaf node.
500    ///
501    /// Returns `Some((grad_sum, hess_sum))` if `node` is a leaf with an active
502    /// leaf state, or `None` if the node has no state (e.g. internal node
503    /// or freshly allocated).
504    ///
505    /// These sums enable inverse-hessian confidence estimation:
506    /// `confidence = 1.0 / (hess_sum + lambda)`. High hessian means the leaf
507    /// has seen consistent, informative data; low hessian means uncertainty.
508    #[inline]
509    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
510        self.leaf_states
511            .get(node.0 as usize)
512            .and_then(|o| o.as_ref())
513            .map(|state| (state.grad_sum, state.hess_sum))
514    }
515
516    /// Route a feature vector from the root down to a leaf, returning the leaf's NodeId.
517    fn route_to_leaf(&self, features: &[f64]) -> NodeId {
518        let mut current = self.root;
519        while !self.arena.is_leaf(current) {
520            let feat_idx = self.arena.get_feature_idx(current) as usize;
521            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
522                // Categorical split: use bitmask routing.
523                // The feature value is cast to a bin index. If that bin's bit is set
524                // in the mask, go left; otherwise go right.
525                // For categorical features, the bin index in the histogram corresponds
526                // to the sorted category position, but for bitmask routing we use
527                // the original bin index directly.
528                let cat_val = features[feat_idx] as u64;
529                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
530                    self.arena.get_left(current)
531                } else {
532                    self.arena.get_right(current)
533                }
534            } else {
535                // Continuous split: standard threshold comparison.
536                let threshold = self.arena.get_threshold(current);
537                if features[feat_idx] <= threshold {
538                    self.arena.get_left(current)
539                } else {
540                    self.arena.get_right(current)
541                }
542            };
543        }
544        current
545    }
546
    /// Get the prediction value for a leaf node.
    ///
    /// Checks (in order): leaf model, live grad/hess statistics, stored leaf value.
    /// Returns `0.0` if no leaf state exists, or if the leaf's hessian sum is
    /// below `config.min_hessian_sum` (fresh-leaf suppression).
    ///
    /// Output clamping priority: per-leaf adaptive bound (if configured and out
    /// of warmup) > global `max_leaf_output` > unclamped.
    #[inline]
    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
        let (raw, leaf_bound) = if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            // min_hessian_sum: suppress fresh leaves with insufficient samples
            if let Some(min_h) = self.config.min_hessian_sum {
                if state.hess_sum < min_h {
                    return 0.0;
                }
            }
            // Prediction source: trainable leaf model first, then closed-form
            // weight from live grad/hess sums, then the arena's stored value.
            let val = if let Some(ref model) = state.leaf_model {
                model.predict(features)
            } else if state.hess_sum != 0.0 {
                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
            } else {
                self.arena.leaf_value[leaf_id.0 as usize]
            };

            // Compute per-leaf adaptive bound while state is in scope
            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));

            (val, bound)
        } else {
            (0.0, None)
        };

        // Priority: per-leaf adaptive bound > global max_leaf_output > unclamped.
        // `adaptive_bound` returns f64::MAX during warmup, meaning "no bound yet".
        if let Some(bound) = leaf_bound {
            if bound < f64::MAX {
                return raw.clamp(-bound, bound);
            }
        }
        if let Some(max) = self.config.max_leaf_output {
            raw.clamp(-max, max)
        } else {
            raw
        }
    }
595
    /// Predict using sigmoid-blended soft routing for smooth interpolation.
    ///
    /// Instead of hard left/right routing at each split node, uses sigmoid
    /// blending: `alpha = sigmoid((threshold - feature) / bandwidth)`. The
    /// prediction is `alpha * left_pred + (1 - alpha) * right_pred`, computed
    /// recursively from root to leaves. Categorical splits are always routed
    /// hard (blending is meaningless for categories).
    ///
    /// The result is a continuous function that varies smoothly with every
    /// feature change — no bins, no boundaries, no jumps.
    ///
    /// # Arguments
    ///
    /// * `features` - Input feature vector.
    /// * `bandwidth` - Controls transition sharpness. Smaller = sharper
    ///   (closer to hard splits), larger = smoother.
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.predict_smooth_recursive(self.root, features, bandwidth)
    }
614
    /// Predict using per-feature auto-calibrated bandwidths.
    ///
    /// Each feature uses its own bandwidth (indexed by feature) derived from
    /// median split threshold gaps. Features with `f64::INFINITY` bandwidth
    /// fall back to hard routing.
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
    }
622
623    /// Predict with parent-leaf linear interpolation.
624    ///
625    /// Routes to the leaf but blends the leaf prediction with the parent node's
626    /// preserved prediction based on the leaf's hessian sum. Fresh leaves
627    /// (low hess_sum) smoothly transition from parent prediction to their own:
628    ///
629    /// `alpha = leaf_hess / (leaf_hess + lambda)`
630    /// `pred = alpha * leaf_pred + (1 - alpha) * parent_pred`
631    ///
632    /// This fixes static predictions from leaves that split but haven't
633    /// accumulated enough samples to outperform their parent.
634    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
635        let mut current = self.root;
636        let mut parent = None;
637        while !self.arena.is_leaf(current) {
638            parent = Some(current);
639            let feat_idx = self.arena.get_feature_idx(current) as usize;
640            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
641                let cat_val = features[feat_idx] as u64;
642                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
643                    self.arena.get_left(current)
644                } else {
645                    self.arena.get_right(current)
646                }
647            } else {
648                let threshold = self.arena.get_threshold(current);
649                if features[feat_idx] <= threshold {
650                    self.arena.get_left(current)
651                } else {
652                    self.arena.get_right(current)
653                }
654            };
655        }
656
657        let leaf_pred = self.leaf_prediction(current, features);
658
659        // No parent (root is leaf) → return leaf prediction directly
660        let parent_id = match parent {
661            Some(p) => p,
662            None => return leaf_pred,
663        };
664
665        // Get parent's preserved prediction from its old leaf state
666        let parent_pred = self.leaf_prediction(parent_id, features);
667
668        // Blend: alpha = leaf_hess / (leaf_hess + lambda)
669        let leaf_hess = self
670            .leaf_states
671            .get(current.0 as usize)
672            .and_then(|o| o.as_ref())
673            .map(|s| s.hess_sum)
674            .unwrap_or(0.0);
675
676        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
677        alpha * leaf_pred + (1.0 - alpha) * parent_pred
678    }
679
    /// Predict with sibling-based interpolation for feature-continuous predictions.
    ///
    /// At the leaf's parent split, blends the leaf prediction with its sibling's
    /// prediction based on the feature's distance from the split threshold:
    ///
    /// Within the margin `m` around the threshold:
    /// `t = (feature - threshold + m) / (2 * m)`  (0 at left edge, 1 at right edge)
    /// `pred = (1 - t) * left_pred + t * right_pred`
    ///
    /// Outside the margin, returns the routed child's prediction directly.
    /// The margin `m` is taken per feature from `bandwidths` (e.g. derived from
    /// auto-bandwidths); a missing, infinite, or non-positive margin falls back
    /// to hard routing. Categorical splits are never interpolated.
    ///
    /// This makes predictions vary continuously as features move near split
    /// boundaries, eliminating step-function artifacts.
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_sibling_recursive(self.root, features, bandwidths)
    }
698
699    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
700        if self.arena.is_leaf(node) {
701            return self.leaf_prediction(node, features);
702        }
703
704        let feat_idx = self.arena.get_feature_idx(node) as usize;
705        let left = self.arena.get_left(node);
706        let right = self.arena.get_right(node);
707
708        // Categorical splits: always hard routing (no interpolation)
709        if let Some(mask) = self.arena.get_categorical_mask(node) {
710            let cat_val = features[feat_idx] as u64;
711            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
712                self.predict_sibling_recursive(left, features, bandwidths)
713            } else {
714                self.predict_sibling_recursive(right, features, bandwidths)
715            };
716        }
717
718        let threshold = self.arena.get_threshold(node);
719        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
720
721        // No valid margin or infinite → hard routing
722        if !margin.is_finite() || margin <= 0.0 {
723            return if features[feat_idx] <= threshold {
724                self.predict_sibling_recursive(left, features, bandwidths)
725            } else {
726                self.predict_sibling_recursive(right, features, bandwidths)
727            };
728        }
729
730        let dist = features[feat_idx] - threshold;
731
732        if dist < -margin {
733            // Firmly in left child territory
734            self.predict_sibling_recursive(left, features, bandwidths)
735        } else if dist > margin {
736            // Firmly in right child territory
737            self.predict_sibling_recursive(right, features, bandwidths)
738        } else {
739            // Within the interpolation margin: linear blend
740            let t = (dist + margin) / (2.0 * margin); // 0.0 at left edge, 1.0 at right edge
741            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
742            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
743            (1.0 - t) * left_pred + t * right_pred
744        }
745    }
746
747    /// Collect all split thresholds per feature from the tree arena.
748    ///
749    /// Returns a `Vec<Vec<f64>>` indexed by feature, containing all thresholds
750    /// used in continuous splits. Categorical splits are excluded.
751    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
752        let n = self.n_features.unwrap_or(0);
753        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];
754
755        for i in 0..self.arena.n_nodes() {
756            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
757                let feat_idx = self.arena.feature_idx[i] as usize;
758                if feat_idx < n {
759                    thresholds[feat_idx].push(self.arena.threshold[i]);
760                }
761            }
762        }
763
764        thresholds
765    }
766
    /// Recursive sigmoid-blended prediction traversal.
    ///
    /// Continuous splits blend both subtrees with weight
    /// `alpha = sigmoid((threshold - feature) / bandwidth)`; categorical splits
    /// route hard. Note this visits every node in the subtree (cost grows with
    /// tree size, unlike hard routing which follows a single path).
    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
        if self.arena.is_leaf(node) {
            // At a leaf, return the leaf prediction (same as regular predict)
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        // Categorical splits: hard routing (sigmoid blending is meaningless for categories)
        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_recursive(left, features, bandwidth)
            } else {
                self.predict_smooth_recursive(right, features, bandwidth)
            };
        }

        // Continuous split: sigmoid blending for smooth transition around the threshold.
        // alpha → 1 routes fully left, alpha → 0 fully right.
        let threshold = self.arena.get_threshold(node);
        let z = (threshold - features[feat_idx]) / bandwidth;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }
798
799    /// Recursive per-feature-bandwidth smooth prediction traversal.
800    fn predict_smooth_auto_recursive(
801        &self,
802        node: NodeId,
803        features: &[f64],
804        bandwidths: &[f64],
805    ) -> f64 {
806        if self.arena.is_leaf(node) {
807            return self.leaf_prediction(node, features);
808        }
809
810        let feat_idx = self.arena.get_feature_idx(node) as usize;
811        let left = self.arena.get_left(node);
812        let right = self.arena.get_right(node);
813
814        // Categorical splits: always hard routing
815        if let Some(mask) = self.arena.get_categorical_mask(node) {
816            let cat_val = features[feat_idx] as u64;
817            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
818                self.predict_smooth_auto_recursive(left, features, bandwidths)
819            } else {
820                self.predict_smooth_auto_recursive(right, features, bandwidths)
821            };
822        }
823
824        let threshold = self.arena.get_threshold(node);
825        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
826
827        // Infinite bandwidth = feature never split on across ensemble → hard routing
828        if !bw.is_finite() {
829            return if features[feat_idx] <= threshold {
830                self.predict_smooth_auto_recursive(left, features, bandwidths)
831            } else {
832                self.predict_smooth_auto_recursive(right, features, bandwidths)
833            };
834        }
835
836        // Sigmoid-blended soft routing with per-feature bandwidth
837        let z = (threshold - features[feat_idx]) / bw;
838        let alpha = 1.0 / (1.0 + math::exp(-z));
839
840        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
841        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);
842
843        alpha * left_pred + (1.0 - alpha) * right_pred
844    }
845
846    /// Generate the feature mask for split evaluation.
847    ///
848    /// If `feature_subsample_rate` is 1.0, all features are included.
849    /// Otherwise, a random subset is selected via xorshift64.
850    ///
851    /// Uses a `Vec<u64>` bitset for O(1) membership testing, replacing the
852    /// previous O(n) `Vec::contains()` which made the fallback loop O(n²).
853    fn generate_feature_mask(&mut self, n_features: usize) {
854        self.feature_mask.clear();
855
856        if self.config.feature_subsample_rate >= 1.0 {
857            self.feature_mask.extend(0..n_features);
858        } else {
859            let target_count =
860                crate::math::ceil((n_features as f64) * self.config.feature_subsample_rate)
861                    as usize;
862            let target_count = target_count.max(1).min(n_features);
863
864            // Prepare the bitset: one bit per feature, O(1) membership test.
865            let n_words = n_features.div_ceil(64);
866            self.feature_mask_bits.clear();
867            self.feature_mask_bits.resize(n_words, 0u64);
868
869            // Include each feature with probability = subsample_rate.
870            for i in 0..n_features {
871                let r = xorshift64(&mut self.rng_state);
872                let p = (r as f64) / (u64::MAX as f64);
873                if p < self.config.feature_subsample_rate {
874                    self.feature_mask.push(i);
875                    self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
876                }
877            }
878
879            // If we didn't get enough features, fill up deterministically.
880            // Now O(n) instead of O(n²) thanks to the bitset.
881            if self.feature_mask.len() < target_count {
882                for i in 0..n_features {
883                    if self.feature_mask.len() >= target_count {
884                        break;
885                    }
886                    if self.feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
887                        self.feature_mask.push(i);
888                        self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
889                    }
890                }
891            }
892        }
893    }
894
    /// Attempt a split at the given leaf node.
    ///
    /// Flow: gate on depth and grace period, sample a feature mask,
    /// materialize pending lazy decay, evaluate one best split candidate per
    /// masked feature (Fisher-ordered for categorical features), drop
    /// candidates that violate monotonic constraints, then commit the best
    /// candidate if the Hoeffding bound (or the tau tie-break) allows it.
    ///
    /// Returns `true` if a split was performed.
    fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
        let depth = self.arena.get_depth(leaf_id);
        let at_max_depth = depth as usize >= self.config.max_depth;

        if at_max_depth {
            // Only proceed if split re-evaluation is enabled and the interval
            // has elapsed since the last evaluation at this leaf.
            match self.config.split_reeval_interval {
                None => return false,
                Some(interval) => {
                    let state = match self
                        .leaf_states
                        .get(leaf_id.0 as usize)
                        .and_then(|o| o.as_ref())
                    {
                        Some(s) => s,
                        None => return false,
                    };
                    let sample_count = self.arena.get_sample_count(leaf_id);
                    if sample_count - state.last_reeval_count < interval as u64 {
                        return false;
                    }
                    // Fall through to evaluate potential split.
                }
            }
        }

        // n_features is set on the first training sample; without it there is
        // nothing to evaluate.
        let n_features = match self.n_features {
            Some(n) => n,
            None => return false,
        };

        let sample_count = self.arena.get_sample_count(leaf_id);
        if sample_count < self.config.grace_period as u64 {
            return false;
        }

        // Generate the feature mask for this split evaluation.
        self.generate_feature_mask(n_features);

        // Materialize pending lazy decay before reading histogram data.
        // This converts un-decayed coordinates to true decayed values so
        // split evaluation sees correct gradient/hessian sums. O(n_features * n_bins)
        // but amortized over grace_period samples -- not per-sample cost.
        if self.config.leaf_decay_alpha.is_some() {
            if let Some(state) = self
                .leaf_states
                .get_mut(leaf_id.0 as usize)
                .and_then(|o| o.as_mut())
            {
                if let Some(ref mut histograms) = state.histograms {
                    histograms.materialize_decay();
                }
            }
        }

        // Evaluate splits for each feature in the mask.
        // We need to borrow leaf_states immutably while feature_mask is borrowed.
        // Collect candidates first.
        let state = match self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            Some(s) => s,
            None => return false,
        };

        let histograms = match &state.histograms {
            Some(h) => h,
            None => return false,
        };

        // Collect (feature_idx, best_split_candidate, optional_fisher_order) for
        // each feature in the mask. For categorical features, we reorder bins by
        // Fisher optimal binary partitioning before evaluation.
        let feature_types = &self.config.feature_types;
        let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();

        for &feat_idx in &self.feature_mask {
            if feat_idx >= histograms.n_features() {
                continue;
            }
            let hist = &histograms.histograms[feat_idx];
            let total_grad = hist.total_gradient();
            let total_hess = hist.total_hessian();

            let is_categorical = feature_types
                .as_ref()
                .is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);

            if is_categorical {
                // Fisher optimal binary partitioning:
                // 1. Compute gradient_sum/hessian_sum ratio per bin
                // 2. Sort bins by this ratio
                // 3. Evaluate splits on the sorted order
                let n_bins = hist.grad_sums.len();
                if n_bins < 2 {
                    continue;
                }

                // Build (bin_index, ratio) pairs, filtering out empty bins
                // (near-zero hessian would make the ratio meaningless).
                let mut bin_order: Vec<usize> = (0..n_bins)
                    .filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
                    .collect();

                if bin_order.len() < 2 {
                    continue;
                }

                // Sort by grad_sum / hess_sum ratio (ascending)
                bin_order.sort_by(|&a, &b| {
                    let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
                    let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
                    ratio_a
                        .partial_cmp(&ratio_b)
                        .unwrap_or(core::cmp::Ordering::Equal)
                });

                // Reorder grad/hess sums according to Fisher order
                let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
                let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();

                if let Some(candidate) = self.split_criterion.evaluate(
                    &sorted_grads,
                    &sorted_hess,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, Some(bin_order)));
                }
            } else {
                // Standard continuous feature -- evaluate as-is
                if let Some(candidate) = self.split_criterion.evaluate(
                    &hist.grad_sums,
                    &hist.hess_sums,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, None));
                }
            }
        }

        // Filter out candidates that violate monotonic constraints.
        // A constraint compares the closed-form child weights that the split
        // would produce.
        if let Some(ref mc) = self.config.monotone_constraints {
            candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true; // No constraint for this feature
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true; // Unconstrained
                }

                let left_val =
                    leaf_weight(candidate.left_grad, candidate.left_hess, self.config.lambda);
                let right_val = leaf_weight(
                    candidate.right_grad,
                    candidate.right_hess,
                    self.config.lambda,
                );

                if constraint > 0 {
                    // Non-decreasing: left_value <= right_value
                    left_val <= right_val
                } else {
                    // Non-increasing: left_value >= right_value
                    left_val >= right_val
                }
            });
        }

        if candidates.is_empty() {
            return false;
        }

        // Sort candidates by gain descending.
        candidates.sort_by(|a, b| {
            b.1.gain
                .partial_cmp(&a.1.gain)
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let best_gain = candidates[0].1.gain;
        // A single candidate competes against a zero-gain "do nothing" baseline.
        let second_best_gain = if candidates.len() > 1 {
            candidates[1].1.gain
        } else {
            0.0
        };

        // Hoeffding bound: epsilon = sqrt(R^2 * ln(1/delta) / (2 * n))
        // R = 1.0 (conservative bound on the range of the gain function).
        //
        // With EWMA decay, the effective sample size is bounded by 1/(1-alpha).
        // We cap n at this value to prevent spurious splits from artificially
        // tight bounds when decay is active.
        let r_squared = 1.0;
        let n = sample_count as f64;
        let effective_n = match self.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        let ln_inv_delta = math::ln(1.0 / self.config.delta);
        let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * effective_n));

        // Split condition: the best is significantly better than second-best,
        // OR the bound is already so tight that more samples won't help.
        let gap = best_gain - second_best_gain;
        if gap <= epsilon && epsilon >= TAU {
            // If this was a re-evaluation at max depth, update the count
            // so we don't re-evaluate again until the next interval elapses.
            if at_max_depth {
                if let Some(state) = self
                    .leaf_states
                    .get_mut(leaf_id.0 as usize)
                    .and_then(|o| o.as_mut())
                {
                    state.last_reeval_count = sample_count;
                }
            }
            return false;
        }

        // --- Execute the split ---
        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];

        // Track split gain for feature importance.
        if best_feat_idx < self.split_gains.len() {
            self.split_gains[best_feat_idx] += best_candidate.gain;
        }

        let best_hist = &histograms.histograms[best_feat_idx];

        // Closed-form initial weights for the new children.
        let left_value = leaf_weight(
            best_candidate.left_grad,
            best_candidate.left_hess,
            self.config.lambda,
        );
        let right_value = leaf_weight(
            best_candidate.right_grad,
            best_candidate.right_hess,
            self.config.lambda,
        );

        // Perform the split -- categorical or continuous.
        let (left_id, right_id) = if let Some(ref order) = fisher_order {
            // Categorical split: build a bitmask from the Fisher-sorted partition.
            // bin_idx in the sorted order means bins order[0..=bin_idx] go left.
            // We need to map those back to original bin indices and set their bits.
            //
            // For categorical features, bin index = category value (since we use
            // one bin per category with midpoint edges).
            let mut mask: u64 = 0;
            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
                // sorted_pos is the original bin index; for categorical features,
                // bin index corresponds to the category's position in sorted categories.
                // The actual category value is stored as an integer that maps to this bin.
                // Categories >= 64 cannot be represented in the u64 mask and
                // therefore route right.
                if sorted_pos < 64 {
                    mask |= 1u64 << sorted_pos;
                }
            }

            // Threshold stores 0.0 for categorical splits (routing uses mask).
            self.arena.split_leaf_categorical(
                leaf_id,
                best_feat_idx as u32,
                0.0,
                left_value,
                right_value,
                mask,
            )
        } else {
            // Continuous split: standard threshold from bin edge.
            // A bin index past the edge array degenerates to "everything left".
            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
                best_hist.edges.edges[best_candidate.bin_idx]
            } else {
                f64::MAX
            };

            self.arena.split_leaf(
                leaf_id,
                best_feat_idx as u32,
                threshold,
                left_value,
                right_value,
            )
        };

        // Set up leaf state for both children.
        //
        // Histogram contents are NOT carried forward: we only know the
        // per-bin breakdown along the split feature, not how the parent's
        // mass on OTHER features divides between the children, so the
        // histogram-subtraction trick cannot be applied here. Instead both
        // children start with empty histograms that reuse the parent's bin
        // edges (bins_ready = true), so new samples accumulate immediately.
        // Learned leaf models, however, are warm-started from the parent.

        let parent_state = self
            .leaf_states
            .get_mut(leaf_id.0 as usize)
            .and_then(|o| o.take());
        let nf = n_features;

        // Ensure Vec is large enough for child NodeIds.
        let max_child = left_id.0.max(right_id.0) as usize;
        if self.leaf_states.len() <= max_child {
            self.leaf_states.resize_with(max_child + 1, || None);
        }

        if let Some(parent) = parent_state {
            if let Some(parent_hists) = parent.histograms {
                // Both children reuse the parent's learned bin edges.
                let edges_per_feature: Vec<BinEdges> = parent_hists
                    .histograms
                    .iter()
                    .map(|h| h.edges.clone())
                    .collect();

                // Fresh (empty) histograms with bins_ready = true: children
                // are immediately able to accumulate incoming samples.
                let left_hists = LeafHistograms::new(&edges_per_feature);
                let right_hists = LeafHistograms::new(&edges_per_feature);

                let ft = self.config.feature_types.as_deref();
                let child_binners_l = make_binners(nf, ft);
                let child_binners_r = make_binners(nf, ft);

                // Warm-start children from parent's learned leaf model.
                // If parent has a model, children inherit its weights (resetting
                // optimizer state). If parent has no model (ClosedForm), children
                // also get None -- the fast path stays fast.
                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());

                let left_state = LeafState {
                    histograms: Some(left_hists),
                    binners: child_binners_l,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: left_model,
                };

                let right_state = LeafState {
                    histograms: Some(right_hists),
                    binners: child_binners_r,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: right_model,
                };

                self.leaf_states[left_id.0 as usize] = Some(left_state);
                self.leaf_states[right_id.0 as usize] = Some(right_state);
            } else {
                // Parent didn't have histograms (shouldn't happen if bins_ready).
                // Children start from scratch but still warm-start the leaf model.
                let ft = self.config.feature_types.as_deref();
                let mut ls = LeafState::new_with_types(nf, ft);
                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[left_id.0 as usize] = Some(ls);
                let mut rs = LeafState::new_with_types(nf, ft);
                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[right_id.0 as usize] = Some(rs);
            }
        } else {
            // No parent state found (shouldn't happen).
            let ft = self.config.feature_types.as_deref();
            let mut ls = LeafState::new_with_types(nf, ft);
            ls.leaf_model = self.make_leaf_model(left_id);
            self.leaf_states[left_id.0 as usize] = Some(ls);
            let mut rs = LeafState::new_with_types(nf, ft);
            rs.leaf_model = self.make_leaf_model(right_id);
            self.leaf_states[right_id.0 as usize] = Some(rs);
        }

        true
    }
1316}
1317
1318impl StreamingTree for HoeffdingTree {
1319    /// Train the tree on a single sample.
1320    ///
1321    /// Routes the sample to its leaf, updates histogram accumulators, and
1322    /// attempts a split if the Hoeffding bound is satisfied.
1323    fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
1324        self.samples_seen += 1;
1325
1326        // Initialize n_features on first sample.
1327        let n_features = if let Some(n) = self.n_features {
1328            n
1329        } else {
1330            let n = features.len();
1331            self.n_features = Some(n);
1332            self.split_gains.resize(n, 0.0);
1333
1334            // Re-initialize the root's leaf state now that we know n_features.
1335            if let Some(state) = self
1336                .leaf_states
1337                .get_mut(self.root.0 as usize)
1338                .and_then(|o| o.as_mut())
1339            {
1340                state.binners = make_binners(n, self.config.feature_types.as_deref());
1341            }
1342            n
1343        };
1344
1345        debug_assert_eq!(
1346            features.len(),
1347            n_features,
1348            "feature count mismatch: got {} but expected {}",
1349            features.len(),
1350            n_features,
1351        );
1352
1353        // Route to leaf.
1354        let leaf_id = self.route_to_leaf(features);
1355
1356        // Increment the sample count in the arena.
1357        self.arena.increment_sample_count(leaf_id);
1358        let sample_count = self.arena.get_sample_count(leaf_id);
1359
1360        // Get or create the leaf state.
1361        let idx = leaf_id.0 as usize;
1362        if self.leaf_states.len() <= idx {
1363            self.leaf_states.resize_with(idx + 1, || None);
1364        }
1365        if self.leaf_states[idx].is_none() {
1366            self.leaf_states[idx] = Some(LeafState::new_with_types(
1367                n_features,
1368                self.config.feature_types.as_deref(),
1369            ));
1370        }
1371        let state = self.leaf_states[idx].as_mut().unwrap();
1372
1373        // Apply per-leaf gradient clipping if enabled.
1374        let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
1375            clip_gradient(state, gradient, sigma)
1376        } else {
1377            gradient
1378        };
1379
1380        // If bins are not yet ready, check if we've reached the grace period.
1381        if !state.bins_ready {
1382            // Observe feature values in the binners.
1383            for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
1384                binner.observe(val);
1385            }
1386
1387            // Accumulate running gradient/hessian sums (with optional EWMA decay).
1388            if let Some(alpha) = self.config.leaf_decay_alpha {
1389                state.grad_sum = alpha * state.grad_sum + gradient;
1390                state.hess_sum = alpha * state.hess_sum + hessian;
1391            } else {
1392                state.grad_sum += gradient;
1393                state.hess_sum += hessian;
1394            }
1395
1396            // Update the leaf value from running sums.
1397            let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1398            self.arena.set_leaf_value(leaf_id, lw);
1399
1400            // Track per-leaf output weight for adaptive bounds.
1401            if self.config.adaptive_leaf_bound.is_some() {
1402                update_output_stats(state, lw, self.config.leaf_decay_alpha);
1403            }
1404
1405            // Update the leaf model if one exists (linear / MLP).
1406            if let Some(ref mut model) = state.leaf_model {
1407                model.update(features, gradient, hessian, self.config.lambda);
1408            }
1409
1410            // Check if we've reached the grace period to compute bin edges.
1411            if sample_count >= self.config.grace_period as u64 {
1412                let edges_per_feature: Vec<BinEdges> = state
1413                    .binners
1414                    .iter()
1415                    .map(|b| b.compute_edges(self.config.n_bins))
1416                    .collect();
1417
1418                let mut histograms = LeafHistograms::new(&edges_per_feature);
1419
1420                // We don't have the raw samples to replay into the histogram,
1421                // but we DO have the running grad/hess sums. We can't distribute
1422                // them across bins retroactively. The histograms start empty and
1423                // will accumulate from the next sample onward.
1424                // However, we should NOT lose the current sample. Let's accumulate
1425                // this sample into the newly created histograms.
1426                if let Some(alpha) = self.config.leaf_decay_alpha {
1427                    histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1428                } else {
1429                    histograms.accumulate(features, gradient, hessian);
1430                }
1431
1432                state.histograms = Some(histograms);
1433                state.bins_ready = true;
1434            }
1435
1436            return;
1437        }
1438
1439        // Bins are ready -- accumulate into histograms (with optional decay).
1440        if let Some(ref mut histograms) = state.histograms {
1441            if let Some(alpha) = self.config.leaf_decay_alpha {
1442                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1443            } else {
1444                histograms.accumulate(features, gradient, hessian);
1445            }
1446        }
1447
1448        // Update running gradient/hessian sums and leaf value (with optional EWMA decay).
1449        if let Some(alpha) = self.config.leaf_decay_alpha {
1450            state.grad_sum = alpha * state.grad_sum + gradient;
1451            state.hess_sum = alpha * state.hess_sum + hessian;
1452        } else {
1453            state.grad_sum += gradient;
1454            state.hess_sum += hessian;
1455        }
1456        let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1457        self.arena.set_leaf_value(leaf_id, lw);
1458
1459        // Track per-leaf output weight for adaptive bounds.
1460        if self.config.adaptive_leaf_bound.is_some() {
1461            update_output_stats(state, lw, self.config.leaf_decay_alpha);
1462        }
1463
1464        // Update the leaf model if one exists (linear / MLP).
1465        if let Some(ref mut model) = state.leaf_model {
1466            model.update(features, gradient, hessian, self.config.lambda);
1467        }
1468
1469        // Attempt split.
1470        // We only try every grace_period samples to avoid excessive computation.
1471        if sample_count % (self.config.grace_period as u64) == 0 {
1472            self.attempt_split(leaf_id);
1473        }
1474    }
1475
1476    /// Predict the leaf value for a feature vector.
1477    ///
1478    /// Routes from the root to a leaf via threshold comparisons and returns
1479    /// the leaf's current weight.
1480    fn predict(&self, features: &[f64]) -> f64 {
1481        let leaf_id = self.route_to_leaf(features);
1482        self.leaf_prediction(leaf_id, features)
1483    }
1484
1485    /// Current number of leaf nodes.
1486    #[inline]
1487    fn n_leaves(&self) -> usize {
1488        self.arena.n_leaves()
1489    }
1490
1491    /// Total number of samples seen since creation.
1492    #[inline]
1493    fn n_samples_seen(&self) -> u64 {
1494        self.samples_seen
1495    }
1496
1497    /// Reset to initial state with a single root leaf.
1498    fn reset(&mut self) {
1499        self.arena.reset();
1500        let root = self.arena.add_leaf(0);
1501        self.root = root;
1502        self.leaf_states.clear();
1503
1504        // Insert a placeholder leaf state for the new root.
1505        let n_features = self.n_features.unwrap_or(0);
1506        self.leaf_states.resize_with(root.0 as usize + 1, || None);
1507        let mut root_state =
1508            LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
1509        root_state.leaf_model = self.make_leaf_model(root);
1510        self.leaf_states[root.0 as usize] = Some(root_state);
1511
1512        self.samples_seen = 0;
1513        self.feature_mask.clear();
1514        self.feature_mask_bits.clear();
1515        self.rng_state = self.config.seed;
1516        self.split_gains.iter_mut().for_each(|g| *g = 0.0);
1517    }
1518
    /// Borrow the accumulated split-gain values.
    ///
    /// NOTE(review): the indexing semantics (presumably one entry per
    /// feature) are established wherever `split_gains` is written, outside
    /// this view — confirm at the writer site before relying on it.
    fn split_gains(&self) -> &[f64] {
        &self.split_gains
    }
1522
1523    fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
1524        let leaf_id = self.route_to_leaf(features);
1525        let value = self.leaf_prediction(leaf_id, features);
1526        if let Some(state) = self
1527            .leaf_states
1528            .get(leaf_id.0 as usize)
1529            .and_then(|o| o.as_ref())
1530        {
1531            // Variance of the leaf weight estimate = 1 / (H_sum + lambda)
1532            let variance = 1.0 / (state.hess_sum + self.config.lambda);
1533            (value, variance)
1534        } else {
1535            (value, f64::INFINITY)
1536        }
1537    }
1538}
1539
impl Clone for HoeffdingTree {
    // Hand-written field-by-field deep clone.
    //
    // NOTE(review): this body is exactly what `#[derive(Clone)]` would
    // generate if every field implements `Clone`. Presumably it is written
    // out manually because the derive cannot be placed on the struct (or a
    // field's `Clone` is hand-implemented elsewhere) — confirm before
    // replacing with a derive.
    fn clone(&self) -> Self {
        Self {
            arena: self.arena.clone(),
            root: self.root,
            config: self.config.clone(),
            leaf_states: self.leaf_states.clone(),
            n_features: self.n_features,
            samples_seen: self.samples_seen,
            split_criterion: self.split_criterion,
            feature_mask: self.feature_mask.clone(),
            feature_mask_bits: self.feature_mask_bits.clone(),
            rng_state: self.rng_state,
            split_gains: self.split_gains.clone(),
        }
    }
}
1557
// SAFETY: All fields are Send + Sync. BinnerKind is a concrete enum with
// Send + Sync variants. XGBoostGain is stateless. Vec<Option<LeafState>>
// and Vec fields are trivially Send + Sync.
//
// NOTE(review): the comment above is self-contradictory — if every field
// really were `Send + Sync`, the compiler would derive both auto traits and
// these `unsafe impl`s would be redundant. Their presence implies at least
// one field is NOT automatically `Send`/`Sync` (likely a boxed trait object
// such as a leaf model). Verify that every concrete implementation of that
// field is actually thread-safe; otherwise these impls are unsound.
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}
1563
1564#[cfg(test)]
1565mod tests {
1566    use super::*;
1567    use crate::tree::builder::TreeConfig;
1568    use crate::tree::StreamingTree;
1569
    /// Simple xorshift64 for test reproducibility (same as the tree uses).
    ///
    /// Thin wrapper so test code reads `test_xorshift` at call sites.
    fn test_xorshift(state: &mut u64) -> u64 {
        xorshift64(state)
    }
1574
    /// Generate a pseudo-random f64 from the RNG state.
    ///
    /// Note: the value lies in (0, 1], not [0, 1): xorshift64 never yields 0
    /// for a non-zero seed, and `u64::MAX / u64::MAX == 1.0`. Good enough
    /// for test-data generation.
    fn test_rand_f64(state: &mut u64) -> f64 {
        let r = test_xorshift(state);
        (r as f64) / (u64::MAX as f64)
    }
1580
1581    // -----------------------------------------------------------------------
1582    // Test 1: Single sample train + predict returns non-NaN.
1583    // -----------------------------------------------------------------------
1584    #[test]
1585    fn single_sample_predict_not_nan() {
1586        let config = TreeConfig::new().grace_period(10);
1587        let mut tree = HoeffdingTree::new(config);
1588
1589        let features = vec![1.0, 2.0, 3.0];
1590        tree.train_one(&features, -0.5, 1.0);
1591
1592        let pred = tree.predict(&features);
1593        assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
1594        assert!(
1595            pred.is_finite(),
1596            "prediction should be finite, got {}",
1597            pred
1598        );
1599
1600        // With gradient=-0.5, hessian=1.0, lambda=1.0:
1601        // leaf_weight = -(-0.5) / (1.0 + 1.0) = 0.25
1602        assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
1603    }
1604
1605    // -----------------------------------------------------------------------
1606    // Test 2: Train 1000 samples from y=2*x + noise, verify RMSE decreases.
1607    // -----------------------------------------------------------------------
1608    #[test]
1609    fn linear_signal_rmse_improves() {
1610        let config = TreeConfig::new()
1611            .max_depth(4)
1612            .n_bins(32)
1613            .grace_period(50)
1614            .lambda(0.1)
1615            .gamma(0.0)
1616            .delta(1e-3);
1617
1618        let mut tree = HoeffdingTree::new(config);
1619        let mut rng_state: u64 = 12345;
1620
1621        // Generate training data: y = 2*x, with x in [0, 10].
1622        // For gradient boosting, gradient = prediction - target (for squared loss),
1623        // hessian = 1.0.
1624        //
1625        // We'll simulate a simple boosting loop:
1626        // - Start with prediction = 0 for all points.
1627        // - gradient = pred - target = 0 - y = -y
1628        // - hessian = 1.0
1629
1630        let n_train = 1000;
1631        let mut features_all: Vec<f64> = Vec::with_capacity(n_train);
1632        let mut targets: Vec<f64> = Vec::with_capacity(n_train);
1633
1634        for _ in 0..n_train {
1635            let x = test_rand_f64(&mut rng_state) * 10.0;
1636            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 0.5;
1637            let y = 2.0 * x + noise;
1638            features_all.push(x);
1639            targets.push(y);
1640        }
1641
1642        // Compute initial RMSE (prediction = 0).
1643        let initial_mse: f64 = targets.iter().map(|y| y * y).sum::<f64>() / n_train as f64;
1644        let initial_rmse = initial_mse.sqrt();
1645
1646        // Train the tree.
1647        for i in 0..n_train {
1648            let feat = [features_all[i]];
1649            let pred = tree.predict(&feat);
1650            // For squared loss: gradient = pred - target, hessian = 1.0.
1651            let gradient = pred - targets[i];
1652            let hessian = 1.0;
1653            tree.train_one(&feat, gradient, hessian);
1654        }
1655
1656        // Compute post-training RMSE.
1657        let mut post_mse = 0.0;
1658        for i in 0..n_train {
1659            let feat = [features_all[i]];
1660            let pred = tree.predict(&feat);
1661            let err = pred - targets[i];
1662            post_mse += err * err;
1663        }
1664        post_mse /= n_train as f64;
1665        let post_rmse = post_mse.sqrt();
1666
1667        assert!(
1668            post_rmse < initial_rmse,
1669            "RMSE should decrease after training: initial={:.4}, post={:.4}",
1670            initial_rmse,
1671            post_rmse,
1672        );
1673    }
1674
1675    // -----------------------------------------------------------------------
1676    // Test 3: No splits before grace_period samples.
1677    // -----------------------------------------------------------------------
1678    #[test]
1679    fn no_splits_before_grace_period() {
1680        let grace = 100;
1681        let config = TreeConfig::new()
1682            .grace_period(grace)
1683            .max_depth(4)
1684            .n_bins(16)
1685            .delta(1e-1); // Very lenient delta to make splits easy.
1686
1687        let mut tree = HoeffdingTree::new(config);
1688        let mut rng_state: u64 = 99999;
1689
1690        // Train grace_period - 1 samples.
1691        for _ in 0..(grace - 1) {
1692            let x = test_rand_f64(&mut rng_state) * 10.0;
1693            let y = 2.0 * x;
1694            let feat = [x];
1695            let pred = tree.predict(&feat);
1696            tree.train_one(&feat, pred - y, 1.0);
1697        }
1698
1699        assert_eq!(
1700            tree.n_leaves(),
1701            1,
1702            "should be exactly 1 leaf before grace_period, got {}",
1703            tree.n_leaves()
1704        );
1705    }
1706
1707    // -----------------------------------------------------------------------
1708    // Test 4: Tree does not exceed max_depth.
1709    // -----------------------------------------------------------------------
1710    #[test]
1711    fn respects_max_depth() {
1712        let max_depth = 3;
1713        let config = TreeConfig::new()
1714            .max_depth(max_depth)
1715            .grace_period(20)
1716            .n_bins(16)
1717            .lambda(0.01)
1718            .gamma(0.0)
1719            .delta(1e-1); // Very lenient.
1720
1721        let mut tree = HoeffdingTree::new(config);
1722        let mut rng_state: u64 = 7777;
1723
1724        // Train many samples with a clear signal to force splitting.
1725        for _ in 0..5000 {
1726            let x = test_rand_f64(&mut rng_state) * 10.0;
1727            let y = if x < 2.5 {
1728                -5.0
1729            } else if x < 5.0 {
1730                -1.0
1731            } else if x < 7.5 {
1732                1.0
1733            } else {
1734                5.0
1735            };
1736            let feat = [x];
1737            let pred = tree.predict(&feat);
1738            tree.train_one(&feat, pred - y, 1.0);
1739        }
1740
1741        // Maximum number of leaves at depth d is 2^d.
1742        let max_leaves = 1usize << max_depth;
1743        assert!(
1744            tree.n_leaves() <= max_leaves,
1745            "tree has {} leaves, but max_depth={} allows at most {}",
1746            tree.n_leaves(),
1747            max_depth,
1748            max_leaves,
1749        );
1750    }
1751
1752    // -----------------------------------------------------------------------
1753    // Test 5: Reset works -- tree returns to single leaf.
1754    // -----------------------------------------------------------------------
1755    #[test]
1756    fn reset_returns_to_single_leaf() {
1757        let config = TreeConfig::new()
1758            .grace_period(20)
1759            .max_depth(4)
1760            .n_bins(16)
1761            .delta(1e-1);
1762
1763        let mut tree = HoeffdingTree::new(config);
1764        let mut rng_state: u64 = 54321;
1765
1766        // Train enough to potentially cause splits.
1767        for _ in 0..2000 {
1768            let x = test_rand_f64(&mut rng_state) * 10.0;
1769            let y = 3.0 * x - 5.0;
1770            let feat = [x];
1771            let pred = tree.predict(&feat);
1772            tree.train_one(&feat, pred - y, 1.0);
1773        }
1774
1775        let pre_reset_samples = tree.n_samples_seen();
1776        assert!(pre_reset_samples > 0);
1777
1778        tree.reset();
1779
1780        assert_eq!(
1781            tree.n_leaves(),
1782            1,
1783            "after reset, should have exactly 1 leaf"
1784        );
1785        assert_eq!(
1786            tree.n_samples_seen(),
1787            0,
1788            "after reset, samples_seen should be 0"
1789        );
1790
1791        // Predict should still work (returns 0.0 from empty leaf).
1792        let pred = tree.predict(&[5.0]);
1793        assert!(
1794            pred.abs() < 1e-10,
1795            "prediction after reset should be ~0.0, got {}",
1796            pred
1797        );
1798    }
1799
1800    // -----------------------------------------------------------------------
1801    // Test 6: Multiple features -- verify tree uses different features.
1802    // -----------------------------------------------------------------------
1803    #[test]
1804    fn multi_feature_training() {
1805        let config = TreeConfig::new()
1806            .grace_period(30)
1807            .max_depth(4)
1808            .n_bins(16)
1809            .lambda(0.1)
1810            .delta(1e-2);
1811
1812        let mut tree = HoeffdingTree::new(config);
1813        let mut rng_state: u64 = 11111;
1814
1815        // y = x0 + 2*x1, two features.
1816        for _ in 0..1000 {
1817            let x0 = test_rand_f64(&mut rng_state) * 5.0;
1818            let x1 = test_rand_f64(&mut rng_state) * 5.0;
1819            let y = x0 + 2.0 * x1;
1820            let feat = [x0, x1];
1821            let pred = tree.predict(&feat);
1822            tree.train_one(&feat, pred - y, 1.0);
1823        }
1824
1825        // Just verify it trained without panicking and produces finite predictions.
1826        let pred = tree.predict(&[2.5, 2.5]);
1827        assert!(
1828            pred.is_finite(),
1829            "multi-feature prediction should be finite"
1830        );
1831        assert_eq!(tree.n_samples_seen(), 1000);
1832    }
1833
1834    // -----------------------------------------------------------------------
1835    // Test 7: Feature subsampling does not panic.
1836    // -----------------------------------------------------------------------
1837    #[test]
1838    fn feature_subsampling_works() {
1839        let config = TreeConfig::new()
1840            .grace_period(30)
1841            .max_depth(3)
1842            .n_bins(16)
1843            .lambda(0.1)
1844            .delta(1e-2)
1845            .feature_subsample_rate(0.5);
1846
1847        let mut tree = HoeffdingTree::new(config);
1848        let mut rng_state: u64 = 33333;
1849
1850        // 5 features, only ~50% considered per split.
1851        for _ in 0..1000 {
1852            let feats: Vec<f64> = (0..5)
1853                .map(|_| test_rand_f64(&mut rng_state) * 10.0)
1854                .collect();
1855            let y: f64 = feats.iter().sum();
1856            let pred = tree.predict(&feats);
1857            tree.train_one(&feats, pred - y, 1.0);
1858        }
1859
1860        let pred = tree.predict(&[1.0, 2.0, 3.0, 4.0, 5.0]);
1861        assert!(pred.is_finite(), "subsampled prediction should be finite");
1862    }
1863
1864    // -----------------------------------------------------------------------
1865    // Test 8: xorshift64 produces deterministic sequence.
1866    // -----------------------------------------------------------------------
1867    #[test]
1868    fn xorshift64_deterministic() {
1869        let mut s1: u64 = 42;
1870        let mut s2: u64 = 42;
1871
1872        let seq1: Vec<u64> = (0..100).map(|_| xorshift64(&mut s1)).collect();
1873        let seq2: Vec<u64> = (0..100).map(|_| xorshift64(&mut s2)).collect();
1874
1875        assert_eq!(seq1, seq2, "xorshift64 should be deterministic");
1876
1877        // Verify no zeros in the sequence (xorshift64 with non-zero seed never produces 0).
1878        for &v in &seq1 {
1879            assert_ne!(v, 0, "xorshift64 should never produce 0 with non-zero seed");
1880        }
1881    }
1882
1883    // -----------------------------------------------------------------------
1884    // Test 9: EWMA leaf decay -- recent data dominates predictions.
1885    // -----------------------------------------------------------------------
1886    #[test]
1887    fn ewma_leaf_decay_recent_data_dominates() {
1888        // half_life=50 => alpha = exp(-ln(2)/50) ≈ 0.9862
1889        let alpha = (-(2.0_f64.ln()) / 50.0).exp();
1890        let config = TreeConfig::new()
1891            .grace_period(20)
1892            .max_depth(4)
1893            .n_bins(16)
1894            .lambda(1.0)
1895            .leaf_decay_alpha(alpha);
1896        let mut tree = HoeffdingTree::new(config);
1897
1898        // Phase 1: 1000 samples targeting 1.0
1899        for _ in 0..1000 {
1900            let pred = tree.predict(&[1.0, 2.0]);
1901            let grad = pred - 1.0; // gradient for squared loss
1902            tree.train_one(&[1.0, 2.0], grad, 1.0);
1903        }
1904
1905        // Phase 2: 100 samples targeting 5.0
1906        for _ in 0..100 {
1907            let pred = tree.predict(&[1.0, 2.0]);
1908            let grad = pred - 5.0;
1909            tree.train_one(&[1.0, 2.0], grad, 1.0);
1910        }
1911
1912        let pred = tree.predict(&[1.0, 2.0]);
1913        // With EWMA, the prediction should be pulled toward 5.0 (recent target).
1914        // Without EWMA, 1000 samples at 1.0 would dominate 100 at 5.0.
1915        assert!(
1916            pred > 2.0,
1917            "EWMA should let recent data (target=5.0) pull prediction above 2.0, got {}",
1918            pred,
1919        );
1920    }
1921
1922    // -----------------------------------------------------------------------
1923    // Test 10: EWMA disabled (None) matches traditional behavior.
1924    // -----------------------------------------------------------------------
1925    #[test]
1926    fn ewma_disabled_matches_traditional() {
1927        let config_no_ewma = TreeConfig::new()
1928            .grace_period(20)
1929            .max_depth(4)
1930            .n_bins(16)
1931            .lambda(1.0);
1932        let mut tree = HoeffdingTree::new(config_no_ewma);
1933
1934        let mut rng_state: u64 = 99999;
1935        for _ in 0..200 {
1936            let x = test_rand_f64(&mut rng_state) * 10.0;
1937            let y = 3.0 * x + 1.0;
1938            let pred = tree.predict(&[x]);
1939            tree.train_one(&[x], pred - y, 1.0);
1940        }
1941
1942        let pred = tree.predict(&[5.0]);
1943        assert!(
1944            pred.is_finite(),
1945            "prediction without EWMA should be finite, got {}",
1946            pred
1947        );
1948    }
1949
1950    // -----------------------------------------------------------------------
1951    // Test 11: Split re-evaluation at max depth grows beyond frozen point.
1952    // -----------------------------------------------------------------------
1953    #[test]
1954    fn split_reeval_at_max_depth() {
1955        let config = TreeConfig::new()
1956            .grace_period(20)
1957            .max_depth(2) // Very shallow to hit max depth quickly
1958            .n_bins(16)
1959            .lambda(1.0)
1960            .split_reeval_interval(50);
1961        let mut tree = HoeffdingTree::new(config);
1962
1963        let mut rng_state: u64 = 54321;
1964        // Train enough to saturate max_depth=2 and then trigger re-evaluation.
1965        for _ in 0..2000 {
1966            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
1967            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
1968            let y = 2.0 * x1 + 3.0 * x2;
1969            let pred = tree.predict(&[x1, x2]);
1970            tree.train_one(&[x1, x2], pred - y, 1.0);
1971        }
1972
1973        // With split_reeval_interval=50, max-depth leaves can re-evaluate
1974        // and potentially split beyond max_depth. The tree should have MORE
1975        // leaves than a max_depth=2 tree without re-eval (which caps at 4).
1976        let leaves = tree.n_leaves();
1977        assert!(
1978            leaves >= 4,
1979            "split re-eval should allow growth beyond max_depth=2 cap (4 leaves), got {}",
1980            leaves,
1981        );
1982    }
1983
1984    // -----------------------------------------------------------------------
1985    // Test 12: Split re-evaluation disabled matches existing behavior.
1986    // -----------------------------------------------------------------------
1987    #[test]
1988    fn split_reeval_disabled_matches_traditional() {
1989        let config = TreeConfig::new()
1990            .grace_period(20)
1991            .max_depth(2)
1992            .n_bins(16)
1993            .lambda(1.0);
1994        // No split_reeval_interval => None => traditional hard cap
1995        let mut tree = HoeffdingTree::new(config);
1996
1997        let mut rng_state: u64 = 77777;
1998        for _ in 0..2000 {
1999            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2000            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2001            let y = 2.0 * x1 + 3.0 * x2;
2002            let pred = tree.predict(&[x1, x2]);
2003            tree.train_one(&[x1, x2], pred - y, 1.0);
2004        }
2005
2006        // Without re-eval, max_depth=2 caps at 4 leaves (2^2).
2007        let leaves = tree.n_leaves();
2008        assert!(
2009            leaves <= 4,
2010            "without re-eval, max_depth=2 should cap at 4 leaves, got {}",
2011            leaves,
2012        );
2013    }
2014
2015    // -----------------------------------------------------------------------
2016    // Test: Gradient clipping clamps outliers
2017    // -----------------------------------------------------------------------
2018    #[test]
2019    fn gradient_clipping_clamps_outliers() {
2020        let config = TreeConfig::new()
2021            .grace_period(20)
2022            .max_depth(2)
2023            .n_bins(16)
2024            .gradient_clip_sigma(2.0);
2025
2026        let mut tree = HoeffdingTree::new(config);
2027
2028        // Train 50 normal samples
2029        let mut rng_state = 42u64;
2030        for _ in 0..50 {
2031            let x = test_rand_f64(&mut rng_state) * 2.0;
2032            let grad = x * 0.1; // small gradients ~[0, 0.2]
2033            tree.train_one(&[x], grad, 1.0);
2034        }
2035
2036        let pred_before = tree.predict(&[1.0]);
2037
2038        // Now inject an extreme outlier gradient
2039        tree.train_one(&[1.0], 1000.0, 1.0);
2040
2041        let pred_after = tree.predict(&[1.0]);
2042
2043        // With clipping at 2-sigma, the outlier should be clamped.
2044        // Without clipping, the prediction would jump massively.
2045        // The change should be bounded.
2046        let delta = (pred_after - pred_before).abs();
2047        assert!(
2048            delta < 100.0,
2049            "gradient clipping should limit impact of outlier, but prediction changed by {}",
2050            delta,
2051        );
2052    }
2053
2054    // -----------------------------------------------------------------------
2055    // Test: clip_gradient function directly
2056    // -----------------------------------------------------------------------
2057    #[test]
2058    fn clip_gradient_welford_tracks_stats() {
2059        let mut state = LeafState::new(1);
2060
2061        // Feed 20 varied gradients to build up statistics
2062        for i in 0..20 {
2063            let grad = 1.0 + (i as f64) * 0.1; // range [1.0, 2.9]
2064            let clipped = clip_gradient(&mut state, grad, 3.0);
2065            // 3-sigma is very wide, so these should not be clipped
2066            assert!(
2067                (clipped - grad).abs() < 1e-10,
2068                "normal gradients should not be clipped at 3-sigma"
2069            );
2070        }
2071
2072        // Now an extreme outlier -- mean is ~1.95, std ~0.59, 3-sigma range is ~[0.18, 3.72]
2073        let clipped = clip_gradient(&mut state, 100.0, 3.0);
2074        assert!(
2075            clipped < 100.0,
2076            "extreme outlier should be clipped, got {}",
2077            clipped,
2078        );
2079        assert!(
2080            clipped > 0.0,
2081            "clipped value should be positive, got {}",
2082            clipped,
2083        );
2084    }
2085
2086    // -----------------------------------------------------------------------
2087    // Test: clip_gradient warmup period
2088    // -----------------------------------------------------------------------
2089    #[test]
2090    fn clip_gradient_warmup_no_clipping() {
2091        let mut state = LeafState::new(1);
2092
2093        // During warmup (< 10 samples), no clipping
2094        for i in 0..9 {
2095            let val = if i == 8 { 1000.0 } else { 1.0 };
2096            let clipped = clip_gradient(&mut state, val, 2.0);
2097            assert_eq!(clipped, val, "warmup should not clip");
2098        }
2099    }
2100
2101    // -----------------------------------------------------------------------
2102    // Test: adaptive_bound warmup returns f64::MAX
2103    // -----------------------------------------------------------------------
2104    #[test]
2105    fn adaptive_bound_warmup_returns_max() {
2106        let mut state = LeafState::new(1);
2107        // Feed < 10 output weights
2108        for i in 0..9 {
2109            update_output_stats(&mut state, 0.5 + i as f64 * 0.01, None);
2110        }
2111        let bound = adaptive_bound(&state, 3.0, None);
2112        assert_eq!(bound, f64::MAX, "warmup should return f64::MAX");
2113    }
2114
2115    // -----------------------------------------------------------------------
2116    // Test: adaptive_bound tightens after warmup (Welford path)
2117    // -----------------------------------------------------------------------
2118    #[test]
2119    fn adaptive_bound_tightens_after_warmup() {
2120        let mut state = LeafState::new(1);
2121        // Feed 20 outputs centered around 0.3 with small variance
2122        for i in 0..20 {
2123            let w = 0.3 + (i as f64 - 10.0) * 0.01; // range [0.2, 0.39]
2124            update_output_stats(&mut state, w, None);
2125        }
2126        let bound = adaptive_bound(&state, 3.0, None);
2127        // Bound should be much less than a global max of 3.0
2128        assert!(
2129            bound < 1.0,
2130            "3-sigma bound on outputs ~0.3 should be < 1.0, got {}",
2131            bound,
2132        );
2133        assert!(bound > 0.2, "bound should be > |mean|, got {}", bound,);
2134    }
2135
2136    // -----------------------------------------------------------------------
2137    // Test: adaptive_bound clamps outlier leaf
2138    // -----------------------------------------------------------------------
2139    #[test]
2140    fn adaptive_bound_clamps_outlier_leaf() {
2141        let mut state = LeafState::new(1);
2142        // Build stats: 20 outputs ~0.3
2143        for _ in 0..20 {
2144            update_output_stats(&mut state, 0.3, None);
2145        }
2146        let bound = adaptive_bound(&state, 3.0, None);
2147        // A leaf output of 2.9 should be clamped
2148        let clamped = (2.9_f64).clamp(-bound, bound);
2149        assert!(
2150            clamped < 2.9,
2151            "2.9 should be clamped by adaptive bound {}, got {}",
2152            bound,
2153            clamped,
2154        );
2155    }
2156
2157    // -----------------------------------------------------------------------
2158    // Test: adaptive_bound with EWMA decay adapts
2159    // -----------------------------------------------------------------------
2160    #[test]
2161    fn adaptive_bound_with_decay_adapts() {
2162        let alpha = 0.95; // fast decay for testing
2163        let mut state = LeafState::new(1);
2164
2165        // Phase 1: outputs around 0.3
2166        for _ in 0..30 {
2167            update_output_stats(&mut state, 0.3, Some(alpha));
2168        }
2169        let bound_phase1 = adaptive_bound(&state, 3.0, Some(alpha));
2170
2171        // Phase 2: outputs shift to 2.0
2172        for _ in 0..100 {
2173            update_output_stats(&mut state, 2.0, Some(alpha));
2174        }
2175        let bound_phase2 = adaptive_bound(&state, 3.0, Some(alpha));
2176
2177        // After regime change, bound should adapt upward
2178        assert!(
2179            bound_phase2 > bound_phase1,
2180            "EWMA bound should adapt: phase1={}, phase2={}",
2181            bound_phase1,
2182            bound_phase2,
2183        );
2184    }
2185
2186    // -----------------------------------------------------------------------
2187    // Test: adaptive_bound disabled by default
2188    // -----------------------------------------------------------------------
2189    #[test]
2190    fn adaptive_bound_disabled_by_default() {
2191        let config = TreeConfig::default();
2192        assert!(
2193            config.adaptive_leaf_bound.is_none(),
2194            "adaptive_leaf_bound should default to None",
2195        );
2196    }
2197
2198    // -----------------------------------------------------------------------
2199    // Test: adaptive_bound warmup falls back to global max_leaf_output
2200    // -----------------------------------------------------------------------
2201    #[test]
2202    fn adaptive_bound_warmup_falls_back_to_global() {
2203        let mut state = LeafState::new(1);
2204        // Only 5 samples — still in warmup
2205        for _ in 0..5 {
2206            update_output_stats(&mut state, 0.3, None);
2207        }
2208        let bound = adaptive_bound(&state, 3.0, None);
2209        assert_eq!(bound, f64::MAX, "warmup should yield f64::MAX");
2210        // In leaf_prediction, f64::MAX falls through to global max_leaf_output
2211    }
2212
2213    // -----------------------------------------------------------------------
2214    // Test: Monotonic constraints filter invalid splits
2215    // -----------------------------------------------------------------------
2216    #[test]
2217    fn monotonic_constraint_splits_respected() {
2218        // Train with +1 constraint on feature 0 (increasing).
2219        // Use a dataset where feature 0 has a negative relationship.
2220        let config = TreeConfig::new()
2221            .grace_period(30)
2222            .max_depth(4)
2223            .n_bins(16)
2224            .monotone_constraints(vec![1]); // feature 0 must be increasing
2225
2226        let mut tree = HoeffdingTree::new(config);
2227
2228        let mut rng_state = 42u64;
2229        for _ in 0..500 {
2230            let x = test_rand_f64(&mut rng_state) * 10.0;
2231            // Negative relationship: high x → low y → positive gradient
2232            let grad = x * 0.5 - 2.5;
2233            tree.train_one(&[x], grad, 1.0);
2234        }
2235
2236        // Any split that occurred should satisfy: left_value <= right_value
2237        // for the monotone +1 constraint. Verify prediction is non-decreasing.
2238        let pred_low = tree.predict(&[0.0]);
2239        let pred_mid = tree.predict(&[5.0]);
2240        let pred_high = tree.predict(&[10.0]);
2241
2242        // Due to constraint, prediction must be non-decreasing
2243        assert!(
2244            pred_low <= pred_mid + 1e-10 && pred_mid <= pred_high + 1e-10,
2245            "monotonic +1 violated: pred(0)={}, pred(5)={}, pred(10)={}",
2246            pred_low,
2247            pred_mid,
2248            pred_high,
2249        );
2250    }
2251
2252    // -----------------------------------------------------------------------
2253    // Test: predict_with_variance returns finite values
2254    // -----------------------------------------------------------------------
2255    #[test]
2256    fn predict_with_variance_finite() {
2257        let config = TreeConfig::new().grace_period(10);
2258        let mut tree = HoeffdingTree::new(config);
2259
2260        // Train a few samples
2261        for i in 0..30 {
2262            let x = i as f64 * 0.1;
2263            tree.train_one(&[x], x - 1.0, 1.0);
2264        }
2265
2266        let (value, variance) = tree.predict_with_variance(&[1.0]);
2267        assert!(value.is_finite(), "value should be finite");
2268        assert!(variance.is_finite(), "variance should be finite");
2269        assert!(variance > 0.0, "variance should be positive");
2270    }
2271
2272    // -----------------------------------------------------------------------
2273    // Test: predict_with_variance decreases with more data
2274    // -----------------------------------------------------------------------
2275    #[test]
2276    fn predict_with_variance_decreases_with_data() {
2277        let config = TreeConfig::new().grace_period(10);
2278        let mut tree = HoeffdingTree::new(config);
2279
2280        // Train 20 samples, check variance
2281        for i in 0..20 {
2282            tree.train_one(&[1.0], 0.5, 1.0);
2283            if i == 0 {
2284                continue;
2285            }
2286        }
2287        let (_, var_20) = tree.predict_with_variance(&[1.0]);
2288
2289        // Train 200 more samples
2290        for _ in 0..200 {
2291            tree.train_one(&[1.0], 0.5, 1.0);
2292        }
2293        let (_, var_220) = tree.predict_with_variance(&[1.0]);
2294
2295        assert!(
2296            var_220 < var_20,
2297            "variance should decrease with more data: var@20={} vs var@220={}",
2298            var_20,
2299            var_220,
2300        );
2301    }
2302
2303    // -----------------------------------------------------------------------
2304    // Test: predict_smooth matches hard prediction at small bandwidth
2305    // -----------------------------------------------------------------------
2306    #[test]
2307    fn predict_smooth_matches_hard_at_small_bandwidth() {
2308        let config = TreeConfig::new()
2309            .max_depth(3)
2310            .n_bins(16)
2311            .grace_period(20)
2312            .lambda(1.0);
2313        let mut tree = HoeffdingTree::new(config);
2314
2315        // Train enough to get splits
2316        let mut rng = 42u64;
2317        for _ in 0..500 {
2318            let x = test_rand_f64(&mut rng) * 10.0;
2319            let y = 2.0 * x + 1.0;
2320            let features = vec![x, x * 0.5];
2321            let pred = tree.predict(&features);
2322            let grad = pred - y;
2323            let hess = 1.0;
2324            tree.train_one(&features, grad, hess);
2325        }
2326
2327        // With very small bandwidth, smooth prediction should approximate hard prediction
2328        let features = vec![5.0, 2.5];
2329        let hard = tree.predict(&features);
2330        let smooth = tree.predict_smooth(&features, 0.001);
2331        assert!(
2332            (hard - smooth).abs() < 0.1,
2333            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2334            hard,
2335            smooth,
2336        );
2337    }
2338
2339    // -----------------------------------------------------------------------
2340    // Test: predict_smooth is continuous
2341    // -----------------------------------------------------------------------
2342    #[test]
2343    fn predict_smooth_is_continuous() {
2344        let config = TreeConfig::new()
2345            .max_depth(3)
2346            .n_bins(16)
2347            .grace_period(20)
2348            .lambda(1.0);
2349        let mut tree = HoeffdingTree::new(config);
2350
2351        // Train to get splits
2352        let mut rng = 42u64;
2353        for _ in 0..500 {
2354            let x = test_rand_f64(&mut rng) * 10.0;
2355            let y = 2.0 * x + 1.0;
2356            let features = vec![x, x * 0.5];
2357            let pred = tree.predict(&features);
2358            let grad = pred - y;
2359            tree.train_one(&features, grad, 1.0);
2360        }
2361
2362        // Check that small input changes produce small output changes (continuity)
2363        let bandwidth = 1.0;
2364        let base = tree.predict_smooth(&[5.0, 2.5], bandwidth);
2365        let nudged = tree.predict_smooth(&[5.001, 2.5], bandwidth);
2366        let diff = (base - nudged).abs();
2367        assert!(
2368            diff < 0.1,
2369            "smooth prediction should be continuous: base={}, nudged={}, diff={}",
2370            base,
2371            nudged,
2372            diff,
2373        );
2374    }
2375
2376    // -----------------------------------------------------------------------
2377    // Test: leaf_grad_hess returns valid sums after training.
2378    // -----------------------------------------------------------------------
2379    #[test]
2380    fn leaf_grad_hess_returns_sums() {
2381        let config = TreeConfig::new().grace_period(100).lambda(1.0);
2382        let mut tree = HoeffdingTree::new(config);
2383
2384        let features = vec![1.0, 2.0, 3.0];
2385
2386        // Train several samples with known gradients
2387        for _ in 0..10 {
2388            tree.train_one(&features, -0.5, 1.0);
2389        }
2390
2391        // The root should be a leaf (grace_period=100, only 10 samples)
2392        let root = tree.root();
2393        let (grad, hess) = tree
2394            .leaf_grad_hess(root)
2395            .expect("root should have leaf state");
2396
2397        // grad_sum should be sum of all gradients: 10 * (-0.5) = -5.0
2398        assert!(
2399            (grad - (-5.0)).abs() < 1e-10,
2400            "grad_sum should be -5.0, got {}",
2401            grad
2402        );
2403        // hess_sum should be sum of all hessians: 10 * 1.0 = 10.0
2404        assert!(
2405            (hess - 10.0).abs() < 1e-10,
2406            "hess_sum should be 10.0, got {}",
2407            hess
2408        );
2409    }
2410
2411    #[test]
2412    fn leaf_grad_hess_returns_none_for_invalid_node() {
2413        let config = TreeConfig::new();
2414        let tree = HoeffdingTree::new(config);
2415
2416        // NodeId::NONE should return None
2417        assert!(tree.leaf_grad_hess(NodeId::NONE).is_none());
2418        // A non-existent node should return None
2419        assert!(tree.leaf_grad_hess(NodeId(999)).is_none());
2420    }
2421}