// irithyll_core/ensemble/mod.rs

1//! SGBT ensemble orchestrator -- the core boosting loop.
2//!
3//! Implements Streaming Gradient Boosted Trees (Gunasekara et al., 2024):
4//! a sequence of boosting steps, each owning a streaming tree and drift detector,
5//! with automatic tree replacement when concept drift is detected.
6//!
7//! # Algorithm
8//!
9//! For each incoming sample `(x, y)`:
10//! 1. Compute the current ensemble prediction: `F(x) = base + lr * Σ tree_s(x)`
11//! 2. For each boosting step `s = 1..N`:
12//!    - Compute gradient `g = loss.gradient(y, current_pred)`
13//!    - Compute hessian `h = loss.hessian(y, current_pred)`
14//!    - Feed `(x, g, h)` to tree `s` (which internally uses weighted squared loss)
15//!    - Update `current_pred += lr * tree_s.predict(x)`
16//! 3. The ensemble adapts incrementally, with each tree targeting the residual
17//!    of all preceding trees.
18
19pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] // Used in doc links + tests
50use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// Type alias for an SGBT model using dynamic (boxed) loss dispatch.
///
/// Use this when the loss function is determined at runtime (e.g., when
/// deserializing a model from JSON where the loss type is stored as a tag).
/// Each gradient/hessian call goes through a vtable.
///
/// For compile-time loss dispatch (preferred for performance), use
/// `SGBT<LogisticLoss>`, `SGBT<HuberLoss>`, etc.
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming Gradient Boosted Trees ensemble.
///
/// The primary entry point for training and prediction. Generic over `L: Loss`
/// so the loss function's gradient/hessian calls are monomorphized (inlined)
/// into the boosting hot loop -- no virtual dispatch overhead.
///
/// The default type parameter `L = SquaredLoss` means `SGBT::new(config)`
/// creates a regression model without specifying the loss type explicitly.
///
/// # Examples
///
/// ```ignore
/// use irithyll::{SGBTConfig, SGBT};
///
/// // Regression with squared loss (default):
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::new(config);
/// ```
///
/// ```ignore
/// use irithyll::{SGBTConfig, SGBT};
/// use irithyll::loss::logistic::LogisticLoss;
///
/// // Classification with logistic loss -- no Box::new()!
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::with_loss(config, LogisticLoss);
/// ```
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Configuration.
    config: SGBTConfig,
    /// Boosting steps (one tree + drift detector each).
    steps: Vec<BoostingStep>,
    /// Loss function (monomorphized -- no vtable).
    loss: L,
    /// Base prediction (initial constant, computed from first batch of targets).
    base_prediction: f64,
    /// Whether base_prediction has been initialized.
    base_initialized: bool,
    /// Running collection of initial targets for computing base_prediction.
    /// Cleared (and its buffer freed) once the base is set.
    initial_targets: Vec<f64>,
    /// Number of initial targets to collect before setting base_prediction.
    initial_target_count: usize,
    /// Total samples trained.
    samples_seen: u64,
    /// RNG state for variant skip logic (also drives reservoir subsampling).
    rng_state: u64,
    /// Per-step EWMA of |marginal contribution| for quality-based pruning.
    /// Empty when `quality_prune_alpha` is `None`.
    contribution_ewma: Vec<f64>,
    /// Per-step consecutive low-contribution sample counter.
    /// Empty when `quality_prune_alpha` is `None`.
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error for error-weighted sample importance.
    /// Only used when `error_weight_alpha` is `Some`.
    rolling_mean_error: f64,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// Computed from median split threshold gaps across all trees.
    auto_bandwidths: Vec<f64>,
    /// Sum of replacement counts across all steps at last bandwidth computation.
    /// Used to detect when trees have been replaced and bandwidths need refresh.
    last_replacement_sum: u64,
}
124
// Manual `Clone`: field-by-field clone of every member, gated on `L: Clone`
// in addition to the struct's `L: Loss` bound.
// NOTE(review): this is identical to what `#[derive(Clone)]` would generate
// and could be replaced by the derive.
impl<L: Loss + Clone> Clone for SGBT<L> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
            rolling_mean_error: self.rolling_mean_error,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
        }
    }
}
145
146impl<L: Loss> fmt::Debug for SGBT<L> {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        f.debug_struct("SGBT")
149            .field("n_steps", &self.steps.len())
150            .field("samples_seen", &self.samples_seen)
151            .field("base_prediction", &self.base_prediction)
152            .field("base_initialized", &self.base_initialized)
153            .finish()
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Convenience constructor for the default loss (SquaredLoss)
159// ---------------------------------------------------------------------------
160
impl SGBT<SquaredLoss> {
    /// Create a new SGBT ensemble with squared loss (regression).
    ///
    /// This is the most common constructor; it simply delegates to
    /// [`with_loss`](SGBT::with_loss) with [`SquaredLoss`]. For
    /// classification or custom losses, use [`with_loss`](SGBT::with_loss).
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}
170
171// ---------------------------------------------------------------------------
172// General impl for all Loss types
173// ---------------------------------------------------------------------------
174
175impl<L: Loss> SGBT<L> {
    /// Create a new SGBT ensemble with a specific loss function.
    ///
    /// The loss is stored by value (monomorphized), giving zero-cost
    /// gradient/hessian dispatch.
    ///
    /// ```ignore
    /// use irithyll::{SGBTConfig, SGBT};
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
    /// let model = SGBT::with_loss(config, LogisticLoss);
    /// ```
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert the configured half-life (in samples) into a per-sample
        // exponential decay factor: alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        // Shared tree template; each boosting step clones it and only
        // overrides the seed.
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        // One boosting step per configured tree, each with its own drift
        // detector. Graduated (active/shadow) steps only when warmup > 0.
        // NOTE(review): seeds derive as `seed ^ i`, so step 0 reuses the base
        // seed unchanged -- confirm this correlation is acceptable.
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        let has_pruning = config.quality_prune_alpha.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            // Pruning bookkeeping is only allocated when pruning is enabled.
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
252
    /// Train on a single observation.
    ///
    /// Accepts any type implementing [`Observation`], including [`Sample`],
    /// [`SampleRef`](crate::SampleRef), or tuples like `(&[f64], f64)` for
    /// zero-copy training.
    ///
    /// Runs one full boosting pass: each step sees the gradient/hessian of
    /// the loss at the running prediction of all preceding steps, trains on
    /// it, and its (scaled) output is folded into the running prediction.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Initialize base prediction from first few targets.
        // Until `initial_target_count` targets have arrived, the base stays
        // at 0.0; once enough are collected the buffer is freed for good.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Current prediction starts from base
        let mut current_pred = self.base_prediction;

        let prune_alpha = self.config.quality_prune_alpha;
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Error-weighted sample importance: compute weight from prediction error.
        // NOTE(review): the error here is measured against the *base* prediction
        // only (no trees applied yet) -- confirm this is intentional.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                // Weight grows with relative error, capped at 10x.
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0 // first sample, no reweighting
            }
        } else {
            1.0
        };

        // Sequential boosting: each step targets the residual of all prior steps
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            // The variant may repeat or skip training (0, 1, or more passes).
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality-based tree pruning: track contribution and replace dead wood
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    // Sustained low contribution: reset the step's tree and
                    // restart its bookkeeping.
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    // Any healthy contribution resets the patience counter.
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Refresh auto-bandwidths when trees have been replaced or not yet computed.
        self.refresh_bandwidths();
    }
333
334    /// Train on a batch of observations.
335    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
336        for sample in samples {
337            self.train_one(sample);
338        }
339    }
340
341    /// Train on a batch with periodic callback for cooperative yielding.
342    ///
343    /// The callback is invoked every `interval` samples with the number of
344    /// samples processed so far. This allows long-running training to yield
345    /// to other tasks in an async runtime, update progress bars, or perform
346    /// periodic checkpointing.
347    ///
348    /// # Example
349    ///
350    /// ```ignore
351    /// use irithyll::{SGBTConfig, SGBT};
352    ///
353    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
354    /// let mut model = SGBT::new(config);
355    /// let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data
356    ///
357    /// model.train_batch_with_callback(&data, 1000, |processed| {
358    ///     println!("Trained {} samples", processed);
359    /// });
360    /// ```
361    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
362        &mut self,
363        samples: &[O],
364        interval: usize,
365        mut callback: F,
366    ) {
367        let interval = interval.max(1); // Prevent zero interval
368        for (i, sample) in samples.iter().enumerate() {
369            self.train_one(sample);
370            if (i + 1) % interval == 0 {
371                callback(i + 1);
372            }
373        }
374        // Final callback if the total isn't a multiple of interval
375        let total = samples.len();
376        if total % interval != 0 {
377            callback(total);
378        }
379    }
380
381    /// Train on a random subsample of a batch using reservoir sampling.
382    ///
383    /// When `max_samples < samples.len()`, selects a representative subset
384    /// using Algorithm R (Vitter, 1985) -- a uniform random sample without
385    /// replacement. The selected samples are then trained in their original
386    /// order to preserve sequential dependencies.
387    ///
388    /// This is ideal for large replay buffers where training on the full
389    /// dataset is prohibitively slow but a representative subset gives
390    /// equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).
391    ///
392    /// When `max_samples >= samples.len()`, all samples are trained.
393    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
394        if max_samples >= samples.len() {
395            self.train_batch(samples);
396            return;
397        }
398
399        // Reservoir sampling (Algorithm R) to select indices
400        let mut reservoir: Vec<usize> = (0..max_samples).collect();
401        let mut rng = self.rng_state;
402
403        for i in max_samples..samples.len() {
404            // Generate random index in [0, i]
405            rng ^= rng << 13;
406            rng ^= rng >> 7;
407            rng ^= rng << 17;
408            let j = (rng % (i as u64 + 1)) as usize;
409            if j < max_samples {
410                reservoir[j] = i;
411            }
412        }
413
414        self.rng_state = rng;
415
416        // Sort to preserve original order (important for EWMA/drift state)
417        reservoir.sort_unstable();
418
419        // Train on the selected subset
420        for &idx in &reservoir {
421            self.train_one(&samples[idx]);
422        }
423    }
424
425    /// Train on a batch with both subsampling and periodic callbacks.
426    ///
427    /// Combines reservoir subsampling with cooperative yield points.
428    /// Ideal for long-running daemon training where you need both
429    /// efficiency (subsampling) and cooperation (yielding).
430    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
431        &mut self,
432        samples: &[O],
433        max_samples: usize,
434        interval: usize,
435        mut callback: F,
436    ) {
437        if max_samples >= samples.len() {
438            self.train_batch_with_callback(samples, interval, callback);
439            return;
440        }
441
442        // Reservoir sampling
443        let mut reservoir: Vec<usize> = (0..max_samples).collect();
444        let mut rng = self.rng_state;
445
446        for i in max_samples..samples.len() {
447            rng ^= rng << 13;
448            rng ^= rng >> 7;
449            rng ^= rng << 17;
450            let j = (rng % (i as u64 + 1)) as usize;
451            if j < max_samples {
452                reservoir[j] = i;
453            }
454        }
455
456        self.rng_state = rng;
457        reservoir.sort_unstable();
458
459        let interval = interval.max(1);
460        for (i, &idx) in reservoir.iter().enumerate() {
461            self.train_one(&samples[idx]);
462            if (i + 1) % interval == 0 {
463                callback(i + 1);
464            }
465        }
466        let total = reservoir.len();
467        if total % interval != 0 {
468            callback(total);
469        }
470    }
471
472    /// Predict the raw output for a feature vector.
473    ///
474    /// Always uses sigmoid-blended soft routing with auto-calibrated per-feature
475    /// bandwidths derived from median split threshold gaps. Features that have
476    /// never been split on use hard routing (bandwidth = infinity).
477    pub fn predict(&self, features: &[f64]) -> f64 {
478        let mut pred = self.base_prediction;
479        if self.auto_bandwidths.is_empty() {
480            // No bandwidths computed yet (no training) — hard routing fallback
481            for step in &self.steps {
482                pred += self.config.learning_rate * step.predict(features);
483            }
484        } else {
485            for step in &self.steps {
486                pred += self.config.learning_rate
487                    * step.predict_smooth_auto(features, &self.auto_bandwidths);
488            }
489        }
490        pred
491    }
492
493    /// Predict using sigmoid-blended soft routing with an explicit bandwidth.
494    ///
495    /// Uses a single bandwidth for all features. For auto-calibrated per-feature
496    /// bandwidths, use [`predict()`](SGBT::predict) which always uses smooth routing.
497    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
498        let mut pred = self.base_prediction;
499        for step in &self.steps {
500            pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
501        }
502        pred
503    }
504
    /// Per-feature auto-calibrated bandwidths used by `predict()`.
    ///
    /// Empty before the first training sample. Each entry corresponds to a
    /// feature index; `f64::INFINITY` means that feature has no splits and
    /// uses hard routing.
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
513
514    /// Predict with parent-leaf linear interpolation.
515    ///
516    /// Blends each leaf prediction with its parent's preserved prediction
517    /// based on sample count, preventing stale predictions from fresh leaves.
518    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
519        let mut pred = self.base_prediction;
520        for step in &self.steps {
521            pred += self.config.learning_rate * step.predict_interpolated(features);
522        }
523        pred
524    }
525
526    /// Predict with sibling-based interpolation for feature-continuous predictions.
527    ///
528    /// At each split node near the threshold boundary, blends left and right
529    /// subtree predictions linearly based on distance from the threshold.
530    /// Uses auto-calibrated bandwidths as the interpolation margin.
531    /// Predictions vary continuously as features change, eliminating
532    /// step-function artifacts.
533    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
534        let mut pred = self.base_prediction;
535        for step in &self.steps {
536            pred += self.config.learning_rate
537                * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
538        }
539        pred
540    }
541
542    /// Predict with graduated active-shadow blending.
543    ///
544    /// Smoothly transitions between active and shadow trees during replacement,
545    /// eliminating prediction dips. Requires `shadow_warmup` to be configured.
546    /// When disabled, equivalent to `predict()`.
547    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
548        let mut pred = self.base_prediction;
549        for step in &self.steps {
550            pred += self.config.learning_rate * step.predict_graduated(features);
551        }
552        pred
553    }
554
555    /// Predict with graduated blending + sibling interpolation (premium path).
556    ///
557    /// Combines graduated active-shadow handoff (no prediction dips during
558    /// tree replacement) with feature-continuous sibling interpolation
559    /// (no step-function artifacts near split boundaries).
560    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
561        let mut pred = self.base_prediction;
562        for step in &self.steps {
563            pred += self.config.learning_rate
564                * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
565        }
566        pred
567    }
568
569    /// Predict with loss transform applied (e.g., sigmoid for logistic loss).
570    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
571        self.loss.predict_transform(self.predict(features))
572    }
573
    /// Predict probability (alias for `predict_transformed`).
    ///
    /// Meaningful for losses whose transform maps to [0, 1] (e.g., logistic).
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
578
579    /// Predict with confidence estimation.
580    ///
581    /// Returns `(prediction, confidence)` where confidence = 1 / sqrt(sum_variance).
582    /// Higher confidence indicates more certain predictions (leaves have seen
583    /// more hessian mass). Confidence of 0.0 means the model has no information.
584    ///
585    /// This enables execution engines to modulate aggressiveness:
586    /// - High confidence + favorable prediction → act immediately
587    /// - Low confidence → fall back to simpler models or wait for more data
588    ///
589    /// The variance per tree is estimated as `1 / (H_sum + lambda)` at the
590    /// leaf where the sample lands. The ensemble variance is the sum of
591    /// per-tree variances (scaled by learning_rate²), and confidence is
592    /// the reciprocal of the standard deviation.
593    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
594        let mut pred = self.base_prediction;
595        let mut total_variance = 0.0;
596        let lr2 = self.config.learning_rate * self.config.learning_rate;
597
598        for step in &self.steps {
599            let (value, variance) = step.predict_with_variance(features);
600            pred += self.config.learning_rate * value;
601            total_variance += lr2 * variance;
602        }
603
604        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
605            1.0 / crate::math::sqrt(total_variance)
606        } else {
607            0.0
608        };
609
610        (pred, confidence)
611    }
612
613    /// Batch prediction.
614    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
615        feature_matrix.iter().map(|f| self.predict(f)).collect()
616    }
617
    /// Number of boosting steps (fixed at construction from `config.n_steps`).
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }
622
623    /// Total trees (active + alternates).
624    pub fn n_trees(&self) -> usize {
625        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
626    }
627
628    /// Total leaves across all active trees.
629    pub fn total_leaves(&self) -> usize {
630        self.steps.iter().map(|s| s.n_leaves()).sum()
631    }
632
    /// Total samples trained (incremented once per `train_one` call).
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
637
    /// The current base prediction (0.0 until initialization completes).
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }
642
    /// Whether the base prediction has been initialized (i.e., enough initial
    /// targets have been collected).
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }
647
    /// Access the configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }
652
    /// Set the learning rate for future boosting rounds.
    ///
    /// This allows external schedulers (e.g., [`lr_schedule::LRScheduler`]) to
    /// adapt the rate over time without rebuilding the model.
    ///
    /// # Arguments
    ///
    /// * `lr` -- New learning rate (should be positive and finite)
    ///
    /// Note: the value is stored as-is; no validation is performed here.
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }
665
    /// Immutable access to the boosting steps.
    ///
    /// Useful for model inspection and export (e.g., ONNX serialization).
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }
672
    /// Immutable access to the loss function.
    pub fn loss(&self) -> &L {
        &self.loss
    }
677
678    /// Feature importances based on accumulated split gains across all trees.
679    ///
680    /// Returns normalized importances (sum to 1.0) indexed by feature.
681    /// Returns an empty Vec if no splits have occurred yet.
682    pub fn feature_importances(&self) -> Vec<f64> {
683        // Aggregate split gains across all boosting steps.
684        let mut totals: Vec<f64> = Vec::new();
685        for step in &self.steps {
686            let gains = step.slot().split_gains();
687            if totals.is_empty() && !gains.is_empty() {
688                totals.resize(gains.len(), 0.0);
689            }
690            for (i, &g) in gains.iter().enumerate() {
691                if i < totals.len() {
692                    totals[i] += g;
693                }
694            }
695        }
696
697        // Normalize to sum to 1.0.
698        let sum: f64 = totals.iter().sum();
699        if sum > 0.0 {
700            totals.iter_mut().for_each(|v| *v /= sum);
701        }
702        totals
703    }
704
    /// Feature names, if configured.
    pub fn feature_names(&self) -> Option<&[String]> {
        self.config.feature_names.as_deref()
    }
709
710    /// Feature importances paired with their names.
711    ///
712    /// Returns `None` if feature names are not configured. Otherwise returns
713    /// `(name, importance)` pairs sorted by importance descending.
714    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
715        let names = self.config.feature_names.as_ref()?;
716        let importances = self.feature_importances();
717        let mut pairs: Vec<(String, f64)> = names
718            .iter()
719            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
720            .map(|(n, &v)| (n.clone(), v))
721            .collect();
722        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
723        Some(pairs)
724    }
725
726    /// Train on a single sample with named features.
727    ///
728    /// Converts a `HashMap<String, f64>` of named features into a positional
729    /// vector using the configured feature names. Missing features default to 0.0.
730    ///
731    /// # Panics
732    ///
733    /// Panics if `feature_names` is not configured.
734    #[cfg(feature = "std")]
735    pub fn train_one_named(
736        &mut self,
737        features: &std::collections::HashMap<alloc::string::String, f64>,
738        target: f64,
739    ) {
740        let names = self
741            .config
742            .feature_names
743            .as_ref()
744            .expect("train_one_named requires feature_names to be configured");
745        let vec: Vec<f64> = names
746            .iter()
747            .map(|name| features.get(name).copied().unwrap_or(0.0))
748            .collect();
749        self.train_one(&(&vec[..], target));
750    }
751
752    /// Predict with named features.
753    ///
754    /// Converts named features into a positional vector, same as `train_one_named`.
755    ///
756    /// # Panics
757    ///
758    /// Panics if `feature_names` is not configured.
759    #[cfg(feature = "std")]
760    pub fn predict_named(
761        &self,
762        features: &std::collections::HashMap<alloc::string::String, f64>,
763    ) -> f64 {
764        let names = self
765            .config
766            .feature_names
767            .as_ref()
768            .expect("predict_named requires feature_names to be configured");
769        let vec: Vec<f64> = names
770            .iter()
771            .map(|name| features.get(name).copied().unwrap_or(0.0))
772            .collect();
773        self.predict(&vec)
774    }
775
776    // NOTE: explain() and explain_named() require the `explain` module which
777    // lives in the full `irithyll` crate, not in `irithyll-core`. Those methods
778    // are provided via the re-export layer in `irithyll::ensemble`.
779
780    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
781    fn refresh_bandwidths(&mut self) {
782        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
783        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
784            self.auto_bandwidths = self.compute_auto_bandwidths();
785            self.last_replacement_sum = current_sum;
786        }
787    }
788
789    /// Compute per-feature auto-calibrated bandwidths from all trees.
790    ///
791    /// For each feature, collects all split thresholds across all trees,
792    /// computes the median gap between consecutive unique thresholds, and
793    /// returns `median_gap * K` (K = 2.0).
794    ///
795    /// Edge cases:
796    /// - Feature with < 3 unique thresholds: `range / n_bins * K`
797    /// - Feature never split on (< 2 unique thresholds): `f64::INFINITY` (hard routing)
798    fn compute_auto_bandwidths(&self) -> Vec<f64> {
799        const K: f64 = 2.0;
800
801        // Determine n_features from the trees
802        let n_features = self
803            .steps
804            .iter()
805            .filter_map(|s| s.slot().active_tree().n_features())
806            .max()
807            .unwrap_or(0);
808
809        if n_features == 0 {
810            return Vec::new();
811        }
812
813        // Collect all thresholds from all trees per feature
814        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
815
816        for step in &self.steps {
817            let tree_thresholds = step
818                .slot()
819                .active_tree()
820                .collect_split_thresholds_per_feature();
821            for (i, ts) in tree_thresholds.into_iter().enumerate() {
822                if i < n_features {
823                    all_thresholds[i].extend(ts);
824                }
825            }
826        }
827
828        let n_bins = self.config.n_bins as f64;
829
830        // Compute per-feature bandwidth
831        all_thresholds
832            .iter()
833            .map(|ts| {
834                if ts.is_empty() {
835                    return f64::INFINITY; // Never split on → hard routing
836                }
837
838                let mut sorted = ts.clone();
839                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
840                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
841
842                if sorted.len() < 2 {
843                    return f64::INFINITY; // Single threshold → hard routing
844                }
845
846                // Compute gaps between consecutive unique thresholds
847                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
848
849                if sorted.len() < 3 {
850                    // Fallback: feature_range / n_bins * K
851                    let range = sorted.last().unwrap() - sorted.first().unwrap();
852                    if range < 1e-15 {
853                        return f64::INFINITY;
854                    }
855                    return (range / n_bins) * K;
856                }
857
858                // Median gap
859                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
860                let median_gap = if gaps.len() % 2 == 0 {
861                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
862                } else {
863                    gaps[gaps.len() / 2]
864                };
865
866                if median_gap < 1e-15 {
867                    f64::INFINITY
868                } else {
869                    median_gap * K
870                }
871            })
872            .collect()
873    }
874
875    /// Reset the ensemble to initial state.
876    pub fn reset(&mut self) {
877        for step in &mut self.steps {
878            step.reset();
879        }
880        self.base_prediction = 0.0;
881        self.base_initialized = false;
882        self.initial_targets.clear();
883        self.samples_seen = 0;
884        self.rng_state = self.config.seed;
885        self.auto_bandwidths.clear();
886        self.last_replacement_sum = 0;
887    }
888
889    /// Serialize the model into a [`ModelState`](crate::serde_support::ModelState).
890    ///
891    /// Auto-detects the [`LossType`](crate::loss::LossType) from the loss
892    /// function's [`Loss::loss_type()`] implementation.
893    ///
894    /// # Errors
895    ///
896    /// Returns [`IrithyllError::Serialization`](crate::IrithyllError::Serialization)
897    /// if the loss does not implement `loss_type()` (returns `None`). For custom
898    /// losses, use [`to_model_state_with`](Self::to_model_state_with) instead.
899    #[cfg(feature = "_serde_support")]
900    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
901        let loss_type = self.loss.loss_type().ok_or_else(|| {
902            crate::error::IrithyllError::Serialization(
903                "cannot auto-detect loss type for serialization: \
904                 implement Loss::loss_type() or use to_model_state_with()"
905                    .into(),
906            )
907        })?;
908        Ok(self.to_model_state_with(loss_type))
909    }
910
911    /// Serialize the model with an explicit [`LossType`](crate::loss::LossType) tag.
912    ///
913    /// Use this for custom loss functions that don't implement `loss_type()`.
914    #[cfg(feature = "_serde_support")]
915    pub fn to_model_state_with(
916        &self,
917        loss_type: crate::loss::LossType,
918    ) -> crate::serde_support::ModelState {
919        use crate::serde_support::{ModelState, StepSnapshot};
920
921        let steps = self
922            .steps
923            .iter()
924            .map(|step| {
925                let slot = step.slot();
926                let tree_snap = snapshot_tree(slot.active_tree());
927                let alt_snap = slot.alternate_tree().map(snapshot_tree);
928                let drift_state = slot.detector().serialize_state();
929                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
930                StepSnapshot {
931                    tree: tree_snap,
932                    alternate_tree: alt_snap,
933                    drift_state,
934                    alt_drift_state,
935                }
936            })
937            .collect();
938
939        ModelState {
940            config: self.config.clone(),
941            loss_type,
942            base_prediction: self.base_prediction,
943            base_initialized: self.base_initialized,
944            initial_targets: self.initial_targets.clone(),
945            initial_target_count: self.initial_target_count,
946            samples_seen: self.samples_seen,
947            rng_state: self.rng_state,
948            steps,
949            rolling_mean_error: self.rolling_mean_error,
950            contribution_ewma: self.contribution_ewma.clone(),
951            low_contrib_count: self.low_contrib_count.clone(),
952        }
953    }
954}
955
956// ---------------------------------------------------------------------------
957// DynSGBT: deserialization returns a dynamically-dispatched model
958// ---------------------------------------------------------------------------
959
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Reconstruct an SGBT model from a [`ModelState`](crate::serde_support::ModelState).
    ///
    /// Returns a [`DynSGBT`] (`SGBT<Box<dyn Loss>>`) because the concrete
    /// loss type is determined at runtime from the serialized tag.
    ///
    /// Rebuilds the full ensemble including tree topology and leaf values.
    /// Histogram accumulators are left empty and will rebuild from continued
    /// training. If drift detector state was serialized, it is restored;
    /// otherwise a fresh detector is created from the config.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        // Turn the serialized tag back into a concrete (boxed) loss object.
        let loss = state.loss_type.into_loss();

        // Convert the configured half-life into a per-sample decay factor:
        // alpha = exp(-ln(2) / half_life) = 2^(-1/half_life).
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Rebuild the per-step tree config from the serialized
                // ensemble config. The seed is XOR-ed with the step index so
                // each step gets a distinct RNG stream, matching construction.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                // Restore the active tree and, when present, the alternate.
                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                // Fresh detector from config; overwritten with the serialized
                // state when one was saved.
                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                // Restore the alternate detector's state only when both the
                // serialized state and the detector itself exist.
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        let has_pruning = state.config.quality_prune_alpha.is_some();

        // Restore pruning state if available, otherwise initialize
        // (zeroed vectors when pruning is enabled, empty otherwise).
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            // Bandwidth cache is not serialized; it is rebuilt lazily.
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1066
1067// ---------------------------------------------------------------------------
1068// Shared snapshot/rebuild helpers for serde (used by SGBT + DistributionalSGBT)
1069// ---------------------------------------------------------------------------
1070
/// Snapshot a [`HoeffdingTree`] into a serializable [`TreeSnapshot`].
///
/// Copies the full node arena (topology, leaf values, bookkeeping) plus the
/// tree-level counters needed to resume training after deserialization.
#[cfg(feature = "_serde_support")]
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;

    let nodes = tree.arena();
    // Child links are stored as raw indices, unwrapped from NodeId.
    let left = nodes.left.iter().map(|id| id.0).collect();
    let right = nodes.right.iter().map(|id| id.0).collect();

    TreeSnapshot {
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        feature_idx: nodes.feature_idx.clone(),
        threshold: nodes.threshold.clone(),
        left,
        right,
        leaf_value: nodes.leaf_value.clone(),
        is_leaf: nodes.is_leaf.clone(),
        depth: nodes.depth.clone(),
        sample_count: nodes.sample_count.clone(),
        categorical_mask: nodes.categorical_mask.clone(),
    }
}
1094
/// Rebuild a [`HoeffdingTree`] from a [`TreeSnapshot`] and a [`TreeConfig`].
///
/// Reconstructs the node arena entry-by-entry, then hands it to
/// `HoeffdingTree::from_arena` together with the serialized counters.
#[cfg(feature = "_serde_support")]
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    let node_count = snapshot.feature_idx.len();
    let mut arena = TreeArena::new();

    for idx in 0..node_count {
        arena.feature_idx.push(snapshot.feature_idx[idx]);
        arena.threshold.push(snapshot.threshold[idx]);
        arena.left.push(NodeId(snapshot.left[idx]));
        arena.right.push(NodeId(snapshot.right[idx]));
        arena.leaf_value.push(snapshot.leaf_value[idx]);
        arena.is_leaf.push(snapshot.is_leaf[idx]);
        arena.depth.push(snapshot.depth[idx]);
        arena.sample_count.push(snapshot.sample_count[idx]);
        // A missing mask entry (shorter vec) becomes None.
        arena
            .categorical_mask
            .push(snapshot.categorical_mask.get(idx).copied().flatten());
    }

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1128
1129#[cfg(test)]
1130mod tests {
1131    use super::*;
1132    use alloc::boxed::Box;
1133    use alloc::vec;
1134    use alloc::vec::Vec;
1135
    // Shared baseline config for most tests: 10 boosting steps, lr 0.1,
    // grace period 20, depth-4 trees, 16 histogram bins.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .build()
            .unwrap()
    }

    // An untrained model (no base prediction, no splits) must predict 0.
    #[test]
    fn new_model_predicts_zero() {
        let model = SGBT::new(default_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.abs() < 1e-12);
    }

    // Smoke test: a single sample is absorbed and the counter increments.
    #[test]
    fn train_one_does_not_panic() {
        let mut model = SGBT::new(default_config());
        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
        assert_eq!(model.n_samples_seen(), 1);
    }

    // After 100 samples on a fixed feature vector, the prediction must be
    // finite (no NaN/inf from the boosting updates).
    #[test]
    fn prediction_changes_after_training() {
        let mut model = SGBT::new(default_config());
        let features = vec![1.0, 2.0, 3.0];
        for i in 0..100 {
            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
        }
        let pred = model.predict(&features);
        assert!(pred.is_finite());
    }
1171
    // Prequential (predict-then-train) evaluation on a noiseless linear
    // signal: RMSE over samples 400..500 must beat RMSE over samples 50..150.
    #[test]
    fn linear_signal_rmse_improves() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        // xorshift64 PRNG — deterministic pseudo-random features in [-5, 5).
        let mut rng: u64 = 12345;
        let mut early_errors = Vec::new();
        let mut late_errors = Vec::new();

        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x1 + 3.0 * x2;

            // Predict BEFORE training on the same sample (prequential).
            let pred = model.predict(&[x1, x2]);
            let error = (pred - target).powi(2);

            if (50..150).contains(&i) {
                early_errors.push(error);
            }
            if i >= 400 {
                late_errors.push(error);
            }

            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();

        assert!(
            late_rmse < early_rmse,
            "RMSE should decrease: early={:.4}, late={:.4}",
            early_rmse,
            late_rmse
        );
    }
1222
    // train_batch must be numerically equivalent to calling train_one in a
    // loop over the same samples in the same order.
    #[test]
    fn train_batch_equivalent_to_sequential() {
        let config = default_config();
        let mut model_seq = SGBT::new(config.clone());
        let mut model_batch = SGBT::new(config);

        let samples: Vec<Sample> = (0..20)
            .map(|i| {
                let x = i as f64 * 0.5;
                Sample::new(vec![x, x * 2.0], x * 3.0)
            })
            .collect();

        for s in &samples {
            model_seq.train_one(s);
        }
        model_batch.train_batch(&samples);

        let pred_seq = model_seq.predict(&[1.0, 2.0]);
        let pred_batch = model_batch.predict(&[1.0, 2.0]);

        assert!(
            (pred_seq - pred_batch).abs() < 1e-10,
            "seq={}, batch={}",
            pred_seq,
            pred_batch
        );
    }

    // reset() must restore the untrained state: zero counter, uninitialized
    // base, and a zero prediction.
    #[test]
    fn reset_returns_to_initial() {
        let mut model = SGBT::new(default_config());
        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
    }

    // The base prediction should initialize close to the mean of the first
    // warm-up targets (100..=149 → midpoint 124.5).
    #[test]
    fn base_prediction_initializes() {
        let mut model = SGBT::new(default_config());
        for i in 0..50 {
            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
        }
        assert!(model.is_initialized());
        let expected = (100.0 + 149.0) / 2.0;
        assert!((model.base_prediction() - expected).abs() < 1.0);
    }

    // with_loss wires a custom loss: an untrained logistic model's
    // transformed prediction is sigmoid(0) = 0.5.
    #[test]
    fn with_loss_uses_custom_loss() {
        use crate::loss::logistic::LogisticLoss;
        let model = SGBT::with_loss(default_config(), LogisticLoss);
        let pred = model.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );
    }
1286
    // leaf_half_life (EWMA leaf decay) must propagate through the config
    // without breaking training — predictions stay finite.
    #[test]
    fn ewma_config_propagates_and_trains() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .leaf_half_life(50)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "EWMA-enabled model should produce finite predictions, got {}",
            pred
        );
    }

    // max_tree_samples must propagate; trains 500 samples past the 200 cap
    // and checks predictions stay finite.
    #[test]
    fn max_tree_samples_config_propagates() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .max_tree_samples(200)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..500 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "max_tree_samples model should produce finite predictions, got {}",
            pred
        );
    }

    // split_reeval_interval must propagate; 1000 pseudo-random samples with
    // re-evaluation every 50, predictions stay finite.
    #[test]
    fn split_reeval_config_propagates() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(2)
            .n_bins(16)
            .split_reeval_interval(50)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        // xorshift64 PRNG for deterministic feature draws.
        let mut rng: u64 = 12345;
        for _ in 0..1000 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x1 + 3.0 * x2;
            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "split re-eval model should produce finite predictions, got {}",
            pred
        );
    }

    // loss() returns the concrete loss type (not a trait object).
    #[test]
    fn loss_accessor_works() {
        use crate::loss::logistic::LogisticLoss;
        let model = SGBT::with_loss(default_config(), LogisticLoss);
        // Verify we can access the concrete loss type
        let _loss: &LogisticLoss = model.loss();
        assert_eq!(_loss.n_outputs(), 1);
    }
1382
    // Clone must be a deep copy: identical predictions immediately after
    // cloning, and fully independent state afterwards.
    #[test]
    fn clone_produces_independent_copy() {
        let config = default_config();
        let mut model = SGBT::new(config);

        // Train the original on some data (xorshift64 PRNG, y = 2x + 1)
        let mut rng: u64 = 99999;
        for _ in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x + 1.0;
            model.train_one(&Sample::new(vec![x], target));
        }

        // Clone the model
        let mut cloned = model.clone();

        // Both should produce identical predictions
        let test_features = [3.0];
        let pred_original = model.predict(&test_features);
        let pred_cloned = cloned.predict(&test_features);
        assert!(
            (pred_original - pred_cloned).abs() < 1e-12,
            "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
        );

        // Train only the clone further -- models should diverge
        for _ in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = -3.0 * x + 5.0; // Different relationship
            cloned.train_one(&Sample::new(vec![x], target));
        }

        let pred_original_after = model.predict(&test_features);
        let pred_cloned_after = cloned.predict(&test_features);

        // Original should be unchanged
        assert!(
            (pred_original - pred_original_after).abs() < 1e-12,
            "original should be unchanged after training clone"
        );

        // Clone should have diverged
        assert!(
            (pred_original_after - pred_cloned_after).abs() > 1e-6,
            "clone should diverge after independent training"
        );
    }
1436
1437    // -------------------------------------------------------------------
1438    // predict_with_confidence returns finite values
1439    // -------------------------------------------------------------------
1440    #[test]
1441    fn predict_with_confidence_finite() {
1442        let config = SGBTConfig::builder()
1443            .n_steps(5)
1444            .grace_period(10)
1445            .build()
1446            .unwrap();
1447        let mut model = SGBT::new(config);
1448
1449        // Train enough to initialize
1450        for i in 0..100 {
1451            let x = i as f64 * 0.1;
1452            model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1453        }
1454
1455        let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1456        assert!(pred.is_finite(), "prediction should be finite");
1457        assert!(confidence.is_finite(), "confidence should be finite");
1458        assert!(
1459            confidence > 0.0,
1460            "confidence should be positive after training"
1461        );
1462    }
1463
1464    // -------------------------------------------------------------------
1465    // predict_with_confidence positive after training
1466    // -------------------------------------------------------------------
1467    #[test]
1468    fn predict_with_confidence_positive_after_training() {
1469        let config = SGBTConfig::builder()
1470            .n_steps(5)
1471            .grace_period(10)
1472            .build()
1473            .unwrap();
1474        let mut model = SGBT::new(config);
1475
1476        // Train enough to initialize and build structure
1477        for i in 0..200 {
1478            let x = i as f64 * 0.05;
1479            model.train_one(&(&[x][..], x * 2.0));
1480        }
1481
1482        let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1483
1484        assert!(pred.is_finite(), "prediction should be finite");
1485        assert!(
1486            confidence > 0.0 && confidence.is_finite(),
1487            "confidence should be finite and positive, got {}",
1488            confidence,
1489        );
1490
1491        // Multiple queries should give consistent confidence
1492        let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1493        assert!(
1494            (pred - pred2).abs() < 1e-12,
1495            "same input should give same prediction"
1496        );
1497        assert!(
1498            (confidence - conf2).abs() < 1e-12,
1499            "same input should give same confidence"
1500        );
1501    }
1502
1503    // -------------------------------------------------------------------
1504    // predict_with_confidence agrees with predict on point estimate
1505    // -------------------------------------------------------------------
1506    #[test]
1507    fn predict_with_confidence_matches_predict() {
1508        let config = SGBTConfig::builder()
1509            .n_steps(10)
1510            .grace_period(10)
1511            .build()
1512            .unwrap();
1513        let mut model = SGBT::new(config);
1514
1515        for i in 0..200 {
1516            let x = (i as f64 - 100.0) * 0.01;
1517            model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1518        }
1519
1520        let pred = model.predict(&[0.5, 0.25]);
1521        let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1522
1523        assert!(
1524            (pred - conf_pred).abs() < 1e-10,
1525            "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1526            pred,
1527            conf_pred,
1528        );
1529    }
1530
1531    // -------------------------------------------------------------------
1532    // gradient clipping config round-trips through builder
1533    // -------------------------------------------------------------------
1534    #[test]
1535    fn gradient_clip_config_builder() {
1536        let config = SGBTConfig::builder()
1537            .n_steps(10)
1538            .gradient_clip_sigma(3.0)
1539            .build()
1540            .unwrap();
1541
1542        assert_eq!(config.gradient_clip_sigma, Some(3.0));
1543    }
1544
1545    // -------------------------------------------------------------------
1546    // monotonic constraints config round-trips through builder
1547    // -------------------------------------------------------------------
1548    #[test]
1549    fn monotone_constraints_config_builder() {
1550        let config = SGBTConfig::builder()
1551            .n_steps(10)
1552            .monotone_constraints(vec![1, -1, 0])
1553            .build()
1554            .unwrap();
1555
1556        assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1557    }
1558
1559    // -------------------------------------------------------------------
1560    // monotonic constraints validation rejects invalid values
1561    // -------------------------------------------------------------------
1562    #[test]
1563    fn monotone_constraints_invalid_value_rejected() {
1564        let result = SGBTConfig::builder()
1565            .n_steps(10)
1566            .monotone_constraints(vec![1, 2, 0])
1567            .build();
1568
1569        assert!(result.is_err(), "constraint value 2 should be rejected");
1570    }
1571
1572    // -------------------------------------------------------------------
1573    // gradient clipping validation rejects non-positive sigma
1574    // -------------------------------------------------------------------
1575    #[test]
1576    fn gradient_clip_sigma_negative_rejected() {
1577        let result = SGBTConfig::builder()
1578            .n_steps(10)
1579            .gradient_clip_sigma(-1.0)
1580            .build();
1581
1582        assert!(result.is_err(), "negative sigma should be rejected");
1583    }
1584
1585    // -------------------------------------------------------------------
1586    // gradient clipping ensemble-level reduces outlier impact
1587    // -------------------------------------------------------------------
1588    #[test]
1589    fn gradient_clipping_reduces_outlier_impact() {
1590        // Without clipping
1591        let config_no_clip = SGBTConfig::builder()
1592            .n_steps(5)
1593            .grace_period(10)
1594            .build()
1595            .unwrap();
1596        let mut model_no_clip = SGBT::new(config_no_clip);
1597
1598        // With clipping
1599        let config_clip = SGBTConfig::builder()
1600            .n_steps(5)
1601            .grace_period(10)
1602            .gradient_clip_sigma(3.0)
1603            .build()
1604            .unwrap();
1605        let mut model_clip = SGBT::new(config_clip);
1606
1607        // Train both on identical normal data
1608        for i in 0..100 {
1609            let x = (i as f64) * 0.01;
1610            let sample = (&[x][..], x * 2.0);
1611            model_no_clip.train_one(&sample);
1612            model_clip.train_one(&sample);
1613        }
1614
1615        let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1616        let pred_clip_before = model_clip.predict(&[0.5]);
1617
1618        // Inject outlier
1619        let outlier = (&[0.5_f64][..], 10000.0);
1620        model_no_clip.train_one(&outlier);
1621        model_clip.train_one(&outlier);
1622
1623        let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1624        let pred_clip_after = model_clip.predict(&[0.5]);
1625
1626        let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1627        let delta_clip = (pred_clip_after - pred_clip_before).abs();
1628
1629        // Clipped model should be less affected by the outlier
1630        assert!(
1631            delta_clip <= delta_no_clip + 1e-10,
1632            "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1633            delta_clip,
1634            delta_no_clip,
1635        );
1636    }
1637
1638    // -------------------------------------------------------------------
1639    // train_batch_with_callback fires at correct intervals
1640    // -------------------------------------------------------------------
1641    #[test]
1642    fn train_batch_with_callback_fires() {
1643        let config = SGBTConfig::builder()
1644            .n_steps(3)
1645            .grace_period(5)
1646            .build()
1647            .unwrap();
1648        let mut model = SGBT::new(config);
1649
1650        let data: Vec<(Vec<f64>, f64)> = (0..25)
1651            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1652            .collect();
1653
1654        let mut callbacks = Vec::new();
1655        model.train_batch_with_callback(&data, 10, |n| {
1656            callbacks.push(n);
1657        });
1658
1659        // Should fire at 10, 20, and 25 (final)
1660        assert_eq!(callbacks, vec![10, 20, 25]);
1661    }
1662
1663    // -------------------------------------------------------------------
1664    // train_batch_subsampled produces deterministic subset
1665    // -------------------------------------------------------------------
1666    #[test]
1667    fn train_batch_subsampled_trains_subset() {
1668        let config = SGBTConfig::builder()
1669            .n_steps(3)
1670            .grace_period(5)
1671            .build()
1672            .unwrap();
1673        let mut model = SGBT::new(config);
1674
1675        let data: Vec<(Vec<f64>, f64)> = (0..100)
1676            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1677            .collect();
1678
1679        // Train on only 20 of 100 samples
1680        model.train_batch_subsampled(&data, 20);
1681
1682        // Model should have seen some samples
1683        assert!(
1684            model.n_samples_seen() > 0,
1685            "model should have trained on subset"
1686        );
1687        assert!(
1688            model.n_samples_seen() <= 20,
1689            "model should have trained at most 20 samples, got {}",
1690            model.n_samples_seen(),
1691        );
1692    }
1693
1694    // -------------------------------------------------------------------
1695    // train_batch_subsampled full dataset = train_batch
1696    // -------------------------------------------------------------------
1697    #[test]
1698    fn train_batch_subsampled_full_equals_batch() {
1699        let config1 = SGBTConfig::builder()
1700            .n_steps(3)
1701            .grace_period(5)
1702            .build()
1703            .unwrap();
1704        let config2 = config1.clone();
1705
1706        let mut model1 = SGBT::new(config1);
1707        let mut model2 = SGBT::new(config2);
1708
1709        let data: Vec<(Vec<f64>, f64)> = (0..50)
1710            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1711            .collect();
1712
1713        model1.train_batch(&data);
1714        model2.train_batch_subsampled(&data, 1000); // max_samples > data.len()
1715
1716        // Both should have identical state
1717        assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1718        let pred1 = model1.predict(&[2.5]);
1719        let pred2 = model2.predict(&[2.5]);
1720        assert!(
1721            (pred1 - pred2).abs() < 1e-12,
1722            "full subsample should equal batch: {} vs {}",
1723            pred1,
1724            pred2,
1725        );
1726    }
1727
1728    // -------------------------------------------------------------------
1729    // train_batch_subsampled_with_callback combines both
1730    // -------------------------------------------------------------------
1731    #[test]
1732    fn train_batch_subsampled_with_callback_works() {
1733        let config = SGBTConfig::builder()
1734            .n_steps(3)
1735            .grace_period(5)
1736            .build()
1737            .unwrap();
1738        let mut model = SGBT::new(config);
1739
1740        let data: Vec<(Vec<f64>, f64)> = (0..200)
1741            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1742            .collect();
1743
1744        let mut callbacks = Vec::new();
1745        model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1746            callbacks.push(n);
1747        });
1748
1749        // Should have trained ~50 samples with callbacks at 10, 20, 30, 40, 50
1750        assert!(!callbacks.is_empty(), "should have received callbacks");
1751        assert_eq!(
1752            *callbacks.last().unwrap(),
1753            50,
1754            "final callback should be total samples"
1755        );
1756    }
1757
1758    // ---------------------------------------------------------------
1759    // Linear leaf model integration tests
1760    // ---------------------------------------------------------------
1761
1762    /// xorshift64 PRNG for deterministic test data.
1763    fn xorshift64(state: &mut u64) -> u64 {
1764        let mut s = *state;
1765        s ^= s << 13;
1766        s ^= s >> 7;
1767        s ^= s << 17;
1768        *state = s;
1769        s
1770    }
1771
1772    fn rand_f64(state: &mut u64) -> f64 {
1773        xorshift64(state) as f64 / u64::MAX as f64
1774    }
1775
1776    fn linear_leaves_config() -> SGBTConfig {
1777        SGBTConfig::builder()
1778            .n_steps(10)
1779            .learning_rate(0.1)
1780            .grace_period(20)
1781            .max_depth(2) // low depth -- linear leaves should shine
1782            .n_bins(16)
1783            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1784                learning_rate: 0.1,
1785                decay: None,
1786                use_adagrad: false,
1787            })
1788            .build()
1789            .unwrap()
1790    }
1791
1792    #[test]
1793    fn linear_leaves_trains_without_panic() {
1794        let mut model = SGBT::new(linear_leaves_config());
1795        let mut rng = 42u64;
1796        for _ in 0..200 {
1797            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1798            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1799            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1800            model.train_one(&Sample::new(vec![x1, x2], y));
1801        }
1802        assert_eq!(model.n_samples_seen(), 200);
1803    }
1804
1805    #[test]
1806    fn linear_leaves_prediction_finite() {
1807        let mut model = SGBT::new(linear_leaves_config());
1808        let mut rng = 42u64;
1809        for _ in 0..200 {
1810            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1811            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1812            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1813            model.train_one(&Sample::new(vec![x1, x2], y));
1814        }
1815        let pred = model.predict(&[0.5, -0.3]);
1816        assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1817    }
1818
1819    #[test]
1820    fn linear_leaves_learns_linear_target() {
1821        let mut model = SGBT::new(linear_leaves_config());
1822        let mut rng = 42u64;
1823        for _ in 0..500 {
1824            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1825            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1826            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1827            model.train_one(&Sample::new(vec![x1, x2], y));
1828        }
1829
1830        // Test on a few points -- should be reasonably close for a linear target.
1831        let mut total_error = 0.0;
1832        for _ in 0..50 {
1833            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1834            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1835            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1836            let pred = model.predict(&[x1, x2]);
1837            total_error += (pred - y).powi(2);
1838        }
1839        let mse = total_error / 50.0;
1840        assert!(
1841            mse < 5.0,
1842            "linear leaves MSE on linear target should be < 5.0, got {mse}"
1843        );
1844    }
1845
1846    #[test]
1847    fn linear_leaves_better_than_constant_at_low_depth() {
1848        // Train two models on a linear target at depth 2:
1849        // one with constant leaves, one with linear leaves.
1850        let constant_config = SGBTConfig::builder()
1851            .n_steps(10)
1852            .learning_rate(0.1)
1853            .grace_period(20)
1854            .max_depth(2)
1855            .n_bins(16)
1856            .seed(0xDEAD)
1857            .build()
1858            .unwrap();
1859        let linear_config = SGBTConfig::builder()
1860            .n_steps(10)
1861            .learning_rate(0.1)
1862            .grace_period(20)
1863            .max_depth(2)
1864            .n_bins(16)
1865            .seed(0xDEAD)
1866            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1867                learning_rate: 0.1,
1868                decay: None,
1869                use_adagrad: false,
1870            })
1871            .build()
1872            .unwrap();
1873
1874        let mut constant_model = SGBT::new(constant_config);
1875        let mut linear_model = SGBT::new(linear_config);
1876        let mut rng = 42u64;
1877
1878        for _ in 0..500 {
1879            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1880            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1881            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1882            let sample = Sample::new(vec![x1, x2], y);
1883            constant_model.train_one(&sample);
1884            linear_model.train_one(&sample);
1885        }
1886
1887        // Evaluate both on test set.
1888        let mut constant_mse = 0.0;
1889        let mut linear_mse = 0.0;
1890        for _ in 0..100 {
1891            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1892            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1893            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1894            constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1895            linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1896        }
1897        constant_mse /= 100.0;
1898        linear_mse /= 100.0;
1899
1900        // Linear leaves should outperform constant leaves on a linear target.
1901        assert!(
1902            linear_mse < constant_mse,
1903            "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1904        );
1905    }
1906
1907    #[test]
1908    fn adaptive_leaves_trains_without_panic() {
1909        let config = SGBTConfig::builder()
1910            .n_steps(10)
1911            .learning_rate(0.1)
1912            .grace_period(20)
1913            .max_depth(3)
1914            .n_bins(16)
1915            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
1916                promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
1917                    learning_rate: 0.1,
1918                    decay: None,
1919                    use_adagrad: false,
1920                }),
1921            })
1922            .build()
1923            .unwrap();
1924
1925        let mut model = SGBT::new(config);
1926        let mut rng = 42u64;
1927        for _ in 0..500 {
1928            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1929            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1930            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1931            model.train_one(&Sample::new(vec![x1, x2], y));
1932        }
1933        let pred = model.predict(&[0.5, -0.3]);
1934        assert!(
1935            pred.is_finite(),
1936            "adaptive leaf prediction should be finite, got {pred}"
1937        );
1938    }
1939
1940    #[test]
1941    fn linear_leaves_with_decay_trains_without_panic() {
1942        let config = SGBTConfig::builder()
1943            .n_steps(10)
1944            .learning_rate(0.1)
1945            .grace_period(20)
1946            .max_depth(3)
1947            .n_bins(16)
1948            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1949                learning_rate: 0.1,
1950                decay: Some(0.995),
1951                use_adagrad: false,
1952            })
1953            .build()
1954            .unwrap();
1955
1956        let mut model = SGBT::new(config);
1957        let mut rng = 42u64;
1958        for _ in 0..500 {
1959            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1960            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1961            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1962            model.train_one(&Sample::new(vec![x1, x2], y));
1963        }
1964        let pred = model.predict(&[0.5, -0.3]);
1965        assert!(
1966            pred.is_finite(),
1967            "decay leaf prediction should be finite, got {pred}"
1968        );
1969    }
1970
1971    // -------------------------------------------------------------------
1972    // predict_smooth returns finite values
1973    // -------------------------------------------------------------------
1974    #[test]
1975    fn predict_smooth_returns_finite() {
1976        let config = SGBTConfig::builder()
1977            .n_steps(5)
1978            .learning_rate(0.1)
1979            .grace_period(10)
1980            .build()
1981            .unwrap();
1982        let mut model = SGBT::new(config);
1983
1984        for i in 0..200 {
1985            let x = (i as f64) * 0.1;
1986            model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
1987        }
1988
1989        let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
1990        let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
1991
1992        assert!(pred_hard.is_finite(), "hard prediction should be finite");
1993        assert!(
1994            pred_smooth.is_finite(),
1995            "smooth prediction should be finite"
1996        );
1997    }
1998
1999    // -------------------------------------------------------------------
2000    // predict_smooth converges to hard predict at small bandwidth
2001    // -------------------------------------------------------------------
2002    #[test]
2003    fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2004        let config = SGBTConfig::builder()
2005            .n_steps(5)
2006            .learning_rate(0.1)
2007            .grace_period(10)
2008            .build()
2009            .unwrap();
2010        let mut model = SGBT::new(config);
2011
2012        for i in 0..300 {
2013            let x = (i as f64) * 0.1;
2014            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2015        }
2016
2017        let features = [5.0, 2.5];
2018        let hard = model.predict(&features);
2019        let smooth = model.predict_smooth(&features, 0.001);
2020
2021        assert!(
2022            (hard - smooth).abs() < 0.5,
2023            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2024            hard,
2025            smooth,
2026        );
2027    }
2028
2029    #[test]
2030    fn auto_bandwidth_computed_after_training() {
2031        let config = SGBTConfig::builder()
2032            .n_steps(5)
2033            .learning_rate(0.1)
2034            .grace_period(10)
2035            .build()
2036            .unwrap();
2037        let mut model = SGBT::new(config);
2038
2039        // Before training, no bandwidths
2040        assert!(model.auto_bandwidths().is_empty());
2041
2042        for i in 0..200 {
2043            let x = (i as f64) * 0.1;
2044            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2045        }
2046
2047        // After training, auto_bandwidths should be populated
2048        let bws = model.auto_bandwidths();
2049        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2050
2051        // predict() always uses smooth routing with auto-bandwidths
2052        let pred = model.predict(&[5.0, 2.5]);
2053        assert!(
2054            pred.is_finite(),
2055            "auto-bandwidth predict should be finite: {}",
2056            pred
2057        );
2058    }
2059
2060    #[test]
2061    fn predict_interpolated_returns_finite() {
2062        let config = SGBTConfig::builder()
2063            .n_steps(5)
2064            .learning_rate(0.01)
2065            .build()
2066            .unwrap();
2067        let mut model = SGBT::new(config);
2068
2069        for i in 0..200 {
2070            let x = (i as f64) * 0.1;
2071            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2072        }
2073
2074        let pred = model.predict_interpolated(&[1.0, 0.5]);
2075        assert!(
2076            pred.is_finite(),
2077            "interpolated prediction should be finite: {}",
2078            pred
2079        );
2080    }
2081
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        // Train on a nonlinear target so the trees develop real structure.
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            model.train_one(&Sample::new(vec![x, x * 0.3], y));
        }

        // Verify the method is callable and produces finite predictions
        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(pred.is_finite(), "sibling interpolated should be finite");

        // If bandwidths are finite, verify sibling produces at least as much
        // variation as hard routing across a feature sweep
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            // Sweep 200 points; hard routing yields a step function, while
            // sibling interpolation should vary at least as often.
            let hard: Vec<f64> = (0..200)
                .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
                .collect();
            let sib: Vec<f64> = (0..200)
                .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
                .collect();
            // Count adjacent pairs whose values differ by more than EPSILON:
            // a proxy for how often each prediction curve changes value.
            let hc = hard
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            let sc = sib
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sc >= hc,
                "sibling should produce >= hard changes: sib={}, hard={}",
                sc,
                hc
            );
        }
    }
2130
2131    #[test]
2132    fn predict_graduated_returns_finite() {
2133        let config = SGBTConfig::builder()
2134            .n_steps(5)
2135            .learning_rate(0.01)
2136            .max_tree_samples(200)
2137            .shadow_warmup(50)
2138            .build()
2139            .unwrap();
2140        let mut model = SGBT::new(config);
2141
2142        for i in 0..300 {
2143            let x = (i as f64) * 0.1;
2144            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2145        }
2146
2147        let pred = model.predict_graduated(&[1.0, 0.5]);
2148        assert!(
2149            pred.is_finite(),
2150            "graduated prediction should be finite: {}",
2151            pred
2152        );
2153
2154        let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2155        assert!(
2156            pred2.is_finite(),
2157            "graduated+sibling prediction should be finite: {}",
2158            pred2
2159        );
2160    }
2161
2162    #[test]
2163    fn shadow_warmup_validation() {
2164        let result = SGBTConfig::builder()
2165            .n_steps(5)
2166            .learning_rate(0.01)
2167            .shadow_warmup(0)
2168            .build();
2169        assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2170    }
2171}