// irithyll_core/tree/hoeffding.rs

//! Hoeffding-bound split decisions for streaming tree construction.
//!
//! [`HoeffdingTree`] is the core streaming decision tree. It grows incrementally:
//! each sample updates per-leaf histogram accumulators, and splits are committed
//! only when the Hoeffding bound guarantees the best candidate split is
//! statistically superior to the runner-up (or a tie-breaking threshold is met).
//!
//! # Algorithm
//!
//! For each incoming `(features, gradient, hessian)` triple:
//!
//! 1. Route the sample from root to a leaf via threshold comparisons.
//! 2. At the leaf, accumulate gradient/hessian into per-feature histograms.
//! 3. Once enough samples arrive (grace period), evaluate candidate splits
//!    using the XGBoost gain formula.
//! 4. Apply the Hoeffding bound: if the gap between the best and second-best
//!    gain exceeds `epsilon = sqrt(R^2 * ln(1/delta) / (2n))`, commit the split.
//! 5. When splitting, use the histogram subtraction trick to initialize one
//!    child's histograms for free.

21use alloc::boxed::Box;
22use alloc::vec;
23use alloc::vec::Vec;
24
25use crate::feature::FeatureType;
26use crate::histogram::bins::LeafHistograms;
27use crate::histogram::{BinEdges, BinnerKind};
28use crate::math;
29use crate::tree::builder::TreeConfig;
30use crate::tree::leaf_model::{LeafModel, LeafModelType};
31use crate::tree::node::{NodeId, TreeArena};
32use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
33use crate::tree::StreamingTree;
34
/// Tie-breaking threshold (tau). When `epsilon < tau`, we accept the best split
/// even if the gap between best and second-best gain is small, because the
/// Hoeffding bound is already tight enough that further samples won't help.
///
/// This is the standard VFDT tie-break: without it, two near-equal candidate
/// splits could stall tree growth indefinitely while epsilon keeps shrinking.
const TAU: f64 = 0.05;
39
40// ---------------------------------------------------------------------------
41// xorshift64 -- minimal deterministic PRNG
42// ---------------------------------------------------------------------------
43
/// Advance an xorshift64 state in place and return the new value.
///
/// Marsaglia's xorshift generator with the classic (13, 7, 17) shift triple.
/// Deterministic, allocation-free, and `no_std`-friendly; any nonzero seed
/// yields the full 2^64 - 1 period. A zero seed stays zero forever, so callers
/// should seed with a nonzero value.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    *state
}
54
55// ---------------------------------------------------------------------------
56// LeafState -- per-leaf bookkeeping
57// ---------------------------------------------------------------------------
58
/// State tracked per leaf node for split decisions.
///
/// Each active leaf owns its own set of histogram accumulators (one per feature)
/// and running gradient/hessian sums for leaf weight updates.
///
/// `Clone` is implemented manually (see the `impl Clone` below) because
/// `leaf_model` holds a `Box<dyn LeafModel>`, which cannot be derived.
struct LeafState {
    /// Histogram accumulators for this leaf. `None` until bin edges are computed
    /// (after the grace period).
    histograms: Option<LeafHistograms>,

    /// Per-feature binning strategies that collect observed values to compute
    /// bin edges. Uses `BinnerKind` enum dispatch instead of `Box<dyn>` to
    /// eliminate N_features heap allocations per new leaf.
    binners: Vec<BinnerKind>,

    /// Whether bin edges have been computed (after grace period samples).
    bins_ready: bool,

    /// Running gradient sum for leaf weight updates.
    grad_sum: f64,

    /// Running hessian sum for leaf weight updates.
    hess_sum: f64,

    /// Sample count at last split re-evaluation (for EFDT-inspired re-eval).
    last_reeval_count: u64,

    /// EWMA gradient mean for gradient clipping (Welford online algorithm).
    clip_grad_mean: f64,

    /// EWMA gradient M2 accumulator (Welford) for variance estimation.
    clip_grad_m2: f64,

    /// Number of gradients observed for clipping statistics.
    clip_grad_count: u64,

    /// EWMA/Welford mean of this leaf's output weight (for adaptive bounds).
    output_mean: f64,

    /// EWMA/Welford M2 (variance accumulator) of this leaf's output weight.
    output_m2: f64,

    /// Number of output weight observations.
    output_count: u64,

    /// Optional trainable leaf model (linear / MLP). `None` for closed-form leaves.
    leaf_model: Option<Box<dyn LeafModel>>,
}
106
/// Manual `Clone` impl: `#[derive(Clone)]` is impossible because
/// `Box<dyn LeafModel>` does not implement `Clone`. The trait object is
/// duplicated through its `clone_warm()` method instead — presumably a
/// "warm" copy that keeps learned state (NOTE(review): confirm against
/// the `LeafModel` trait definition).
impl Clone for LeafState {
    fn clone(&self) -> Self {
        Self {
            histograms: self.histograms.clone(),
            binners: self.binners.clone(),
            bins_ready: self.bins_ready,
            grad_sum: self.grad_sum,
            hess_sum: self.hess_sum,
            last_reeval_count: self.last_reeval_count,
            clip_grad_mean: self.clip_grad_mean,
            clip_grad_m2: self.clip_grad_m2,
            clip_grad_count: self.clip_grad_count,
            output_mean: self.output_mean,
            output_m2: self.output_m2,
            output_count: self.output_count,
            // Trait-object field: clone via the model's own method.
            leaf_model: self.leaf_model.as_ref().map(|m| m.clone_warm()),
        }
    }
}
126
127/// Clip a gradient using Welford online stats tracked per leaf.
128///
129/// Updates running mean/variance, then clamps the gradient to `mean ± sigma * std_dev`.
130/// Returns the (possibly clamped) gradient. During warmup (< 10 samples), no clipping
131/// is applied to let the statistics stabilize.
132#[inline]
133fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
134    state.clip_grad_count += 1;
135    let n = state.clip_grad_count as f64;
136
137    // Welford online update
138    let delta = gradient - state.clip_grad_mean;
139    state.clip_grad_mean += delta / n;
140    let delta2 = gradient - state.clip_grad_mean;
141    state.clip_grad_m2 += delta * delta2;
142
143    // No clipping during warmup
144    if state.clip_grad_count < 10 {
145        return gradient;
146    }
147
148    let variance = state.clip_grad_m2 / (n - 1.0);
149    let std_dev = math::sqrt(variance);
150
151    if std_dev < 1e-15 {
152        return gradient; // All gradients identical -- no clipping needed
153    }
154
155    let lo = state.clip_grad_mean - sigma * std_dev;
156    let hi = state.clip_grad_mean + sigma * std_dev;
157    gradient.clamp(lo, hi)
158}
159
160/// Update per-leaf output weight tracking for adaptive bounds.
161///
162/// If `decay_alpha` is Some, uses EWMA synchronized with leaf_decay_alpha.
163/// Otherwise uses Welford online algorithm (batch scenarios).
164#[inline]
165fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
166    state.output_count += 1;
167
168    if let Some(alpha) = decay_alpha {
169        // EWMA — synchronized with leaf gradient decay
170        if state.output_count == 1 {
171            state.output_mean = weight;
172            state.output_m2 = 0.0;
173        } else {
174            let diff = weight - state.output_mean;
175            state.output_mean = alpha * state.output_mean + (1.0 - alpha) * weight;
176            let diff2 = weight - state.output_mean;
177            state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * diff * diff2;
178        }
179    } else {
180        // Welford online (no decay — batch scenarios)
181        let delta = weight - state.output_mean;
182        state.output_mean += delta / (state.output_count as f64);
183        let delta2 = weight - state.output_mean;
184        state.output_m2 += delta * delta2;
185    }
186}
187
188/// Get the adaptive output bound for this leaf.
189///
190/// Returns `|mean| + k * std`, with a floor of 0.01 to never fully suppress a leaf.
191/// During warmup (< 10 samples), returns `f64::MAX` (no bound).
192#[inline]
193fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
194    if state.output_count < 10 {
195        return f64::MAX; // warmup — no bound yet
196    }
197
198    let variance = if decay_alpha.is_some() {
199        // EWMA variance is stored directly in output_m2
200        state.output_m2.max(0.0)
201    } else {
202        // Welford: variance = M2 / (n - 1)
203        state.output_m2 / (state.output_count as f64 - 1.0)
204    };
205    let std = math::sqrt(variance);
206
207    // Bound = |mean| + k * std, floor at 0.01 to never fully suppress a leaf
208    (math::abs(state.output_mean) + k * std).max(0.01)
209}
210
211/// Create binners according to feature types.
212fn make_binners(n_features: usize, feature_types: Option<&[FeatureType]>) -> Vec<BinnerKind> {
213    (0..n_features)
214        .map(|i| {
215            if let Some(ft) = feature_types {
216                if i < ft.len() && ft[i] == FeatureType::Categorical {
217                    return BinnerKind::categorical();
218                }
219            }
220            BinnerKind::uniform()
221        })
222        .collect()
223}
224
225impl LeafState {
226    /// Create a fresh leaf state for a leaf with `n_features` features.
227    fn new(n_features: usize) -> Self {
228        Self::new_with_types(n_features, None)
229    }
230
231    /// Create a fresh leaf state respecting per-feature type declarations.
232    fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
233        let binners = make_binners(n_features, feature_types);
234
235        Self {
236            histograms: None,
237            binners,
238            bins_ready: false,
239            grad_sum: 0.0,
240            hess_sum: 0.0,
241            last_reeval_count: 0,
242            clip_grad_mean: 0.0,
243            clip_grad_m2: 0.0,
244            clip_grad_count: 0,
245            output_mean: 0.0,
246            output_m2: 0.0,
247            output_count: 0,
248            leaf_model: None,
249        }
250    }
251
252    /// Create a leaf state with pre-computed histograms (used after a split
253    /// when we can initialize the child via histogram subtraction).
254    #[allow(dead_code)]
255    fn with_histograms(histograms: LeafHistograms) -> Self {
256        let n_features = histograms.n_features();
257        let binners: Vec<BinnerKind> = (0..n_features).map(|_| BinnerKind::uniform()).collect();
258
259        // Recover grad/hess sums from the histograms.
260        let grad_sum: f64 = histograms
261            .histograms
262            .first()
263            .map_or(0.0, |h| h.total_gradient());
264        let hess_sum: f64 = histograms
265            .histograms
266            .first()
267            .map_or(0.0, |h| h.total_hessian());
268
269        Self {
270            histograms: Some(histograms),
271            binners,
272            bins_ready: true,
273            grad_sum,
274            hess_sum,
275            last_reeval_count: 0,
276            clip_grad_mean: 0.0,
277            clip_grad_m2: 0.0,
278            clip_grad_count: 0,
279            output_mean: 0.0,
280            output_m2: 0.0,
281            output_count: 0,
282            leaf_model: None,
283        }
284    }
285}
286
287// ---------------------------------------------------------------------------
288// HoeffdingTree
289// ---------------------------------------------------------------------------
290
/// A streaming decision tree that uses Hoeffding-bound split decisions.
///
/// The tree grows incrementally: each call to [`train_one`](StreamingTree::train_one)
/// routes one sample to its leaf, updates histograms, and potentially triggers
/// a split when statistical evidence is sufficient.
///
/// # Feature subsampling
///
/// When `config.feature_subsample_rate < 1.0`, each split evaluation considers
/// only a random subset of features (selected via a deterministic xorshift64 RNG).
/// This adds diversity when the tree is used inside an ensemble.
///
/// All randomness flows through `rng_state`, so training is fully
/// reproducible given the same config seed and sample order.
pub struct HoeffdingTree {
    /// Arena-allocated node storage.
    arena: TreeArena,

    /// Root node identifier.
    root: NodeId,

    /// Tree configuration / hyperparameters.
    config: TreeConfig,

    /// Per-leaf state indexed by `NodeId.0`. Dense Vec -- NodeIds are
    /// contiguous u32 indices from TreeArena, so direct indexing is optimal.
    /// `None` entries are internal nodes or not-yet-initialized slots.
    leaf_states: Vec<Option<LeafState>>,

    /// Number of features, learned from the first sample.
    n_features: Option<usize>,

    /// Total samples seen across all calls to `train_one`.
    samples_seen: u64,

    /// Split gain evaluator.
    split_criterion: XGBoostGain,

    /// Scratch buffer for the feature mask (avoids repeated allocation).
    feature_mask: Vec<usize>,

    /// Bitset scratch buffer for O(1) membership test during feature mask generation.
    /// Each bit `i` indicates whether feature `i` is already in `feature_mask`.
    feature_mask_bits: Vec<u64>,

    /// xorshift64 RNG state for feature subsampling.
    rng_state: u64,

    /// Accumulated split gains per feature for importance tracking.
    /// Indexed by feature index; grows lazily when n_features is learned.
    split_gains: Vec<f64>,

    /// Per-node auto-bandwidth for soft routing, indexed by `NodeId.0`.
    /// Recomputed after every structural change (split).
    node_bandwidths: Vec<f64>,
}
343
344impl HoeffdingTree {
345    /// Create a new `HoeffdingTree` with the given configuration.
346    ///
347    /// The tree starts with a single root leaf and no feature information;
348    /// the number of features is inferred from the first training sample.
349    pub fn new(config: TreeConfig) -> Self {
350        let mut arena = TreeArena::new();
351        let root = arena.add_leaf(0);
352
353        // Insert a placeholder leaf state for the root. We don't know n_features
354        // yet, so give it 0 binners -- it will be properly initialized on the
355        // first sample.
356        let mut leaf_states = vec![None; root.0 as usize + 1];
357        let root_model = match config.leaf_model_type {
358            LeafModelType::ClosedForm => None,
359            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
360        };
361        leaf_states[root.0 as usize] = Some(LeafState {
362            histograms: None,
363            binners: Vec::new(),
364            bins_ready: false,
365            grad_sum: 0.0,
366            hess_sum: 0.0,
367            last_reeval_count: 0,
368            clip_grad_mean: 0.0,
369            clip_grad_m2: 0.0,
370            clip_grad_count: 0,
371            output_mean: 0.0,
372            output_m2: 0.0,
373            output_count: 0,
374            leaf_model: root_model,
375        });
376
377        let seed = config.seed;
378        Self {
379            arena,
380            root,
381            config,
382            leaf_states,
383            n_features: None,
384            samples_seen: 0,
385            split_criterion: XGBoostGain::default(),
386            feature_mask: Vec::new(),
387            feature_mask_bits: Vec::new(),
388            rng_state: seed,
389            split_gains: Vec::new(),
390            node_bandwidths: Vec::new(),
391        }
392    }
393
394    /// Create a leaf model for a new leaf if the config requires one.
395    ///
396    /// Returns `None` for `ClosedForm` (the default), which uses the existing
397    /// `leaf_weight()` path with zero overhead. For `Linear` and `MLP`, returns
398    /// a fresh model seeded deterministically from the config seed and node id.
399    fn make_leaf_model(&self, node: NodeId) -> Option<Box<dyn LeafModel>> {
400        match self.config.leaf_model_type {
401            LeafModelType::ClosedForm => None,
402            _ => Some(
403                self.config
404                    .leaf_model_type
405                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
406            ),
407        }
408    }
409
410    /// Reconstruct a `HoeffdingTree` from a pre-built arena.
411    ///
412    /// Used during model deserialization. The tree is restored with node
413    /// topology and leaf values intact, but histogram accumulators are empty
414    /// (they will rebuild naturally from continued training).
415    ///
416    /// The root is assumed to be `NodeId(0)`. Leaf states are created empty
417    /// for all current leaf nodes in the arena.
418    pub fn from_arena(
419        config: TreeConfig,
420        arena: TreeArena,
421        n_features: Option<usize>,
422        samples_seen: u64,
423        rng_state: u64,
424    ) -> Self {
425        let root = if arena.n_nodes() > 0 {
426            NodeId(0)
427        } else {
428            // Empty arena -- add a root leaf (shouldn't normally happen in restore).
429            let mut arena_mut = arena;
430            let root = arena_mut.add_leaf(0);
431            return Self {
432                arena: arena_mut,
433                root,
434                config: config.clone(),
435                leaf_states: {
436                    let mut v = vec![None; root.0 as usize + 1];
437                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
438                    v
439                },
440                n_features,
441                samples_seen,
442                split_criterion: XGBoostGain::default(),
443                feature_mask: Vec::new(),
444                feature_mask_bits: Vec::new(),
445                rng_state,
446                split_gains: vec![0.0; n_features.unwrap_or(0)],
447                node_bandwidths: Vec::new(),
448            };
449        };
450
451        // Build leaf states for every leaf in the arena.
452        let nf = n_features.unwrap_or(0);
453        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
454        for (i, slot) in leaf_states.iter_mut().enumerate() {
455            if arena.is_leaf[i] {
456                *slot = Some(LeafState::new(nf));
457            }
458        }
459
460        Self {
461            arena,
462            root,
463            config,
464            leaf_states,
465            n_features,
466            samples_seen,
467            split_criterion: XGBoostGain::default(),
468            feature_mask: Vec::new(),
469            feature_mask_bits: Vec::new(),
470            rng_state,
471            split_gains: vec![0.0; nf],
472            node_bandwidths: Vec::new(),
473        }
474    }
475
    /// Root node identifier.
    #[inline]
    pub fn root(&self) -> NodeId {
        self.root
    }

    /// Immutable access to the underlying arena.
    #[inline]
    pub fn arena(&self) -> &TreeArena {
        &self.arena
    }

    /// Immutable access to the tree configuration.
    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.config
    }

    /// Number of features (learned from the first sample, `None` before any training).
    #[inline]
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    /// Current RNG state (for deterministic checkpoint/restore).
    ///
    /// Persist this alongside the arena so feature subsampling replays
    /// identically after a restore.
    #[inline]
    pub fn rng_state(&self) -> u64 {
        self.rng_state
    }
505
506    /// Read-only access to the gradient and hessian sums for a leaf node.
507    ///
508    /// Returns `Some((grad_sum, hess_sum))` if `node` is a leaf with an active
509    /// leaf state, or `None` if the node has no state (e.g. internal node
510    /// or freshly allocated).
511    ///
512    /// These sums enable inverse-hessian confidence estimation:
513    /// `confidence = 1.0 / (hess_sum + lambda)`. High hessian means the leaf
514    /// has seen consistent, informative data; low hessian means uncertainty.
515    #[inline]
516    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
517        self.leaf_states
518            .get(node.0 as usize)
519            .and_then(|o| o.as_ref())
520            .map(|state| (state.grad_sum, state.hess_sum))
521    }
522
    /// Route a feature vector from the root down to a leaf, returning the leaf's NodeId.
    ///
    /// Categorical nodes test a 64-bit category bitmask; continuous nodes use
    /// `<= threshold` goes left. NOTE(review): `features[feat_idx]` is indexed
    /// unchecked — assumes callers pass at least `n_features` values; confirm.
    fn route_to_leaf(&self, features: &[f64]) -> NodeId {
        let mut current = self.root;
        while !self.arena.is_leaf(current) {
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                // Categorical split: use bitmask routing.
                // The feature value is cast to a bin index. If that bin's bit is set
                // in the mask, go left; otherwise go right.
                // For categorical features, the bin index in the histogram corresponds
                // to the sorted category position, but for bitmask routing we use
                // the original bin index directly.
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    // Categories >= 64 always route right (mask can't represent them).
                    self.arena.get_right(current)
                }
            } else {
                // Continuous split: standard threshold comparison.
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }
        current
    }
553
    /// Get the prediction value for a leaf node.
    ///
    /// Checks (in order): leaf model, live grad/hess statistics, stored leaf value.
    /// Returns `0.0` if no leaf state exists.
    ///
    /// The raw value is then clamped with priority:
    /// per-leaf adaptive bound > global `max_leaf_output` > unclamped.
    #[inline]
    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
        let (raw, leaf_bound) = if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            // min_hessian_sum: suppress fresh leaves with insufficient samples
            if let Some(min_h) = self.config.min_hessian_sum {
                if state.hess_sum < min_h {
                    return 0.0;
                }
            }
            let val = if let Some(ref model) = state.leaf_model {
                // Trainable leaf model (linear / MLP) takes precedence.
                model.predict(features)
            } else if state.hess_sum != 0.0 {
                // Closed-form leaf weight from live gradient/hessian stats.
                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
            } else {
                // No live stats yet (e.g. freshly restored tree): fall back
                // to the value stored in the arena.
                self.arena.leaf_value[leaf_id.0 as usize]
            };

            // Compute per-leaf adaptive bound while state is in scope
            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));

            (val, bound)
        } else {
            (0.0, None)
        };

        // Priority: per-leaf adaptive bound > global max_leaf_output > unclamped
        if let Some(bound) = leaf_bound {
            // f64::MAX means "warmup, no bound yet" — skip to the global clamp.
            if bound < f64::MAX {
                return raw.clamp(-bound, bound);
            }
        }
        if let Some(max) = self.config.max_leaf_output {
            raw.clamp(-max, max)
        } else {
            raw
        }
    }
602
    /// Predict using sigmoid-blended soft routing for smooth interpolation.
    ///
    /// Instead of hard left/right routing at each split node, uses sigmoid
    /// blending: `alpha = sigmoid((threshold - feature) / bandwidth)`. The
    /// prediction is `alpha * left_pred + (1 - alpha) * right_pred`, computed
    /// recursively from root to leaves.
    ///
    /// The result is a continuous function that varies smoothly with every
    /// feature change — no bins, no boundaries, no jumps.
    ///
    /// # Arguments
    ///
    /// * `features` - Input feature vector.
    /// * `bandwidth` - Controls transition sharpness. Smaller = sharper
    ///   (closer to hard splits), larger = smoother.
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.predict_smooth_recursive(self.root, features, bandwidth)
    }

    /// Predict using per-feature auto-calibrated bandwidths.
    ///
    /// Each feature uses its own bandwidth derived from median split threshold
    /// gaps. Features with `f64::INFINITY` bandwidth fall back to hard routing.
    ///
    /// NOTE(review): `bandwidths` is assumed to be indexed by feature index —
    /// confirm against how the caller calibrates it.
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
    }
629
    /// Predict with parent-leaf linear interpolation.
    ///
    /// Routes to the leaf but blends the leaf prediction with the parent node's
    /// preserved prediction based on the leaf's hessian sum. Fresh leaves
    /// (low hess_sum) smoothly transition from parent prediction to their own:
    ///
    /// `alpha = leaf_hess / (leaf_hess + lambda)`
    /// `pred = alpha * leaf_pred + (1 - alpha) * parent_pred`
    ///
    /// This fixes static predictions from leaves that split but haven't
    /// accumulated enough samples to outperform their parent.
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        // Routing mirrors `route_to_leaf`, but additionally remembers the last
        // internal node visited so its prediction can be blended in below.
        let mut current = self.root;
        let mut parent = None;
        while !self.arena.is_leaf(current) {
            parent = Some(current);
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                // Categorical split: bit set in the mask => left child.
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                // Continuous split: `<= threshold` goes left.
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }

        let leaf_pred = self.leaf_prediction(current, features);

        // No parent (root is leaf) → return leaf prediction directly
        let parent_id = match parent {
            Some(p) => p,
            None => return leaf_pred,
        };

        // Get parent's preserved prediction from its old leaf state
        let parent_pred = self.leaf_prediction(parent_id, features);

        // Blend: alpha = leaf_hess / (leaf_hess + lambda)
        let leaf_hess = self
            .leaf_states
            .get(current.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|s| s.hess_sum)
            .unwrap_or(0.0);

        // alpha → 1 as the leaf accumulates evidence; lambda keeps the
        // denominator nonzero even for a brand-new leaf.
        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
        alpha * leaf_pred + (1.0 - alpha) * parent_pred
    }
686
    /// Predict with sibling-based interpolation for feature-continuous predictions.
    ///
    /// At the leaf's parent split, blends the leaf prediction with its sibling's
    /// prediction based on the feature's distance from the split threshold:
    ///
    /// Within the margin `m` around the threshold:
    /// `t = (feature - threshold + m) / (2 * m)`  (0 at left edge, 1 at right edge)
    /// `pred = (1 - t) * left_pred + t * right_pred`
    ///
    /// Outside the margin, returns the routed child's prediction directly.
    /// The margin `m` is derived from auto-bandwidths if available, otherwise
    /// defaults to `feature_range / n_bins` heuristic per feature.
    ///
    /// This makes predictions vary continuously as features move near split
    /// boundaries, eliminating step-function artifacts.
    ///
    /// `bandwidths` is indexed by feature; missing/non-finite entries mean
    /// hard routing for that feature.
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_sibling_recursive(self.root, features, bandwidths)
    }
705
706    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
707        if self.arena.is_leaf(node) {
708            return self.leaf_prediction(node, features);
709        }
710
711        let feat_idx = self.arena.get_feature_idx(node) as usize;
712        let left = self.arena.get_left(node);
713        let right = self.arena.get_right(node);
714
715        // Categorical splits: always hard routing (no interpolation)
716        if let Some(mask) = self.arena.get_categorical_mask(node) {
717            let cat_val = features[feat_idx] as u64;
718            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
719                self.predict_sibling_recursive(left, features, bandwidths)
720            } else {
721                self.predict_sibling_recursive(right, features, bandwidths)
722            };
723        }
724
725        let threshold = self.arena.get_threshold(node);
726        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
727
728        // No valid margin or infinite → hard routing
729        if !margin.is_finite() || margin <= 0.0 {
730            return if features[feat_idx] <= threshold {
731                self.predict_sibling_recursive(left, features, bandwidths)
732            } else {
733                self.predict_sibling_recursive(right, features, bandwidths)
734            };
735        }
736
737        let dist = features[feat_idx] - threshold;
738
739        if dist < -margin {
740            // Firmly in left child territory
741            self.predict_sibling_recursive(left, features, bandwidths)
742        } else if dist > margin {
743            // Firmly in right child territory
744            self.predict_sibling_recursive(right, features, bandwidths)
745        } else {
746            // Within the interpolation margin: linear blend
747            let t = (dist + margin) / (2.0 * margin); // 0.0 at left edge, 1.0 at right edge
748            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
749            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
750            (1.0 - t) * left_pred + t * right_pred
751        }
752    }
753
754    /// Collect all split thresholds per feature from the tree arena.
755    ///
756    /// Returns a `Vec<Vec<f64>>` indexed by feature, containing all thresholds
757    /// used in continuous splits. Categorical splits are excluded.
758    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
759        let n = self.n_features.unwrap_or(0);
760        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];
761
762        for i in 0..self.arena.n_nodes() {
763            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
764                let feat_idx = self.arena.feature_idx[i] as usize;
765                if feat_idx < n {
766                    thresholds[feat_idx].push(self.arena.threshold[i]);
767                }
768            }
769        }
770
771        thresholds
772    }
773
774    /// Compute per-node bandwidth from nearest neighbor thresholds on the same feature.
775    fn compute_node_bandwidth(&self, node: NodeId, all_thresholds: &[Vec<f64>]) -> f64 {
776        let feat_idx = self.arena.get_feature_idx(node) as usize;
777        let threshold = self.arena.get_threshold(node);
778
779        let thresholds = if feat_idx < all_thresholds.len() {
780            &all_thresholds[feat_idx]
781        } else {
782            return f64::INFINITY;
783        };
784
785        // Find nearest neighbors (thresholds are sorted)
786        let below = thresholds.iter().rev().find(|&&t| t < threshold - 1e-15);
787        let above = thresholds.iter().find(|&&t| t > threshold + 1e-15);
788
789        match (below, above) {
790            (Some(&b), Some(&a)) => (threshold - b).min(a - threshold),
791            (Some(&b), None) => threshold - b,
792            (None, Some(&a)) => a - threshold,
793            (None, None) => f64::INFINITY,
794        }
795    }
796
797    /// Recompute all node bandwidths. Call after structural changes.
798    pub fn recompute_bandwidths(&mut self) {
799        let n = self.arena.n_nodes();
800        self.node_bandwidths.resize(n, f64::INFINITY);
801
802        // Collect and sort thresholds once
803        let mut all_thresholds = self.collect_split_thresholds_per_feature();
804        for v in &mut all_thresholds {
805            v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
806        }
807
808        for i in 0..n {
809            let nid = NodeId(i as u32);
810            if !self.arena.is_leaf(nid) {
811                self.node_bandwidths[i] = self.compute_node_bandwidth(nid, &all_thresholds);
812            } else {
813                self.node_bandwidths[i] = f64::INFINITY;
814            }
815        }
816    }
817
818    /// Predict using per-node auto-bandwidth soft routing.
819    /// Every prediction is a continuous weighted blend — no step discontinuities.
820    pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
821        self.predict_soft_recursive(self.root, features)
822    }
823
824    fn predict_soft_recursive(&self, node: NodeId, features: &[f64]) -> f64 {
825        if self.arena.is_leaf(node) {
826            return self.leaf_prediction(node, features);
827        }
828
829        let feat_idx = self.arena.get_feature_idx(node) as usize;
830        let left = self.arena.get_left(node);
831        let right = self.arena.get_right(node);
832
833        // Categorical: hard routing
834        if let Some(mask) = self.arena.get_categorical_mask(node) {
835            let cat_val = features[feat_idx] as u64;
836            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
837                self.predict_soft_recursive(left, features)
838            } else {
839                self.predict_soft_recursive(right, features)
840            };
841        }
842
843        let threshold = self.arena.get_threshold(node);
844        let margin = self
845            .node_bandwidths
846            .get(node.0 as usize)
847            .copied()
848            .unwrap_or(f64::INFINITY);
849
850        let left_pred = self.predict_soft_recursive(left, features);
851        let right_pred = self.predict_soft_recursive(right, features);
852
853        // Non-finite or zero margin: sigmoid fallback
854        if !margin.is_finite() || margin <= 0.0 {
855            let dist = features[feat_idx] - threshold;
856            let scale = math::abs(threshold) * 0.01 + 1e-10;
857            let z = (-dist / scale).clamp(-500.0, 500.0);
858            let t = 1.0 / (1.0 + math::exp(z));
859            return (1.0 - t) * left_pred + t * right_pred;
860        }
861
862        // Linear soft routing: always blend
863        let dist = features[feat_idx] - threshold;
864        let t = ((dist + margin) / (2.0 * margin)).clamp(0.0, 1.0);
865        (1.0 - t) * left_pred + t * right_pred
866    }
867
868    /// Recursive sigmoid-blended prediction traversal.
869    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
870        if self.arena.is_leaf(node) {
871            // At a leaf, return the leaf prediction (same as regular predict)
872            return self.leaf_prediction(node, features);
873        }
874
875        let feat_idx = self.arena.get_feature_idx(node) as usize;
876        let left = self.arena.get_left(node);
877        let right = self.arena.get_right(node);
878
879        // Categorical splits: hard routing (sigmoid blending is meaningless for categories)
880        if let Some(mask) = self.arena.get_categorical_mask(node) {
881            let cat_val = features[feat_idx] as u64;
882            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
883                self.predict_smooth_recursive(left, features, bandwidth)
884            } else {
885                self.predict_smooth_recursive(right, features, bandwidth)
886            };
887        }
888
889        // Continuous split: sigmoid blending for smooth transition around the threshold
890        let threshold = self.arena.get_threshold(node);
891        let z = (threshold - features[feat_idx]) / bandwidth;
892        let alpha = 1.0 / (1.0 + math::exp(-z));
893
894        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
895        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);
896
897        alpha * left_pred + (1.0 - alpha) * right_pred
898    }
899
900    /// Recursive per-feature-bandwidth smooth prediction traversal.
901    fn predict_smooth_auto_recursive(
902        &self,
903        node: NodeId,
904        features: &[f64],
905        bandwidths: &[f64],
906    ) -> f64 {
907        if self.arena.is_leaf(node) {
908            return self.leaf_prediction(node, features);
909        }
910
911        let feat_idx = self.arena.get_feature_idx(node) as usize;
912        let left = self.arena.get_left(node);
913        let right = self.arena.get_right(node);
914
915        // Categorical splits: always hard routing
916        if let Some(mask) = self.arena.get_categorical_mask(node) {
917            let cat_val = features[feat_idx] as u64;
918            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
919                self.predict_smooth_auto_recursive(left, features, bandwidths)
920            } else {
921                self.predict_smooth_auto_recursive(right, features, bandwidths)
922            };
923        }
924
925        let threshold = self.arena.get_threshold(node);
926        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
927
928        // Infinite bandwidth = feature never split on across ensemble → hard routing
929        if !bw.is_finite() {
930            return if features[feat_idx] <= threshold {
931                self.predict_smooth_auto_recursive(left, features, bandwidths)
932            } else {
933                self.predict_smooth_auto_recursive(right, features, bandwidths)
934            };
935        }
936
937        // Sigmoid-blended soft routing with per-feature bandwidth
938        let z = (threshold - features[feat_idx]) / bw;
939        let alpha = 1.0 / (1.0 + math::exp(-z));
940
941        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
942        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);
943
944        alpha * left_pred + (1.0 - alpha) * right_pred
945    }
946
947    /// Generate the feature mask for split evaluation.
948    ///
949    /// If `feature_subsample_rate` is 1.0, all features are included.
950    /// Otherwise, a random subset is selected via xorshift64.
951    ///
952    /// Uses a `Vec<u64>` bitset for O(1) membership testing, replacing the
953    /// previous O(n) `Vec::contains()` which made the fallback loop O(n²).
954    fn generate_feature_mask(&mut self, n_features: usize) {
955        self.feature_mask.clear();
956
957        if self.config.feature_subsample_rate >= 1.0 {
958            self.feature_mask.extend(0..n_features);
959        } else {
960            let target_count =
961                crate::math::ceil((n_features as f64) * self.config.feature_subsample_rate)
962                    as usize;
963            let target_count = target_count.max(1).min(n_features);
964
965            // Prepare the bitset: one bit per feature, O(1) membership test.
966            let n_words = n_features.div_ceil(64);
967            self.feature_mask_bits.clear();
968            self.feature_mask_bits.resize(n_words, 0u64);
969
970            // Include each feature with probability = subsample_rate.
971            for i in 0..n_features {
972                let r = xorshift64(&mut self.rng_state);
973                let p = (r as f64) / (u64::MAX as f64);
974                if p < self.config.feature_subsample_rate {
975                    self.feature_mask.push(i);
976                    self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
977                }
978            }
979
980            // If we didn't get enough features, fill up deterministically.
981            // Now O(n) instead of O(n²) thanks to the bitset.
982            if self.feature_mask.len() < target_count {
983                for i in 0..n_features {
984                    if self.feature_mask.len() >= target_count {
985                        break;
986                    }
987                    if self.feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
988                        self.feature_mask.push(i);
989                        self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
990                    }
991                }
992            }
993        }
994    }
995
    /// Attempt a split at the given leaf node.
    ///
    /// Pipeline: depth gate -> grace-period gate -> feature-mask generation ->
    /// per-feature candidate evaluation (Fisher bin ordering for categorical
    /// features) -> monotonic-constraint filtering -> optional CIR
    /// (adaptive-depth) gate -> Hoeffding-bound gate -> arena split + child
    /// leaf-state creation -> bandwidth recompute.
    ///
    /// Returns `true` if a split was performed.
    fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
        let depth = self.arena.get_depth(leaf_id);

        // When adaptive_depth is enabled, max_depth * 2 is the hard safety ceiling;
        // the per-split CIR test handles generalization. Otherwise, use static max_depth.
        let hard_ceiling = if self.config.adaptive_depth.is_some() {
            self.config.max_depth.saturating_mul(2)
        } else {
            self.config.max_depth
        };
        let at_max_depth = depth as usize >= hard_ceiling;

        if at_max_depth {
            // Only proceed if split re-evaluation is enabled and the interval
            // has elapsed since the last evaluation at this leaf.
            match self.config.split_reeval_interval {
                None => return false,
                Some(interval) => {
                    let state = match self
                        .leaf_states
                        .get(leaf_id.0 as usize)
                        .and_then(|o| o.as_ref())
                    {
                        Some(s) => s,
                        None => return false,
                    };
                    let sample_count = self.arena.get_sample_count(leaf_id);
                    // NOTE(review): assumes sample_count >= last_reeval_count
                    // (both are monotone counters) so the subtraction cannot
                    // wrap -- confirm no counter reset path exists.
                    if sample_count - state.last_reeval_count < interval as u64 {
                        return false;
                    }
                    // Fall through to evaluate potential split.
                }
            }
        }

        // n_features is learned from the first sample; no samples yet means
        // there is nothing to split.
        let n_features = match self.n_features {
            Some(n) => n,
            None => return false,
        };

        // Grace period: require a minimum number of samples at this leaf
        // before any split evaluation is attempted.
        let sample_count = self.arena.get_sample_count(leaf_id);
        if sample_count < self.config.grace_period as u64 {
            return false;
        }

        // Generate the feature mask for this split evaluation.
        self.generate_feature_mask(n_features);

        // Materialize pending lazy decay before reading histogram data.
        // This converts un-decayed coordinates to true decayed values so
        // split evaluation sees correct gradient/hessian sums. O(n_features * n_bins)
        // but amortized over grace_period samples -- not per-sample cost.
        if self.config.leaf_decay_alpha.is_some() {
            if let Some(state) = self
                .leaf_states
                .get_mut(leaf_id.0 as usize)
                .and_then(|o| o.as_mut())
            {
                if let Some(ref mut histograms) = state.histograms {
                    histograms.materialize_decay();
                }
            }
        }

        // Evaluate splits for each feature in the mask.
        // We need to borrow leaf_states immutably while feature_mask is borrowed.
        // Collect candidates first.
        let state = match self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            Some(s) => s,
            None => return false,
        };

        let histograms = match &state.histograms {
            Some(h) => h,
            None => return false,
        };

        // Collect (feature_idx, best_split_candidate, optional_fisher_order) for
        // each feature in the mask. For categorical features, we reorder bins by
        // Fisher optimal binary partitioning before evaluation.
        let feature_types = &self.config.feature_types;
        let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();

        for &feat_idx in &self.feature_mask {
            if feat_idx >= histograms.n_features() {
                continue;
            }
            let hist = &histograms.histograms[feat_idx];
            let total_grad = hist.total_gradient();
            let total_hess = hist.total_hessian();

            let is_categorical = feature_types
                .as_ref()
                .is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);

            if is_categorical {
                // Fisher optimal binary partitioning:
                // 1. Compute gradient_sum/hessian_sum ratio per bin
                // 2. Sort bins by this ratio
                // 3. Evaluate splits on the sorted order
                let n_bins = hist.grad_sums.len();
                if n_bins < 2 {
                    continue;
                }

                // Build (bin_index, ratio) pairs, filtering out empty bins
                // (near-zero hessian would make the ratio unstable).
                let mut bin_order: Vec<usize> = (0..n_bins)
                    .filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
                    .collect();

                if bin_order.len() < 2 {
                    continue;
                }

                // Sort by grad_sum / hess_sum ratio (ascending)
                bin_order.sort_by(|&a, &b| {
                    let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
                    let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
                    ratio_a
                        .partial_cmp(&ratio_b)
                        .unwrap_or(core::cmp::Ordering::Equal)
                });

                // Reorder grad/hess sums according to Fisher order
                let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
                let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();

                if let Some(candidate) = self.split_criterion.evaluate(
                    &sorted_grads,
                    &sorted_hess,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, Some(bin_order)));
                }
            } else {
                // Standard continuous feature -- evaluate as-is
                if let Some(candidate) = self.split_criterion.evaluate(
                    &hist.grad_sums,
                    &hist.hess_sums,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, None));
                }
            }
        }

        // Filter out candidates that violate monotonic constraints.
        // A constraint of +1 requires non-decreasing predictions in the
        // feature; -1 requires non-increasing; 0 means unconstrained.
        if let Some(ref mc) = self.config.monotone_constraints {
            candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true; // No constraint for this feature
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true; // Unconstrained
                }

                // Compare the would-be child leaf values under the constraint.
                let left_val =
                    leaf_weight(candidate.left_grad, candidate.left_hess, self.config.lambda);
                let right_val = leaf_weight(
                    candidate.right_grad,
                    candidate.right_hess,
                    self.config.lambda,
                );

                if constraint > 0 {
                    // Non-decreasing: left_value <= right_value
                    left_val <= right_val
                } else {
                    // Non-increasing: left_value >= right_value
                    left_val >= right_val
                }
            });
        }

        if candidates.is_empty() {
            return false;
        }

        // Sort candidates by gain descending.
        candidates.sort_by(|a, b| {
            b.1.gain
                .partial_cmp(&a.1.gain)
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let best_gain = candidates[0].1.gain;
        let second_best_gain = if candidates.len() > 1 {
            candidates[1].1.gain
        } else {
            0.0
        };

        // Per-split information criterion (Lunde-Kleppe-Skaug 2020).
        // Acts as a FIRST gate before the Hoeffding bound check.
        if let Some(cir_factor) = self.config.adaptive_depth {
            let n = sample_count as f64;
            if n > 1.0 {
                // Use effective_n with EWMA decay if configured
                let effective_n = match self.config.leaf_decay_alpha {
                    Some(alpha) => n.min(1.0 / (1.0 - alpha)),
                    None => n,
                };

                // Get gradient variance from Welford stats in the leaf state
                let grad_var = self
                    .leaf_states
                    .get(leaf_id.0 as usize)
                    .and_then(|o| o.as_ref())
                    .map(|leaf_state| {
                        if leaf_state.clip_grad_count > 1 {
                            // Unbiased sample variance from Welford M2.
                            leaf_state.clip_grad_m2 / (leaf_state.clip_grad_count as f64 - 1.0)
                        } else {
                            // Fallback: estimate from grad/hess sums
                            let mean_grad = leaf_state.grad_sum / leaf_state.hess_sum.max(1.0);
                            mean_grad * mean_grad + 1.0
                        }
                    })
                    .unwrap_or(1.0);

                // Penalty scales with variance and feature count, shrinks
                // with (effective) sample size.
                let n_feat = self.n_features.unwrap_or(1) as f64;
                let penalty = cir_factor * grad_var / effective_n * n_feat;

                if best_gain <= penalty {
                    return false; // Don't split — insufficient generalization evidence
                }
            }
        }

        // Hoeffding bound: epsilon = sqrt(R^2 * ln(1/delta) / (2 * n))
        // R = 1.0 (conservative bound on the range of the gain function).
        //
        // With EWMA decay, the effective sample size is bounded by 1/(1-alpha).
        // We cap n at this value to prevent spurious splits from artificially
        // tight bounds when decay is active.
        let r_squared = 1.0;
        let n = sample_count as f64;
        let effective_n = match self.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        let ln_inv_delta = math::ln(1.0 / self.config.delta);
        let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * effective_n));

        // Split condition: the best is significantly better than second-best,
        // OR the bound is already so tight that more samples won't help
        // (epsilon < TAU tie-break).
        let gap = best_gain - second_best_gain;
        if gap <= epsilon && epsilon >= TAU {
            // If this was a re-evaluation at max depth, update the count
            // so we don't re-evaluate again until the next interval elapses.
            if at_max_depth {
                if let Some(state) = self
                    .leaf_states
                    .get_mut(leaf_id.0 as usize)
                    .and_then(|o| o.as_mut())
                {
                    state.last_reeval_count = sample_count;
                }
            }
            return false;
        }

        // --- Execute the split ---
        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];

        // Track split gain for feature importance.
        if best_feat_idx < self.split_gains.len() {
            self.split_gains[best_feat_idx] += best_candidate.gain;
        }

        let best_hist = &histograms.histograms[best_feat_idx];

        // Initial child leaf values from the candidate's grad/hess partition.
        let left_value = leaf_weight(
            best_candidate.left_grad,
            best_candidate.left_hess,
            self.config.lambda,
        );
        let right_value = leaf_weight(
            best_candidate.right_grad,
            best_candidate.right_hess,
            self.config.lambda,
        );

        // Perform the split -- categorical or continuous.
        let (left_id, right_id) = if let Some(ref order) = fisher_order {
            // Categorical split: build a bitmask from the Fisher-sorted partition.
            // bin_idx in the sorted order means bins order[0..=bin_idx] go left.
            // We need to map those back to original bin indices and set their bits.
            //
            // For categorical features, bin index = category value (since we use
            // one bin per category with midpoint edges).
            let mut mask: u64 = 0;
            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
                // sorted_pos is the original bin index; for categorical features,
                // bin index corresponds to the category's position in sorted categories.
                // The actual category value is stored as an integer that maps to this bin.
                if sorted_pos < 64 {
                    mask |= 1u64 << sorted_pos;
                }
            }

            // Threshold stores 0.0 for categorical splits (routing uses mask).
            self.arena.split_leaf_categorical(
                leaf_id,
                best_feat_idx as u32,
                0.0,
                left_value,
                right_value,
                mask,
            )
        } else {
            // Continuous split: standard threshold from bin edge.
            // NOTE(review): uses edges[bin_idx] as the split point -- confirm
            // this indexing matches the edge convention used by the split
            // criterion (left/right boundary of the winning bin).
            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
                best_hist.edges.edges[best_candidate.bin_idx]
            } else {
                f64::MAX
            };

            self.arena.split_leaf(
                leaf_id,
                best_feat_idx as u32,
                threshold,
                left_value,
                right_value,
            )
        };

        // Build child histograms using the subtraction trick.
        // The "left" child gets a fresh histogram set built from the parent's
        // bins, populated by scanning parent bins [0..=bin_idx].
        // Instead of re-scanning, we use the subtraction trick: one child
        // gets the parent's histograms minus the other child's.
        //
        // Strategy: build the left child's histogram directly from the
        // parent histogram for each feature (summing bins 0..=best_bin for
        // the split feature). Then the right child = parent - left.
        //
        // Actually, for a streaming tree, the cleaner approach is:
        // - Remove the parent's state
        // - Create fresh states for both children (they'll accumulate from
        //   new samples going forward)
        // - BUT we can seed one child with the parent's histograms by
        //   constructing "virtual" histograms from the parent's data.
        //
        // The simplest correct approach: both children start fresh with
        // pre-computed bin edges from the parent, so they're immediately
        // ready to accumulate. We don't carry forward the parent's histogram
        // data because new samples will naturally populate the children.
        //
        // However, to be more efficient, we CAN carry forward histogram data
        // using the subtraction trick. Let's do this properly:

        // Take the parent's state out of the slot (the node is no longer a leaf).
        let parent_state = self
            .leaf_states
            .get_mut(leaf_id.0 as usize)
            .and_then(|o| o.take());
        let nf = n_features;

        // Ensure Vec is large enough for child NodeIds.
        let max_child = left_id.0.max(right_id.0) as usize;
        if self.leaf_states.len() <= max_child {
            self.leaf_states.resize_with(max_child + 1, || None);
        }

        if let Some(parent) = parent_state {
            if let Some(parent_hists) = parent.histograms {
                // Build left child histograms from the parent.
                let edges_per_feature: Vec<BinEdges> = parent_hists
                    .histograms
                    .iter()
                    .map(|h| h.edges.clone())
                    .collect();

                // The left child inherits a copy of the parent histograms,
                // but we really want to compute how much of the parent's data
                // would have gone left vs right. We don't have that per-feature
                // breakdown for features other than the split feature.
                //
                // Correct approach: create fresh histogram states for both
                // children with the same bin edges. They start empty but
                // with bins_ready = true, so new samples immediately accumulate.
                let left_hists = LeafHistograms::new(&edges_per_feature);
                let right_hists = LeafHistograms::new(&edges_per_feature);

                let ft = self.config.feature_types.as_deref();
                let child_binners_l = make_binners(nf, ft);
                let child_binners_r = make_binners(nf, ft);

                // Warm-start children from parent's learned leaf model.
                // If parent has a model, children inherit its weights (resetting
                // optimizer state). If parent has no model (ClosedForm), children
                // also get None -- the fast path stays fast.
                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());

                // Fresh accumulators; only bin edges and (optionally) the leaf
                // model carry over from the parent.
                let left_state = LeafState {
                    histograms: Some(left_hists),
                    binners: child_binners_l,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: left_model,
                };

                let right_state = LeafState {
                    histograms: Some(right_hists),
                    binners: child_binners_r,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: right_model,
                };

                self.leaf_states[left_id.0 as usize] = Some(left_state);
                self.leaf_states[right_id.0 as usize] = Some(right_state);
            } else {
                // Parent didn't have histograms (shouldn't happen if bins_ready).
                let ft = self.config.feature_types.as_deref();
                let mut ls = LeafState::new_with_types(nf, ft);
                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[left_id.0 as usize] = Some(ls);
                let mut rs = LeafState::new_with_types(nf, ft);
                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[right_id.0 as usize] = Some(rs);
            }
        } else {
            // No parent state found (shouldn't happen).
            let ft = self.config.feature_types.as_deref();
            let mut ls = LeafState::new_with_types(nf, ft);
            ls.leaf_model = self.make_leaf_model(left_id);
            self.leaf_states[left_id.0 as usize] = Some(ls);
            let mut rs = LeafState::new_with_types(nf, ft);
            rs.leaf_model = self.make_leaf_model(right_id);
            self.leaf_states[right_id.0 as usize] = Some(rs);
        }

        // Recompute per-node bandwidths after structural change.
        self.recompute_bandwidths();

        true
    }
1464}
1465
impl StreamingTree for HoeffdingTree {
    /// Train the tree on a single sample.
    ///
    /// Routes the sample to its leaf, updates histogram accumulators, and
    /// attempts a split if the Hoeffding bound is satisfied.
    ///
    /// Per-sample pipeline:
    /// 1. Lazily fix `n_features` from the first sample seen.
    /// 2. Route `features` to a leaf and bump that leaf's sample count.
    /// 3. Optionally clip the gradient using per-leaf statistics.
    /// 4. Before bin edges exist: feed the binners and the running
    ///    gradient/hessian sums; once the grace period is reached, freeze
    ///    edges and create histograms (seeded with only the current sample).
    /// 5. After bin edges exist: accumulate into histograms and attempt a
    ///    split every `grace_period` samples.
    fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
        self.samples_seen += 1;

        // Initialize n_features on first sample.
        let n_features = if let Some(n) = self.n_features {
            n
        } else {
            let n = features.len();
            self.n_features = Some(n);
            self.split_gains.resize(n, 0.0);

            // Re-initialize the root's leaf state now that we know n_features.
            if let Some(state) = self
                .leaf_states
                .get_mut(self.root.0 as usize)
                .and_then(|o| o.as_mut())
            {
                state.binners = make_binners(n, self.config.feature_types.as_deref());
            }
            n
        };

        // Debug-only guard: release builds skip this check entirely, so a
        // wrong-arity feature vector is not caught there.
        debug_assert_eq!(
            features.len(),
            n_features,
            "feature count mismatch: got {} but expected {}",
            features.len(),
            n_features,
        );

        // Route to leaf.
        let leaf_id = self.route_to_leaf(features);

        // Increment the sample count in the arena.
        self.arena.increment_sample_count(leaf_id);
        let sample_count = self.arena.get_sample_count(leaf_id);

        // Get or create the leaf state. `leaf_states` is a sparse vector
        // indexed by NodeId, so it may need to grow first.
        let idx = leaf_id.0 as usize;
        if self.leaf_states.len() <= idx {
            self.leaf_states.resize_with(idx + 1, || None);
        }
        if self.leaf_states[idx].is_none() {
            self.leaf_states[idx] = Some(LeafState::new_with_types(
                n_features,
                self.config.feature_types.as_deref(),
            ));
        }
        let state = self.leaf_states[idx].as_mut().unwrap();

        // Apply per-leaf gradient clipping if enabled. Shadows `gradient`
        // so all code below sees the clipped value.
        let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
            clip_gradient(state, gradient, sigma)
        } else {
            gradient
        };

        // If bins are not yet ready, check if we've reached the grace period.
        if !state.bins_ready {
            // Observe feature values in the binners.
            for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
                binner.observe(val);
            }

            // Accumulate running gradient/hessian sums (with optional EWMA decay).
            if let Some(alpha) = self.config.leaf_decay_alpha {
                state.grad_sum = alpha * state.grad_sum + gradient;
                state.hess_sum = alpha * state.hess_sum + hessian;
            } else {
                state.grad_sum += gradient;
                state.hess_sum += hessian;
            }

            // Update the leaf value from running sums.
            let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
            self.arena.set_leaf_value(leaf_id, lw);

            // Track per-leaf output weight for adaptive bounds.
            if self.config.adaptive_leaf_bound.is_some() {
                update_output_stats(state, lw, self.config.leaf_decay_alpha);
            }

            // Update the leaf model if one exists (linear / MLP).
            if let Some(ref mut model) = state.leaf_model {
                model.update(features, gradient, hessian, self.config.lambda);
            }

            // Check if we've reached the grace period to compute bin edges.
            if sample_count >= self.config.grace_period as u64 {
                let edges_per_feature: Vec<BinEdges> = state
                    .binners
                    .iter()
                    .map(|b| b.compute_edges(self.config.n_bins))
                    .collect();

                let mut histograms = LeafHistograms::new(&edges_per_feature);

                // We don't have the raw samples to replay into the histogram,
                // but we DO have the running grad/hess sums. We can't distribute
                // them across bins retroactively. The histograms start empty and
                // will accumulate from the next sample onward.
                // However, we should NOT lose the current sample. Let's accumulate
                // this sample into the newly created histograms.
                if let Some(alpha) = self.config.leaf_decay_alpha {
                    histograms.accumulate_with_decay(features, gradient, hessian, alpha);
                } else {
                    histograms.accumulate(features, gradient, hessian);
                }

                state.histograms = Some(histograms);
                state.bins_ready = true;
            }

            // Early return: split attempts happen only on the bins-ready path
            // below, so the sample that triggers edge creation never attempts
            // a split in the same call.
            return;
        }

        // Bins are ready -- accumulate into histograms (with optional decay).
        if let Some(ref mut histograms) = state.histograms {
            if let Some(alpha) = self.config.leaf_decay_alpha {
                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
            } else {
                histograms.accumulate(features, gradient, hessian);
            }
        }

        // Update running gradient/hessian sums and leaf value (with optional EWMA decay).
        if let Some(alpha) = self.config.leaf_decay_alpha {
            state.grad_sum = alpha * state.grad_sum + gradient;
            state.hess_sum = alpha * state.hess_sum + hessian;
        } else {
            state.grad_sum += gradient;
            state.hess_sum += hessian;
        }
        let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
        self.arena.set_leaf_value(leaf_id, lw);

        // Track per-leaf output weight for adaptive bounds.
        if self.config.adaptive_leaf_bound.is_some() {
            update_output_stats(state, lw, self.config.leaf_decay_alpha);
        }

        // Update the leaf model if one exists (linear / MLP).
        if let Some(ref mut model) = state.leaf_model {
            model.update(features, gradient, hessian, self.config.lambda);
        }

        // Attempt split.
        // We only try every grace_period samples to avoid excessive computation.
        if sample_count % (self.config.grace_period as u64) == 0 {
            self.attempt_split(leaf_id);
        }
    }

    /// Predict the leaf value for a feature vector.
    ///
    /// Routes from the root to a leaf via threshold comparisons and returns
    /// the leaf's current weight.
    fn predict(&self, features: &[f64]) -> f64 {
        let leaf_id = self.route_to_leaf(features);
        self.leaf_prediction(leaf_id, features)
    }

    /// Current number of leaf nodes.
    #[inline]
    fn n_leaves(&self) -> usize {
        self.arena.n_leaves()
    }

    /// Total number of samples seen since creation.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Reset to initial state with a single root leaf.
    ///
    /// Note: `n_features` is intentionally NOT cleared, so the new root's
    /// leaf state keeps the learned feature arity and `split_gains` keeps
    /// its length (its values are zeroed below).
    fn reset(&mut self) {
        self.arena.reset();
        let root = self.arena.add_leaf(0);
        self.root = root;
        self.leaf_states.clear();

        // Insert a placeholder leaf state for the new root.
        let n_features = self.n_features.unwrap_or(0);
        self.leaf_states.resize_with(root.0 as usize + 1, || None);
        let mut root_state =
            LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
        root_state.leaf_model = self.make_leaf_model(root);
        self.leaf_states[root.0 as usize] = Some(root_state);

        self.samples_seen = 0;
        self.feature_mask.clear();
        self.feature_mask_bits.clear();
        // Restore the PRNG to the configured seed.
        self.rng_state = self.config.seed;
        self.split_gains.iter_mut().for_each(|g| *g = 0.0);
        self.node_bandwidths.clear();
    }

    /// Per-feature split-gain vector (length = `n_features` once known).
    fn split_gains(&self) -> &[f64] {
        &self.split_gains
    }

    /// Predict and return `(value, variance)` for a feature vector.
    ///
    /// A leaf with no recorded state yields `f64::INFINITY` variance
    /// (totally uncertain).
    fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
        let leaf_id = self.route_to_leaf(features);
        let value = self.leaf_prediction(leaf_id, features);
        if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            // Variance of the leaf weight estimate = 1 / (H_sum + lambda)
            let variance = 1.0 / (state.hess_sum + self.config.lambda);
            (value, variance)
        } else {
            (value, f64::INFINITY)
        }
    }
}
1688
1689impl Clone for HoeffdingTree {
1690    fn clone(&self) -> Self {
1691        Self {
1692            arena: self.arena.clone(),
1693            root: self.root,
1694            config: self.config.clone(),
1695            leaf_states: self.leaf_states.clone(),
1696            n_features: self.n_features,
1697            samples_seen: self.samples_seen,
1698            split_criterion: self.split_criterion,
1699            feature_mask: self.feature_mask.clone(),
1700            feature_mask_bits: self.feature_mask_bits.clone(),
1701            rng_state: self.rng_state,
1702            split_gains: self.split_gains.clone(),
1703            node_bandwidths: self.node_bandwidths.clone(),
1704        }
1705    }
1706}
1707
// SAFETY: All fields are Send + Sync. BinnerKind is a concrete enum with
// Send + Sync variants. XGBoostGain is stateless. Vec<Option<LeafState>>
// and Vec fields are trivially Send + Sync.
//
// NOTE(review): if every field really is Send + Sync, these manual impls are
// redundant -- the compiler auto-implements Send/Sync in that case. Their
// presence suggests some field (perhaps inside BinnerKind, LeafModel, or
// LeafState) is NOT auto-Send/Sync; identify which field blocks the auto
// impls and justify it here, because an incorrect `unsafe impl` is undefined
// behavior under concurrent use. TODO confirm against the struct definition.
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}
1713
1714#[cfg(test)]
1715mod tests {
1716    use super::*;
1717    use crate::tree::builder::TreeConfig;
1718    use crate::tree::StreamingTree;
1719
1720    /// Simple xorshift64 for test reproducibility (same as the tree uses).
1721    fn test_xorshift(state: &mut u64) -> u64 {
1722        xorshift64(state)
1723    }
1724
1725    /// Generate a pseudo-random f64 in [0, 1) from the RNG state.
1726    fn test_rand_f64(state: &mut u64) -> f64 {
1727        let r = test_xorshift(state);
1728        (r as f64) / (u64::MAX as f64)
1729    }
1730
1731    // -----------------------------------------------------------------------
1732    // Test 1: Single sample train + predict returns non-NaN.
1733    // -----------------------------------------------------------------------
1734    #[test]
1735    fn single_sample_predict_not_nan() {
1736        let config = TreeConfig::new().grace_period(10);
1737        let mut tree = HoeffdingTree::new(config);
1738
1739        let features = vec![1.0, 2.0, 3.0];
1740        tree.train_one(&features, -0.5, 1.0);
1741
1742        let pred = tree.predict(&features);
1743        assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
1744        assert!(
1745            pred.is_finite(),
1746            "prediction should be finite, got {}",
1747            pred
1748        );
1749
1750        // With gradient=-0.5, hessian=1.0, lambda=1.0:
1751        // leaf_weight = -(-0.5) / (1.0 + 1.0) = 0.25
1752        assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
1753    }
1754
1755    // -----------------------------------------------------------------------
1756    // Test 2: Train 1000 samples from y=2*x + noise, verify RMSE decreases.
1757    // -----------------------------------------------------------------------
1758    #[test]
1759    fn linear_signal_rmse_improves() {
1760        let config = TreeConfig::new()
1761            .max_depth(4)
1762            .n_bins(32)
1763            .grace_period(50)
1764            .lambda(0.1)
1765            .gamma(0.0)
1766            .delta(1e-3);
1767
1768        let mut tree = HoeffdingTree::new(config);
1769        let mut rng_state: u64 = 12345;
1770
1771        // Generate training data: y = 2*x, with x in [0, 10].
1772        // For gradient boosting, gradient = prediction - target (for squared loss),
1773        // hessian = 1.0.
1774        //
1775        // We'll simulate a simple boosting loop:
1776        // - Start with prediction = 0 for all points.
1777        // - gradient = pred - target = 0 - y = -y
1778        // - hessian = 1.0
1779
1780        let n_train = 1000;
1781        let mut features_all: Vec<f64> = Vec::with_capacity(n_train);
1782        let mut targets: Vec<f64> = Vec::with_capacity(n_train);
1783
1784        for _ in 0..n_train {
1785            let x = test_rand_f64(&mut rng_state) * 10.0;
1786            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 0.5;
1787            let y = 2.0 * x + noise;
1788            features_all.push(x);
1789            targets.push(y);
1790        }
1791
1792        // Compute initial RMSE (prediction = 0).
1793        let initial_mse: f64 = targets.iter().map(|y| y * y).sum::<f64>() / n_train as f64;
1794        let initial_rmse = initial_mse.sqrt();
1795
1796        // Train the tree.
1797        for i in 0..n_train {
1798            let feat = [features_all[i]];
1799            let pred = tree.predict(&feat);
1800            // For squared loss: gradient = pred - target, hessian = 1.0.
1801            let gradient = pred - targets[i];
1802            let hessian = 1.0;
1803            tree.train_one(&feat, gradient, hessian);
1804        }
1805
1806        // Compute post-training RMSE.
1807        let mut post_mse = 0.0;
1808        for i in 0..n_train {
1809            let feat = [features_all[i]];
1810            let pred = tree.predict(&feat);
1811            let err = pred - targets[i];
1812            post_mse += err * err;
1813        }
1814        post_mse /= n_train as f64;
1815        let post_rmse = post_mse.sqrt();
1816
1817        assert!(
1818            post_rmse < initial_rmse,
1819            "RMSE should decrease after training: initial={:.4}, post={:.4}",
1820            initial_rmse,
1821            post_rmse,
1822        );
1823    }
1824
1825    // -----------------------------------------------------------------------
1826    // Test 3: No splits before grace_period samples.
1827    // -----------------------------------------------------------------------
1828    #[test]
1829    fn no_splits_before_grace_period() {
1830        let grace = 100;
1831        let config = TreeConfig::new()
1832            .grace_period(grace)
1833            .max_depth(4)
1834            .n_bins(16)
1835            .delta(1e-1); // Very lenient delta to make splits easy.
1836
1837        let mut tree = HoeffdingTree::new(config);
1838        let mut rng_state: u64 = 99999;
1839
1840        // Train grace_period - 1 samples.
1841        for _ in 0..(grace - 1) {
1842            let x = test_rand_f64(&mut rng_state) * 10.0;
1843            let y = 2.0 * x;
1844            let feat = [x];
1845            let pred = tree.predict(&feat);
1846            tree.train_one(&feat, pred - y, 1.0);
1847        }
1848
1849        assert_eq!(
1850            tree.n_leaves(),
1851            1,
1852            "should be exactly 1 leaf before grace_period, got {}",
1853            tree.n_leaves()
1854        );
1855    }
1856
1857    // -----------------------------------------------------------------------
1858    // Test 4: Tree does not exceed max_depth.
1859    // -----------------------------------------------------------------------
1860    #[test]
1861    fn respects_max_depth() {
1862        let max_depth = 3;
1863        let config = TreeConfig::new()
1864            .max_depth(max_depth)
1865            .grace_period(20)
1866            .n_bins(16)
1867            .lambda(0.01)
1868            .gamma(0.0)
1869            .delta(1e-1); // Very lenient.
1870
1871        let mut tree = HoeffdingTree::new(config);
1872        let mut rng_state: u64 = 7777;
1873
1874        // Train many samples with a clear signal to force splitting.
1875        for _ in 0..5000 {
1876            let x = test_rand_f64(&mut rng_state) * 10.0;
1877            let y = if x < 2.5 {
1878                -5.0
1879            } else if x < 5.0 {
1880                -1.0
1881            } else if x < 7.5 {
1882                1.0
1883            } else {
1884                5.0
1885            };
1886            let feat = [x];
1887            let pred = tree.predict(&feat);
1888            tree.train_one(&feat, pred - y, 1.0);
1889        }
1890
1891        // Maximum number of leaves at depth d is 2^d.
1892        let max_leaves = 1usize << max_depth;
1893        assert!(
1894            tree.n_leaves() <= max_leaves,
1895            "tree has {} leaves, but max_depth={} allows at most {}",
1896            tree.n_leaves(),
1897            max_depth,
1898            max_leaves,
1899        );
1900    }
1901
1902    // -----------------------------------------------------------------------
1903    // Test 5: Reset works -- tree returns to single leaf.
1904    // -----------------------------------------------------------------------
1905    #[test]
1906    fn reset_returns_to_single_leaf() {
1907        let config = TreeConfig::new()
1908            .grace_period(20)
1909            .max_depth(4)
1910            .n_bins(16)
1911            .delta(1e-1);
1912
1913        let mut tree = HoeffdingTree::new(config);
1914        let mut rng_state: u64 = 54321;
1915
1916        // Train enough to potentially cause splits.
1917        for _ in 0..2000 {
1918            let x = test_rand_f64(&mut rng_state) * 10.0;
1919            let y = 3.0 * x - 5.0;
1920            let feat = [x];
1921            let pred = tree.predict(&feat);
1922            tree.train_one(&feat, pred - y, 1.0);
1923        }
1924
1925        let pre_reset_samples = tree.n_samples_seen();
1926        assert!(pre_reset_samples > 0);
1927
1928        tree.reset();
1929
1930        assert_eq!(
1931            tree.n_leaves(),
1932            1,
1933            "after reset, should have exactly 1 leaf"
1934        );
1935        assert_eq!(
1936            tree.n_samples_seen(),
1937            0,
1938            "after reset, samples_seen should be 0"
1939        );
1940
1941        // Predict should still work (returns 0.0 from empty leaf).
1942        let pred = tree.predict(&[5.0]);
1943        assert!(
1944            pred.abs() < 1e-10,
1945            "prediction after reset should be ~0.0, got {}",
1946            pred
1947        );
1948    }
1949
1950    // -----------------------------------------------------------------------
1951    // Test 6: Multiple features -- verify tree uses different features.
1952    // -----------------------------------------------------------------------
1953    #[test]
1954    fn multi_feature_training() {
1955        let config = TreeConfig::new()
1956            .grace_period(30)
1957            .max_depth(4)
1958            .n_bins(16)
1959            .lambda(0.1)
1960            .delta(1e-2);
1961
1962        let mut tree = HoeffdingTree::new(config);
1963        let mut rng_state: u64 = 11111;
1964
1965        // y = x0 + 2*x1, two features.
1966        for _ in 0..1000 {
1967            let x0 = test_rand_f64(&mut rng_state) * 5.0;
1968            let x1 = test_rand_f64(&mut rng_state) * 5.0;
1969            let y = x0 + 2.0 * x1;
1970            let feat = [x0, x1];
1971            let pred = tree.predict(&feat);
1972            tree.train_one(&feat, pred - y, 1.0);
1973        }
1974
1975        // Just verify it trained without panicking and produces finite predictions.
1976        let pred = tree.predict(&[2.5, 2.5]);
1977        assert!(
1978            pred.is_finite(),
1979            "multi-feature prediction should be finite"
1980        );
1981        assert_eq!(tree.n_samples_seen(), 1000);
1982    }
1983
1984    // -----------------------------------------------------------------------
1985    // Test 7: Feature subsampling does not panic.
1986    // -----------------------------------------------------------------------
1987    #[test]
1988    fn feature_subsampling_works() {
1989        let config = TreeConfig::new()
1990            .grace_period(30)
1991            .max_depth(3)
1992            .n_bins(16)
1993            .lambda(0.1)
1994            .delta(1e-2)
1995            .feature_subsample_rate(0.5);
1996
1997        let mut tree = HoeffdingTree::new(config);
1998        let mut rng_state: u64 = 33333;
1999
2000        // 5 features, only ~50% considered per split.
2001        for _ in 0..1000 {
2002            let feats: Vec<f64> = (0..5)
2003                .map(|_| test_rand_f64(&mut rng_state) * 10.0)
2004                .collect();
2005            let y: f64 = feats.iter().sum();
2006            let pred = tree.predict(&feats);
2007            tree.train_one(&feats, pred - y, 1.0);
2008        }
2009
2010        let pred = tree.predict(&[1.0, 2.0, 3.0, 4.0, 5.0]);
2011        assert!(pred.is_finite(), "subsampled prediction should be finite");
2012    }
2013
2014    // -----------------------------------------------------------------------
2015    // Test 8: xorshift64 produces deterministic sequence.
2016    // -----------------------------------------------------------------------
2017    #[test]
2018    fn xorshift64_deterministic() {
2019        let mut s1: u64 = 42;
2020        let mut s2: u64 = 42;
2021
2022        let seq1: Vec<u64> = (0..100).map(|_| xorshift64(&mut s1)).collect();
2023        let seq2: Vec<u64> = (0..100).map(|_| xorshift64(&mut s2)).collect();
2024
2025        assert_eq!(seq1, seq2, "xorshift64 should be deterministic");
2026
2027        // Verify no zeros in the sequence (xorshift64 with non-zero seed never produces 0).
2028        for &v in &seq1 {
2029            assert_ne!(v, 0, "xorshift64 should never produce 0 with non-zero seed");
2030        }
2031    }
2032
2033    // -----------------------------------------------------------------------
2034    // Test 9: EWMA leaf decay -- recent data dominates predictions.
2035    // -----------------------------------------------------------------------
2036    #[test]
2037    fn ewma_leaf_decay_recent_data_dominates() {
2038        // half_life=50 => alpha = exp(-ln(2)/50) ≈ 0.9862
2039        let alpha = (-(2.0_f64.ln()) / 50.0).exp();
2040        let config = TreeConfig::new()
2041            .grace_period(20)
2042            .max_depth(4)
2043            .n_bins(16)
2044            .lambda(1.0)
2045            .leaf_decay_alpha(alpha);
2046        let mut tree = HoeffdingTree::new(config);
2047
2048        // Phase 1: 1000 samples targeting 1.0
2049        for _ in 0..1000 {
2050            let pred = tree.predict(&[1.0, 2.0]);
2051            let grad = pred - 1.0; // gradient for squared loss
2052            tree.train_one(&[1.0, 2.0], grad, 1.0);
2053        }
2054
2055        // Phase 2: 100 samples targeting 5.0
2056        for _ in 0..100 {
2057            let pred = tree.predict(&[1.0, 2.0]);
2058            let grad = pred - 5.0;
2059            tree.train_one(&[1.0, 2.0], grad, 1.0);
2060        }
2061
2062        let pred = tree.predict(&[1.0, 2.0]);
2063        // With EWMA, the prediction should be pulled toward 5.0 (recent target).
2064        // Without EWMA, 1000 samples at 1.0 would dominate 100 at 5.0.
2065        assert!(
2066            pred > 2.0,
2067            "EWMA should let recent data (target=5.0) pull prediction above 2.0, got {}",
2068            pred,
2069        );
2070    }
2071
2072    // -----------------------------------------------------------------------
2073    // Test 10: EWMA disabled (None) matches traditional behavior.
2074    // -----------------------------------------------------------------------
2075    #[test]
2076    fn ewma_disabled_matches_traditional() {
2077        let config_no_ewma = TreeConfig::new()
2078            .grace_period(20)
2079            .max_depth(4)
2080            .n_bins(16)
2081            .lambda(1.0);
2082        let mut tree = HoeffdingTree::new(config_no_ewma);
2083
2084        let mut rng_state: u64 = 99999;
2085        for _ in 0..200 {
2086            let x = test_rand_f64(&mut rng_state) * 10.0;
2087            let y = 3.0 * x + 1.0;
2088            let pred = tree.predict(&[x]);
2089            tree.train_one(&[x], pred - y, 1.0);
2090        }
2091
2092        let pred = tree.predict(&[5.0]);
2093        assert!(
2094            pred.is_finite(),
2095            "prediction without EWMA should be finite, got {}",
2096            pred
2097        );
2098    }
2099
2100    // -----------------------------------------------------------------------
2101    // Test 11: Split re-evaluation at max depth grows beyond frozen point.
2102    // -----------------------------------------------------------------------
2103    #[test]
2104    fn split_reeval_at_max_depth() {
2105        let config = TreeConfig::new()
2106            .grace_period(20)
2107            .max_depth(2) // Very shallow to hit max depth quickly
2108            .n_bins(16)
2109            .lambda(1.0)
2110            .split_reeval_interval(50);
2111        let mut tree = HoeffdingTree::new(config);
2112
2113        let mut rng_state: u64 = 54321;
2114        // Train enough to saturate max_depth=2 and then trigger re-evaluation.
2115        for _ in 0..2000 {
2116            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2117            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2118            let y = 2.0 * x1 + 3.0 * x2;
2119            let pred = tree.predict(&[x1, x2]);
2120            tree.train_one(&[x1, x2], pred - y, 1.0);
2121        }
2122
2123        // With split_reeval_interval=50, max-depth leaves can re-evaluate
2124        // and potentially split beyond max_depth. The tree should have MORE
2125        // leaves than a max_depth=2 tree without re-eval (which caps at 4).
2126        let leaves = tree.n_leaves();
2127        assert!(
2128            leaves >= 4,
2129            "split re-eval should allow growth beyond max_depth=2 cap (4 leaves), got {}",
2130            leaves,
2131        );
2132    }
2133
2134    // -----------------------------------------------------------------------
2135    // Test 12: Split re-evaluation disabled matches existing behavior.
2136    // -----------------------------------------------------------------------
2137    #[test]
2138    fn split_reeval_disabled_matches_traditional() {
2139        let config = TreeConfig::new()
2140            .grace_period(20)
2141            .max_depth(2)
2142            .n_bins(16)
2143            .lambda(1.0);
2144        // No split_reeval_interval => None => traditional hard cap
2145        let mut tree = HoeffdingTree::new(config);
2146
2147        let mut rng_state: u64 = 77777;
2148        for _ in 0..2000 {
2149            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2150            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2151            let y = 2.0 * x1 + 3.0 * x2;
2152            let pred = tree.predict(&[x1, x2]);
2153            tree.train_one(&[x1, x2], pred - y, 1.0);
2154        }
2155
2156        // Without re-eval, max_depth=2 caps at 4 leaves (2^2).
2157        let leaves = tree.n_leaves();
2158        assert!(
2159            leaves <= 4,
2160            "without re-eval, max_depth=2 should cap at 4 leaves, got {}",
2161            leaves,
2162        );
2163    }
2164
2165    // -----------------------------------------------------------------------
2166    // Test: Gradient clipping clamps outliers
2167    // -----------------------------------------------------------------------
2168    #[test]
2169    fn gradient_clipping_clamps_outliers() {
2170        let config = TreeConfig::new()
2171            .grace_period(20)
2172            .max_depth(2)
2173            .n_bins(16)
2174            .gradient_clip_sigma(2.0);
2175
2176        let mut tree = HoeffdingTree::new(config);
2177
2178        // Train 50 normal samples
2179        let mut rng_state = 42u64;
2180        for _ in 0..50 {
2181            let x = test_rand_f64(&mut rng_state) * 2.0;
2182            let grad = x * 0.1; // small gradients ~[0, 0.2]
2183            tree.train_one(&[x], grad, 1.0);
2184        }
2185
2186        let pred_before = tree.predict(&[1.0]);
2187
2188        // Now inject an extreme outlier gradient
2189        tree.train_one(&[1.0], 1000.0, 1.0);
2190
2191        let pred_after = tree.predict(&[1.0]);
2192
2193        // With clipping at 2-sigma, the outlier should be clamped.
2194        // Without clipping, the prediction would jump massively.
2195        // The change should be bounded.
2196        let delta = (pred_after - pred_before).abs();
2197        assert!(
2198            delta < 100.0,
2199            "gradient clipping should limit impact of outlier, but prediction changed by {}",
2200            delta,
2201        );
2202    }
2203
2204    // -----------------------------------------------------------------------
2205    // Test: clip_gradient function directly
2206    // -----------------------------------------------------------------------
2207    #[test]
2208    fn clip_gradient_welford_tracks_stats() {
2209        let mut state = LeafState::new(1);
2210
2211        // Feed 20 varied gradients to build up statistics
2212        for i in 0..20 {
2213            let grad = 1.0 + (i as f64) * 0.1; // range [1.0, 2.9]
2214            let clipped = clip_gradient(&mut state, grad, 3.0);
2215            // 3-sigma is very wide, so these should not be clipped
2216            assert!(
2217                (clipped - grad).abs() < 1e-10,
2218                "normal gradients should not be clipped at 3-sigma"
2219            );
2220        }
2221
2222        // Now an extreme outlier -- mean is ~1.95, std ~0.59, 3-sigma range is ~[0.18, 3.72]
2223        let clipped = clip_gradient(&mut state, 100.0, 3.0);
2224        assert!(
2225            clipped < 100.0,
2226            "extreme outlier should be clipped, got {}",
2227            clipped,
2228        );
2229        assert!(
2230            clipped > 0.0,
2231            "clipped value should be positive, got {}",
2232            clipped,
2233        );
2234    }
2235
2236    // -----------------------------------------------------------------------
2237    // Test: clip_gradient warmup period
2238    // -----------------------------------------------------------------------
2239    #[test]
2240    fn clip_gradient_warmup_no_clipping() {
2241        let mut state = LeafState::new(1);
2242
2243        // During warmup (< 10 samples), no clipping
2244        for i in 0..9 {
2245            let val = if i == 8 { 1000.0 } else { 1.0 };
2246            let clipped = clip_gradient(&mut state, val, 2.0);
2247            assert_eq!(clipped, val, "warmup should not clip");
2248        }
2249    }
2250
2251    // -----------------------------------------------------------------------
2252    // Test: adaptive_bound warmup returns f64::MAX
2253    // -----------------------------------------------------------------------
2254    #[test]
2255    fn adaptive_bound_warmup_returns_max() {
2256        let mut state = LeafState::new(1);
2257        // Feed < 10 output weights
2258        for i in 0..9 {
2259            update_output_stats(&mut state, 0.5 + i as f64 * 0.01, None);
2260        }
2261        let bound = adaptive_bound(&state, 3.0, None);
2262        assert_eq!(bound, f64::MAX, "warmup should return f64::MAX");
2263    }
2264
2265    // -----------------------------------------------------------------------
2266    // Test: adaptive_bound tightens after warmup (Welford path)
2267    // -----------------------------------------------------------------------
2268    #[test]
2269    fn adaptive_bound_tightens_after_warmup() {
2270        let mut state = LeafState::new(1);
2271        // Feed 20 outputs centered around 0.3 with small variance
2272        for i in 0..20 {
2273            let w = 0.3 + (i as f64 - 10.0) * 0.01; // range [0.2, 0.39]
2274            update_output_stats(&mut state, w, None);
2275        }
2276        let bound = adaptive_bound(&state, 3.0, None);
2277        // Bound should be much less than a global max of 3.0
2278        assert!(
2279            bound < 1.0,
2280            "3-sigma bound on outputs ~0.3 should be < 1.0, got {}",
2281            bound,
2282        );
2283        assert!(bound > 0.2, "bound should be > |mean|, got {}", bound,);
2284    }
2285
2286    // -----------------------------------------------------------------------
2287    // Test: adaptive_bound clamps outlier leaf
2288    // -----------------------------------------------------------------------
2289    #[test]
2290    fn adaptive_bound_clamps_outlier_leaf() {
2291        let mut state = LeafState::new(1);
2292        // Build stats: 20 outputs ~0.3
2293        for _ in 0..20 {
2294            update_output_stats(&mut state, 0.3, None);
2295        }
2296        let bound = adaptive_bound(&state, 3.0, None);
2297        // A leaf output of 2.9 should be clamped
2298        let clamped = (2.9_f64).clamp(-bound, bound);
2299        assert!(
2300            clamped < 2.9,
2301            "2.9 should be clamped by adaptive bound {}, got {}",
2302            bound,
2303            clamped,
2304        );
2305    }
2306
2307    // -----------------------------------------------------------------------
2308    // Test: adaptive_bound with EWMA decay adapts
2309    // -----------------------------------------------------------------------
2310    #[test]
2311    fn adaptive_bound_with_decay_adapts() {
2312        let alpha = 0.95; // fast decay for testing
2313        let mut state = LeafState::new(1);
2314
2315        // Phase 1: outputs around 0.3
2316        for _ in 0..30 {
2317            update_output_stats(&mut state, 0.3, Some(alpha));
2318        }
2319        let bound_phase1 = adaptive_bound(&state, 3.0, Some(alpha));
2320
2321        // Phase 2: outputs shift to 2.0
2322        for _ in 0..100 {
2323            update_output_stats(&mut state, 2.0, Some(alpha));
2324        }
2325        let bound_phase2 = adaptive_bound(&state, 3.0, Some(alpha));
2326
2327        // After regime change, bound should adapt upward
2328        assert!(
2329            bound_phase2 > bound_phase1,
2330            "EWMA bound should adapt: phase1={}, phase2={}",
2331            bound_phase1,
2332            bound_phase2,
2333        );
2334    }
2335
2336    // -----------------------------------------------------------------------
2337    // Test: adaptive_bound disabled by default
2338    // -----------------------------------------------------------------------
2339    #[test]
2340    fn adaptive_bound_disabled_by_default() {
2341        let config = TreeConfig::default();
2342        assert!(
2343            config.adaptive_leaf_bound.is_none(),
2344            "adaptive_leaf_bound should default to None",
2345        );
2346    }
2347
2348    // -----------------------------------------------------------------------
2349    // Test: adaptive_bound warmup falls back to global max_leaf_output
2350    // -----------------------------------------------------------------------
2351    #[test]
2352    fn adaptive_bound_warmup_falls_back_to_global() {
2353        let mut state = LeafState::new(1);
2354        // Only 5 samples — still in warmup
2355        for _ in 0..5 {
2356            update_output_stats(&mut state, 0.3, None);
2357        }
2358        let bound = adaptive_bound(&state, 3.0, None);
2359        assert_eq!(bound, f64::MAX, "warmup should yield f64::MAX");
2360        // In leaf_prediction, f64::MAX falls through to global max_leaf_output
2361    }
2362
2363    // -----------------------------------------------------------------------
2364    // Test: Monotonic constraints filter invalid splits
2365    // -----------------------------------------------------------------------
2366    #[test]
2367    fn monotonic_constraint_splits_respected() {
2368        // Train with +1 constraint on feature 0 (increasing).
2369        // Use a dataset where feature 0 has a negative relationship.
2370        let config = TreeConfig::new()
2371            .grace_period(30)
2372            .max_depth(4)
2373            .n_bins(16)
2374            .monotone_constraints(vec![1]); // feature 0 must be increasing
2375
2376        let mut tree = HoeffdingTree::new(config);
2377
2378        let mut rng_state = 42u64;
2379        for _ in 0..500 {
2380            let x = test_rand_f64(&mut rng_state) * 10.0;
2381            // Negative relationship: high x → low y → positive gradient
2382            let grad = x * 0.5 - 2.5;
2383            tree.train_one(&[x], grad, 1.0);
2384        }
2385
2386        // Any split that occurred should satisfy: left_value <= right_value
2387        // for the monotone +1 constraint. Verify prediction is non-decreasing.
2388        let pred_low = tree.predict(&[0.0]);
2389        let pred_mid = tree.predict(&[5.0]);
2390        let pred_high = tree.predict(&[10.0]);
2391
2392        // Due to constraint, prediction must be non-decreasing
2393        assert!(
2394            pred_low <= pred_mid + 1e-10 && pred_mid <= pred_high + 1e-10,
2395            "monotonic +1 violated: pred(0)={}, pred(5)={}, pred(10)={}",
2396            pred_low,
2397            pred_mid,
2398            pred_high,
2399        );
2400    }
2401
2402    // -----------------------------------------------------------------------
2403    // Test: predict_with_variance returns finite values
2404    // -----------------------------------------------------------------------
2405    #[test]
2406    fn predict_with_variance_finite() {
2407        let config = TreeConfig::new().grace_period(10);
2408        let mut tree = HoeffdingTree::new(config);
2409
2410        // Train a few samples
2411        for i in 0..30 {
2412            let x = i as f64 * 0.1;
2413            tree.train_one(&[x], x - 1.0, 1.0);
2414        }
2415
2416        let (value, variance) = tree.predict_with_variance(&[1.0]);
2417        assert!(value.is_finite(), "value should be finite");
2418        assert!(variance.is_finite(), "variance should be finite");
2419        assert!(variance > 0.0, "variance should be positive");
2420    }
2421
2422    // -----------------------------------------------------------------------
2423    // Test: predict_with_variance decreases with more data
2424    // -----------------------------------------------------------------------
2425    #[test]
2426    fn predict_with_variance_decreases_with_data() {
2427        let config = TreeConfig::new().grace_period(10);
2428        let mut tree = HoeffdingTree::new(config);
2429
2430        // Train 20 samples, check variance
2431        for i in 0..20 {
2432            tree.train_one(&[1.0], 0.5, 1.0);
2433            if i == 0 {
2434                continue;
2435            }
2436        }
2437        let (_, var_20) = tree.predict_with_variance(&[1.0]);
2438
2439        // Train 200 more samples
2440        for _ in 0..200 {
2441            tree.train_one(&[1.0], 0.5, 1.0);
2442        }
2443        let (_, var_220) = tree.predict_with_variance(&[1.0]);
2444
2445        assert!(
2446            var_220 < var_20,
2447            "variance should decrease with more data: var@20={} vs var@220={}",
2448            var_20,
2449            var_220,
2450        );
2451    }
2452
2453    // -----------------------------------------------------------------------
2454    // Test: predict_smooth matches hard prediction at small bandwidth
2455    // -----------------------------------------------------------------------
2456    #[test]
2457    fn predict_smooth_matches_hard_at_small_bandwidth() {
2458        let config = TreeConfig::new()
2459            .max_depth(3)
2460            .n_bins(16)
2461            .grace_period(20)
2462            .lambda(1.0);
2463        let mut tree = HoeffdingTree::new(config);
2464
2465        // Train enough to get splits
2466        let mut rng = 42u64;
2467        for _ in 0..500 {
2468            let x = test_rand_f64(&mut rng) * 10.0;
2469            let y = 2.0 * x + 1.0;
2470            let features = vec![x, x * 0.5];
2471            let pred = tree.predict(&features);
2472            let grad = pred - y;
2473            let hess = 1.0;
2474            tree.train_one(&features, grad, hess);
2475        }
2476
2477        // With very small bandwidth, smooth prediction should approximate hard prediction
2478        let features = vec![5.0, 2.5];
2479        let hard = tree.predict(&features);
2480        let smooth = tree.predict_smooth(&features, 0.001);
2481        assert!(
2482            (hard - smooth).abs() < 0.1,
2483            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2484            hard,
2485            smooth,
2486        );
2487    }
2488
2489    // -----------------------------------------------------------------------
2490    // Test: predict_smooth is continuous
2491    // -----------------------------------------------------------------------
2492    #[test]
2493    fn predict_smooth_is_continuous() {
2494        let config = TreeConfig::new()
2495            .max_depth(3)
2496            .n_bins(16)
2497            .grace_period(20)
2498            .lambda(1.0);
2499        let mut tree = HoeffdingTree::new(config);
2500
2501        // Train to get splits
2502        let mut rng = 42u64;
2503        for _ in 0..500 {
2504            let x = test_rand_f64(&mut rng) * 10.0;
2505            let y = 2.0 * x + 1.0;
2506            let features = vec![x, x * 0.5];
2507            let pred = tree.predict(&features);
2508            let grad = pred - y;
2509            tree.train_one(&features, grad, 1.0);
2510        }
2511
2512        // Check that small input changes produce small output changes (continuity)
2513        let bandwidth = 1.0;
2514        let base = tree.predict_smooth(&[5.0, 2.5], bandwidth);
2515        let nudged = tree.predict_smooth(&[5.001, 2.5], bandwidth);
2516        let diff = (base - nudged).abs();
2517        assert!(
2518            diff < 0.1,
2519            "smooth prediction should be continuous: base={}, nudged={}, diff={}",
2520            base,
2521            nudged,
2522            diff,
2523        );
2524    }
2525
2526    // -----------------------------------------------------------------------
2527    // Test: leaf_grad_hess returns valid sums after training.
2528    // -----------------------------------------------------------------------
2529    #[test]
2530    fn leaf_grad_hess_returns_sums() {
2531        let config = TreeConfig::new().grace_period(100).lambda(1.0);
2532        let mut tree = HoeffdingTree::new(config);
2533
2534        let features = vec![1.0, 2.0, 3.0];
2535
2536        // Train several samples with known gradients
2537        for _ in 0..10 {
2538            tree.train_one(&features, -0.5, 1.0);
2539        }
2540
2541        // The root should be a leaf (grace_period=100, only 10 samples)
2542        let root = tree.root();
2543        let (grad, hess) = tree
2544            .leaf_grad_hess(root)
2545            .expect("root should have leaf state");
2546
2547        // grad_sum should be sum of all gradients: 10 * (-0.5) = -5.0
2548        assert!(
2549            (grad - (-5.0)).abs() < 1e-10,
2550            "grad_sum should be -5.0, got {}",
2551            grad
2552        );
2553        // hess_sum should be sum of all hessians: 10 * 1.0 = 10.0
2554        assert!(
2555            (hess - 10.0).abs() < 1e-10,
2556            "hess_sum should be 10.0, got {}",
2557            hess
2558        );
2559    }
2560
2561    #[test]
2562    fn leaf_grad_hess_returns_none_for_invalid_node() {
2563        let config = TreeConfig::new();
2564        let tree = HoeffdingTree::new(config);
2565
2566        // NodeId::NONE should return None
2567        assert!(tree.leaf_grad_hess(NodeId::NONE).is_none());
2568        // A non-existent node should return None
2569        assert!(tree.leaf_grad_hess(NodeId(999)).is_none());
2570    }
2571
2572    // -----------------------------------------------------------------------
2573    // Adaptive depth (per-split information criterion) tests
2574    // -----------------------------------------------------------------------
2575
2576    #[test]
2577    fn adaptive_depth_none_identical_to_static_max_depth() {
2578        // With adaptive_depth = None, behavior should be identical to the
2579        // classic static max_depth check.
2580        let config_static = TreeConfig::new()
2581            .max_depth(3)
2582            .n_bins(32)
2583            .grace_period(20)
2584            .lambda(0.1)
2585            .delta(1e-3);
2586
2587        let config_none = TreeConfig::new()
2588            .max_depth(3)
2589            .n_bins(32)
2590            .grace_period(20)
2591            .lambda(0.1)
2592            .delta(1e-3);
2593
2594        // Verify adaptive_depth is None by default
2595        assert!(config_none.adaptive_depth.is_none());
2596
2597        let mut tree_static = HoeffdingTree::new(config_static);
2598        let mut tree_none = HoeffdingTree::new(config_none);
2599
2600        let mut rng_state: u64 = 42;
2601        for _ in 0..2000 {
2602            let x = test_rand_f64(&mut rng_state) * 10.0;
2603            let y = 2.0 * x;
2604            let feat = [x, x * 0.5, x * x];
2605            let pred_s = tree_static.predict(&feat);
2606            let pred_n = tree_none.predict(&feat);
2607            tree_static.train_one(&feat, pred_s - y, 1.0);
2608            tree_none.train_one(&feat, pred_n - y, 1.0);
2609        }
2610
2611        // Both trees should have the same number of nodes
2612        assert_eq!(
2613            tree_static.arena().n_nodes(),
2614            tree_none.arena().n_nodes(),
2615            "adaptive_depth=None should produce identical tree structure to static max_depth"
2616        );
2617    }
2618
2619    #[test]
2620    fn adaptive_depth_few_samples_stays_shallow() {
2621        // With adaptive_depth enabled and few noisy samples,
2622        // the CIR penalty should prevent deep splits.
2623        let config = TreeConfig::new()
2624            .max_depth(6)
2625            .n_bins(32)
2626            .grace_period(20)
2627            .lambda(0.1)
2628            .delta(1e-3)
2629            .adaptive_depth(7.5);
2630
2631        let mut tree = HoeffdingTree::new(config);
2632        let mut rng_state: u64 = 99;
2633
2634        // Feed a small number of noisy samples — mostly noise, weak signal
2635        for _ in 0..100 {
2636            let x = test_rand_f64(&mut rng_state) * 10.0;
2637            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 20.0; // large noise
2638            let y = 0.1 * x + noise;
2639            let feat = [x, test_rand_f64(&mut rng_state) * 5.0];
2640            let pred = tree.predict(&feat);
2641            tree.train_one(&feat, pred - y, 1.0);
2642        }
2643
2644        // With few noisy samples and CIR penalty, the tree should stay shallower
2645        // than the hard ceiling of max_depth*2 = 12. In fact, with this much noise
2646        // and few samples, we expect very few splits.
2647        let n_nodes = tree.arena().n_nodes();
2648        assert!(
2649            n_nodes <= 15,
2650            "adaptive_depth with few noisy samples should keep tree shallow, got {} nodes",
2651            n_nodes
2652        );
2653    }
2654
2655    #[test]
2656    fn adaptive_depth_many_samples_grows_deeper() {
2657        // With many clean samples, the CIR penalty decays (1/n) and the tree
2658        // should grow deeper than with few samples.
2659        let config_few = TreeConfig::new()
2660            .max_depth(6)
2661            .n_bins(32)
2662            .grace_period(20)
2663            .lambda(0.1)
2664            .delta(1e-3)
2665            .adaptive_depth(7.5);
2666
2667        let config_many = TreeConfig::new()
2668            .max_depth(6)
2669            .n_bins(32)
2670            .grace_period(20)
2671            .lambda(0.1)
2672            .delta(1e-3)
2673            .adaptive_depth(7.5);
2674
2675        let mut tree_few = HoeffdingTree::new(config_few);
2676        let mut tree_many = HoeffdingTree::new(config_many);
2677
2678        let mut rng_state: u64 = 42;
2679
2680        // Strong signal: y = 3*x1 + 2*x2 (clean)
2681        // Train "few" with 200 samples
2682        for _ in 0..200 {
2683            let x1 = test_rand_f64(&mut rng_state) * 10.0;
2684            let x2 = test_rand_f64(&mut rng_state) * 5.0;
2685            let y = 3.0 * x1 + 2.0 * x2;
2686            let feat = [x1, x2];
2687            let pred = tree_few.predict(&feat);
2688            tree_few.train_one(&feat, pred - y, 1.0);
2689        }
2690
2691        // Train "many" with 5000 samples
2692        let mut rng_state2: u64 = 42;
2693        for _ in 0..5000 {
2694            let x1 = test_rand_f64(&mut rng_state2) * 10.0;
2695            let x2 = test_rand_f64(&mut rng_state2) * 5.0;
2696            let y = 3.0 * x1 + 2.0 * x2;
2697            let feat = [x1, x2];
2698            let pred = tree_many.predict(&feat);
2699            tree_many.train_one(&feat, pred - y, 1.0);
2700        }
2701
2702        // The tree with more samples should have at least as many nodes
2703        // (penalty = cir_factor * var / n * n_feat decays with n)
2704        assert!(
2705            tree_many.arena().n_nodes() >= tree_few.arena().n_nodes(),
2706            "more samples should allow deeper growth: many={} vs few={}",
2707            tree_many.arena().n_nodes(),
2708            tree_few.arena().n_nodes()
2709        );
2710    }
2711
2712    #[test]
2713    fn adaptive_depth_penalty_scales_inversely_with_n() {
2714        // Directly verify the penalty math: penalty = cir_factor * grad_var / n * n_feat
2715        // With fixed cir_factor=7.5, grad_var=1.0, n_feat=2:
2716        //   n=100  -> penalty = 7.5 * 1.0 / 100 * 2 = 0.15
2717        //   n=1000 -> penalty = 7.5 * 1.0 / 1000 * 2 = 0.015
2718        // So a gain of 0.05 would fail at n=100 but pass at n=1000.
2719        let cir_factor: f64 = 7.5;
2720        let grad_var: f64 = 1.0;
2721        let n_feat: f64 = 2.0;
2722
2723        let penalty_100 = cir_factor * grad_var / 100.0 * n_feat;
2724        let penalty_1000 = cir_factor * grad_var / 1000.0 * n_feat;
2725
2726        assert!(
2727            (penalty_100 - 0.15).abs() < 1e-10,
2728            "penalty at n=100 should be 0.15, got {}",
2729            penalty_100
2730        );
2731        assert!(
2732            (penalty_1000 - 0.015).abs() < 1e-10,
2733            "penalty at n=1000 should be 0.015, got {}",
2734            penalty_1000
2735        );
2736        assert!(
2737            penalty_100 > penalty_1000,
2738            "penalty should decrease with more samples"
2739        );
2740
2741        // A gain of 0.05 should fail at n=100 (0.05 <= 0.15) but pass at n=1000 (0.05 > 0.015)
2742        let gain = 0.05;
2743        assert!(gain <= penalty_100, "gain should fail CIR at n=100");
2744        assert!(gain > penalty_1000, "gain should pass CIR at n=1000");
2745    }
2746
2747    #[test]
2748    fn adaptive_depth_hard_ceiling_respected() {
2749        // Even with adaptive_depth enabled, trees should never exceed max_depth * 2.
2750        let config = TreeConfig::new()
2751            .max_depth(3)
2752            .n_bins(32)
2753            .grace_period(10)
2754            .lambda(0.01)
2755            .gamma(0.0)
2756            .delta(1e-2) // Loose bound for easy splitting
2757            .adaptive_depth(0.001); // Very small factor = almost no CIR penalty
2758
2759        let mut tree = HoeffdingTree::new(config);
2760        let mut rng_state: u64 = 777;
2761
2762        // Train with very strong, clean signal to maximize splitting
2763        for _ in 0..10000 {
2764            let x = test_rand_f64(&mut rng_state) * 100.0;
2765            let y = x * x; // Strong nonlinear signal
2766            let feat = [x];
2767            let pred = tree.predict(&feat);
2768            tree.train_one(&feat, pred - y, 1.0);
2769        }
2770
2771        // Hard ceiling = max_depth * 2 = 6, so max leaves = 2^6 = 64
2772        let max_leaves = 1usize << 6;
2773        let n_leaves = tree.arena().n_leaves();
2774        assert!(
2775            n_leaves <= max_leaves,
2776            "tree should respect hard ceiling of max_depth*2=6 ({} max leaves), got {} leaves",
2777            max_leaves,
2778            n_leaves
2779        );
2780    }
2781
2782    // -----------------------------------------------------------------------
2783    // Soft routing tests
2784    // -----------------------------------------------------------------------
2785
2786    #[test]
2787    fn soft_routed_prediction_finite() {
2788        let config = TreeConfig::new().grace_period(20).max_depth(4).n_bins(16);
2789        let mut tree = HoeffdingTree::new(config);
2790        let mut rng: u64 = 42;
2791
2792        // Train until splits happen
2793        for _ in 0..2000 {
2794            let x = test_rand_f64(&mut rng) * 10.0;
2795            let y = x * 2.0 + 1.0;
2796            let feat = [x, x * 0.5];
2797            let pred = tree.predict(&feat);
2798            tree.train_one(&feat, pred - y, 1.0);
2799        }
2800
2801        // Verify soft-routed prediction is finite
2802        for i in 0..20 {
2803            let x = i as f64 * 0.5;
2804            let pred = tree.predict_soft_routed(&[x, x * 0.5]);
2805            assert!(
2806                pred.is_finite(),
2807                "soft-routed prediction should be finite at x={}, got {}",
2808                x,
2809                pred
2810            );
2811        }
2812    }
2813
2814    #[test]
2815    fn soft_routed_smoother_than_hard() {
2816        let config = TreeConfig::new()
2817            .grace_period(20)
2818            .max_depth(4)
2819            .n_bins(32)
2820            .delta(1e-2);
2821        let mut tree = HoeffdingTree::new(config);
2822        let mut rng: u64 = 123;
2823
2824        // Train on y = sin(x) with enough data for splits
2825        for _ in 0..5000 {
2826            let x = test_rand_f64(&mut rng) * 6.283; // 0..2pi
2827            let y = math::sin(x);
2828            let feat = [x];
2829            let pred = tree.predict(&feat);
2830            tree.train_one(&feat, pred - y, 1.0);
2831        }
2832
2833        // Ensure we actually have splits
2834        if tree.arena().n_nodes() < 3 {
2835            return; // Tree didn't split -- can't test smoothness
2836        }
2837
2838        // Sweep and compute total variation for both methods
2839        let n_points = 100;
2840        let hard_preds: Vec<f64> = (0..n_points)
2841            .map(|i| {
2842                let x = i as f64 * 6.283 / n_points as f64;
2843                tree.predict(&[x])
2844            })
2845            .collect();
2846        let soft_preds: Vec<f64> = (0..n_points)
2847            .map(|i| {
2848                let x = i as f64 * 6.283 / n_points as f64;
2849                tree.predict_soft_routed(&[x])
2850            })
2851            .collect();
2852
2853        let hard_tv: f64 = hard_preds.windows(2).map(|w| math::abs(w[1] - w[0])).sum();
2854        let soft_tv: f64 = soft_preds.windows(2).map(|w| math::abs(w[1] - w[0])).sum();
2855
2856        assert!(
2857            soft_tv <= hard_tv + 1e-10,
2858            "soft routing should have <= total variation than hard: soft_tv={}, hard_tv={}",
2859            soft_tv,
2860            hard_tv
2861        );
2862    }
2863
2864    #[test]
2865    fn node_bandwidths_computed_after_splits() {
2866        let config = TreeConfig::new().grace_period(20).max_depth(4).n_bins(16);
2867        let mut tree = HoeffdingTree::new(config);
2868        let mut rng: u64 = 999;
2869
2870        // Train until splits happen
2871        for _ in 0..3000 {
2872            let x = test_rand_f64(&mut rng) * 10.0;
2873            let y = x * x;
2874            let feat = [x];
2875            let pred = tree.predict(&feat);
2876            tree.train_one(&feat, pred - y, 1.0);
2877        }
2878
2879        if tree.arena().n_nodes() < 3 {
2880            return; // No splits -- can't test bandwidths
2881        }
2882
2883        // Verify node_bandwidths is populated
2884        assert!(
2885            !tree.node_bandwidths.is_empty(),
2886            "node_bandwidths should be non-empty after splits"
2887        );
2888
2889        // Verify internal nodes have finite positive bandwidths
2890        let mut found_finite = false;
2891        for i in 0..tree.arena().n_nodes() {
2892            let nid = NodeId(i as u32);
2893            if !tree.arena().is_leaf(nid) {
2894                let bw = tree.node_bandwidths[i];
2895                assert!(
2896                    bw > 0.0,
2897                    "bandwidth for internal node {} should be > 0, got {}",
2898                    i,
2899                    bw
2900                );
2901                if bw.is_finite() {
2902                    found_finite = true;
2903                }
2904            }
2905        }
2906        assert!(
2907            found_finite,
2908            "at least one internal node should have a finite bandwidth"
2909        );
2910    }
2911
2912    #[test]
2913    fn soft_routed_agrees_at_training_points() {
2914        let config = TreeConfig::new().grace_period(20).max_depth(3).n_bins(16);
2915        let mut tree = HoeffdingTree::new(config);
2916
2917        // Train on simple data
2918        let training_data: Vec<(f64, f64)> = (0..500)
2919            .map(|i| {
2920                let x = i as f64 * 0.02;
2921                let y = x * 3.0 + 1.0;
2922                (x, y)
2923            })
2924            .collect();
2925
2926        for &(x, y) in &training_data {
2927            let feat = [x];
2928            let pred = tree.predict(&feat);
2929            tree.train_one(&feat, pred - y, 1.0);
2930        }
2931
2932        // At training points far from boundaries, hard and soft should be similar
2933        let mut total_diff = 0.0;
2934        let mut count = 0;
2935        for &(x, _) in &training_data {
2936            let feat = [x];
2937            let hard = tree.predict(&feat);
2938            let soft = tree.predict_soft_routed(&feat);
2939            if hard.is_finite() && soft.is_finite() {
2940                total_diff += math::abs(hard - soft);
2941                count += 1;
2942            }
2943        }
2944
2945        if count > 0 {
2946            let mean_diff = total_diff / count as f64;
2947            // Soft routing is a blend, so predictions will differ somewhat,
2948            // but the mean difference should be bounded.
2949            assert!(
2950                mean_diff < 5.0,
2951                "mean difference between hard and soft should be reasonable, got {}",
2952                mean_diff
2953            );
2954        }
2955    }
2956}