// irithyll_core/tree/hoeffding.rs
1//! Hoeffding-bound split decisions for streaming tree construction.
2//!
3//! [`HoeffdingTree`] is the core streaming decision tree. It grows incrementally:
4//! each sample updates per-leaf histogram accumulators, and splits are committed
5//! only when the Hoeffding bound guarantees the best candidate split is
6//! statistically superior to the runner-up (or a tie-breaking threshold is met).
7//!
8//! # Algorithm
9//!
10//! For each incoming `(features, gradient, hessian)` triple:
11//!
12//! 1. Route the sample from root to a leaf via threshold comparisons.
13//! 2. At the leaf, accumulate gradient/hessian into per-feature histograms.
14//! 3. Once enough samples arrive (grace period), evaluate candidate splits
15//!    using the XGBoost gain formula.
16//! 4. Apply the Hoeffding bound: if the gap between the best and second-best
17//!    gain exceeds `epsilon = sqrt(R^2 * ln(1/delta) / (2n))`, commit the split.
18//! 5. When splitting, use the histogram subtraction trick to initialize one
19//!    child's histograms for free.
20
21use alloc::boxed::Box;
22use alloc::vec;
23use alloc::vec::Vec;
24
25use crate::feature::FeatureType;
26use crate::histogram::bins::LeafHistograms;
27use crate::histogram::{BinEdges, BinnerKind};
28use crate::math;
29use crate::tree::builder::TreeConfig;
30use crate::tree::leaf_model::{LeafModel, LeafModelType};
31use crate::tree::node::{NodeId, TreeArena};
32use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
33use crate::tree::StreamingTree;
34
/// Tie-breaking threshold (tau). When `epsilon < tau`, we accept the best split
/// even if the gap between best and second-best gain is small, because the
/// Hoeffding bound is already tight enough that further samples won't help.
///
/// Raising tau splits more eagerly (lower latency, more premature splits);
/// lowering it demands more statistical evidence before committing.
const TAU: f64 = 0.05;
39
40// ---------------------------------------------------------------------------
41// xorshift64 -- minimal deterministic PRNG
42// ---------------------------------------------------------------------------
43
/// Advance an xorshift64 state and return the new value.
///
/// Classic shift-XOR triple (13, 7, 17). A state of 0 maps to 0, so callers
/// must seed with a non-zero value to get a useful sequence.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let x = *state;
    let x = x ^ (x << 13);
    let x = x ^ (x >> 7);
    let x = x ^ (x << 17);
    *state = x;
    x
}
54
55// ---------------------------------------------------------------------------
56// LeafState -- per-leaf bookkeeping
57// ---------------------------------------------------------------------------
58
/// State tracked per leaf node for split decisions.
///
/// Each active leaf owns its own set of histogram accumulators (one per feature)
/// and running gradient/hessian sums for leaf weight updates.
///
/// `Clone` is implemented manually (below) because `leaf_model` holds a
/// `Box<dyn LeafModel>`, which has no derived `Clone`.
struct LeafState {
    /// Histogram accumulators for this leaf. `None` until bin edges are computed
    /// (after the grace period).
    histograms: Option<LeafHistograms>,

    /// Per-feature binning strategies that collect observed values to compute
    /// bin edges. Uses `BinnerKind` enum dispatch instead of `Box<dyn>` to
    /// eliminate N_features heap allocations per new leaf.
    binners: Vec<BinnerKind>,

    /// Whether bin edges have been computed (after grace period samples).
    bins_ready: bool,

    /// Running gradient sum for leaf weight updates.
    grad_sum: f64,

    /// Running hessian sum for leaf weight updates.
    hess_sum: f64,

    /// Sample count at last split re-evaluation (for EFDT-inspired re-eval).
    last_reeval_count: u64,

    /// Welford running mean of observed gradients, used for gradient clipping.
    clip_grad_mean: f64,

    /// Welford M2 accumulator (sum of squared deviations) for gradient clipping.
    clip_grad_m2: f64,

    /// Number of gradients observed for clipping statistics.
    clip_grad_count: u64,

    /// Mean of this leaf's output weight (EWMA or Welford, depending on config)
    /// for adaptive output bounds.
    output_mean: f64,

    /// Variance accumulator of this leaf's output weight. Holds the EWMA
    /// variance directly, or the Welford M2 sum, depending on config.
    output_m2: f64,

    /// Number of output weight observations.
    output_count: u64,

    /// Optional trainable leaf model (linear / MLP). `None` for closed-form leaves.
    leaf_model: Option<Box<dyn LeafModel>>,
}
106
107impl Clone for LeafState {
108    fn clone(&self) -> Self {
109        Self {
110            histograms: self.histograms.clone(),
111            binners: self.binners.clone(),
112            bins_ready: self.bins_ready,
113            grad_sum: self.grad_sum,
114            hess_sum: self.hess_sum,
115            last_reeval_count: self.last_reeval_count,
116            clip_grad_mean: self.clip_grad_mean,
117            clip_grad_m2: self.clip_grad_m2,
118            clip_grad_count: self.clip_grad_count,
119            output_mean: self.output_mean,
120            output_m2: self.output_m2,
121            output_count: self.output_count,
122            leaf_model: self.leaf_model.as_ref().map(|m| m.clone_warm()),
123        }
124    }
125}
126
127/// Clip a gradient using Welford online stats tracked per leaf.
128///
129/// Updates running mean/variance, then clamps the gradient to `mean ± sigma * std_dev`.
130/// Returns the (possibly clamped) gradient. During warmup (< 10 samples), no clipping
131/// is applied to let the statistics stabilize.
132#[inline]
133fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
134    state.clip_grad_count += 1;
135    let n = state.clip_grad_count as f64;
136
137    // Welford online update
138    let delta = gradient - state.clip_grad_mean;
139    state.clip_grad_mean += delta / n;
140    let delta2 = gradient - state.clip_grad_mean;
141    state.clip_grad_m2 += delta * delta2;
142
143    // No clipping during warmup
144    if state.clip_grad_count < 10 {
145        return gradient;
146    }
147
148    let variance = state.clip_grad_m2 / (n - 1.0);
149    let std_dev = math::sqrt(variance);
150
151    if std_dev < 1e-15 {
152        return gradient; // All gradients identical -- no clipping needed
153    }
154
155    let lo = state.clip_grad_mean - sigma * std_dev;
156    let hi = state.clip_grad_mean + sigma * std_dev;
157    gradient.clamp(lo, hi)
158}
159
160/// Update per-leaf output weight tracking for adaptive bounds.
161///
162/// If `decay_alpha` is Some, uses EWMA synchronized with leaf_decay_alpha.
163/// Otherwise uses Welford online algorithm (batch scenarios).
164#[inline]
165fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
166    state.output_count += 1;
167
168    if let Some(alpha) = decay_alpha {
169        // EWMA — synchronized with leaf gradient decay
170        if state.output_count == 1 {
171            state.output_mean = weight;
172            state.output_m2 = 0.0;
173        } else {
174            let diff = weight - state.output_mean;
175            state.output_mean = alpha * state.output_mean + (1.0 - alpha) * weight;
176            let diff2 = weight - state.output_mean;
177            state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * diff * diff2;
178        }
179    } else {
180        // Welford online (no decay — batch scenarios)
181        let delta = weight - state.output_mean;
182        state.output_mean += delta / (state.output_count as f64);
183        let delta2 = weight - state.output_mean;
184        state.output_m2 += delta * delta2;
185    }
186}
187
188/// Get the adaptive output bound for this leaf.
189///
190/// Returns `|mean| + k * std`, with a floor of 0.01 to never fully suppress a leaf.
191/// During warmup (< 10 samples), returns `f64::MAX` (no bound).
192#[inline]
193fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
194    if state.output_count < 10 {
195        return f64::MAX; // warmup — no bound yet
196    }
197
198    let variance = if decay_alpha.is_some() {
199        // EWMA variance is stored directly in output_m2
200        state.output_m2.max(0.0)
201    } else {
202        // Welford: variance = M2 / (n - 1)
203        state.output_m2 / (state.output_count as f64 - 1.0)
204    };
205    let std = math::sqrt(variance);
206
207    // Bound = |mean| + k * std, floor at 0.01 to never fully suppress a leaf
208    (math::abs(state.output_mean) + k * std).max(0.01)
209}
210
211/// Create binners according to feature types.
212fn make_binners(n_features: usize, feature_types: Option<&[FeatureType]>) -> Vec<BinnerKind> {
213    (0..n_features)
214        .map(|i| {
215            if let Some(ft) = feature_types {
216                if i < ft.len() && ft[i] == FeatureType::Categorical {
217                    return BinnerKind::categorical();
218                }
219            }
220            BinnerKind::uniform()
221        })
222        .collect()
223}
224
225impl LeafState {
226    /// Create a fresh leaf state for a leaf with `n_features` features.
227    fn new(n_features: usize) -> Self {
228        Self::new_with_types(n_features, None)
229    }
230
231    /// Create a fresh leaf state respecting per-feature type declarations.
232    fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
233        let binners = make_binners(n_features, feature_types);
234
235        Self {
236            histograms: None,
237            binners,
238            bins_ready: false,
239            grad_sum: 0.0,
240            hess_sum: 0.0,
241            last_reeval_count: 0,
242            clip_grad_mean: 0.0,
243            clip_grad_m2: 0.0,
244            clip_grad_count: 0,
245            output_mean: 0.0,
246            output_m2: 0.0,
247            output_count: 0,
248            leaf_model: None,
249        }
250    }
251
252    /// Create a leaf state with pre-computed histograms (used after a split
253    /// when we can initialize the child via histogram subtraction).
254    #[allow(dead_code)]
255    fn with_histograms(histograms: LeafHistograms) -> Self {
256        let n_features = histograms.n_features();
257        let binners: Vec<BinnerKind> = (0..n_features).map(|_| BinnerKind::uniform()).collect();
258
259        // Recover grad/hess sums from the histograms.
260        let grad_sum: f64 = histograms
261            .histograms
262            .first()
263            .map_or(0.0, |h| h.total_gradient());
264        let hess_sum: f64 = histograms
265            .histograms
266            .first()
267            .map_or(0.0, |h| h.total_hessian());
268
269        Self {
270            histograms: Some(histograms),
271            binners,
272            bins_ready: true,
273            grad_sum,
274            hess_sum,
275            last_reeval_count: 0,
276            clip_grad_mean: 0.0,
277            clip_grad_m2: 0.0,
278            clip_grad_count: 0,
279            output_mean: 0.0,
280            output_m2: 0.0,
281            output_count: 0,
282            leaf_model: None,
283        }
284    }
285}
286
287// ---------------------------------------------------------------------------
288// HoeffdingTree
289// ---------------------------------------------------------------------------
290
/// A streaming decision tree that uses Hoeffding-bound split decisions.
///
/// The tree grows incrementally: each call to [`train_one`](StreamingTree::train_one)
/// routes one sample to its leaf, updates histograms, and potentially triggers
/// a split when statistical evidence is sufficient.
///
/// # Feature subsampling
///
/// When `config.feature_subsample_rate < 1.0`, each split evaluation considers
/// only a random subset of features (selected via a deterministic xorshift64 RNG).
/// This adds diversity when the tree is used inside an ensemble.
pub struct HoeffdingTree {
    /// Arena-allocated node storage.
    arena: TreeArena,

    /// Root node identifier.
    root: NodeId,

    /// Tree configuration / hyperparameters.
    config: TreeConfig,

    /// Per-leaf state indexed by `NodeId.0`. Dense Vec -- NodeIds are
    /// contiguous u32 indices from TreeArena, so direct indexing is optimal.
    ///
    /// NOTE(review): `predict_interpolated` reads a parent node's entry after
    /// it has split, so entries appear to be retained for internal nodes --
    /// confirm against the split-commit logic (not visible in this chunk).
    leaf_states: Vec<Option<LeafState>>,

    /// Number of features, learned from the first sample.
    n_features: Option<usize>,

    /// Total samples seen across all calls to `train_one`.
    samples_seen: u64,

    /// Split gain evaluator.
    split_criterion: XGBoostGain,

    /// Scratch buffer for the feature mask (avoids repeated allocation).
    feature_mask: Vec<usize>,

    /// Bitset scratch buffer for O(1) membership test during feature mask generation.
    /// Each bit `i` indicates whether feature `i` is already in `feature_mask`.
    feature_mask_bits: Vec<u64>,

    /// xorshift64 RNG state for feature subsampling (seeded from `config.seed`).
    rng_state: u64,

    /// Accumulated split gains per feature for importance tracking.
    /// Indexed by feature index; grows lazily when n_features is learned.
    split_gains: Vec<f64>,
}
339
340impl HoeffdingTree {
341    /// Create a new `HoeffdingTree` with the given configuration.
342    ///
343    /// The tree starts with a single root leaf and no feature information;
344    /// the number of features is inferred from the first training sample.
345    pub fn new(config: TreeConfig) -> Self {
346        let mut arena = TreeArena::new();
347        let root = arena.add_leaf(0);
348
349        // Insert a placeholder leaf state for the root. We don't know n_features
350        // yet, so give it 0 binners -- it will be properly initialized on the
351        // first sample.
352        let mut leaf_states = vec![None; root.0 as usize + 1];
353        let root_model = match config.leaf_model_type {
354            LeafModelType::ClosedForm => None,
355            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
356        };
357        leaf_states[root.0 as usize] = Some(LeafState {
358            histograms: None,
359            binners: Vec::new(),
360            bins_ready: false,
361            grad_sum: 0.0,
362            hess_sum: 0.0,
363            last_reeval_count: 0,
364            clip_grad_mean: 0.0,
365            clip_grad_m2: 0.0,
366            clip_grad_count: 0,
367            output_mean: 0.0,
368            output_m2: 0.0,
369            output_count: 0,
370            leaf_model: root_model,
371        });
372
373        let seed = config.seed;
374        Self {
375            arena,
376            root,
377            config,
378            leaf_states,
379            n_features: None,
380            samples_seen: 0,
381            split_criterion: XGBoostGain::default(),
382            feature_mask: Vec::new(),
383            feature_mask_bits: Vec::new(),
384            rng_state: seed,
385            split_gains: Vec::new(),
386        }
387    }
388
389    /// Create a leaf model for a new leaf if the config requires one.
390    ///
391    /// Returns `None` for `ClosedForm` (the default), which uses the existing
392    /// `leaf_weight()` path with zero overhead. For `Linear` and `MLP`, returns
393    /// a fresh model seeded deterministically from the config seed and node id.
394    fn make_leaf_model(&self, node: NodeId) -> Option<Box<dyn LeafModel>> {
395        match self.config.leaf_model_type {
396            LeafModelType::ClosedForm => None,
397            _ => Some(
398                self.config
399                    .leaf_model_type
400                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
401            ),
402        }
403    }
404
405    /// Reconstruct a `HoeffdingTree` from a pre-built arena.
406    ///
407    /// Used during model deserialization. The tree is restored with node
408    /// topology and leaf values intact, but histogram accumulators are empty
409    /// (they will rebuild naturally from continued training).
410    ///
411    /// The root is assumed to be `NodeId(0)`. Leaf states are created empty
412    /// for all current leaf nodes in the arena.
413    pub fn from_arena(
414        config: TreeConfig,
415        arena: TreeArena,
416        n_features: Option<usize>,
417        samples_seen: u64,
418        rng_state: u64,
419    ) -> Self {
420        let root = if arena.n_nodes() > 0 {
421            NodeId(0)
422        } else {
423            // Empty arena -- add a root leaf (shouldn't normally happen in restore).
424            let mut arena_mut = arena;
425            let root = arena_mut.add_leaf(0);
426            return Self {
427                arena: arena_mut,
428                root,
429                config: config.clone(),
430                leaf_states: {
431                    let mut v = vec![None; root.0 as usize + 1];
432                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
433                    v
434                },
435                n_features,
436                samples_seen,
437                split_criterion: XGBoostGain::default(),
438                feature_mask: Vec::new(),
439                feature_mask_bits: Vec::new(),
440                rng_state,
441                split_gains: vec![0.0; n_features.unwrap_or(0)],
442            };
443        };
444
445        // Build leaf states for every leaf in the arena.
446        let nf = n_features.unwrap_or(0);
447        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
448        for (i, slot) in leaf_states.iter_mut().enumerate() {
449            if arena.is_leaf[i] {
450                *slot = Some(LeafState::new(nf));
451            }
452        }
453
454        Self {
455            arena,
456            root,
457            config,
458            leaf_states,
459            n_features,
460            samples_seen,
461            split_criterion: XGBoostGain::default(),
462            feature_mask: Vec::new(),
463            feature_mask_bits: Vec::new(),
464            rng_state,
465            split_gains: vec![0.0; nf],
466        }
467    }
468
    /// Root node identifier.
    #[inline]
    pub fn root(&self) -> NodeId {
        self.root
    }

    /// Immutable access to the underlying arena (node topology and leaf values).
    #[inline]
    pub fn arena(&self) -> &TreeArena {
        &self.arena
    }

    /// Immutable access to the tree configuration.
    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.config
    }

    /// Number of features (learned from the first sample, `None` before any training).
    #[inline]
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    /// Current xorshift64 RNG state (for deterministic checkpoint/restore).
    #[inline]
    pub fn rng_state(&self) -> u64 {
        self.rng_state
    }
498
499    /// Read-only access to the gradient and hessian sums for a leaf node.
500    ///
501    /// Returns `Some((grad_sum, hess_sum))` if `node` is a leaf with an active
502    /// leaf state, or `None` if the node has no state (e.g. internal node
503    /// or freshly allocated).
504    ///
505    /// These sums enable inverse-hessian confidence estimation:
506    /// `confidence = 1.0 / (hess_sum + lambda)`. High hessian means the leaf
507    /// has seen consistent, informative data; low hessian means uncertainty.
508    #[inline]
509    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
510        self.leaf_states
511            .get(node.0 as usize)
512            .and_then(|o| o.as_ref())
513            .map(|state| (state.grad_sum, state.hess_sum))
514    }
515
516    /// Route a feature vector from the root down to a leaf, returning the leaf's NodeId.
517    fn route_to_leaf(&self, features: &[f64]) -> NodeId {
518        let mut current = self.root;
519        while !self.arena.is_leaf(current) {
520            let feat_idx = self.arena.get_feature_idx(current) as usize;
521            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
522                // Categorical split: use bitmask routing.
523                // The feature value is cast to a bin index. If that bin's bit is set
524                // in the mask, go left; otherwise go right.
525                // For categorical features, the bin index in the histogram corresponds
526                // to the sorted category position, but for bitmask routing we use
527                // the original bin index directly.
528                let cat_val = features[feat_idx] as u64;
529                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
530                    self.arena.get_left(current)
531                } else {
532                    self.arena.get_right(current)
533                }
534            } else {
535                // Continuous split: standard threshold comparison.
536                let threshold = self.arena.get_threshold(current);
537                if features[feat_idx] <= threshold {
538                    self.arena.get_left(current)
539                } else {
540                    self.arena.get_right(current)
541                }
542            };
543        }
544        current
545    }
546
    /// Get the prediction value for a leaf node.
    ///
    /// Value source priority: trainable leaf model > closed-form weight from
    /// live grad/hess sums > stored arena leaf value. Returns `0.0` if no leaf
    /// state exists, or if `min_hessian_sum` suppresses a fresh leaf.
    ///
    /// Output clamping priority: per-leaf adaptive bound (when finite) >
    /// global `max_leaf_output` > unclamped.
    #[inline]
    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
        let (raw, leaf_bound) = if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            // min_hessian_sum: suppress fresh leaves with insufficient samples
            if let Some(min_h) = self.config.min_hessian_sum {
                if state.hess_sum < min_h {
                    return 0.0;
                }
            }
            let val = if let Some(ref model) = state.leaf_model {
                model.predict(features)
            } else if state.hess_sum != 0.0 {
                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
            } else {
                // No live statistics yet -- fall back to the stored leaf value.
                self.arena.leaf_value[leaf_id.0 as usize]
            };

            // Compute per-leaf adaptive bound while state is in scope
            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));

            (val, bound)
        } else {
            (0.0, None)
        };

        // Priority: per-leaf adaptive bound > global max_leaf_output > unclamped
        if let Some(bound) = leaf_bound {
            // `f64::MAX` signals adaptive_bound is still in warmup -- fall
            // through to the global clamp instead of clamping to MAX.
            if bound < f64::MAX {
                return raw.clamp(-bound, bound);
            }
        }
        if let Some(max) = self.config.max_leaf_output {
            raw.clamp(-max, max)
        } else {
            raw
        }
    }
595
    /// Predict using sigmoid-blended soft routing for smooth interpolation.
    ///
    /// Instead of hard left/right routing at each split node, uses sigmoid
    /// blending: `alpha = sigmoid((threshold - feature) / bandwidth)`. The
    /// prediction is `alpha * left_pred + (1 - alpha) * right_pred`, computed
    /// recursively from root to leaves.
    ///
    /// The result is a continuous function that varies smoothly with every
    /// feature change — no bins, no boundaries, no jumps.
    ///
    /// # Arguments
    ///
    /// * `features` - Input feature vector.
    /// * `bandwidth` - Controls transition sharpness. Smaller = sharper
    ///   (closer to hard splits), larger = smoother. Applied uniformly to
    ///   every continuous split.
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.predict_smooth_recursive(self.root, features, bandwidth)
    }

    /// Predict using per-feature auto-calibrated bandwidths.
    ///
    /// Each feature uses its own bandwidth derived from median split threshold
    /// gaps. Features with `f64::INFINITY` bandwidth fall back to hard routing.
    ///
    /// # Arguments
    ///
    /// * `features` - Input feature vector.
    /// * `bandwidths` - Per-feature bandwidths, indexed by feature.
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
    }
622
623    /// Predict with parent-leaf linear interpolation.
624    ///
625    /// Routes to the leaf but blends the leaf prediction with the parent node's
626    /// preserved prediction based on the leaf's hessian sum. Fresh leaves
627    /// (low hess_sum) smoothly transition from parent prediction to their own:
628    ///
629    /// `alpha = leaf_hess / (leaf_hess + lambda)`
630    /// `pred = alpha * leaf_pred + (1 - alpha) * parent_pred`
631    ///
632    /// This fixes static predictions from leaves that split but haven't
633    /// accumulated enough samples to outperform their parent.
634    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
635        let mut current = self.root;
636        let mut parent = None;
637        while !self.arena.is_leaf(current) {
638            parent = Some(current);
639            let feat_idx = self.arena.get_feature_idx(current) as usize;
640            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
641                let cat_val = features[feat_idx] as u64;
642                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
643                    self.arena.get_left(current)
644                } else {
645                    self.arena.get_right(current)
646                }
647            } else {
648                let threshold = self.arena.get_threshold(current);
649                if features[feat_idx] <= threshold {
650                    self.arena.get_left(current)
651                } else {
652                    self.arena.get_right(current)
653                }
654            };
655        }
656
657        let leaf_pred = self.leaf_prediction(current, features);
658
659        // No parent (root is leaf) → return leaf prediction directly
660        let parent_id = match parent {
661            Some(p) => p,
662            None => return leaf_pred,
663        };
664
665        // Get parent's preserved prediction from its old leaf state
666        let parent_pred = self.leaf_prediction(parent_id, features);
667
668        // Blend: alpha = leaf_hess / (leaf_hess + lambda)
669        let leaf_hess = self
670            .leaf_states
671            .get(current.0 as usize)
672            .and_then(|o| o.as_ref())
673            .map(|s| s.hess_sum)
674            .unwrap_or(0.0);
675
676        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
677        alpha * leaf_pred + (1.0 - alpha) * parent_pred
678    }
679
    /// Predict with sibling-based interpolation for feature-continuous predictions.
    ///
    /// At the leaf's parent split, blends the leaf prediction with its sibling's
    /// prediction based on the feature's distance from the split threshold:
    ///
    /// Within the margin `m` around the threshold:
    /// `t = (feature - threshold + m) / (2 * m)`  (0 at left edge, 1 at right edge)
    /// `pred = (1 - t) * left_pred + t * right_pred`
    ///
    /// Outside the margin, returns the routed child's prediction directly.
    /// The margin `m` is derived from auto-bandwidths if available, otherwise
    /// defaults to `feature_range / n_bins` heuristic per feature.
    ///
    /// This makes predictions vary continuously as features move near split
    /// boundaries, eliminating step-function artifacts.
    ///
    /// # Arguments
    ///
    /// * `features` - Input feature vector.
    /// * `bandwidths` - Per-feature margins; non-finite or non-positive values
    ///   fall back to hard routing for that feature.
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_sibling_recursive(self.root, features, bandwidths)
    }
698
699    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
700        if self.arena.is_leaf(node) {
701            return self.leaf_prediction(node, features);
702        }
703
704        let feat_idx = self.arena.get_feature_idx(node) as usize;
705        let left = self.arena.get_left(node);
706        let right = self.arena.get_right(node);
707
708        // Categorical splits: always hard routing (no interpolation)
709        if let Some(mask) = self.arena.get_categorical_mask(node) {
710            let cat_val = features[feat_idx] as u64;
711            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
712                self.predict_sibling_recursive(left, features, bandwidths)
713            } else {
714                self.predict_sibling_recursive(right, features, bandwidths)
715            };
716        }
717
718        let threshold = self.arena.get_threshold(node);
719        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
720
721        // No valid margin or infinite → hard routing
722        if !margin.is_finite() || margin <= 0.0 {
723            return if features[feat_idx] <= threshold {
724                self.predict_sibling_recursive(left, features, bandwidths)
725            } else {
726                self.predict_sibling_recursive(right, features, bandwidths)
727            };
728        }
729
730        let dist = features[feat_idx] - threshold;
731
732        if dist < -margin {
733            // Firmly in left child territory
734            self.predict_sibling_recursive(left, features, bandwidths)
735        } else if dist > margin {
736            // Firmly in right child territory
737            self.predict_sibling_recursive(right, features, bandwidths)
738        } else {
739            // Within the interpolation margin: linear blend
740            let t = (dist + margin) / (2.0 * margin); // 0.0 at left edge, 1.0 at right edge
741            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
742            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
743            (1.0 - t) * left_pred + t * right_pred
744        }
745    }
746
747    /// Collect all split thresholds per feature from the tree arena.
748    ///
749    /// Returns a `Vec<Vec<f64>>` indexed by feature, containing all thresholds
750    /// used in continuous splits. Categorical splits are excluded.
751    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
752        let n = self.n_features.unwrap_or(0);
753        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];
754
755        for i in 0..self.arena.n_nodes() {
756            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
757                let feat_idx = self.arena.feature_idx[i] as usize;
758                if feat_idx < n {
759                    thresholds[feat_idx].push(self.arena.threshold[i]);
760                }
761            }
762        }
763
764        thresholds
765    }
766
767    /// Recursive sigmoid-blended prediction traversal.
768    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
769        if self.arena.is_leaf(node) {
770            // At a leaf, return the leaf prediction (same as regular predict)
771            return self.leaf_prediction(node, features);
772        }
773
774        let feat_idx = self.arena.get_feature_idx(node) as usize;
775        let left = self.arena.get_left(node);
776        let right = self.arena.get_right(node);
777
778        // Categorical splits: hard routing (sigmoid blending is meaningless for categories)
779        if let Some(mask) = self.arena.get_categorical_mask(node) {
780            let cat_val = features[feat_idx] as u64;
781            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
782                self.predict_smooth_recursive(left, features, bandwidth)
783            } else {
784                self.predict_smooth_recursive(right, features, bandwidth)
785            };
786        }
787
788        // Continuous split: sigmoid blending for smooth transition around the threshold
789        let threshold = self.arena.get_threshold(node);
790        let z = (threshold - features[feat_idx]) / bandwidth;
791        let alpha = 1.0 / (1.0 + math::exp(-z));
792
793        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
794        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);
795
796        alpha * left_pred + (1.0 - alpha) * right_pred
797    }
798
799    /// Recursive per-feature-bandwidth smooth prediction traversal.
800    fn predict_smooth_auto_recursive(
801        &self,
802        node: NodeId,
803        features: &[f64],
804        bandwidths: &[f64],
805    ) -> f64 {
806        if self.arena.is_leaf(node) {
807            return self.leaf_prediction(node, features);
808        }
809
810        let feat_idx = self.arena.get_feature_idx(node) as usize;
811        let left = self.arena.get_left(node);
812        let right = self.arena.get_right(node);
813
814        // Categorical splits: always hard routing
815        if let Some(mask) = self.arena.get_categorical_mask(node) {
816            let cat_val = features[feat_idx] as u64;
817            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
818                self.predict_smooth_auto_recursive(left, features, bandwidths)
819            } else {
820                self.predict_smooth_auto_recursive(right, features, bandwidths)
821            };
822        }
823
824        let threshold = self.arena.get_threshold(node);
825        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
826
827        // Infinite bandwidth = feature never split on across ensemble → hard routing
828        if !bw.is_finite() {
829            return if features[feat_idx] <= threshold {
830                self.predict_smooth_auto_recursive(left, features, bandwidths)
831            } else {
832                self.predict_smooth_auto_recursive(right, features, bandwidths)
833            };
834        }
835
836        // Sigmoid-blended soft routing with per-feature bandwidth
837        let z = (threshold - features[feat_idx]) / bw;
838        let alpha = 1.0 / (1.0 + math::exp(-z));
839
840        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
841        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);
842
843        alpha * left_pred + (1.0 - alpha) * right_pred
844    }
845
846    /// Generate the feature mask for split evaluation.
847    ///
848    /// If `feature_subsample_rate` is 1.0, all features are included.
849    /// Otherwise, a random subset is selected via xorshift64.
850    ///
851    /// Uses a `Vec<u64>` bitset for O(1) membership testing, replacing the
852    /// previous O(n) `Vec::contains()` which made the fallback loop O(n²).
853    fn generate_feature_mask(&mut self, n_features: usize) {
854        self.feature_mask.clear();
855
856        if self.config.feature_subsample_rate >= 1.0 {
857            self.feature_mask.extend(0..n_features);
858        } else {
859            let target_count =
860                crate::math::ceil((n_features as f64) * self.config.feature_subsample_rate)
861                    as usize;
862            let target_count = target_count.max(1).min(n_features);
863
864            // Prepare the bitset: one bit per feature, O(1) membership test.
865            let n_words = n_features.div_ceil(64);
866            self.feature_mask_bits.clear();
867            self.feature_mask_bits.resize(n_words, 0u64);
868
869            // Include each feature with probability = subsample_rate.
870            for i in 0..n_features {
871                let r = xorshift64(&mut self.rng_state);
872                let p = (r as f64) / (u64::MAX as f64);
873                if p < self.config.feature_subsample_rate {
874                    self.feature_mask.push(i);
875                    self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
876                }
877            }
878
879            // If we didn't get enough features, fill up deterministically.
880            // Now O(n) instead of O(n²) thanks to the bitset.
881            if self.feature_mask.len() < target_count {
882                for i in 0..n_features {
883                    if self.feature_mask.len() >= target_count {
884                        break;
885                    }
886                    if self.feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
887                        self.feature_mask.push(i);
888                        self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
889                    }
890                }
891            }
892        }
893    }
894
    /// Attempt a split at the given leaf node.
    ///
    /// Returns `true` if a split was performed.
    ///
    /// Pipeline:
    /// 1. Depth gate: a hard ceiling on depth, optionally bypassed on a
    ///    re-evaluation interval when `split_reeval_interval` is set.
    /// 2. Grace-period gate on the leaf's sample count.
    /// 3. Candidate generation per feature in the subsampled mask
    ///    (categorical features are Fisher-ordered before evaluation).
    /// 4. Monotonic-constraint filtering of candidates.
    /// 5. CIR penalty gate (when `adaptive_depth` is set), then the
    ///    Hoeffding bound / tie-break (`TAU`) test.
    /// 6. Arena split and construction of fresh child leaf states.
    fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
        let depth = self.arena.get_depth(leaf_id);

        // When adaptive_depth is enabled, max_depth * 2 is the hard safety ceiling;
        // the per-split CIR test handles generalization. Otherwise, use static max_depth.
        let hard_ceiling = if self.config.adaptive_depth.is_some() {
            self.config.max_depth.saturating_mul(2)
        } else {
            self.config.max_depth
        };
        let at_max_depth = depth as usize >= hard_ceiling;

        if at_max_depth {
            // Only proceed if split re-evaluation is enabled and the interval
            // has elapsed since the last evaluation at this leaf.
            match self.config.split_reeval_interval {
                None => return false,
                Some(interval) => {
                    let state = match self
                        .leaf_states
                        .get(leaf_id.0 as usize)
                        .and_then(|o| o.as_ref())
                    {
                        Some(s) => s,
                        None => return false,
                    };
                    let sample_count = self.arena.get_sample_count(leaf_id);
                    // last_reeval_count is only ever set to a past sample_count,
                    // so this subtraction cannot underflow.
                    if sample_count - state.last_reeval_count < interval as u64 {
                        return false;
                    }
                    // Fall through to evaluate potential split.
                }
            }
        }

        // n_features is set on the first training sample; no samples means
        // nothing to split on.
        let n_features = match self.n_features {
            Some(n) => n,
            None => return false,
        };

        let sample_count = self.arena.get_sample_count(leaf_id);
        if sample_count < self.config.grace_period as u64 {
            return false;
        }

        // Generate the feature mask for this split evaluation.
        self.generate_feature_mask(n_features);

        // Materialize pending lazy decay before reading histogram data.
        // This converts un-decayed coordinates to true decayed values so
        // split evaluation sees correct gradient/hessian sums. O(n_features * n_bins)
        // but amortized over grace_period samples -- not per-sample cost.
        if self.config.leaf_decay_alpha.is_some() {
            if let Some(state) = self
                .leaf_states
                .get_mut(leaf_id.0 as usize)
                .and_then(|o| o.as_mut())
            {
                if let Some(ref mut histograms) = state.histograms {
                    histograms.materialize_decay();
                }
            }
        }

        // Evaluate splits for each feature in the mask.
        // We need to borrow leaf_states immutably while feature_mask is borrowed.
        // Collect candidates first.
        let state = match self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            Some(s) => s,
            None => return false,
        };

        // No histograms yet (bin edges not computed) -- nothing to evaluate.
        let histograms = match &state.histograms {
            Some(h) => h,
            None => return false,
        };

        // Collect (feature_idx, best_split_candidate, optional_fisher_order) for
        // each feature in the mask. For categorical features, we reorder bins by
        // Fisher optimal binary partitioning before evaluation.
        let feature_types = &self.config.feature_types;
        let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();

        for &feat_idx in &self.feature_mask {
            if feat_idx >= histograms.n_features() {
                continue;
            }
            let hist = &histograms.histograms[feat_idx];
            let total_grad = hist.total_gradient();
            let total_hess = hist.total_hessian();

            let is_categorical = feature_types
                .as_ref()
                .is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);

            if is_categorical {
                // Fisher optimal binary partitioning:
                // 1. Compute gradient_sum/hessian_sum ratio per bin
                // 2. Sort bins by this ratio
                // 3. Evaluate splits on the sorted order
                let n_bins = hist.grad_sums.len();
                if n_bins < 2 {
                    continue;
                }

                // Build (bin_index, ratio) pairs, filtering out empty bins
                // (near-zero hessian would make the ratio meaningless).
                let mut bin_order: Vec<usize> = (0..n_bins)
                    .filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
                    .collect();

                if bin_order.len() < 2 {
                    continue;
                }

                // Sort by grad_sum / hess_sum ratio (ascending).
                // NaN-safe: incomparable ratios are treated as equal.
                bin_order.sort_by(|&a, &b| {
                    let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
                    let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
                    ratio_a
                        .partial_cmp(&ratio_b)
                        .unwrap_or(core::cmp::Ordering::Equal)
                });

                // Reorder grad/hess sums according to Fisher order
                let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
                let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();

                if let Some(candidate) = self.split_criterion.evaluate(
                    &sorted_grads,
                    &sorted_hess,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    // Keep the Fisher order so the winning bin_idx can be
                    // mapped back to original bin indices when building the mask.
                    candidates.push((feat_idx, candidate, Some(bin_order)));
                }
            } else {
                // Standard continuous feature -- evaluate as-is
                if let Some(candidate) = self.split_criterion.evaluate(
                    &hist.grad_sums,
                    &hist.hess_sums,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, None));
                }
            }
        }

        // Filter out candidates that violate monotonic constraints.
        // A candidate is kept when the implied child leaf weights respect the
        // configured direction (+1 non-decreasing, -1 non-increasing, 0 free).
        if let Some(ref mc) = self.config.monotone_constraints {
            candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true; // No constraint for this feature
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true; // Unconstrained
                }

                let left_val =
                    leaf_weight(candidate.left_grad, candidate.left_hess, self.config.lambda);
                let right_val = leaf_weight(
                    candidate.right_grad,
                    candidate.right_hess,
                    self.config.lambda,
                );

                if constraint > 0 {
                    // Non-decreasing: left_value <= right_value
                    left_val <= right_val
                } else {
                    // Non-increasing: left_value >= right_value
                    left_val >= right_val
                }
            });
        }

        if candidates.is_empty() {
            return false;
        }

        // Sort candidates by gain descending (NaN-safe comparison).
        candidates.sort_by(|a, b| {
            b.1.gain
                .partial_cmp(&a.1.gain)
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let best_gain = candidates[0].1.gain;
        // Runner-up gain; with a single candidate the gap is measured
        // against zero gain.
        let second_best_gain = if candidates.len() > 1 {
            candidates[1].1.gain
        } else {
            0.0
        };

        // Per-split information criterion (Lunde-Kleppe-Skaug 2020).
        // Acts as a FIRST gate before the Hoeffding bound check.
        if let Some(cir_factor) = self.config.adaptive_depth {
            let n = sample_count as f64;
            if n > 1.0 {
                // Use effective_n with EWMA decay if configured: a decayed
                // accumulator carries at most 1/(1-alpha) samples of weight.
                let effective_n = match self.config.leaf_decay_alpha {
                    Some(alpha) => n.min(1.0 / (1.0 - alpha)),
                    None => n,
                };

                // Get gradient variance from Welford stats in the leaf state
                let grad_var = self
                    .leaf_states
                    .get(leaf_id.0 as usize)
                    .and_then(|o| o.as_ref())
                    .map(|leaf_state| {
                        if leaf_state.clip_grad_count > 1 {
                            leaf_state.clip_grad_m2 / (leaf_state.clip_grad_count as f64 - 1.0)
                        } else {
                            // Fallback heuristic when Welford stats are absent:
                            // squared mean gradient plus a floor of 1.0.
                            let mean_grad = leaf_state.grad_sum / leaf_state.hess_sum.max(1.0);
                            mean_grad * mean_grad + 1.0
                        }
                    })
                    .unwrap_or(1.0);

                let n_feat = self.n_features.unwrap_or(1) as f64;
                let penalty = cir_factor * grad_var / effective_n * n_feat;

                if best_gain <= penalty {
                    return false; // Don't split — insufficient generalization evidence
                }
            }
        }

        // Hoeffding bound: epsilon = sqrt(R^2 * ln(1/delta) / (2 * n))
        // R = 1.0 (conservative bound on the range of the gain function).
        //
        // With EWMA decay, the effective sample size is bounded by 1/(1-alpha).
        // We cap n at this value to prevent spurious splits from artificially
        // tight bounds when decay is active.
        let r_squared = 1.0;
        let n = sample_count as f64;
        let effective_n = match self.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        let ln_inv_delta = math::ln(1.0 / self.config.delta);
        let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * effective_n));

        // Split condition: the best is significantly better than second-best
        // (gap > epsilon), OR the bound is already so tight that more samples
        // won't help (epsilon < TAU tie-break). This branch is the REJECT path.
        let gap = best_gain - second_best_gain;
        if gap <= epsilon && epsilon >= TAU {
            // If this was a re-evaluation at max depth, update the count
            // so we don't re-evaluate again until the next interval elapses.
            if at_max_depth {
                if let Some(state) = self
                    .leaf_states
                    .get_mut(leaf_id.0 as usize)
                    .and_then(|o| o.as_mut())
                {
                    state.last_reeval_count = sample_count;
                }
            }
            return false;
        }

        // --- Execute the split ---
        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];

        // Track split gain for feature importance.
        if best_feat_idx < self.split_gains.len() {
            self.split_gains[best_feat_idx] += best_candidate.gain;
        }

        let best_hist = &histograms.histograms[best_feat_idx];

        // Initial child leaf weights from the candidate's grad/hess partition.
        let left_value = leaf_weight(
            best_candidate.left_grad,
            best_candidate.left_hess,
            self.config.lambda,
        );
        let right_value = leaf_weight(
            best_candidate.right_grad,
            best_candidate.right_hess,
            self.config.lambda,
        );

        // Perform the split -- categorical or continuous.
        let (left_id, right_id) = if let Some(ref order) = fisher_order {
            // Categorical split: build a bitmask from the Fisher-sorted partition.
            // bin_idx in the sorted order means bins order[0..=bin_idx] go left.
            // We need to map those back to original bin indices and set their bits.
            //
            // For categorical features, bin index = category value (since we use
            // one bin per category with midpoint edges).
            let mut mask: u64 = 0;
            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
                // sorted_pos is the original bin index; for categorical features,
                // bin index corresponds to the category's position in sorted categories.
                // The actual category value is stored as an integer that maps to this bin.
                // Categories >= 64 cannot be represented in the u64 mask and
                // therefore route right.
                if sorted_pos < 64 {
                    mask |= 1u64 << sorted_pos;
                }
            }

            // Threshold stores 0.0 for categorical splits (routing uses mask).
            self.arena.split_leaf_categorical(
                leaf_id,
                best_feat_idx as u32,
                0.0,
                left_value,
                right_value,
                mask,
            )
        } else {
            // Continuous split: standard threshold from bin edge.
            // A bin_idx past the last edge degrades to an effectively
            // always-left threshold (f64::MAX).
            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
                best_hist.edges.edges[best_candidate.bin_idx]
            } else {
                f64::MAX
            };

            self.arena.split_leaf(
                leaf_id,
                best_feat_idx as u32,
                threshold,
                left_value,
                right_value,
            )
        };

        // Build the children's leaf states.
        //
        // NOTE(review): despite the "subtraction trick" mentioned in the
        // module docs, the code below does NOT carry the parent's histogram
        // counts into the children. Both children start with EMPTY histograms
        // that merely inherit the parent's bin edges (bins_ready = true), so
        // they begin accumulating from the next routed sample immediately.
        // Only the learned leaf model (if any) is warm-started from the parent.

        // Take (remove) the parent's state; the parent is no longer a leaf.
        let parent_state = self
            .leaf_states
            .get_mut(leaf_id.0 as usize)
            .and_then(|o| o.take());
        let nf = n_features;

        // Ensure Vec is large enough for child NodeIds.
        let max_child = left_id.0.max(right_id.0) as usize;
        if self.leaf_states.len() <= max_child {
            self.leaf_states.resize_with(max_child + 1, || None);
        }

        if let Some(parent) = parent_state {
            if let Some(parent_hists) = parent.histograms {
                // Children reuse the parent's bin edges so they are
                // immediately ready to accumulate (no new grace period).
                let edges_per_feature: Vec<BinEdges> = parent_hists
                    .histograms
                    .iter()
                    .map(|h| h.edges.clone())
                    .collect();

                // Fresh, empty histograms over the inherited edges. We cannot
                // split the parent's per-feature bin contents between the
                // children because we don't know, for non-split features,
                // which side each past sample would have routed to.
                let left_hists = LeafHistograms::new(&edges_per_feature);
                let right_hists = LeafHistograms::new(&edges_per_feature);

                let ft = self.config.feature_types.as_deref();
                let child_binners_l = make_binners(nf, ft);
                let child_binners_r = make_binners(nf, ft);

                // Warm-start children from parent's learned leaf model.
                // If parent has a model, children inherit its weights (resetting
                // optimizer state). If parent has no model (ClosedForm), children
                // also get None -- the fast path stays fast.
                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());

                let left_state = LeafState {
                    histograms: Some(left_hists),
                    binners: child_binners_l,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: left_model,
                };

                let right_state = LeafState {
                    histograms: Some(right_hists),
                    binners: child_binners_r,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: right_model,
                };

                self.leaf_states[left_id.0 as usize] = Some(left_state);
                self.leaf_states[right_id.0 as usize] = Some(right_state);
            } else {
                // Parent didn't have histograms (shouldn't happen if bins_ready).
                // Children start entirely fresh but still warm-start the model.
                let ft = self.config.feature_types.as_deref();
                let mut ls = LeafState::new_with_types(nf, ft);
                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[left_id.0 as usize] = Some(ls);
                let mut rs = LeafState::new_with_types(nf, ft);
                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[right_id.0 as usize] = Some(rs);
            }
        } else {
            // No parent state found (shouldn't happen). Build children from
            // scratch, including fresh leaf models from config.
            let ft = self.config.feature_types.as_deref();
            let mut ls = LeafState::new_with_types(nf, ft);
            ls.leaf_model = self.make_leaf_model(left_id);
            self.leaf_states[left_id.0 as usize] = Some(ls);
            let mut rs = LeafState::new_with_types(nf, ft);
            rs.leaf_model = self.make_leaf_model(right_id);
            self.leaf_states[right_id.0 as usize] = Some(rs);
        }

        true
    }
1360}
1361
1362impl StreamingTree for HoeffdingTree {
    /// Train the tree on a single sample.
    ///
    /// Routes the sample to its leaf, updates histogram accumulators, and
    /// attempts a split if the Hoeffding bound is satisfied.
    ///
    /// Lifecycle per leaf: before the grace period, feature values feed the
    /// binners and only running grad/hess sums are kept; once `grace_period`
    /// samples arrive, bin edges are computed and subsequent samples
    /// accumulate into per-feature histograms. Splits are attempted every
    /// `grace_period` samples thereafter.
    fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
        self.samples_seen += 1;

        // Initialize n_features on first sample.
        let n_features = if let Some(n) = self.n_features {
            n
        } else {
            let n = features.len();
            self.n_features = Some(n);
            self.split_gains.resize(n, 0.0);

            // Re-initialize the root's leaf state now that we know n_features.
            if let Some(state) = self
                .leaf_states
                .get_mut(self.root.0 as usize)
                .and_then(|o| o.as_mut())
            {
                state.binners = make_binners(n, self.config.feature_types.as_deref());
            }
            n
        };

        debug_assert_eq!(
            features.len(),
            n_features,
            "feature count mismatch: got {} but expected {}",
            features.len(),
            n_features,
        );

        // Route to leaf.
        let leaf_id = self.route_to_leaf(features);

        // Increment the sample count in the arena.
        self.arena.increment_sample_count(leaf_id);
        let sample_count = self.arena.get_sample_count(leaf_id);

        // Get or create the leaf state (leaf_states is indexed by NodeId).
        let idx = leaf_id.0 as usize;
        if self.leaf_states.len() <= idx {
            self.leaf_states.resize_with(idx + 1, || None);
        }
        if self.leaf_states[idx].is_none() {
            self.leaf_states[idx] = Some(LeafState::new_with_types(
                n_features,
                self.config.feature_types.as_deref(),
            ));
        }
        let state = self.leaf_states[idx].as_mut().unwrap();

        // Apply per-leaf gradient clipping if enabled. The clipped value
        // shadows the raw gradient for everything below.
        let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
            clip_gradient(state, gradient, sigma)
        } else {
            gradient
        };

        // If bins are not yet ready, check if we've reached the grace period.
        if !state.bins_ready {
            // Observe feature values in the binners.
            for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
                binner.observe(val);
            }

            // Accumulate running gradient/hessian sums (with optional EWMA decay).
            if let Some(alpha) = self.config.leaf_decay_alpha {
                state.grad_sum = alpha * state.grad_sum + gradient;
                state.hess_sum = alpha * state.hess_sum + hessian;
            } else {
                state.grad_sum += gradient;
                state.hess_sum += hessian;
            }

            // Update the leaf value from running sums.
            let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
            self.arena.set_leaf_value(leaf_id, lw);

            // Track per-leaf output weight for adaptive bounds.
            if self.config.adaptive_leaf_bound.is_some() {
                update_output_stats(state, lw, self.config.leaf_decay_alpha);
            }

            // Update the leaf model if one exists (linear / MLP).
            if let Some(ref mut model) = state.leaf_model {
                model.update(features, gradient, hessian, self.config.lambda);
            }

            // Check if we've reached the grace period to compute bin edges.
            if sample_count >= self.config.grace_period as u64 {
                let edges_per_feature: Vec<BinEdges> = state
                    .binners
                    .iter()
                    .map(|b| b.compute_edges(self.config.n_bins))
                    .collect();

                let mut histograms = LeafHistograms::new(&edges_per_feature);

                // We don't have the raw samples to replay into the histogram,
                // but we DO have the running grad/hess sums. We can't distribute
                // them across bins retroactively. The histograms start empty and
                // will accumulate from the next sample onward.
                // However, we should NOT lose the current sample. Let's accumulate
                // this sample into the newly created histograms.
                if let Some(alpha) = self.config.leaf_decay_alpha {
                    histograms.accumulate_with_decay(features, gradient, hessian, alpha);
                } else {
                    histograms.accumulate(features, gradient, hessian);
                }

                state.histograms = Some(histograms);
                state.bins_ready = true;
            }

            // Pre-grace samples never trigger a split attempt.
            return;
        }

        // Bins are ready -- accumulate into histograms (with optional decay).
        if let Some(ref mut histograms) = state.histograms {
            if let Some(alpha) = self.config.leaf_decay_alpha {
                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
            } else {
                histograms.accumulate(features, gradient, hessian);
            }
        }

        // Update running gradient/hessian sums and leaf value (with optional EWMA decay).
        if let Some(alpha) = self.config.leaf_decay_alpha {
            state.grad_sum = alpha * state.grad_sum + gradient;
            state.hess_sum = alpha * state.hess_sum + hessian;
        } else {
            state.grad_sum += gradient;
            state.hess_sum += hessian;
        }
        let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
        self.arena.set_leaf_value(leaf_id, lw);

        // Track per-leaf output weight for adaptive bounds.
        if self.config.adaptive_leaf_bound.is_some() {
            update_output_stats(state, lw, self.config.leaf_decay_alpha);
        }

        // Update the leaf model if one exists (linear / MLP).
        if let Some(ref mut model) = state.leaf_model {
            model.update(features, gradient, hessian, self.config.lambda);
        }

        // Attempt split.
        // We only try every grace_period samples to avoid excessive computation.
        // NOTE(review): assumes grace_period > 0 — a zero value would panic
        // on this modulo. Confirm TreeConfig validates it.
        if sample_count % (self.config.grace_period as u64) == 0 {
            self.attempt_split(leaf_id);
        }
    }
1519
1520    /// Predict the leaf value for a feature vector.
1521    ///
1522    /// Routes from the root to a leaf via threshold comparisons and returns
1523    /// the leaf's current weight.
1524    fn predict(&self, features: &[f64]) -> f64 {
1525        let leaf_id = self.route_to_leaf(features);
1526        self.leaf_prediction(leaf_id, features)
1527    }
1528
1529    /// Current number of leaf nodes.
1530    #[inline]
1531    fn n_leaves(&self) -> usize {
1532        self.arena.n_leaves()
1533    }
1534
1535    /// Total number of samples seen since creation.
1536    #[inline]
1537    fn n_samples_seen(&self) -> u64 {
1538        self.samples_seen
1539    }
1540
1541    /// Reset to initial state with a single root leaf.
1542    fn reset(&mut self) {
1543        self.arena.reset();
1544        let root = self.arena.add_leaf(0);
1545        self.root = root;
1546        self.leaf_states.clear();
1547
1548        // Insert a placeholder leaf state for the new root.
1549        let n_features = self.n_features.unwrap_or(0);
1550        self.leaf_states.resize_with(root.0 as usize + 1, || None);
1551        let mut root_state =
1552            LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
1553        root_state.leaf_model = self.make_leaf_model(root);
1554        self.leaf_states[root.0 as usize] = Some(root_state);
1555
1556        self.samples_seen = 0;
1557        self.feature_mask.clear();
1558        self.feature_mask_bits.clear();
1559        self.rng_state = self.config.seed;
1560        self.split_gains.iter_mut().for_each(|g| *g = 0.0);
1561    }
1562
1563    fn split_gains(&self) -> &[f64] {
1564        &self.split_gains
1565    }
1566
1567    fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
1568        let leaf_id = self.route_to_leaf(features);
1569        let value = self.leaf_prediction(leaf_id, features);
1570        if let Some(state) = self
1571            .leaf_states
1572            .get(leaf_id.0 as usize)
1573            .and_then(|o| o.as_ref())
1574        {
1575            // Variance of the leaf weight estimate = 1 / (H_sum + lambda)
1576            let variance = 1.0 / (state.hess_sum + self.config.lambda);
1577            (value, variance)
1578        } else {
1579            (value, f64::INFINITY)
1580        }
1581    }
1582}
1583
// Manual `Clone` implementation: clones every field, exactly as a derived
// `Clone` would.
// NOTE(review): the compiler will reject this struct literal if a field is
// missing, so the impl stays in sync mechanically — but if every field is
// `Clone`, prefer `#[derive(Clone)]` on the struct definition and delete
// this impl.
impl Clone for HoeffdingTree {
    fn clone(&self) -> Self {
        Self {
            arena: self.arena.clone(),
            root: self.root,
            config: self.config.clone(),
            leaf_states: self.leaf_states.clone(),
            n_features: self.n_features,
            samples_seen: self.samples_seen,
            split_criterion: self.split_criterion,
            feature_mask: self.feature_mask.clone(),
            feature_mask_bits: self.feature_mask_bits.clone(),
            rng_state: self.rng_state,
            split_gains: self.split_gains.clone(),
        }
    }
}
1601
// SAFETY: All fields are Send + Sync. BinnerKind is a concrete enum with
// Send + Sync variants. XGBoostGain is stateless. Vec<Option<LeafState>>
// and Vec fields are trivially Send + Sync.
// NOTE(review): if the claim above is literally true, `Send`/`Sync` are
// auto-implemented and these `unsafe impl`s are redundant — their presence
// suggests some field (e.g. a trait object without `+ Send + Sync` bounds,
// perhaps inside `leaf_model`) is NOT auto-Send/Sync. Identify which field
// requires this and document it here, or remove the impls; a blanket
// `unsafe impl` silently disables the compiler's thread-safety check and is
// unsound if any field is ever made non-thread-safe. TODO confirm.
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}
1607
1608#[cfg(test)]
1609mod tests {
1610    use super::*;
1611    use crate::tree::builder::TreeConfig;
1612    use crate::tree::StreamingTree;
1613
1614    /// Simple xorshift64 for test reproducibility (same as the tree uses).
1615    fn test_xorshift(state: &mut u64) -> u64 {
1616        xorshift64(state)
1617    }
1618
1619    /// Generate a pseudo-random f64 in [0, 1) from the RNG state.
1620    fn test_rand_f64(state: &mut u64) -> f64 {
1621        let r = test_xorshift(state);
1622        (r as f64) / (u64::MAX as f64)
1623    }
1624
1625    // -----------------------------------------------------------------------
1626    // Test 1: Single sample train + predict returns non-NaN.
1627    // -----------------------------------------------------------------------
1628    #[test]
1629    fn single_sample_predict_not_nan() {
1630        let config = TreeConfig::new().grace_period(10);
1631        let mut tree = HoeffdingTree::new(config);
1632
1633        let features = vec![1.0, 2.0, 3.0];
1634        tree.train_one(&features, -0.5, 1.0);
1635
1636        let pred = tree.predict(&features);
1637        assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
1638        assert!(
1639            pred.is_finite(),
1640            "prediction should be finite, got {}",
1641            pred
1642        );
1643
1644        // With gradient=-0.5, hessian=1.0, lambda=1.0:
1645        // leaf_weight = -(-0.5) / (1.0 + 1.0) = 0.25
1646        assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
1647    }
1648
1649    // -----------------------------------------------------------------------
1650    // Test 2: Train 1000 samples from y=2*x + noise, verify RMSE decreases.
1651    // -----------------------------------------------------------------------
1652    #[test]
1653    fn linear_signal_rmse_improves() {
1654        let config = TreeConfig::new()
1655            .max_depth(4)
1656            .n_bins(32)
1657            .grace_period(50)
1658            .lambda(0.1)
1659            .gamma(0.0)
1660            .delta(1e-3);
1661
1662        let mut tree = HoeffdingTree::new(config);
1663        let mut rng_state: u64 = 12345;
1664
1665        // Generate training data: y = 2*x, with x in [0, 10].
1666        // For gradient boosting, gradient = prediction - target (for squared loss),
1667        // hessian = 1.0.
1668        //
1669        // We'll simulate a simple boosting loop:
1670        // - Start with prediction = 0 for all points.
1671        // - gradient = pred - target = 0 - y = -y
1672        // - hessian = 1.0
1673
1674        let n_train = 1000;
1675        let mut features_all: Vec<f64> = Vec::with_capacity(n_train);
1676        let mut targets: Vec<f64> = Vec::with_capacity(n_train);
1677
1678        for _ in 0..n_train {
1679            let x = test_rand_f64(&mut rng_state) * 10.0;
1680            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 0.5;
1681            let y = 2.0 * x + noise;
1682            features_all.push(x);
1683            targets.push(y);
1684        }
1685
1686        // Compute initial RMSE (prediction = 0).
1687        let initial_mse: f64 = targets.iter().map(|y| y * y).sum::<f64>() / n_train as f64;
1688        let initial_rmse = initial_mse.sqrt();
1689
1690        // Train the tree.
1691        for i in 0..n_train {
1692            let feat = [features_all[i]];
1693            let pred = tree.predict(&feat);
1694            // For squared loss: gradient = pred - target, hessian = 1.0.
1695            let gradient = pred - targets[i];
1696            let hessian = 1.0;
1697            tree.train_one(&feat, gradient, hessian);
1698        }
1699
1700        // Compute post-training RMSE.
1701        let mut post_mse = 0.0;
1702        for i in 0..n_train {
1703            let feat = [features_all[i]];
1704            let pred = tree.predict(&feat);
1705            let err = pred - targets[i];
1706            post_mse += err * err;
1707        }
1708        post_mse /= n_train as f64;
1709        let post_rmse = post_mse.sqrt();
1710
1711        assert!(
1712            post_rmse < initial_rmse,
1713            "RMSE should decrease after training: initial={:.4}, post={:.4}",
1714            initial_rmse,
1715            post_rmse,
1716        );
1717    }
1718
1719    // -----------------------------------------------------------------------
1720    // Test 3: No splits before grace_period samples.
1721    // -----------------------------------------------------------------------
1722    #[test]
1723    fn no_splits_before_grace_period() {
1724        let grace = 100;
1725        let config = TreeConfig::new()
1726            .grace_period(grace)
1727            .max_depth(4)
1728            .n_bins(16)
1729            .delta(1e-1); // Very lenient delta to make splits easy.
1730
1731        let mut tree = HoeffdingTree::new(config);
1732        let mut rng_state: u64 = 99999;
1733
1734        // Train grace_period - 1 samples.
1735        for _ in 0..(grace - 1) {
1736            let x = test_rand_f64(&mut rng_state) * 10.0;
1737            let y = 2.0 * x;
1738            let feat = [x];
1739            let pred = tree.predict(&feat);
1740            tree.train_one(&feat, pred - y, 1.0);
1741        }
1742
1743        assert_eq!(
1744            tree.n_leaves(),
1745            1,
1746            "should be exactly 1 leaf before grace_period, got {}",
1747            tree.n_leaves()
1748        );
1749    }
1750
1751    // -----------------------------------------------------------------------
1752    // Test 4: Tree does not exceed max_depth.
1753    // -----------------------------------------------------------------------
1754    #[test]
1755    fn respects_max_depth() {
1756        let max_depth = 3;
1757        let config = TreeConfig::new()
1758            .max_depth(max_depth)
1759            .grace_period(20)
1760            .n_bins(16)
1761            .lambda(0.01)
1762            .gamma(0.0)
1763            .delta(1e-1); // Very lenient.
1764
1765        let mut tree = HoeffdingTree::new(config);
1766        let mut rng_state: u64 = 7777;
1767
1768        // Train many samples with a clear signal to force splitting.
1769        for _ in 0..5000 {
1770            let x = test_rand_f64(&mut rng_state) * 10.0;
1771            let y = if x < 2.5 {
1772                -5.0
1773            } else if x < 5.0 {
1774                -1.0
1775            } else if x < 7.5 {
1776                1.0
1777            } else {
1778                5.0
1779            };
1780            let feat = [x];
1781            let pred = tree.predict(&feat);
1782            tree.train_one(&feat, pred - y, 1.0);
1783        }
1784
1785        // Maximum number of leaves at depth d is 2^d.
1786        let max_leaves = 1usize << max_depth;
1787        assert!(
1788            tree.n_leaves() <= max_leaves,
1789            "tree has {} leaves, but max_depth={} allows at most {}",
1790            tree.n_leaves(),
1791            max_depth,
1792            max_leaves,
1793        );
1794    }
1795
1796    // -----------------------------------------------------------------------
1797    // Test 5: Reset works -- tree returns to single leaf.
1798    // -----------------------------------------------------------------------
1799    #[test]
1800    fn reset_returns_to_single_leaf() {
1801        let config = TreeConfig::new()
1802            .grace_period(20)
1803            .max_depth(4)
1804            .n_bins(16)
1805            .delta(1e-1);
1806
1807        let mut tree = HoeffdingTree::new(config);
1808        let mut rng_state: u64 = 54321;
1809
1810        // Train enough to potentially cause splits.
1811        for _ in 0..2000 {
1812            let x = test_rand_f64(&mut rng_state) * 10.0;
1813            let y = 3.0 * x - 5.0;
1814            let feat = [x];
1815            let pred = tree.predict(&feat);
1816            tree.train_one(&feat, pred - y, 1.0);
1817        }
1818
1819        let pre_reset_samples = tree.n_samples_seen();
1820        assert!(pre_reset_samples > 0);
1821
1822        tree.reset();
1823
1824        assert_eq!(
1825            tree.n_leaves(),
1826            1,
1827            "after reset, should have exactly 1 leaf"
1828        );
1829        assert_eq!(
1830            tree.n_samples_seen(),
1831            0,
1832            "after reset, samples_seen should be 0"
1833        );
1834
1835        // Predict should still work (returns 0.0 from empty leaf).
1836        let pred = tree.predict(&[5.0]);
1837        assert!(
1838            pred.abs() < 1e-10,
1839            "prediction after reset should be ~0.0, got {}",
1840            pred
1841        );
1842    }
1843
1844    // -----------------------------------------------------------------------
1845    // Test 6: Multiple features -- verify tree uses different features.
1846    // -----------------------------------------------------------------------
1847    #[test]
1848    fn multi_feature_training() {
1849        let config = TreeConfig::new()
1850            .grace_period(30)
1851            .max_depth(4)
1852            .n_bins(16)
1853            .lambda(0.1)
1854            .delta(1e-2);
1855
1856        let mut tree = HoeffdingTree::new(config);
1857        let mut rng_state: u64 = 11111;
1858
1859        // y = x0 + 2*x1, two features.
1860        for _ in 0..1000 {
1861            let x0 = test_rand_f64(&mut rng_state) * 5.0;
1862            let x1 = test_rand_f64(&mut rng_state) * 5.0;
1863            let y = x0 + 2.0 * x1;
1864            let feat = [x0, x1];
1865            let pred = tree.predict(&feat);
1866            tree.train_one(&feat, pred - y, 1.0);
1867        }
1868
1869        // Just verify it trained without panicking and produces finite predictions.
1870        let pred = tree.predict(&[2.5, 2.5]);
1871        assert!(
1872            pred.is_finite(),
1873            "multi-feature prediction should be finite"
1874        );
1875        assert_eq!(tree.n_samples_seen(), 1000);
1876    }
1877
1878    // -----------------------------------------------------------------------
1879    // Test 7: Feature subsampling does not panic.
1880    // -----------------------------------------------------------------------
1881    #[test]
1882    fn feature_subsampling_works() {
1883        let config = TreeConfig::new()
1884            .grace_period(30)
1885            .max_depth(3)
1886            .n_bins(16)
1887            .lambda(0.1)
1888            .delta(1e-2)
1889            .feature_subsample_rate(0.5);
1890
1891        let mut tree = HoeffdingTree::new(config);
1892        let mut rng_state: u64 = 33333;
1893
1894        // 5 features, only ~50% considered per split.
1895        for _ in 0..1000 {
1896            let feats: Vec<f64> = (0..5)
1897                .map(|_| test_rand_f64(&mut rng_state) * 10.0)
1898                .collect();
1899            let y: f64 = feats.iter().sum();
1900            let pred = tree.predict(&feats);
1901            tree.train_one(&feats, pred - y, 1.0);
1902        }
1903
1904        let pred = tree.predict(&[1.0, 2.0, 3.0, 4.0, 5.0]);
1905        assert!(pred.is_finite(), "subsampled prediction should be finite");
1906    }
1907
1908    // -----------------------------------------------------------------------
1909    // Test 8: xorshift64 produces deterministic sequence.
1910    // -----------------------------------------------------------------------
1911    #[test]
1912    fn xorshift64_deterministic() {
1913        let mut s1: u64 = 42;
1914        let mut s2: u64 = 42;
1915
1916        let seq1: Vec<u64> = (0..100).map(|_| xorshift64(&mut s1)).collect();
1917        let seq2: Vec<u64> = (0..100).map(|_| xorshift64(&mut s2)).collect();
1918
1919        assert_eq!(seq1, seq2, "xorshift64 should be deterministic");
1920
1921        // Verify no zeros in the sequence (xorshift64 with non-zero seed never produces 0).
1922        for &v in &seq1 {
1923            assert_ne!(v, 0, "xorshift64 should never produce 0 with non-zero seed");
1924        }
1925    }
1926
1927    // -----------------------------------------------------------------------
1928    // Test 9: EWMA leaf decay -- recent data dominates predictions.
1929    // -----------------------------------------------------------------------
1930    #[test]
1931    fn ewma_leaf_decay_recent_data_dominates() {
1932        // half_life=50 => alpha = exp(-ln(2)/50) ≈ 0.9862
1933        let alpha = (-(2.0_f64.ln()) / 50.0).exp();
1934        let config = TreeConfig::new()
1935            .grace_period(20)
1936            .max_depth(4)
1937            .n_bins(16)
1938            .lambda(1.0)
1939            .leaf_decay_alpha(alpha);
1940        let mut tree = HoeffdingTree::new(config);
1941
1942        // Phase 1: 1000 samples targeting 1.0
1943        for _ in 0..1000 {
1944            let pred = tree.predict(&[1.0, 2.0]);
1945            let grad = pred - 1.0; // gradient for squared loss
1946            tree.train_one(&[1.0, 2.0], grad, 1.0);
1947        }
1948
1949        // Phase 2: 100 samples targeting 5.0
1950        for _ in 0..100 {
1951            let pred = tree.predict(&[1.0, 2.0]);
1952            let grad = pred - 5.0;
1953            tree.train_one(&[1.0, 2.0], grad, 1.0);
1954        }
1955
1956        let pred = tree.predict(&[1.0, 2.0]);
1957        // With EWMA, the prediction should be pulled toward 5.0 (recent target).
1958        // Without EWMA, 1000 samples at 1.0 would dominate 100 at 5.0.
1959        assert!(
1960            pred > 2.0,
1961            "EWMA should let recent data (target=5.0) pull prediction above 2.0, got {}",
1962            pred,
1963        );
1964    }
1965
1966    // -----------------------------------------------------------------------
1967    // Test 10: EWMA disabled (None) matches traditional behavior.
1968    // -----------------------------------------------------------------------
1969    #[test]
1970    fn ewma_disabled_matches_traditional() {
1971        let config_no_ewma = TreeConfig::new()
1972            .grace_period(20)
1973            .max_depth(4)
1974            .n_bins(16)
1975            .lambda(1.0);
1976        let mut tree = HoeffdingTree::new(config_no_ewma);
1977
1978        let mut rng_state: u64 = 99999;
1979        for _ in 0..200 {
1980            let x = test_rand_f64(&mut rng_state) * 10.0;
1981            let y = 3.0 * x + 1.0;
1982            let pred = tree.predict(&[x]);
1983            tree.train_one(&[x], pred - y, 1.0);
1984        }
1985
1986        let pred = tree.predict(&[5.0]);
1987        assert!(
1988            pred.is_finite(),
1989            "prediction without EWMA should be finite, got {}",
1990            pred
1991        );
1992    }
1993
1994    // -----------------------------------------------------------------------
1995    // Test 11: Split re-evaluation at max depth grows beyond frozen point.
1996    // -----------------------------------------------------------------------
1997    #[test]
1998    fn split_reeval_at_max_depth() {
1999        let config = TreeConfig::new()
2000            .grace_period(20)
2001            .max_depth(2) // Very shallow to hit max depth quickly
2002            .n_bins(16)
2003            .lambda(1.0)
2004            .split_reeval_interval(50);
2005        let mut tree = HoeffdingTree::new(config);
2006
2007        let mut rng_state: u64 = 54321;
2008        // Train enough to saturate max_depth=2 and then trigger re-evaluation.
2009        for _ in 0..2000 {
2010            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2011            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2012            let y = 2.0 * x1 + 3.0 * x2;
2013            let pred = tree.predict(&[x1, x2]);
2014            tree.train_one(&[x1, x2], pred - y, 1.0);
2015        }
2016
2017        // With split_reeval_interval=50, max-depth leaves can re-evaluate
2018        // and potentially split beyond max_depth. The tree should have MORE
2019        // leaves than a max_depth=2 tree without re-eval (which caps at 4).
2020        let leaves = tree.n_leaves();
2021        assert!(
2022            leaves >= 4,
2023            "split re-eval should allow growth beyond max_depth=2 cap (4 leaves), got {}",
2024            leaves,
2025        );
2026    }
2027
2028    // -----------------------------------------------------------------------
2029    // Test 12: Split re-evaluation disabled matches existing behavior.
2030    // -----------------------------------------------------------------------
2031    #[test]
2032    fn split_reeval_disabled_matches_traditional() {
2033        let config = TreeConfig::new()
2034            .grace_period(20)
2035            .max_depth(2)
2036            .n_bins(16)
2037            .lambda(1.0);
2038        // No split_reeval_interval => None => traditional hard cap
2039        let mut tree = HoeffdingTree::new(config);
2040
2041        let mut rng_state: u64 = 77777;
2042        for _ in 0..2000 {
2043            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2044            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2045            let y = 2.0 * x1 + 3.0 * x2;
2046            let pred = tree.predict(&[x1, x2]);
2047            tree.train_one(&[x1, x2], pred - y, 1.0);
2048        }
2049
2050        // Without re-eval, max_depth=2 caps at 4 leaves (2^2).
2051        let leaves = tree.n_leaves();
2052        assert!(
2053            leaves <= 4,
2054            "without re-eval, max_depth=2 should cap at 4 leaves, got {}",
2055            leaves,
2056        );
2057    }
2058
2059    // -----------------------------------------------------------------------
2060    // Test: Gradient clipping clamps outliers
2061    // -----------------------------------------------------------------------
2062    #[test]
2063    fn gradient_clipping_clamps_outliers() {
2064        let config = TreeConfig::new()
2065            .grace_period(20)
2066            .max_depth(2)
2067            .n_bins(16)
2068            .gradient_clip_sigma(2.0);
2069
2070        let mut tree = HoeffdingTree::new(config);
2071
2072        // Train 50 normal samples
2073        let mut rng_state = 42u64;
2074        for _ in 0..50 {
2075            let x = test_rand_f64(&mut rng_state) * 2.0;
2076            let grad = x * 0.1; // small gradients ~[0, 0.2]
2077            tree.train_one(&[x], grad, 1.0);
2078        }
2079
2080        let pred_before = tree.predict(&[1.0]);
2081
2082        // Now inject an extreme outlier gradient
2083        tree.train_one(&[1.0], 1000.0, 1.0);
2084
2085        let pred_after = tree.predict(&[1.0]);
2086
2087        // With clipping at 2-sigma, the outlier should be clamped.
2088        // Without clipping, the prediction would jump massively.
2089        // The change should be bounded.
2090        let delta = (pred_after - pred_before).abs();
2091        assert!(
2092            delta < 100.0,
2093            "gradient clipping should limit impact of outlier, but prediction changed by {}",
2094            delta,
2095        );
2096    }
2097
2098    // -----------------------------------------------------------------------
2099    // Test: clip_gradient function directly
2100    // -----------------------------------------------------------------------
2101    #[test]
2102    fn clip_gradient_welford_tracks_stats() {
2103        let mut state = LeafState::new(1);
2104
2105        // Feed 20 varied gradients to build up statistics
2106        for i in 0..20 {
2107            let grad = 1.0 + (i as f64) * 0.1; // range [1.0, 2.9]
2108            let clipped = clip_gradient(&mut state, grad, 3.0);
2109            // 3-sigma is very wide, so these should not be clipped
2110            assert!(
2111                (clipped - grad).abs() < 1e-10,
2112                "normal gradients should not be clipped at 3-sigma"
2113            );
2114        }
2115
2116        // Now an extreme outlier -- mean is ~1.95, std ~0.59, 3-sigma range is ~[0.18, 3.72]
2117        let clipped = clip_gradient(&mut state, 100.0, 3.0);
2118        assert!(
2119            clipped < 100.0,
2120            "extreme outlier should be clipped, got {}",
2121            clipped,
2122        );
2123        assert!(
2124            clipped > 0.0,
2125            "clipped value should be positive, got {}",
2126            clipped,
2127        );
2128    }
2129
2130    // -----------------------------------------------------------------------
2131    // Test: clip_gradient warmup period
2132    // -----------------------------------------------------------------------
2133    #[test]
2134    fn clip_gradient_warmup_no_clipping() {
2135        let mut state = LeafState::new(1);
2136
2137        // During warmup (< 10 samples), no clipping
2138        for i in 0..9 {
2139            let val = if i == 8 { 1000.0 } else { 1.0 };
2140            let clipped = clip_gradient(&mut state, val, 2.0);
2141            assert_eq!(clipped, val, "warmup should not clip");
2142        }
2143    }
2144
2145    // -----------------------------------------------------------------------
2146    // Test: adaptive_bound warmup returns f64::MAX
2147    // -----------------------------------------------------------------------
2148    #[test]
2149    fn adaptive_bound_warmup_returns_max() {
2150        let mut state = LeafState::new(1);
2151        // Feed < 10 output weights
2152        for i in 0..9 {
2153            update_output_stats(&mut state, 0.5 + i as f64 * 0.01, None);
2154        }
2155        let bound = adaptive_bound(&state, 3.0, None);
2156        assert_eq!(bound, f64::MAX, "warmup should return f64::MAX");
2157    }
2158
2159    // -----------------------------------------------------------------------
2160    // Test: adaptive_bound tightens after warmup (Welford path)
2161    // -----------------------------------------------------------------------
2162    #[test]
2163    fn adaptive_bound_tightens_after_warmup() {
2164        let mut state = LeafState::new(1);
2165        // Feed 20 outputs centered around 0.3 with small variance
2166        for i in 0..20 {
2167            let w = 0.3 + (i as f64 - 10.0) * 0.01; // range [0.2, 0.39]
2168            update_output_stats(&mut state, w, None);
2169        }
2170        let bound = adaptive_bound(&state, 3.0, None);
2171        // Bound should be much less than a global max of 3.0
2172        assert!(
2173            bound < 1.0,
2174            "3-sigma bound on outputs ~0.3 should be < 1.0, got {}",
2175            bound,
2176        );
2177        assert!(bound > 0.2, "bound should be > |mean|, got {}", bound,);
2178    }
2179
2180    // -----------------------------------------------------------------------
2181    // Test: adaptive_bound clamps outlier leaf
2182    // -----------------------------------------------------------------------
2183    #[test]
2184    fn adaptive_bound_clamps_outlier_leaf() {
2185        let mut state = LeafState::new(1);
2186        // Build stats: 20 outputs ~0.3
2187        for _ in 0..20 {
2188            update_output_stats(&mut state, 0.3, None);
2189        }
2190        let bound = adaptive_bound(&state, 3.0, None);
2191        // A leaf output of 2.9 should be clamped
2192        let clamped = (2.9_f64).clamp(-bound, bound);
2193        assert!(
2194            clamped < 2.9,
2195            "2.9 should be clamped by adaptive bound {}, got {}",
2196            bound,
2197            clamped,
2198        );
2199    }
2200
2201    // -----------------------------------------------------------------------
2202    // Test: adaptive_bound with EWMA decay adapts
2203    // -----------------------------------------------------------------------
2204    #[test]
2205    fn adaptive_bound_with_decay_adapts() {
2206        let alpha = 0.95; // fast decay for testing
2207        let mut state = LeafState::new(1);
2208
2209        // Phase 1: outputs around 0.3
2210        for _ in 0..30 {
2211            update_output_stats(&mut state, 0.3, Some(alpha));
2212        }
2213        let bound_phase1 = adaptive_bound(&state, 3.0, Some(alpha));
2214
2215        // Phase 2: outputs shift to 2.0
2216        for _ in 0..100 {
2217            update_output_stats(&mut state, 2.0, Some(alpha));
2218        }
2219        let bound_phase2 = adaptive_bound(&state, 3.0, Some(alpha));
2220
2221        // After regime change, bound should adapt upward
2222        assert!(
2223            bound_phase2 > bound_phase1,
2224            "EWMA bound should adapt: phase1={}, phase2={}",
2225            bound_phase1,
2226            bound_phase2,
2227        );
2228    }
2229
2230    // -----------------------------------------------------------------------
2231    // Test: adaptive_bound disabled by default
2232    // -----------------------------------------------------------------------
2233    #[test]
2234    fn adaptive_bound_disabled_by_default() {
2235        let config = TreeConfig::default();
2236        assert!(
2237            config.adaptive_leaf_bound.is_none(),
2238            "adaptive_leaf_bound should default to None",
2239        );
2240    }
2241
2242    // -----------------------------------------------------------------------
2243    // Test: adaptive_bound warmup falls back to global max_leaf_output
2244    // -----------------------------------------------------------------------
2245    #[test]
2246    fn adaptive_bound_warmup_falls_back_to_global() {
2247        let mut state = LeafState::new(1);
2248        // Only 5 samples — still in warmup
2249        for _ in 0..5 {
2250            update_output_stats(&mut state, 0.3, None);
2251        }
2252        let bound = adaptive_bound(&state, 3.0, None);
2253        assert_eq!(bound, f64::MAX, "warmup should yield f64::MAX");
2254        // In leaf_prediction, f64::MAX falls through to global max_leaf_output
2255    }
2256
2257    // -----------------------------------------------------------------------
2258    // Test: Monotonic constraints filter invalid splits
2259    // -----------------------------------------------------------------------
2260    #[test]
2261    fn monotonic_constraint_splits_respected() {
2262        // Train with +1 constraint on feature 0 (increasing).
2263        // Use a dataset where feature 0 has a negative relationship.
2264        let config = TreeConfig::new()
2265            .grace_period(30)
2266            .max_depth(4)
2267            .n_bins(16)
2268            .monotone_constraints(vec![1]); // feature 0 must be increasing
2269
2270        let mut tree = HoeffdingTree::new(config);
2271
2272        let mut rng_state = 42u64;
2273        for _ in 0..500 {
2274            let x = test_rand_f64(&mut rng_state) * 10.0;
2275            // Negative relationship: high x → low y → positive gradient
2276            let grad = x * 0.5 - 2.5;
2277            tree.train_one(&[x], grad, 1.0);
2278        }
2279
2280        // Any split that occurred should satisfy: left_value <= right_value
2281        // for the monotone +1 constraint. Verify prediction is non-decreasing.
2282        let pred_low = tree.predict(&[0.0]);
2283        let pred_mid = tree.predict(&[5.0]);
2284        let pred_high = tree.predict(&[10.0]);
2285
2286        // Due to constraint, prediction must be non-decreasing
2287        assert!(
2288            pred_low <= pred_mid + 1e-10 && pred_mid <= pred_high + 1e-10,
2289            "monotonic +1 violated: pred(0)={}, pred(5)={}, pred(10)={}",
2290            pred_low,
2291            pred_mid,
2292            pred_high,
2293        );
2294    }
2295
2296    // -----------------------------------------------------------------------
2297    // Test: predict_with_variance returns finite values
2298    // -----------------------------------------------------------------------
2299    #[test]
2300    fn predict_with_variance_finite() {
2301        let config = TreeConfig::new().grace_period(10);
2302        let mut tree = HoeffdingTree::new(config);
2303
2304        // Train a few samples
2305        for i in 0..30 {
2306            let x = i as f64 * 0.1;
2307            tree.train_one(&[x], x - 1.0, 1.0);
2308        }
2309
2310        let (value, variance) = tree.predict_with_variance(&[1.0]);
2311        assert!(value.is_finite(), "value should be finite");
2312        assert!(variance.is_finite(), "variance should be finite");
2313        assert!(variance > 0.0, "variance should be positive");
2314    }
2315
2316    // -----------------------------------------------------------------------
2317    // Test: predict_with_variance decreases with more data
2318    // -----------------------------------------------------------------------
2319    #[test]
2320    fn predict_with_variance_decreases_with_data() {
2321        let config = TreeConfig::new().grace_period(10);
2322        let mut tree = HoeffdingTree::new(config);
2323
2324        // Train 20 samples, check variance
2325        for i in 0..20 {
2326            tree.train_one(&[1.0], 0.5, 1.0);
2327            if i == 0 {
2328                continue;
2329            }
2330        }
2331        let (_, var_20) = tree.predict_with_variance(&[1.0]);
2332
2333        // Train 200 more samples
2334        for _ in 0..200 {
2335            tree.train_one(&[1.0], 0.5, 1.0);
2336        }
2337        let (_, var_220) = tree.predict_with_variance(&[1.0]);
2338
2339        assert!(
2340            var_220 < var_20,
2341            "variance should decrease with more data: var@20={} vs var@220={}",
2342            var_20,
2343            var_220,
2344        );
2345    }
2346
2347    // -----------------------------------------------------------------------
2348    // Test: predict_smooth matches hard prediction at small bandwidth
2349    // -----------------------------------------------------------------------
2350    #[test]
2351    fn predict_smooth_matches_hard_at_small_bandwidth() {
2352        let config = TreeConfig::new()
2353            .max_depth(3)
2354            .n_bins(16)
2355            .grace_period(20)
2356            .lambda(1.0);
2357        let mut tree = HoeffdingTree::new(config);
2358
2359        // Train enough to get splits
2360        let mut rng = 42u64;
2361        for _ in 0..500 {
2362            let x = test_rand_f64(&mut rng) * 10.0;
2363            let y = 2.0 * x + 1.0;
2364            let features = vec![x, x * 0.5];
2365            let pred = tree.predict(&features);
2366            let grad = pred - y;
2367            let hess = 1.0;
2368            tree.train_one(&features, grad, hess);
2369        }
2370
2371        // With very small bandwidth, smooth prediction should approximate hard prediction
2372        let features = vec![5.0, 2.5];
2373        let hard = tree.predict(&features);
2374        let smooth = tree.predict_smooth(&features, 0.001);
2375        assert!(
2376            (hard - smooth).abs() < 0.1,
2377            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2378            hard,
2379            smooth,
2380        );
2381    }
2382
2383    // -----------------------------------------------------------------------
2384    // Test: predict_smooth is continuous
2385    // -----------------------------------------------------------------------
2386    #[test]
2387    fn predict_smooth_is_continuous() {
2388        let config = TreeConfig::new()
2389            .max_depth(3)
2390            .n_bins(16)
2391            .grace_period(20)
2392            .lambda(1.0);
2393        let mut tree = HoeffdingTree::new(config);
2394
2395        // Train to get splits
2396        let mut rng = 42u64;
2397        for _ in 0..500 {
2398            let x = test_rand_f64(&mut rng) * 10.0;
2399            let y = 2.0 * x + 1.0;
2400            let features = vec![x, x * 0.5];
2401            let pred = tree.predict(&features);
2402            let grad = pred - y;
2403            tree.train_one(&features, grad, 1.0);
2404        }
2405
2406        // Check that small input changes produce small output changes (continuity)
2407        let bandwidth = 1.0;
2408        let base = tree.predict_smooth(&[5.0, 2.5], bandwidth);
2409        let nudged = tree.predict_smooth(&[5.001, 2.5], bandwidth);
2410        let diff = (base - nudged).abs();
2411        assert!(
2412            diff < 0.1,
2413            "smooth prediction should be continuous: base={}, nudged={}, diff={}",
2414            base,
2415            nudged,
2416            diff,
2417        );
2418    }
2419
2420    // -----------------------------------------------------------------------
2421    // Test: leaf_grad_hess returns valid sums after training.
2422    // -----------------------------------------------------------------------
2423    #[test]
2424    fn leaf_grad_hess_returns_sums() {
2425        let config = TreeConfig::new().grace_period(100).lambda(1.0);
2426        let mut tree = HoeffdingTree::new(config);
2427
2428        let features = vec![1.0, 2.0, 3.0];
2429
2430        // Train several samples with known gradients
2431        for _ in 0..10 {
2432            tree.train_one(&features, -0.5, 1.0);
2433        }
2434
2435        // The root should be a leaf (grace_period=100, only 10 samples)
2436        let root = tree.root();
2437        let (grad, hess) = tree
2438            .leaf_grad_hess(root)
2439            .expect("root should have leaf state");
2440
2441        // grad_sum should be sum of all gradients: 10 * (-0.5) = -5.0
2442        assert!(
2443            (grad - (-5.0)).abs() < 1e-10,
2444            "grad_sum should be -5.0, got {}",
2445            grad
2446        );
2447        // hess_sum should be sum of all hessians: 10 * 1.0 = 10.0
2448        assert!(
2449            (hess - 10.0).abs() < 1e-10,
2450            "hess_sum should be 10.0, got {}",
2451            hess
2452        );
2453    }
2454
2455    #[test]
2456    fn leaf_grad_hess_returns_none_for_invalid_node() {
2457        let config = TreeConfig::new();
2458        let tree = HoeffdingTree::new(config);
2459
2460        // NodeId::NONE should return None
2461        assert!(tree.leaf_grad_hess(NodeId::NONE).is_none());
2462        // A non-existent node should return None
2463        assert!(tree.leaf_grad_hess(NodeId(999)).is_none());
2464    }
2465
2466    // -----------------------------------------------------------------------
2467    // Adaptive depth (per-split information criterion) tests
2468    // -----------------------------------------------------------------------
2469
2470    #[test]
2471    fn adaptive_depth_none_identical_to_static_max_depth() {
2472        // With adaptive_depth = None, behavior should be identical to the
2473        // classic static max_depth check.
2474        let config_static = TreeConfig::new()
2475            .max_depth(3)
2476            .n_bins(32)
2477            .grace_period(20)
2478            .lambda(0.1)
2479            .delta(1e-3);
2480
2481        let config_none = TreeConfig::new()
2482            .max_depth(3)
2483            .n_bins(32)
2484            .grace_period(20)
2485            .lambda(0.1)
2486            .delta(1e-3);
2487
2488        // Verify adaptive_depth is None by default
2489        assert!(config_none.adaptive_depth.is_none());
2490
2491        let mut tree_static = HoeffdingTree::new(config_static);
2492        let mut tree_none = HoeffdingTree::new(config_none);
2493
2494        let mut rng_state: u64 = 42;
2495        for _ in 0..2000 {
2496            let x = test_rand_f64(&mut rng_state) * 10.0;
2497            let y = 2.0 * x;
2498            let feat = [x, x * 0.5, x * x];
2499            let pred_s = tree_static.predict(&feat);
2500            let pred_n = tree_none.predict(&feat);
2501            tree_static.train_one(&feat, pred_s - y, 1.0);
2502            tree_none.train_one(&feat, pred_n - y, 1.0);
2503        }
2504
2505        // Both trees should have the same number of nodes
2506        assert_eq!(
2507            tree_static.arena().n_nodes(),
2508            tree_none.arena().n_nodes(),
2509            "adaptive_depth=None should produce identical tree structure to static max_depth"
2510        );
2511    }
2512
2513    #[test]
2514    fn adaptive_depth_few_samples_stays_shallow() {
2515        // With adaptive_depth enabled and few noisy samples,
2516        // the CIR penalty should prevent deep splits.
2517        let config = TreeConfig::new()
2518            .max_depth(6)
2519            .n_bins(32)
2520            .grace_period(20)
2521            .lambda(0.1)
2522            .delta(1e-3)
2523            .adaptive_depth(7.5);
2524
2525        let mut tree = HoeffdingTree::new(config);
2526        let mut rng_state: u64 = 99;
2527
2528        // Feed a small number of noisy samples — mostly noise, weak signal
2529        for _ in 0..100 {
2530            let x = test_rand_f64(&mut rng_state) * 10.0;
2531            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 20.0; // large noise
2532            let y = 0.1 * x + noise;
2533            let feat = [x, test_rand_f64(&mut rng_state) * 5.0];
2534            let pred = tree.predict(&feat);
2535            tree.train_one(&feat, pred - y, 1.0);
2536        }
2537
2538        // With few noisy samples and CIR penalty, the tree should stay shallower
2539        // than the hard ceiling of max_depth*2 = 12. In fact, with this much noise
2540        // and few samples, we expect very few splits.
2541        let n_nodes = tree.arena().n_nodes();
2542        assert!(
2543            n_nodes <= 15,
2544            "adaptive_depth with few noisy samples should keep tree shallow, got {} nodes",
2545            n_nodes
2546        );
2547    }
2548
2549    #[test]
2550    fn adaptive_depth_many_samples_grows_deeper() {
2551        // With many clean samples, the CIR penalty decays (1/n) and the tree
2552        // should grow deeper than with few samples.
2553        let config_few = TreeConfig::new()
2554            .max_depth(6)
2555            .n_bins(32)
2556            .grace_period(20)
2557            .lambda(0.1)
2558            .delta(1e-3)
2559            .adaptive_depth(7.5);
2560
2561        let config_many = TreeConfig::new()
2562            .max_depth(6)
2563            .n_bins(32)
2564            .grace_period(20)
2565            .lambda(0.1)
2566            .delta(1e-3)
2567            .adaptive_depth(7.5);
2568
2569        let mut tree_few = HoeffdingTree::new(config_few);
2570        let mut tree_many = HoeffdingTree::new(config_many);
2571
2572        let mut rng_state: u64 = 42;
2573
2574        // Strong signal: y = 3*x1 + 2*x2 (clean)
2575        // Train "few" with 200 samples
2576        for _ in 0..200 {
2577            let x1 = test_rand_f64(&mut rng_state) * 10.0;
2578            let x2 = test_rand_f64(&mut rng_state) * 5.0;
2579            let y = 3.0 * x1 + 2.0 * x2;
2580            let feat = [x1, x2];
2581            let pred = tree_few.predict(&feat);
2582            tree_few.train_one(&feat, pred - y, 1.0);
2583        }
2584
2585        // Train "many" with 5000 samples
2586        let mut rng_state2: u64 = 42;
2587        for _ in 0..5000 {
2588            let x1 = test_rand_f64(&mut rng_state2) * 10.0;
2589            let x2 = test_rand_f64(&mut rng_state2) * 5.0;
2590            let y = 3.0 * x1 + 2.0 * x2;
2591            let feat = [x1, x2];
2592            let pred = tree_many.predict(&feat);
2593            tree_many.train_one(&feat, pred - y, 1.0);
2594        }
2595
2596        // The tree with more samples should have at least as many nodes
2597        // (penalty = cir_factor * var / n * n_feat decays with n)
2598        assert!(
2599            tree_many.arena().n_nodes() >= tree_few.arena().n_nodes(),
2600            "more samples should allow deeper growth: many={} vs few={}",
2601            tree_many.arena().n_nodes(),
2602            tree_few.arena().n_nodes()
2603        );
2604    }
2605
2606    #[test]
2607    fn adaptive_depth_penalty_scales_inversely_with_n() {
2608        // Directly verify the penalty math: penalty = cir_factor * grad_var / n * n_feat
2609        // With fixed cir_factor=7.5, grad_var=1.0, n_feat=2:
2610        //   n=100  -> penalty = 7.5 * 1.0 / 100 * 2 = 0.15
2611        //   n=1000 -> penalty = 7.5 * 1.0 / 1000 * 2 = 0.015
2612        // So a gain of 0.05 would fail at n=100 but pass at n=1000.
2613        let cir_factor: f64 = 7.5;
2614        let grad_var: f64 = 1.0;
2615        let n_feat: f64 = 2.0;
2616
2617        let penalty_100 = cir_factor * grad_var / 100.0 * n_feat;
2618        let penalty_1000 = cir_factor * grad_var / 1000.0 * n_feat;
2619
2620        assert!(
2621            (penalty_100 - 0.15).abs() < 1e-10,
2622            "penalty at n=100 should be 0.15, got {}",
2623            penalty_100
2624        );
2625        assert!(
2626            (penalty_1000 - 0.015).abs() < 1e-10,
2627            "penalty at n=1000 should be 0.015, got {}",
2628            penalty_1000
2629        );
2630        assert!(
2631            penalty_100 > penalty_1000,
2632            "penalty should decrease with more samples"
2633        );
2634
2635        // A gain of 0.05 should fail at n=100 (0.05 <= 0.15) but pass at n=1000 (0.05 > 0.015)
2636        let gain = 0.05;
2637        assert!(gain <= penalty_100, "gain should fail CIR at n=100");
2638        assert!(gain > penalty_1000, "gain should pass CIR at n=1000");
2639    }
2640
2641    #[test]
2642    fn adaptive_depth_hard_ceiling_respected() {
2643        // Even with adaptive_depth enabled, trees should never exceed max_depth * 2.
2644        let config = TreeConfig::new()
2645            .max_depth(3)
2646            .n_bins(32)
2647            .grace_period(10)
2648            .lambda(0.01)
2649            .gamma(0.0)
2650            .delta(1e-2) // Loose bound for easy splitting
2651            .adaptive_depth(0.001); // Very small factor = almost no CIR penalty
2652
2653        let mut tree = HoeffdingTree::new(config);
2654        let mut rng_state: u64 = 777;
2655
2656        // Train with very strong, clean signal to maximize splitting
2657        for _ in 0..10000 {
2658            let x = test_rand_f64(&mut rng_state) * 100.0;
2659            let y = x * x; // Strong nonlinear signal
2660            let feat = [x];
2661            let pred = tree.predict(&feat);
2662            tree.train_one(&feat, pred - y, 1.0);
2663        }
2664
2665        // Hard ceiling = max_depth * 2 = 6, so max leaves = 2^6 = 64
2666        let max_leaves = 1usize << 6;
2667        let n_leaves = tree.arena().n_leaves();
2668        assert!(
2669            n_leaves <= max_leaves,
2670            "tree should respect hard ceiling of max_depth*2=6 ({} max leaves), got {} leaves",
2671            max_leaves,
2672            n_leaves
2673        );
2674    }
2675}