Skip to main content

irithyll_core/ensemble/
mod.rs

//! SGBT ensemble orchestrator -- the core boosting loop.
//!
//! Implements Streaming Gradient Boosted Trees (Gunasekara et al., 2024):
//! a sequence of boosting steps, each owning a streaming tree and drift detector,
//! with automatic tree replacement when concept drift is detected.
//!
//! # Algorithm
//!
//! For each incoming sample `(x, y)`:
//! 1. Compute the current ensemble prediction: `F(x) = base + lr * Σ tree_s(x)`
//! 2. For each boosting step `s = 1..N`:
//!    - Compute gradient `g = loss.gradient(y, current_pred)`
//!    - Compute hessian `h = loss.hessian(y, current_pred)`
//!    - Feed `(x, g, h)` to tree `s` (which internally uses weighted squared loss)
//!    - Update `current_pred += lr * tree_s.predict(x)`
//! 3. The ensemble adapts incrementally, with each tree targeting the residual
//!    of all preceding trees.
// --- Submodules -------------------------------------------------------------

pub mod adaptive;
pub mod adaptive_forest;
pub mod bagged;
pub mod config;
pub mod distributional;
pub mod lr_schedule;
pub mod moe;
pub mod moe_distributional;
pub mod multi_target;
pub mod multiclass;
#[cfg(feature = "parallel")]
pub mod parallel;
pub mod quantile_regressor;
pub mod replacement;
pub mod stacked;
pub mod step;
pub mod variants;

// --- Imports (no_std-friendly: only `alloc` and `core`) ---------------------

use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;

use core::fmt;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::step::BoostingStep;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::Observation;
#[allow(unused_imports)] // Used in doc links + tests
use crate::sample::Sample;
use crate::tree::builder::TreeConfig;
52
/// Type alias for an SGBT model using dynamic (boxed) loss dispatch.
///
/// Use this when the loss function is determined at runtime (e.g., when
/// deserializing a model from JSON where the loss type is stored as a tag).
///
/// For compile-time loss dispatch (preferred for performance, since the
/// gradient/hessian calls are monomorphized), use `SGBT<LogisticLoss>`,
/// `SGBT<HuberLoss>`, etc.
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming Gradient Boosted Trees ensemble.
///
/// The primary entry point for training and prediction. Generic over `L: Loss`
/// so the loss function's gradient/hessian calls are monomorphized (inlined)
/// into the boosting hot loop -- no virtual dispatch overhead.
///
/// The default type parameter `L = SquaredLoss` means `SGBT::new(config)`
/// creates a regression model without specifying the loss type explicitly.
///
/// # Examples
///
/// ```text
/// use irithyll::{SGBTConfig, SGBT};
///
/// // Regression with squared loss (default):
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::new(config);
/// ```
///
/// ```text
/// use irithyll::{SGBTConfig, SGBT};
/// use irithyll::loss::logistic::LogisticLoss;
///
/// // Classification with logistic loss -- no Box::new()!
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::with_loss(config, LogisticLoss);
/// ```
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Configuration.
    config: SGBTConfig,
    /// Boosting steps (one tree + drift detector each).
    steps: Vec<BoostingStep>,
    /// Loss function (monomorphized -- no vtable).
    loss: L,
    /// Base prediction (initial constant, computed from first batch of targets).
    base_prediction: f64,
    /// Whether `base_prediction` has been initialized.
    base_initialized: bool,
    /// Running collection of initial targets for computing `base_prediction`.
    /// Cleared and shrunk once the base is set.
    initial_targets: Vec<f64>,
    /// Number of initial targets to collect before setting `base_prediction`.
    initial_target_count: usize,
    /// Total samples trained.
    samples_seen: u64,
    /// RNG state for variant skip logic (xorshift64-style updates).
    rng_state: u64,
    /// Per-step EWMA of |marginal contribution| for quality-based pruning.
    /// Empty when `quality_prune_alpha` is `None`.
    contribution_ewma: Vec<f64>,
    /// Per-step consecutive low-contribution sample counter.
    /// Empty when `quality_prune_alpha` is `None`.
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error for error-weighted sample importance.
    /// Only used when `error_weight_alpha` is `Some`.
    rolling_mean_error: f64,
    /// EWMA of contribution standard deviation (σ proxy for adaptive_mts).
    /// Only updated when `adaptive_mts` is `Some`.
    rolling_contribution_sigma: f64,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// Computed from median split threshold gaps across all trees.
    auto_bandwidths: Vec<f64>,
    /// Sum of replacement counts across all steps at last bandwidth computation.
    /// Used to detect when trees have been replaced and bandwidths need refresh.
    last_replacement_sum: u64,
}
127
// Manual `Clone` so the bound can be spelled `L: Loss + Clone` explicitly;
// the body is a plain field-by-field copy, equivalent to what
// `#[derive(Clone)]` would generate.
impl<L: Loss + Clone> Clone for SGBT<L> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
            rolling_mean_error: self.rolling_mean_error,
            rolling_contribution_sigma: self.rolling_contribution_sigma,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
        }
    }
}
149
150impl<L: Loss> fmt::Debug for SGBT<L> {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("SGBT")
153            .field("n_steps", &self.steps.len())
154            .field("samples_seen", &self.samples_seen)
155            .field("base_prediction", &self.base_prediction)
156            .field("base_initialized", &self.base_initialized)
157            .finish()
158    }
159}
160
161// ---------------------------------------------------------------------------
162// Convenience constructor for the default loss (SquaredLoss)
163// ---------------------------------------------------------------------------
164
165impl SGBT<SquaredLoss> {
166    /// Create a new SGBT ensemble with squared loss (regression).
167    ///
168    /// This is the most common constructor. For classification or custom
169    /// losses, use [`with_loss`](SGBT::with_loss).
170    pub fn new(config: SGBTConfig) -> Self {
171        Self::with_loss(config, SquaredLoss)
172    }
173}
174
175// ---------------------------------------------------------------------------
176// General impl for all Loss types
177// ---------------------------------------------------------------------------
178
179impl<L: Loss> SGBT<L> {
    /// Create a new SGBT ensemble with a specific loss function.
    ///
    /// The loss is stored by value (monomorphized), giving zero-cost
    /// gradient/hessian dispatch.
    ///
    /// ```ignore
    /// use irithyll::{SGBTConfig, SGBT};
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
    /// let model = SGBT::with_loss(config, LogisticLoss);
    /// ```
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Optional leaf half-life (in samples) -> per-sample EWMA decay:
        // alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        // Project the ensemble-level config onto the per-tree configuration.
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // One BoostingStep per boosting round; each owns its own tree and a
        // fresh drift detector. A nonzero shadow_warmup selects the graduated
        // (active/shadow) replacement variant of the step.
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                // Derive a distinct seed per step (base seed XOR step index)
                // so steps don't make identical randomized choices.
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        // Pruning bookkeeping vectors are only allocated when one of the
        // pruning features is enabled; otherwise they stay empty.
        let has_pruning =
            config.quality_prune_alpha.is_some() || config.proactive_prune_interval.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
259
    /// Train on a single observation.
    ///
    /// Accepts any type implementing [`Observation`], including [`Sample`],
    /// [`SampleRef`](crate::SampleRef), or tuples like `(&[f64], f64)` for
    /// zero-copy training.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Initialize the base prediction from the first `initial_target_count`
        // targets; the aggregate (mean/median/...) is delegated to the loss.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                // Warmup buffer no longer needed; release its memory.
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Current prediction starts from base
        let mut current_pred = self.base_prediction;

        // Adaptive MTS: shrink the effective max_tree_samples when the spread
        // of per-step contributions is high relative to its slow EWMA.
        // NOTE(review): `contribution_variance` is defined elsewhere in this
        // type; assumed to return a sigma-like spread -- confirm.
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma = self.contribution_variance(features);
            // Slow EWMA (alpha = 0.001) of the contribution spread.
            self.rolling_contribution_sigma =
                0.999 * self.rolling_contribution_sigma + 0.001 * sigma;

            let normalized = if self.rolling_contribution_sigma > 1e-10 {
                sigma / self.rolling_contribution_sigma
            } else {
                1.0
            };
            let factor = 1.0 / (1.0 + k * normalized);
            // Never drop below twice the grace period, so trees can still split.
            let effective_mts =
                ((base_mts as f64) * factor).max(self.config.grace_period as f64 * 2.0) as u64;
            for step in &mut self.steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        // Quality tracking is active if explicit quality pruning is configured,
        // or if proactive pruning is on (then a default EWMA alpha of 0.01 is
        // used for its contribution tracking).
        let prune_alpha = self
            .config
            .quality_prune_alpha
            .or_else(|| self.config.proactive_prune_interval.map(|_| 0.01));
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Error-weighted sample importance: up-weight samples whose absolute
        // error (relative to the rolling mean error) is large, capped at 10x.
        // NOTE(review): the error here is measured against the *base*
        // prediction only (tree contributions are not applied yet at this
        // point) -- confirm this is intended.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0 // first sample, no reweighting
            }
        } else {
            1.0
        };

        // Sequential boosting: each step targets the residual of all prior steps
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            // The configured variant decides how many times this step trains
            // on the sample (may be 0 = skip), using the shared RNG state.
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality-based tree pruning: track contribution and replace dead wood
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    // Reset the step once it has contributed ~nothing for
                    // `prune_patience` consecutive samples.
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Proactive pruning: replace worst-contributing tree at interval
        if let Some(interval) = self.config.proactive_prune_interval {
            if self.samples_seen % interval == 0
                && self.samples_seen > 0
                && !self.contribution_ewma.is_empty()
            {
                // Only consider steps old enough to have had a fair chance.
                let min_age = interval / 2;
                let worst_idx = self
                    .steps
                    .iter()
                    .enumerate()
                    .zip(self.contribution_ewma.iter())
                    .filter(|((_, step), _)| step.n_samples_seen() >= min_age)
                    .min_by(|((_, _), a_ewma), ((_, _), b_ewma)| {
                        a_ewma
                            .partial_cmp(b_ewma)
                            .unwrap_or(core::cmp::Ordering::Equal)
                    })
                    .map(|((i, _), _)| i);

                if let Some(idx) = worst_idx {
                    self.steps[idx].reset();
                    self.contribution_ewma[idx] = 0.0;
                    self.low_contrib_count[idx] = 0;
                }
            }
        }

        // Refresh auto-bandwidths when trees have been replaced or not yet computed.
        self.refresh_bandwidths();
    }
390
391    /// Train on a batch of observations.
392    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
393        for sample in samples {
394            self.train_one(sample);
395        }
396    }
397
398    /// Train on a batch with periodic callback for cooperative yielding.
399    ///
400    /// The callback is invoked every `interval` samples with the number of
401    /// samples processed so far. This allows long-running training to yield
402    /// to other tasks in an async runtime, update progress bars, or perform
403    /// periodic checkpointing.
404    ///
405    /// # Example
406    ///
407    /// ```ignore
408    /// use irithyll::{SGBTConfig, SGBT};
409    ///
410    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
411    /// let mut model = SGBT::new(config);
412    /// let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data
413    ///
414    /// model.train_batch_with_callback(&data, 1000, |processed| {
415    ///     println!("Trained {} samples", processed);
416    /// });
417    /// ```
418    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
419        &mut self,
420        samples: &[O],
421        interval: usize,
422        mut callback: F,
423    ) {
424        let interval = interval.max(1); // Prevent zero interval
425        for (i, sample) in samples.iter().enumerate() {
426            self.train_one(sample);
427            if (i + 1) % interval == 0 {
428                callback(i + 1);
429            }
430        }
431        // Final callback if the total isn't a multiple of interval
432        let total = samples.len();
433        if total % interval != 0 {
434            callback(total);
435        }
436    }
437
438    /// Train on a random subsample of a batch using reservoir sampling.
439    ///
440    /// When `max_samples < samples.len()`, selects a representative subset
441    /// using Algorithm R (Vitter, 1985) -- a uniform random sample without
442    /// replacement. The selected samples are then trained in their original
443    /// order to preserve sequential dependencies.
444    ///
445    /// This is ideal for large replay buffers where training on the full
446    /// dataset is prohibitively slow but a representative subset gives
447    /// equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).
448    ///
449    /// When `max_samples >= samples.len()`, all samples are trained.
450    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
451        if max_samples >= samples.len() {
452            self.train_batch(samples);
453            return;
454        }
455
456        // Reservoir sampling (Algorithm R) to select indices
457        let mut reservoir: Vec<usize> = (0..max_samples).collect();
458        let mut rng = self.rng_state;
459
460        for i in max_samples..samples.len() {
461            // Generate random index in [0, i]
462            rng ^= rng << 13;
463            rng ^= rng >> 7;
464            rng ^= rng << 17;
465            let j = (rng % (i as u64 + 1)) as usize;
466            if j < max_samples {
467                reservoir[j] = i;
468            }
469        }
470
471        self.rng_state = rng;
472
473        // Sort to preserve original order (important for EWMA/drift state)
474        reservoir.sort_unstable();
475
476        // Train on the selected subset
477        for &idx in &reservoir {
478            self.train_one(&samples[idx]);
479        }
480    }
481
482    /// Train on a batch with both subsampling and periodic callbacks.
483    ///
484    /// Combines reservoir subsampling with cooperative yield points.
485    /// Ideal for long-running daemon training where you need both
486    /// efficiency (subsampling) and cooperation (yielding).
487    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
488        &mut self,
489        samples: &[O],
490        max_samples: usize,
491        interval: usize,
492        mut callback: F,
493    ) {
494        if max_samples >= samples.len() {
495            self.train_batch_with_callback(samples, interval, callback);
496            return;
497        }
498
499        // Reservoir sampling
500        let mut reservoir: Vec<usize> = (0..max_samples).collect();
501        let mut rng = self.rng_state;
502
503        for i in max_samples..samples.len() {
504            rng ^= rng << 13;
505            rng ^= rng >> 7;
506            rng ^= rng << 17;
507            let j = (rng % (i as u64 + 1)) as usize;
508            if j < max_samples {
509                reservoir[j] = i;
510            }
511        }
512
513        self.rng_state = rng;
514        reservoir.sort_unstable();
515
516        let interval = interval.max(1);
517        for (i, &idx) in reservoir.iter().enumerate() {
518            self.train_one(&samples[idx]);
519            if (i + 1) % interval == 0 {
520                callback(i + 1);
521            }
522        }
523        let total = reservoir.len();
524        if total % interval != 0 {
525            callback(total);
526        }
527    }
528
529    /// Predict the raw output for a feature vector.
530    ///
531    /// Always uses sigmoid-blended soft routing with auto-calibrated per-feature
532    /// bandwidths derived from median split threshold gaps. Features that have
533    /// never been split on use hard routing (bandwidth = infinity).
534    pub fn predict(&self, features: &[f64]) -> f64 {
535        let mut pred = self.base_prediction;
536        if self.auto_bandwidths.is_empty() {
537            // No bandwidths computed yet (no training) — hard routing fallback
538            for step in &self.steps {
539                pred += self.config.learning_rate * step.predict(features);
540            }
541        } else {
542            for step in &self.steps {
543                pred += self.config.learning_rate
544                    * step.predict_smooth_auto(features, &self.auto_bandwidths);
545            }
546        }
547        pred
548    }
549
550    /// Predict using sigmoid-blended soft routing with an explicit bandwidth.
551    ///
552    /// Uses a single bandwidth for all features. For auto-calibrated per-feature
553    /// bandwidths, use [`predict()`](SGBT::predict) which always uses smooth routing.
554    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
555        let mut pred = self.base_prediction;
556        for step in &self.steps {
557            pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
558        }
559        pred
560    }
561
562    /// Per-feature auto-calibrated bandwidths used by `predict()`.
563    ///
564    /// Empty before the first training sample. Each entry corresponds to a
565    /// feature index; `f64::INFINITY` means that feature has no splits and
566    /// uses hard routing.
567    pub fn auto_bandwidths(&self) -> &[f64] {
568        &self.auto_bandwidths
569    }
570
571    /// Predict with parent-leaf linear interpolation.
572    ///
573    /// Blends each leaf prediction with its parent's preserved prediction
574    /// based on sample count, preventing stale predictions from fresh leaves.
575    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
576        let mut pred = self.base_prediction;
577        for step in &self.steps {
578            pred += self.config.learning_rate * step.predict_interpolated(features);
579        }
580        pred
581    }
582
583    /// Predict with sibling-based interpolation for feature-continuous predictions.
584    ///
585    /// At each split node near the threshold boundary, blends left and right
586    /// subtree predictions linearly based on distance from the threshold.
587    /// Uses auto-calibrated bandwidths as the interpolation margin.
588    /// Predictions vary continuously as features change, eliminating
589    /// step-function artifacts.
590    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
591        let mut pred = self.base_prediction;
592        for step in &self.steps {
593            pred += self.config.learning_rate
594                * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
595        }
596        pred
597    }
598
599    /// Predict with graduated active-shadow blending.
600    ///
601    /// Smoothly transitions between active and shadow trees during replacement,
602    /// eliminating prediction dips. Requires `shadow_warmup` to be configured.
603    /// When disabled, equivalent to `predict()`.
604    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
605        let mut pred = self.base_prediction;
606        for step in &self.steps {
607            pred += self.config.learning_rate * step.predict_graduated(features);
608        }
609        pred
610    }
611
612    /// Predict with graduated blending + sibling interpolation (premium path).
613    ///
614    /// Combines graduated active-shadow handoff (no prediction dips during
615    /// tree replacement) with feature-continuous sibling interpolation
616    /// (no step-function artifacts near split boundaries).
617    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
618        let mut pred = self.base_prediction;
619        for step in &self.steps {
620            pred += self.config.learning_rate
621                * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
622        }
623        pred
624    }
625
626    /// Predict with loss transform applied (e.g., sigmoid for logistic loss).
627    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
628        self.loss.predict_transform(self.predict(features))
629    }
630
631    /// Predict probability (alias for `predict_transformed`).
632    pub fn predict_proba(&self, features: &[f64]) -> f64 {
633        self.predict_transformed(features)
634    }
635
636    /// Predict with confidence estimation.
637    ///
638    /// Returns `(prediction, confidence)` where confidence = 1 / sqrt(sum_variance).
639    /// Higher confidence indicates more certain predictions (leaves have seen
640    /// more hessian mass). Confidence of 0.0 means the model has no information.
641    ///
642    /// This enables execution engines to modulate aggressiveness:
643    /// - High confidence + favorable prediction → act immediately
644    /// - Low confidence → fall back to simpler models or wait for more data
645    ///
646    /// The variance per tree is estimated as `1 / (H_sum + lambda)` at the
647    /// leaf where the sample lands. The ensemble variance is the sum of
648    /// per-tree variances (scaled by learning_rate²), and confidence is
649    /// the reciprocal of the standard deviation.
650    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
651        let mut pred = self.base_prediction;
652        let mut total_variance = 0.0;
653        let lr2 = self.config.learning_rate * self.config.learning_rate;
654
655        for step in &self.steps {
656            let (value, variance) = step.predict_with_variance(features);
657            pred += self.config.learning_rate * value;
658            total_variance += lr2 * variance;
659        }
660
661        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
662            1.0 / crate::math::sqrt(total_variance)
663        } else {
664            0.0
665        };
666
667        (pred, confidence)
668    }
669
670    /// Batch prediction.
671    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
672        feature_matrix.iter().map(|f| self.predict(f)).collect()
673    }
674
675    /// Number of boosting steps.
676    pub fn n_steps(&self) -> usize {
677        self.steps.len()
678    }
679
680    /// Total trees (active + alternates).
681    pub fn n_trees(&self) -> usize {
682        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
683    }
684
685    /// Total leaves across all active trees.
686    pub fn total_leaves(&self) -> usize {
687        self.steps.iter().map(|s| s.n_leaves()).sum()
688    }
689
690    /// Total samples trained.
691    pub fn n_samples_seen(&self) -> u64 {
692        self.samples_seen
693    }
694
695    /// The current base prediction.
696    pub fn base_prediction(&self) -> f64 {
697        self.base_prediction
698    }
699
700    /// Whether the base prediction has been initialized.
701    pub fn is_initialized(&self) -> bool {
702        self.base_initialized
703    }
704
705    /// Access the configuration.
706    pub fn config(&self) -> &SGBTConfig {
707        &self.config
708    }
709
710    /// Set the learning rate for future boosting rounds.
711    ///
712    /// This allows external schedulers (e.g., [`lr_schedule::LRScheduler`]) to
713    /// adapt the rate over time without rebuilding the model.
714    ///
715    /// # Arguments
716    ///
717    /// * `lr` -- New learning rate (should be positive and finite)
718    #[inline]
719    pub fn set_learning_rate(&mut self, lr: f64) {
720        self.config.learning_rate = lr;
721    }
722
723    /// Immutable access to the boosting steps.
724    ///
725    /// Useful for model inspection and export (e.g., ONNX serialization).
726    pub fn steps(&self) -> &[BoostingStep] {
727        &self.steps
728    }
729
730    /// Immutable access to the loss function.
731    pub fn loss(&self) -> &L {
732        &self.loss
733    }
734
735    /// Feature importances based on accumulated split gains across all trees.
736    ///
737    /// Returns normalized importances (sum to 1.0) indexed by feature.
738    /// Returns an empty Vec if no splits have occurred yet.
739    pub fn feature_importances(&self) -> Vec<f64> {
740        // Aggregate split gains across all boosting steps.
741        let mut totals: Vec<f64> = Vec::new();
742        for step in &self.steps {
743            let gains = step.slot().split_gains();
744            if totals.is_empty() && !gains.is_empty() {
745                totals.resize(gains.len(), 0.0);
746            }
747            for (i, &g) in gains.iter().enumerate() {
748                if i < totals.len() {
749                    totals[i] += g;
750                }
751            }
752        }
753
754        // Normalize to sum to 1.0.
755        let sum: f64 = totals.iter().sum();
756        if sum > 0.0 {
757            totals.iter_mut().for_each(|v| *v /= sum);
758        }
759        totals
760    }
761
762    /// Feature names, if configured.
763    pub fn feature_names(&self) -> Option<&[String]> {
764        self.config.feature_names.as_deref()
765    }
766
767    /// Feature importances paired with their names.
768    ///
769    /// Returns `None` if feature names are not configured. Otherwise returns
770    /// `(name, importance)` pairs sorted by importance descending.
771    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
772        let names = self.config.feature_names.as_ref()?;
773        let importances = self.feature_importances();
774        let mut pairs: Vec<(String, f64)> = names
775            .iter()
776            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
777            .map(|(n, &v)| (n.clone(), v))
778            .collect();
779        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
780        Some(pairs)
781    }
782
783    /// Train on a single sample with named features.
784    ///
785    /// Converts a `HashMap<String, f64>` of named features into a positional
786    /// vector using the configured feature names. Missing features default to 0.0.
787    ///
788    /// # Panics
789    ///
790    /// Panics if `feature_names` is not configured.
791    #[cfg(feature = "std")]
792    pub fn train_one_named(
793        &mut self,
794        features: &std::collections::HashMap<alloc::string::String, f64>,
795        target: f64,
796    ) {
797        let names = self
798            .config
799            .feature_names
800            .as_ref()
801            .expect("train_one_named requires feature_names to be configured");
802        let vec: Vec<f64> = names
803            .iter()
804            .map(|name| features.get(name).copied().unwrap_or(0.0))
805            .collect();
806        self.train_one(&(&vec[..], target));
807    }
808
809    /// Predict with named features.
810    ///
811    /// Converts named features into a positional vector, same as `train_one_named`.
812    ///
813    /// # Panics
814    ///
815    /// Panics if `feature_names` is not configured.
816    #[cfg(feature = "std")]
817    pub fn predict_named(
818        &self,
819        features: &std::collections::HashMap<alloc::string::String, f64>,
820    ) -> f64 {
821        let names = self
822            .config
823            .feature_names
824            .as_ref()
825            .expect("predict_named requires feature_names to be configured");
826        let vec: Vec<f64> = names
827            .iter()
828            .map(|name| features.get(name).copied().unwrap_or(0.0))
829            .collect();
830        self.predict(&vec)
831    }
832
833    // NOTE: explain() and explain_named() require the `explain` module which
834    // lives in the full `irithyll` crate, not in `irithyll-core`. Those methods
835    // are provided via the re-export layer in `irithyll::ensemble`.
836
837    /// Compute tree contribution standard deviation (σ proxy for adaptive_mts).
838    ///
839    /// Measures how much individual tree predictions vary across the ensemble,
840    /// which serves as a proxy for model uncertainty in the base SGBT
841    /// (the distributional variant uses honest_sigma instead).
842    fn contribution_variance(&self, features: &[f64]) -> f64 {
843        let n = self.steps.len();
844        if n <= 1 {
845            return 0.0;
846        }
847
848        let lr = self.config.learning_rate;
849        let mut sum = 0.0;
850        let mut sq_sum = 0.0;
851        for step in &self.steps {
852            let c = lr * step.predict(features);
853            sum += c;
854            sq_sum += c * c;
855        }
856        let n_f = n as f64;
857        let mean = sum / n_f;
858        let var = (sq_sum / n_f) - (mean * mean);
859        // Bessel's correction + numerical safety
860        crate::math::sqrt((var.abs() * n_f / (n_f - 1.0)).max(0.0))
861    }
862
863    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
864    fn refresh_bandwidths(&mut self) {
865        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
866        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
867            self.auto_bandwidths = self.compute_auto_bandwidths();
868            self.last_replacement_sum = current_sum;
869        }
870    }
871
872    /// Compute per-feature auto-calibrated bandwidths from all trees.
873    ///
874    /// For each feature, collects all split thresholds across all trees,
875    /// computes the median gap between consecutive unique thresholds, and
876    /// returns `median_gap * K` (K = 2.0).
877    ///
878    /// Edge cases:
879    /// - Feature with < 3 unique thresholds: `range / n_bins * K`
880    /// - Feature never split on (< 2 unique thresholds): `f64::INFINITY` (hard routing)
881    fn compute_auto_bandwidths(&self) -> Vec<f64> {
882        const K: f64 = 2.0;
883
884        // Determine n_features from the trees
885        let n_features = self
886            .steps
887            .iter()
888            .filter_map(|s| s.slot().active_tree().n_features())
889            .max()
890            .unwrap_or(0);
891
892        if n_features == 0 {
893            return Vec::new();
894        }
895
896        // Collect all thresholds from all trees per feature
897        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
898
899        for step in &self.steps {
900            let tree_thresholds = step
901                .slot()
902                .active_tree()
903                .collect_split_thresholds_per_feature();
904            for (i, ts) in tree_thresholds.into_iter().enumerate() {
905                if i < n_features {
906                    all_thresholds[i].extend(ts);
907                }
908            }
909        }
910
911        let n_bins = self.config.n_bins as f64;
912
913        // Compute per-feature bandwidth
914        all_thresholds
915            .iter()
916            .map(|ts| {
917                if ts.is_empty() {
918                    return f64::INFINITY; // Never split on → hard routing
919                }
920
921                let mut sorted = ts.clone();
922                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
923                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
924
925                if sorted.len() < 2 {
926                    return f64::INFINITY; // Single threshold → hard routing
927                }
928
929                // Compute gaps between consecutive unique thresholds
930                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
931
932                if sorted.len() < 3 {
933                    // Fallback: feature_range / n_bins * K
934                    let range = sorted.last().unwrap() - sorted.first().unwrap();
935                    if range < 1e-15 {
936                        return f64::INFINITY;
937                    }
938                    return (range / n_bins) * K;
939                }
940
941                // Median gap
942                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
943                let median_gap = if gaps.len() % 2 == 0 {
944                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
945                } else {
946                    gaps[gaps.len() / 2]
947                };
948
949                if median_gap < 1e-15 {
950                    f64::INFINITY
951                } else {
952                    median_gap * K
953                }
954            })
955            .collect()
956    }
957
958    /// Reset the ensemble to initial state.
959    pub fn reset(&mut self) {
960        for step in &mut self.steps {
961            step.reset();
962        }
963        self.base_prediction = 0.0;
964        self.base_initialized = false;
965        self.initial_targets.clear();
966        self.samples_seen = 0;
967        self.rng_state = self.config.seed;
968        self.rolling_contribution_sigma = 0.0;
969        self.auto_bandwidths.clear();
970        self.last_replacement_sum = 0;
971    }
972
973    /// Serialize the model into a [`ModelState`](crate::serde_support::ModelState).
974    ///
975    /// Auto-detects the [`LossType`](crate::loss::LossType) from the loss
976    /// function's [`Loss::loss_type()`] implementation.
977    ///
978    /// # Errors
979    ///
980    /// Returns [`IrithyllError::Serialization`](crate::IrithyllError::Serialization)
981    /// if the loss does not implement `loss_type()` (returns `None`). For custom
982    /// losses, use [`to_model_state_with`](Self::to_model_state_with) instead.
983    #[cfg(feature = "_serde_support")]
984    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
985        let loss_type = self.loss.loss_type().ok_or_else(|| {
986            crate::error::IrithyllError::Serialization(
987                "cannot auto-detect loss type for serialization: \
988                 implement Loss::loss_type() or use to_model_state_with()"
989                    .into(),
990            )
991        })?;
992        Ok(self.to_model_state_with(loss_type))
993    }
994
995    /// Serialize the model with an explicit [`LossType`](crate::loss::LossType) tag.
996    ///
997    /// Use this for custom loss functions that don't implement `loss_type()`.
998    #[cfg(feature = "_serde_support")]
999    pub fn to_model_state_with(
1000        &self,
1001        loss_type: crate::loss::LossType,
1002    ) -> crate::serde_support::ModelState {
1003        use crate::serde_support::{ModelState, StepSnapshot};
1004
1005        let steps = self
1006            .steps
1007            .iter()
1008            .map(|step| {
1009                let slot = step.slot();
1010                let tree_snap = snapshot_tree(slot.active_tree());
1011                let alt_snap = slot.alternate_tree().map(snapshot_tree);
1012                let drift_state = slot.detector().serialize_state();
1013                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
1014                StepSnapshot {
1015                    tree: tree_snap,
1016                    alternate_tree: alt_snap,
1017                    drift_state,
1018                    alt_drift_state,
1019                }
1020            })
1021            .collect();
1022
1023        ModelState {
1024            config: self.config.clone(),
1025            loss_type,
1026            base_prediction: self.base_prediction,
1027            base_initialized: self.base_initialized,
1028            initial_targets: self.initial_targets.clone(),
1029            initial_target_count: self.initial_target_count,
1030            samples_seen: self.samples_seen,
1031            rng_state: self.rng_state,
1032            steps,
1033            rolling_mean_error: self.rolling_mean_error,
1034            contribution_ewma: self.contribution_ewma.clone(),
1035            low_contrib_count: self.low_contrib_count.clone(),
1036        }
1037    }
1038}
1039
1040// ---------------------------------------------------------------------------
1041// DynSGBT: deserialization returns a dynamically-dispatched model
1042// ---------------------------------------------------------------------------
1043
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Reconstruct an SGBT model from a [`ModelState`](crate::serde_support::ModelState).
    ///
    /// Returns a [`DynSGBT`] (`SGBT<Box<dyn Loss>>`) because the concrete
    /// loss type is determined at runtime from the serialized tag.
    ///
    /// Rebuilds the full ensemble including tree topology and leaf values.
    /// Histogram accumulators are left empty and will rebuild from continued
    /// training. If drift detector state was serialized, it is restored;
    /// otherwise a fresh detector is created from the config.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        // Concrete loss is chosen at runtime from the serialized tag.
        let loss = state.loss_type.into_loss();

        // Per-sample leaf decay factor derived from the configured half-life:
        // alpha = exp(-ln 2 / half_life).
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Rebuild the per-step tree configuration from the ensemble
                // config; XOR-ing the seed with the step index decorrelates
                // per-step randomness.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .adaptive_depth_opt(state.config.adaptive_depth)
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                // Restore tree topology + leaf values for the active tree and,
                // when snapshotted, the alternate (candidate) tree.
                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                // Fresh detector from config; overwrite with serialized state
                // when it was saved.
                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                // Same restoration for the alternate tree's detector, if any.
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        let has_pruning = state.config.quality_prune_alpha.is_some()
            || state.config.proactive_prune_interval.is_some();

        // Restore pruning state if available, otherwise initialize
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            // Transient caches are not serialized; rebuilt lazily after load.
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1153
1154// ---------------------------------------------------------------------------
1155// Shared snapshot/rebuild helpers for serde (used by SGBT + DistributionalSGBT)
1156// ---------------------------------------------------------------------------
1157
/// Snapshot a [`HoeffdingTree`] into a serializable [`TreeSnapshot`].
#[cfg(feature = "_serde_support")]
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;

    let nodes = tree.arena();
    // Child NodeIds are flattened to their raw indices; every other field is
    // a direct copy of the arena's parallel vectors.
    TreeSnapshot {
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        feature_idx: nodes.feature_idx.clone(),
        threshold: nodes.threshold.clone(),
        left: nodes.left.iter().map(|id| id.0).collect(),
        right: nodes.right.iter().map(|id| id.0).collect(),
        leaf_value: nodes.leaf_value.clone(),
        is_leaf: nodes.is_leaf.clone(),
        depth: nodes.depth.clone(),
        sample_count: nodes.sample_count.clone(),
        categorical_mask: nodes.categorical_mask.clone(),
    }
}
1181
/// Rebuild a [`HoeffdingTree`] from a [`TreeSnapshot`] and a [`TreeConfig`].
#[cfg(feature = "_serde_support")]
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    let node_count = snapshot.feature_idx.len();
    let mut arena = TreeArena::new();

    // Bulk-copy the parallel node vectors; raw child indices are re-wrapped
    // as NodeIds.
    arena.feature_idx.extend_from_slice(&snapshot.feature_idx);
    arena.threshold.extend_from_slice(&snapshot.threshold);
    arena.left.extend(snapshot.left.iter().map(|&raw| NodeId(raw)));
    arena.right.extend(snapshot.right.iter().map(|&raw| NodeId(raw)));
    arena.leaf_value.extend_from_slice(&snapshot.leaf_value);
    arena.is_leaf.extend_from_slice(&snapshot.is_leaf);
    arena.depth.extend_from_slice(&snapshot.depth);
    arena.sample_count.extend_from_slice(&snapshot.sample_count);
    // The mask vector may be shorter than the node count (defensive `.get`);
    // missing entries become None.
    arena
        .categorical_mask
        .extend((0..node_count).map(|i| snapshot.categorical_mask.get(i).copied().flatten()));

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1215
1216#[cfg(test)]
1217mod tests {
1218    use super::*;
1219    use alloc::boxed::Box;
1220    use alloc::vec;
1221    use alloc::vec::Vec;
1222
1223    fn default_config() -> SGBTConfig {
1224        SGBTConfig::builder()
1225            .n_steps(10)
1226            .learning_rate(0.1)
1227            .grace_period(20)
1228            .max_depth(4)
1229            .n_bins(16)
1230            .build()
1231            .unwrap()
1232    }
1233
1234    #[test]
1235    fn new_model_predicts_zero() {
1236        let model = SGBT::new(default_config());
1237        let pred = model.predict(&[1.0, 2.0, 3.0]);
1238        assert!(pred.abs() < 1e-12);
1239    }
1240
1241    #[test]
1242    fn train_one_does_not_panic() {
1243        let mut model = SGBT::new(default_config());
1244        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
1245        assert_eq!(model.n_samples_seen(), 1);
1246    }
1247
1248    #[test]
1249    fn prediction_changes_after_training() {
1250        let mut model = SGBT::new(default_config());
1251        let features = vec![1.0, 2.0, 3.0];
1252        for i in 0..100 {
1253            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
1254        }
1255        let pred = model.predict(&features);
1256        assert!(pred.is_finite());
1257    }
1258
1259    #[test]
1260    fn linear_signal_rmse_improves() {
1261        let config = SGBTConfig::builder()
1262            .n_steps(20)
1263            .learning_rate(0.1)
1264            .grace_period(10)
1265            .max_depth(3)
1266            .n_bins(16)
1267            .build()
1268            .unwrap();
1269        let mut model = SGBT::new(config);
1270
1271        let mut rng: u64 = 12345;
1272        let mut early_errors = Vec::new();
1273        let mut late_errors = Vec::new();
1274
1275        for i in 0..500 {
1276            rng ^= rng << 13;
1277            rng ^= rng >> 7;
1278            rng ^= rng << 17;
1279            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1280            rng ^= rng << 13;
1281            rng ^= rng >> 7;
1282            rng ^= rng << 17;
1283            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1284            let target = 2.0 * x1 + 3.0 * x2;
1285
1286            let pred = model.predict(&[x1, x2]);
1287            let error = (pred - target).powi(2);
1288
1289            if (50..150).contains(&i) {
1290                early_errors.push(error);
1291            }
1292            if i >= 400 {
1293                late_errors.push(error);
1294            }
1295
1296            model.train_one(&Sample::new(vec![x1, x2], target));
1297        }
1298
1299        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
1300        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();
1301
1302        assert!(
1303            late_rmse < early_rmse,
1304            "RMSE should decrease: early={:.4}, late={:.4}",
1305            early_rmse,
1306            late_rmse
1307        );
1308    }
1309
1310    #[test]
1311    fn train_batch_equivalent_to_sequential() {
1312        let config = default_config();
1313        let mut model_seq = SGBT::new(config.clone());
1314        let mut model_batch = SGBT::new(config);
1315
1316        let samples: Vec<Sample> = (0..20)
1317            .map(|i| {
1318                let x = i as f64 * 0.5;
1319                Sample::new(vec![x, x * 2.0], x * 3.0)
1320            })
1321            .collect();
1322
1323        for s in &samples {
1324            model_seq.train_one(s);
1325        }
1326        model_batch.train_batch(&samples);
1327
1328        let pred_seq = model_seq.predict(&[1.0, 2.0]);
1329        let pred_batch = model_batch.predict(&[1.0, 2.0]);
1330
1331        assert!(
1332            (pred_seq - pred_batch).abs() < 1e-10,
1333            "seq={}, batch={}",
1334            pred_seq,
1335            pred_batch
1336        );
1337    }
1338
1339    #[test]
1340    fn reset_returns_to_initial() {
1341        let mut model = SGBT::new(default_config());
1342        for i in 0..100 {
1343            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
1344        }
1345        model.reset();
1346        assert_eq!(model.n_samples_seen(), 0);
1347        assert!(!model.is_initialized());
1348        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
1349    }
1350
1351    #[test]
1352    fn base_prediction_initializes() {
1353        let mut model = SGBT::new(default_config());
1354        for i in 0..50 {
1355            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
1356        }
1357        assert!(model.is_initialized());
1358        let expected = (100.0 + 149.0) / 2.0;
1359        assert!((model.base_prediction() - expected).abs() < 1.0);
1360    }
1361
1362    #[test]
1363    fn with_loss_uses_custom_loss() {
1364        use crate::loss::logistic::LogisticLoss;
1365        let model = SGBT::with_loss(default_config(), LogisticLoss);
1366        let pred = model.predict_transformed(&[1.0, 2.0]);
1367        assert!(
1368            (pred - 0.5).abs() < 1e-6,
1369            "sigmoid(0) should be 0.5, got {}",
1370            pred
1371        );
1372    }
1373
1374    #[test]
1375    fn ewma_config_propagates_and_trains() {
1376        let config = SGBTConfig::builder()
1377            .n_steps(5)
1378            .learning_rate(0.1)
1379            .grace_period(10)
1380            .max_depth(3)
1381            .n_bins(16)
1382            .leaf_half_life(50)
1383            .build()
1384            .unwrap();
1385        let mut model = SGBT::new(config);
1386
1387        for i in 0..200 {
1388            let x = (i as f64) * 0.1;
1389            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1390        }
1391
1392        let pred = model.predict(&[1.0, 2.0]);
1393        assert!(
1394            pred.is_finite(),
1395            "EWMA-enabled model should produce finite predictions, got {}",
1396            pred
1397        );
1398    }
1399
1400    #[test]
1401    fn max_tree_samples_config_propagates() {
1402        let config = SGBTConfig::builder()
1403            .n_steps(5)
1404            .learning_rate(0.1)
1405            .grace_period(10)
1406            .max_depth(3)
1407            .n_bins(16)
1408            .max_tree_samples(200)
1409            .build()
1410            .unwrap();
1411        let mut model = SGBT::new(config);
1412
1413        for i in 0..500 {
1414            let x = (i as f64) * 0.1;
1415            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1416        }
1417
1418        let pred = model.predict(&[1.0, 2.0]);
1419        assert!(
1420            pred.is_finite(),
1421            "max_tree_samples model should produce finite predictions, got {}",
1422            pred
1423        );
1424    }
1425
1426    #[test]
1427    fn split_reeval_config_propagates() {
1428        let config = SGBTConfig::builder()
1429            .n_steps(5)
1430            .learning_rate(0.1)
1431            .grace_period(10)
1432            .max_depth(2)
1433            .n_bins(16)
1434            .split_reeval_interval(50)
1435            .build()
1436            .unwrap();
1437        let mut model = SGBT::new(config);
1438
1439        let mut rng: u64 = 12345;
1440        for _ in 0..1000 {
1441            rng ^= rng << 13;
1442            rng ^= rng >> 7;
1443            rng ^= rng << 17;
1444            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1445            rng ^= rng << 13;
1446            rng ^= rng >> 7;
1447            rng ^= rng << 17;
1448            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1449            let target = 2.0 * x1 + 3.0 * x2;
1450            model.train_one(&Sample::new(vec![x1, x2], target));
1451        }
1452
1453        let pred = model.predict(&[1.0, 2.0]);
1454        assert!(
1455            pred.is_finite(),
1456            "split re-eval model should produce finite predictions, got {}",
1457            pred
1458        );
1459    }
1460
1461    #[test]
1462    fn loss_accessor_works() {
1463        use crate::loss::logistic::LogisticLoss;
1464        let model = SGBT::with_loss(default_config(), LogisticLoss);
1465        // Verify we can access the concrete loss type
1466        let _loss: &LogisticLoss = model.loss();
1467        assert_eq!(_loss.n_outputs(), 1);
1468    }
1469
1470    #[test]
1471    fn clone_produces_independent_copy() {
1472        let config = default_config();
1473        let mut model = SGBT::new(config);
1474
1475        // Train the original on some data
1476        let mut rng: u64 = 99999;
1477        for _ in 0..200 {
1478            rng ^= rng << 13;
1479            rng ^= rng >> 7;
1480            rng ^= rng << 17;
1481            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1482            let target = 2.0 * x + 1.0;
1483            model.train_one(&Sample::new(vec![x], target));
1484        }
1485
1486        // Clone the model
1487        let mut cloned = model.clone();
1488
1489        // Both should produce identical predictions
1490        let test_features = [3.0];
1491        let pred_original = model.predict(&test_features);
1492        let pred_cloned = cloned.predict(&test_features);
1493        assert!(
1494            (pred_original - pred_cloned).abs() < 1e-12,
1495            "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
1496        );
1497
1498        // Train only the clone further -- models should diverge
1499        for _ in 0..200 {
1500            rng ^= rng << 13;
1501            rng ^= rng >> 7;
1502            rng ^= rng << 17;
1503            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1504            let target = -3.0 * x + 5.0; // Different relationship
1505            cloned.train_one(&Sample::new(vec![x], target));
1506        }
1507
1508        let pred_original_after = model.predict(&test_features);
1509        let pred_cloned_after = cloned.predict(&test_features);
1510
1511        // Original should be unchanged
1512        assert!(
1513            (pred_original - pred_original_after).abs() < 1e-12,
1514            "original should be unchanged after training clone"
1515        );
1516
1517        // Clone should have diverged
1518        assert!(
1519            (pred_original_after - pred_cloned_after).abs() > 1e-6,
1520            "clone should diverge after independent training"
1521        );
1522    }
1523
1524    // -------------------------------------------------------------------
1525    // predict_with_confidence returns finite values
1526    // -------------------------------------------------------------------
1527    #[test]
1528    fn predict_with_confidence_finite() {
1529        let config = SGBTConfig::builder()
1530            .n_steps(5)
1531            .grace_period(10)
1532            .build()
1533            .unwrap();
1534        let mut model = SGBT::new(config);
1535
1536        // Train enough to initialize
1537        for i in 0..100 {
1538            let x = i as f64 * 0.1;
1539            model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1540        }
1541
1542        let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1543        assert!(pred.is_finite(), "prediction should be finite");
1544        assert!(confidence.is_finite(), "confidence should be finite");
1545        assert!(
1546            confidence > 0.0,
1547            "confidence should be positive after training"
1548        );
1549    }
1550
1551    // -------------------------------------------------------------------
1552    // predict_with_confidence positive after training
1553    // -------------------------------------------------------------------
1554    #[test]
1555    fn predict_with_confidence_positive_after_training() {
1556        let config = SGBTConfig::builder()
1557            .n_steps(5)
1558            .grace_period(10)
1559            .build()
1560            .unwrap();
1561        let mut model = SGBT::new(config);
1562
1563        // Train enough to initialize and build structure
1564        for i in 0..200 {
1565            let x = i as f64 * 0.05;
1566            model.train_one(&(&[x][..], x * 2.0));
1567        }
1568
1569        let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1570
1571        assert!(pred.is_finite(), "prediction should be finite");
1572        assert!(
1573            confidence > 0.0 && confidence.is_finite(),
1574            "confidence should be finite and positive, got {}",
1575            confidence,
1576        );
1577
1578        // Multiple queries should give consistent confidence
1579        let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1580        assert!(
1581            (pred - pred2).abs() < 1e-12,
1582            "same input should give same prediction"
1583        );
1584        assert!(
1585            (confidence - conf2).abs() < 1e-12,
1586            "same input should give same confidence"
1587        );
1588    }
1589
1590    // -------------------------------------------------------------------
1591    // predict_with_confidence agrees with predict on point estimate
1592    // -------------------------------------------------------------------
1593    #[test]
1594    fn predict_with_confidence_matches_predict() {
1595        let config = SGBTConfig::builder()
1596            .n_steps(10)
1597            .grace_period(10)
1598            .build()
1599            .unwrap();
1600        let mut model = SGBT::new(config);
1601
1602        for i in 0..200 {
1603            let x = (i as f64 - 100.0) * 0.01;
1604            model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1605        }
1606
1607        let pred = model.predict(&[0.5, 0.25]);
1608        let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1609
1610        assert!(
1611            (pred - conf_pred).abs() < 1e-10,
1612            "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1613            pred,
1614            conf_pred,
1615        );
1616    }
1617
1618    // -------------------------------------------------------------------
1619    // gradient clipping config round-trips through builder
1620    // -------------------------------------------------------------------
1621    #[test]
1622    fn gradient_clip_config_builder() {
1623        let config = SGBTConfig::builder()
1624            .n_steps(10)
1625            .gradient_clip_sigma(3.0)
1626            .build()
1627            .unwrap();
1628
1629        assert_eq!(config.gradient_clip_sigma, Some(3.0));
1630    }
1631
1632    // -------------------------------------------------------------------
1633    // monotonic constraints config round-trips through builder
1634    // -------------------------------------------------------------------
1635    #[test]
1636    fn monotone_constraints_config_builder() {
1637        let config = SGBTConfig::builder()
1638            .n_steps(10)
1639            .monotone_constraints(vec![1, -1, 0])
1640            .build()
1641            .unwrap();
1642
1643        assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1644    }
1645
1646    // -------------------------------------------------------------------
1647    // monotonic constraints validation rejects invalid values
1648    // -------------------------------------------------------------------
1649    #[test]
1650    fn monotone_constraints_invalid_value_rejected() {
1651        let result = SGBTConfig::builder()
1652            .n_steps(10)
1653            .monotone_constraints(vec![1, 2, 0])
1654            .build();
1655
1656        assert!(result.is_err(), "constraint value 2 should be rejected");
1657    }
1658
1659    // -------------------------------------------------------------------
1660    // gradient clipping validation rejects non-positive sigma
1661    // -------------------------------------------------------------------
1662    #[test]
1663    fn gradient_clip_sigma_negative_rejected() {
1664        let result = SGBTConfig::builder()
1665            .n_steps(10)
1666            .gradient_clip_sigma(-1.0)
1667            .build();
1668
1669        assert!(result.is_err(), "negative sigma should be rejected");
1670    }
1671
1672    // -------------------------------------------------------------------
1673    // gradient clipping ensemble-level reduces outlier impact
1674    // -------------------------------------------------------------------
1675    #[test]
1676    fn gradient_clipping_reduces_outlier_impact() {
1677        // Without clipping
1678        let config_no_clip = SGBTConfig::builder()
1679            .n_steps(5)
1680            .grace_period(10)
1681            .build()
1682            .unwrap();
1683        let mut model_no_clip = SGBT::new(config_no_clip);
1684
1685        // With clipping
1686        let config_clip = SGBTConfig::builder()
1687            .n_steps(5)
1688            .grace_period(10)
1689            .gradient_clip_sigma(3.0)
1690            .build()
1691            .unwrap();
1692        let mut model_clip = SGBT::new(config_clip);
1693
1694        // Train both on identical normal data
1695        for i in 0..100 {
1696            let x = (i as f64) * 0.01;
1697            let sample = (&[x][..], x * 2.0);
1698            model_no_clip.train_one(&sample);
1699            model_clip.train_one(&sample);
1700        }
1701
1702        let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1703        let pred_clip_before = model_clip.predict(&[0.5]);
1704
1705        // Inject outlier
1706        let outlier = (&[0.5_f64][..], 10000.0);
1707        model_no_clip.train_one(&outlier);
1708        model_clip.train_one(&outlier);
1709
1710        let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1711        let pred_clip_after = model_clip.predict(&[0.5]);
1712
1713        let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1714        let delta_clip = (pred_clip_after - pred_clip_before).abs();
1715
1716        // Clipped model should be less affected by the outlier
1717        assert!(
1718            delta_clip <= delta_no_clip + 1e-10,
1719            "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1720            delta_clip,
1721            delta_no_clip,
1722        );
1723    }
1724
1725    // -------------------------------------------------------------------
1726    // train_batch_with_callback fires at correct intervals
1727    // -------------------------------------------------------------------
1728    #[test]
1729    fn train_batch_with_callback_fires() {
1730        let config = SGBTConfig::builder()
1731            .n_steps(3)
1732            .grace_period(5)
1733            .build()
1734            .unwrap();
1735        let mut model = SGBT::new(config);
1736
1737        let data: Vec<(Vec<f64>, f64)> = (0..25)
1738            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1739            .collect();
1740
1741        let mut callbacks = Vec::new();
1742        model.train_batch_with_callback(&data, 10, |n| {
1743            callbacks.push(n);
1744        });
1745
1746        // Should fire at 10, 20, and 25 (final)
1747        assert_eq!(callbacks, vec![10, 20, 25]);
1748    }
1749
1750    // -------------------------------------------------------------------
1751    // train_batch_subsampled produces deterministic subset
1752    // -------------------------------------------------------------------
1753    #[test]
1754    fn train_batch_subsampled_trains_subset() {
1755        let config = SGBTConfig::builder()
1756            .n_steps(3)
1757            .grace_period(5)
1758            .build()
1759            .unwrap();
1760        let mut model = SGBT::new(config);
1761
1762        let data: Vec<(Vec<f64>, f64)> = (0..100)
1763            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1764            .collect();
1765
1766        // Train on only 20 of 100 samples
1767        model.train_batch_subsampled(&data, 20);
1768
1769        // Model should have seen some samples
1770        assert!(
1771            model.n_samples_seen() > 0,
1772            "model should have trained on subset"
1773        );
1774        assert!(
1775            model.n_samples_seen() <= 20,
1776            "model should have trained at most 20 samples, got {}",
1777            model.n_samples_seen(),
1778        );
1779    }
1780
1781    // -------------------------------------------------------------------
1782    // train_batch_subsampled full dataset = train_batch
1783    // -------------------------------------------------------------------
1784    #[test]
1785    fn train_batch_subsampled_full_equals_batch() {
1786        let config1 = SGBTConfig::builder()
1787            .n_steps(3)
1788            .grace_period(5)
1789            .build()
1790            .unwrap();
1791        let config2 = config1.clone();
1792
1793        let mut model1 = SGBT::new(config1);
1794        let mut model2 = SGBT::new(config2);
1795
1796        let data: Vec<(Vec<f64>, f64)> = (0..50)
1797            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1798            .collect();
1799
1800        model1.train_batch(&data);
1801        model2.train_batch_subsampled(&data, 1000); // max_samples > data.len()
1802
1803        // Both should have identical state
1804        assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1805        let pred1 = model1.predict(&[2.5]);
1806        let pred2 = model2.predict(&[2.5]);
1807        assert!(
1808            (pred1 - pred2).abs() < 1e-12,
1809            "full subsample should equal batch: {} vs {}",
1810            pred1,
1811            pred2,
1812        );
1813    }
1814
1815    // -------------------------------------------------------------------
1816    // train_batch_subsampled_with_callback combines both
1817    // -------------------------------------------------------------------
1818    #[test]
1819    fn train_batch_subsampled_with_callback_works() {
1820        let config = SGBTConfig::builder()
1821            .n_steps(3)
1822            .grace_period(5)
1823            .build()
1824            .unwrap();
1825        let mut model = SGBT::new(config);
1826
1827        let data: Vec<(Vec<f64>, f64)> = (0..200)
1828            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1829            .collect();
1830
1831        let mut callbacks = Vec::new();
1832        model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1833            callbacks.push(n);
1834        });
1835
1836        // Should have trained ~50 samples with callbacks at 10, 20, 30, 40, 50
1837        assert!(!callbacks.is_empty(), "should have received callbacks");
1838        assert_eq!(
1839            *callbacks.last().unwrap(),
1840            50,
1841            "final callback should be total samples"
1842        );
1843    }
1844
1845    // ---------------------------------------------------------------
1846    // Linear leaf model integration tests
1847    // ---------------------------------------------------------------
1848
1849    /// xorshift64 PRNG for deterministic test data.
1850    fn xorshift64(state: &mut u64) -> u64 {
1851        let mut s = *state;
1852        s ^= s << 13;
1853        s ^= s >> 7;
1854        s ^= s << 17;
1855        *state = s;
1856        s
1857    }
1858
1859    fn rand_f64(state: &mut u64) -> f64 {
1860        xorshift64(state) as f64 / u64::MAX as f64
1861    }
1862
1863    fn linear_leaves_config() -> SGBTConfig {
1864        SGBTConfig::builder()
1865            .n_steps(10)
1866            .learning_rate(0.1)
1867            .grace_period(20)
1868            .max_depth(2) // low depth -- linear leaves should shine
1869            .n_bins(16)
1870            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1871                learning_rate: 0.1,
1872                decay: None,
1873                use_adagrad: false,
1874            })
1875            .build()
1876            .unwrap()
1877    }
1878
1879    #[test]
1880    fn linear_leaves_trains_without_panic() {
1881        let mut model = SGBT::new(linear_leaves_config());
1882        let mut rng = 42u64;
1883        for _ in 0..200 {
1884            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1885            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1886            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1887            model.train_one(&Sample::new(vec![x1, x2], y));
1888        }
1889        assert_eq!(model.n_samples_seen(), 200);
1890    }
1891
1892    #[test]
1893    fn linear_leaves_prediction_finite() {
1894        let mut model = SGBT::new(linear_leaves_config());
1895        let mut rng = 42u64;
1896        for _ in 0..200 {
1897            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1898            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1899            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1900            model.train_one(&Sample::new(vec![x1, x2], y));
1901        }
1902        let pred = model.predict(&[0.5, -0.3]);
1903        assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1904    }
1905
1906    #[test]
1907    fn linear_leaves_learns_linear_target() {
1908        let mut model = SGBT::new(linear_leaves_config());
1909        let mut rng = 42u64;
1910        for _ in 0..500 {
1911            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1912            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1913            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1914            model.train_one(&Sample::new(vec![x1, x2], y));
1915        }
1916
1917        // Test on a few points -- should be reasonably close for a linear target.
1918        let mut total_error = 0.0;
1919        for _ in 0..50 {
1920            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1921            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1922            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1923            let pred = model.predict(&[x1, x2]);
1924            total_error += (pred - y).powi(2);
1925        }
1926        let mse = total_error / 50.0;
1927        assert!(
1928            mse < 5.0,
1929            "linear leaves MSE on linear target should be < 5.0, got {mse}"
1930        );
1931    }
1932
1933    #[test]
1934    fn linear_leaves_better_than_constant_at_low_depth() {
1935        // Train two models on a linear target at depth 2:
1936        // one with constant leaves, one with linear leaves.
1937        let constant_config = SGBTConfig::builder()
1938            .n_steps(10)
1939            .learning_rate(0.1)
1940            .grace_period(20)
1941            .max_depth(2)
1942            .n_bins(16)
1943            .seed(0xDEAD)
1944            .build()
1945            .unwrap();
1946        let linear_config = SGBTConfig::builder()
1947            .n_steps(10)
1948            .learning_rate(0.1)
1949            .grace_period(20)
1950            .max_depth(2)
1951            .n_bins(16)
1952            .seed(0xDEAD)
1953            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1954                learning_rate: 0.1,
1955                decay: None,
1956                use_adagrad: false,
1957            })
1958            .build()
1959            .unwrap();
1960
1961        let mut constant_model = SGBT::new(constant_config);
1962        let mut linear_model = SGBT::new(linear_config);
1963        let mut rng = 42u64;
1964
1965        for _ in 0..500 {
1966            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1967            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1968            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1969            let sample = Sample::new(vec![x1, x2], y);
1970            constant_model.train_one(&sample);
1971            linear_model.train_one(&sample);
1972        }
1973
1974        // Evaluate both on test set.
1975        let mut constant_mse = 0.0;
1976        let mut linear_mse = 0.0;
1977        for _ in 0..100 {
1978            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1979            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1980            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1981            constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1982            linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1983        }
1984        constant_mse /= 100.0;
1985        linear_mse /= 100.0;
1986
1987        // Linear leaves should outperform constant leaves on a linear target.
1988        assert!(
1989            linear_mse < constant_mse,
1990            "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1991        );
1992    }
1993
1994    #[test]
1995    fn adaptive_leaves_trains_without_panic() {
1996        let config = SGBTConfig::builder()
1997            .n_steps(10)
1998            .learning_rate(0.1)
1999            .grace_period(20)
2000            .max_depth(3)
2001            .n_bins(16)
2002            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
2003                promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
2004                    learning_rate: 0.1,
2005                    decay: None,
2006                    use_adagrad: false,
2007                }),
2008            })
2009            .build()
2010            .unwrap();
2011
2012        let mut model = SGBT::new(config);
2013        let mut rng = 42u64;
2014        for _ in 0..500 {
2015            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2016            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2017            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2018            model.train_one(&Sample::new(vec![x1, x2], y));
2019        }
2020        let pred = model.predict(&[0.5, -0.3]);
2021        assert!(
2022            pred.is_finite(),
2023            "adaptive leaf prediction should be finite, got {pred}"
2024        );
2025    }
2026
2027    #[test]
2028    fn linear_leaves_with_decay_trains_without_panic() {
2029        let config = SGBTConfig::builder()
2030            .n_steps(10)
2031            .learning_rate(0.1)
2032            .grace_period(20)
2033            .max_depth(3)
2034            .n_bins(16)
2035            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
2036                learning_rate: 0.1,
2037                decay: Some(0.995),
2038                use_adagrad: false,
2039            })
2040            .build()
2041            .unwrap();
2042
2043        let mut model = SGBT::new(config);
2044        let mut rng = 42u64;
2045        for _ in 0..500 {
2046            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2047            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2048            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2049            model.train_one(&Sample::new(vec![x1, x2], y));
2050        }
2051        let pred = model.predict(&[0.5, -0.3]);
2052        assert!(
2053            pred.is_finite(),
2054            "decay leaf prediction should be finite, got {pred}"
2055        );
2056    }
2057
2058    // -------------------------------------------------------------------
2059    // predict_smooth returns finite values
2060    // -------------------------------------------------------------------
2061    #[test]
2062    fn predict_smooth_returns_finite() {
2063        let config = SGBTConfig::builder()
2064            .n_steps(5)
2065            .learning_rate(0.1)
2066            .grace_period(10)
2067            .build()
2068            .unwrap();
2069        let mut model = SGBT::new(config);
2070
2071        for i in 0..200 {
2072            let x = (i as f64) * 0.1;
2073            model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
2074        }
2075
2076        let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
2077        let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
2078
2079        assert!(pred_hard.is_finite(), "hard prediction should be finite");
2080        assert!(
2081            pred_smooth.is_finite(),
2082            "smooth prediction should be finite"
2083        );
2084    }
2085
2086    // -------------------------------------------------------------------
2087    // predict_smooth converges to hard predict at small bandwidth
2088    // -------------------------------------------------------------------
2089    #[test]
2090    fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2091        let config = SGBTConfig::builder()
2092            .n_steps(5)
2093            .learning_rate(0.1)
2094            .grace_period(10)
2095            .build()
2096            .unwrap();
2097        let mut model = SGBT::new(config);
2098
2099        for i in 0..300 {
2100            let x = (i as f64) * 0.1;
2101            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2102        }
2103
2104        let features = [5.0, 2.5];
2105        let hard = model.predict(&features);
2106        let smooth = model.predict_smooth(&features, 0.001);
2107
2108        assert!(
2109            (hard - smooth).abs() < 0.5,
2110            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2111            hard,
2112            smooth,
2113        );
2114    }
2115
2116    #[test]
2117    fn auto_bandwidth_computed_after_training() {
2118        let config = SGBTConfig::builder()
2119            .n_steps(5)
2120            .learning_rate(0.1)
2121            .grace_period(10)
2122            .build()
2123            .unwrap();
2124        let mut model = SGBT::new(config);
2125
2126        // Before training, no bandwidths
2127        assert!(model.auto_bandwidths().is_empty());
2128
2129        for i in 0..200 {
2130            let x = (i as f64) * 0.1;
2131            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2132        }
2133
2134        // After training, auto_bandwidths should be populated
2135        let bws = model.auto_bandwidths();
2136        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2137
2138        // predict() always uses smooth routing with auto-bandwidths
2139        let pred = model.predict(&[5.0, 2.5]);
2140        assert!(
2141            pred.is_finite(),
2142            "auto-bandwidth predict should be finite: {}",
2143            pred
2144        );
2145    }
2146
2147    #[test]
2148    fn predict_interpolated_returns_finite() {
2149        let config = SGBTConfig::builder()
2150            .n_steps(5)
2151            .learning_rate(0.01)
2152            .build()
2153            .unwrap();
2154        let mut model = SGBT::new(config);
2155
2156        for i in 0..200 {
2157            let x = (i as f64) * 0.1;
2158            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2159        }
2160
2161        let pred = model.predict_interpolated(&[1.0, 0.5]);
2162        assert!(
2163            pred.is_finite(),
2164            "interpolated prediction should be finite: {}",
2165            pred
2166        );
2167    }
2168
2169    #[test]
2170    fn predict_sibling_interpolated_varies_with_features() {
2171        let config = SGBTConfig::builder()
2172            .n_steps(10)
2173            .learning_rate(0.1)
2174            .grace_period(10)
2175            .max_depth(6)
2176            .delta(0.1)
2177            .build()
2178            .unwrap();
2179        let mut model = SGBT::new(config);
2180
2181        for i in 0..2000 {
2182            let x = (i as f64) * 0.01;
2183            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2184            model.train_one(&Sample::new(vec![x, x * 0.3], y));
2185        }
2186
2187        // Verify the method is callable and produces finite predictions
2188        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2189        assert!(pred.is_finite(), "sibling interpolated should be finite");
2190
2191        // If bandwidths are finite, verify sibling produces at least as much
2192        // variation as hard routing across a feature sweep
2193        let bws = model.auto_bandwidths();
2194        if bws.iter().any(|&b| b.is_finite()) {
2195            let hard: Vec<f64> = (0..200)
2196                .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
2197                .collect();
2198            let sib: Vec<f64> = (0..200)
2199                .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
2200                .collect();
2201            let hc = hard
2202                .windows(2)
2203                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2204                .count();
2205            let sc = sib
2206                .windows(2)
2207                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2208                .count();
2209            assert!(
2210                sc >= hc,
2211                "sibling should produce >= hard changes: sib={}, hard={}",
2212                sc,
2213                hc
2214            );
2215        }
2216    }
2217
2218    #[test]
2219    fn predict_graduated_returns_finite() {
2220        let config = SGBTConfig::builder()
2221            .n_steps(5)
2222            .learning_rate(0.01)
2223            .max_tree_samples(200)
2224            .shadow_warmup(50)
2225            .build()
2226            .unwrap();
2227        let mut model = SGBT::new(config);
2228
2229        for i in 0..300 {
2230            let x = (i as f64) * 0.1;
2231            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2232        }
2233
2234        let pred = model.predict_graduated(&[1.0, 0.5]);
2235        assert!(
2236            pred.is_finite(),
2237            "graduated prediction should be finite: {}",
2238            pred
2239        );
2240
2241        let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2242        assert!(
2243            pred2.is_finite(),
2244            "graduated+sibling prediction should be finite: {}",
2245            pred2
2246        );
2247    }
2248
2249    #[test]
2250    fn shadow_warmup_validation() {
2251        let result = SGBTConfig::builder()
2252            .n_steps(5)
2253            .learning_rate(0.01)
2254            .shadow_warmup(0)
2255            .build();
2256        assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2257    }
2258
2259    // -------------------------------------------------------------------
2260    // adaptive_mts config tests
2261    // -------------------------------------------------------------------
2262
2263    #[test]
2264    fn adaptive_mts_defaults_to_none() {
2265        let cfg = SGBTConfig::default();
2266        assert!(
2267            cfg.adaptive_mts.is_none(),
2268            "adaptive_mts should default to None"
2269        );
2270    }
2271
2272    #[test]
2273    fn adaptive_mts_config_builder() {
2274        let cfg = SGBTConfig::builder()
2275            .n_steps(10)
2276            .adaptive_mts(500, 2.0)
2277            .build()
2278            .unwrap();
2279        assert_eq!(
2280            cfg.adaptive_mts,
2281            Some((500, 2.0)),
2282            "adaptive_mts should store (base_mts, k)"
2283        );
2284    }
2285
2286    #[test]
2287    fn adaptive_mts_validation_rejects_low_base() {
2288        let result = SGBTConfig::builder()
2289            .n_steps(5)
2290            .adaptive_mts(50, 1.0)
2291            .build();
2292        assert!(
2293            result.is_err(),
2294            "adaptive_mts with base_mts < 100 should fail"
2295        );
2296    }
2297
2298    #[test]
2299    fn adaptive_mts_validation_rejects_zero_k() {
2300        let result = SGBTConfig::builder()
2301            .n_steps(5)
2302            .adaptive_mts(500, 0.0)
2303            .build();
2304        assert!(result.is_err(), "adaptive_mts with k=0 should fail");
2305    }
2306
2307    #[test]
2308    fn adaptive_mts_trains_without_panic() {
2309        let config = SGBTConfig::builder()
2310            .n_steps(5)
2311            .learning_rate(0.1)
2312            .grace_period(10)
2313            .max_depth(3)
2314            .n_bins(16)
2315            .adaptive_mts(200, 1.0)
2316            .build()
2317            .unwrap();
2318        let mut model = SGBT::new(config);
2319
2320        for i in 0..500 {
2321            let x = (i as f64) * 0.1;
2322            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2323        }
2324
2325        let pred = model.predict(&[1.0, 2.0]);
2326        assert!(
2327            pred.is_finite(),
2328            "adaptive_mts model should produce finite predictions, got {}",
2329            pred
2330        );
2331    }
2332
2333    // -------------------------------------------------------------------
2334    // proactive_prune config tests
2335    // -------------------------------------------------------------------
2336
2337    #[test]
2338    fn proactive_prune_defaults_to_none() {
2339        let cfg = SGBTConfig::default();
2340        assert!(
2341            cfg.proactive_prune_interval.is_none(),
2342            "proactive_prune_interval should default to None"
2343        );
2344    }
2345
2346    #[test]
2347    fn proactive_prune_config_builder() {
2348        let cfg = SGBTConfig::builder()
2349            .n_steps(10)
2350            .proactive_prune_interval(500)
2351            .build()
2352            .unwrap();
2353        assert_eq!(
2354            cfg.proactive_prune_interval,
2355            Some(500),
2356            "proactive_prune_interval should be set"
2357        );
2358    }
2359
2360    #[test]
2361    fn proactive_prune_validation_rejects_low_interval() {
2362        let result = SGBTConfig::builder()
2363            .n_steps(5)
2364            .proactive_prune_interval(50)
2365            .build();
2366        assert!(
2367            result.is_err(),
2368            "proactive_prune_interval < 100 should fail"
2369        );
2370    }
2371
2372    #[test]
2373    fn proactive_prune_enables_contribution_tracking() {
2374        let config = SGBTConfig::builder()
2375            .n_steps(5)
2376            .learning_rate(0.1)
2377            .grace_period(10)
2378            .max_depth(3)
2379            .n_bins(16)
2380            .proactive_prune_interval(200)
2381            .build()
2382            .unwrap();
2383
2384        // quality_prune_alpha is None, but proactive_prune_interval is set
2385        assert!(config.quality_prune_alpha.is_none());
2386
2387        let mut model = SGBT::new(config);
2388
2389        // Train a few samples so contribution_ewma gets populated
2390        for i in 0..100 {
2391            let x = (i as f64) * 0.1;
2392            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2393        }
2394
2395        // contribution_ewma should be populated (not empty)
2396        // We can verify by checking that model trains without panic and produces
2397        // finite predictions -- the tracking is happening internally.
2398        let pred = model.predict(&[1.0, 2.0]);
2399        assert!(
2400            pred.is_finite(),
2401            "proactive_prune model should produce finite predictions, got {}",
2402            pred
2403        );
2404    }
2405
2406    #[test]
2407    fn proactive_prune_trains_without_panic() {
2408        let config = SGBTConfig::builder()
2409            .n_steps(5)
2410            .learning_rate(0.1)
2411            .grace_period(10)
2412            .max_depth(3)
2413            .n_bins(16)
2414            .proactive_prune_interval(200)
2415            .build()
2416            .unwrap();
2417        let mut model = SGBT::new(config);
2418
2419        // Train past the prune interval to exercise the proactive prune path
2420        for i in 0..500 {
2421            let x = (i as f64) * 0.1;
2422            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2423        }
2424
2425        let pred = model.predict(&[1.0, 2.0]);
2426        assert!(
2427            pred.is_finite(),
2428            "proactive_prune model should produce finite predictions after pruning, got {}",
2429            pred
2430        );
2431    }
2432}