Skip to main content

irithyll_core/ensemble/
mod.rs

//! SGBT ensemble orchestrator -- the core boosting loop.
//!
//! Implements Streaming Gradient Boosted Trees (Gunasekara et al., 2024):
//! a sequence of boosting steps, each owning a streaming tree and drift detector,
//! with automatic tree replacement when concept drift is detected.
//!
//! # Algorithm
//!
//! For each incoming sample `(x, y)`:
//! 1. Compute the current ensemble prediction: `F(x) = base + lr * Σ tree_s(x)`
//! 2. For each boosting step `s = 1..N`:
//!    - Compute gradient `g = loss.gradient(y, current_pred)`
//!    - Compute hessian `h = loss.hessian(y, current_pred)`
//!    - Feed `(x, g, h)` to tree `s` (which internally uses weighted squared loss)
//!    - Update `current_pred += lr * tree_s.predict(x)`
//! 3. The ensemble adapts incrementally, with each tree targeting the residual
//!    of all preceding trees.

19pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] // Used in doc links + tests
50use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// Type alias for an SGBT model using dynamic (boxed) loss dispatch.
///
/// Use this when the loss function is determined at runtime (e.g., when
/// deserializing a model from JSON where the loss type is stored as a tag).
/// Every gradient/hessian call goes through a vtable with this alias.
///
/// For compile-time loss dispatch (preferred for performance), use
/// `SGBT<LogisticLoss>`, `SGBT<HuberLoss>`, etc.
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming Gradient Boosted Trees ensemble.
///
/// The primary entry point for training and prediction. Generic over `L: Loss`
/// so the loss function's gradient/hessian calls are monomorphized (inlined)
/// into the boosting hot loop -- no virtual dispatch overhead.
///
/// The default type parameter `L = SquaredLoss` means `SGBT::new(config)`
/// creates a regression model without specifying the loss type explicitly.
///
/// # Examples
///
/// ```text
/// use irithyll::{SGBTConfig, SGBT};
///
/// // Regression with squared loss (default):
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::new(config);
/// ```
///
/// ```text
/// use irithyll::{SGBTConfig, SGBT};
/// use irithyll::loss::logistic::LogisticLoss;
///
/// // Classification with logistic loss -- no Box::new()!
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::with_loss(config, LogisticLoss);
/// ```
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Configuration.
    config: SGBTConfig,
    /// Boosting steps (one tree + drift detector each).
    steps: Vec<BoostingStep>,
    /// Loss function (monomorphized -- no vtable).
    loss: L,
    /// Base prediction (initial constant, computed from first batch of targets).
    base_prediction: f64,
    /// Whether base_prediction has been initialized.
    base_initialized: bool,
    /// Running collection of initial targets for computing base_prediction.
    /// Cleared and shrunk once the base is set.
    initial_targets: Vec<f64>,
    /// Number of initial targets to collect before setting base_prediction.
    initial_target_count: usize,
    /// Total samples trained.
    samples_seen: u64,
    /// RNG state for variant skip logic (also advanced by reservoir subsampling).
    rng_state: u64,
    /// Per-step EWMA of |marginal contribution| for quality-based pruning.
    /// Empty when `quality_prune_alpha` is `None`.
    contribution_ewma: Vec<f64>,
    /// Per-step consecutive low-contribution sample counter.
    /// Empty when `quality_prune_alpha` is `None`.
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error for error-weighted sample importance.
    /// Only used when `error_weight_alpha` is `Some`.
    rolling_mean_error: f64,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// Computed from median split threshold gaps across all trees.
    auto_bandwidths: Vec<f64>,
    /// Sum of replacement counts across all steps at last bandwidth computation.
    /// Used to detect when trees have been replaced and bandwidths need refresh.
    last_replacement_sum: u64,
}
124
125impl<L: Loss + Clone> Clone for SGBT<L> {
126    fn clone(&self) -> Self {
127        Self {
128            config: self.config.clone(),
129            steps: self.steps.clone(),
130            loss: self.loss.clone(),
131            base_prediction: self.base_prediction,
132            base_initialized: self.base_initialized,
133            initial_targets: self.initial_targets.clone(),
134            initial_target_count: self.initial_target_count,
135            samples_seen: self.samples_seen,
136            rng_state: self.rng_state,
137            contribution_ewma: self.contribution_ewma.clone(),
138            low_contrib_count: self.low_contrib_count.clone(),
139            rolling_mean_error: self.rolling_mean_error,
140            auto_bandwidths: self.auto_bandwidths.clone(),
141            last_replacement_sum: self.last_replacement_sum,
142        }
143    }
144}
145
146impl<L: Loss> fmt::Debug for SGBT<L> {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        f.debug_struct("SGBT")
149            .field("n_steps", &self.steps.len())
150            .field("samples_seen", &self.samples_seen)
151            .field("base_prediction", &self.base_prediction)
152            .field("base_initialized", &self.base_initialized)
153            .finish()
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Convenience constructor for the default loss (SquaredLoss)
159// ---------------------------------------------------------------------------
160
161impl SGBT<SquaredLoss> {
162    /// Create a new SGBT ensemble with squared loss (regression).
163    ///
164    /// This is the most common constructor. For classification or custom
165    /// losses, use [`with_loss`](SGBT::with_loss).
166    pub fn new(config: SGBTConfig) -> Self {
167        Self::with_loss(config, SquaredLoss)
168    }
169}
170
171// ---------------------------------------------------------------------------
172// General impl for all Loss types
173// ---------------------------------------------------------------------------
174
175impl<L: Loss> SGBT<L> {
    /// Create a new SGBT ensemble with a specific loss function.
    ///
    /// The loss is stored by value (monomorphized), giving zero-cost
    /// gradient/hessian dispatch.
    ///
    /// ```ignore
    /// use irithyll::{SGBTConfig, SGBT};
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
    /// let model = SGBT::with_loss(config, LogisticLoss);
    /// ```
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert leaf half-life (in samples) to a per-sample decay factor:
        // alpha = exp(-ln 2 / half_life), so leaf weight halves every
        // `half_life` samples.
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        // One shared tree template; every step clones it below with its own seed.
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // shadow_warmup > 0 enables graduated (active/shadow) tree replacement.
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                // Per-step seed so randomized choices (e.g. feature subsampling)
                // differ across steps.
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        // Copy out fields needed after `config` is moved into the struct.
        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        let has_pruning = config.quality_prune_alpha.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            // Pruning bookkeeping is only allocated when pruning is enabled.
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
253
    /// Train on a single observation.
    ///
    /// Accepts any type implementing [`Observation`], including [`Sample`],
    /// [`SampleRef`](crate::SampleRef), or tuples like `(&[f64], f64)` for
    /// zero-copy training.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Initialize base prediction from first few targets
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                // Warm-up buffer is never needed again; release its memory.
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Current prediction starts from base
        let mut current_pred = self.base_prediction;

        let prune_alpha = self.config.quality_prune_alpha;
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Error-weighted sample importance: compute weight from prediction error.
        // Weight grows with |error| relative to the rolling mean, capped at 10.0
        // to bound the gradient/hessian scaling.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0 // first sample, no reweighting
            }
        } else {
            1.0
        };

        // Sequential boosting: each step targets the residual of all prior steps
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            // The variant decides how many times this sample is fed to the
            // tree (skip/repeat logic); it may advance the shared RNG state.
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality-based tree pruning: track contribution and replace dead wood
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        // Contribution stayed below threshold for `patience`
                        // consecutive samples: reset the step and start fresh.
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Refresh auto-bandwidths when trees have been replaced or not yet computed.
        self.refresh_bandwidths();
    }
334
335    /// Train on a batch of observations.
336    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
337        for sample in samples {
338            self.train_one(sample);
339        }
340    }
341
342    /// Train on a batch with periodic callback for cooperative yielding.
343    ///
344    /// The callback is invoked every `interval` samples with the number of
345    /// samples processed so far. This allows long-running training to yield
346    /// to other tasks in an async runtime, update progress bars, or perform
347    /// periodic checkpointing.
348    ///
349    /// # Example
350    ///
351    /// ```ignore
352    /// use irithyll::{SGBTConfig, SGBT};
353    ///
354    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
355    /// let mut model = SGBT::new(config);
356    /// let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data
357    ///
358    /// model.train_batch_with_callback(&data, 1000, |processed| {
359    ///     println!("Trained {} samples", processed);
360    /// });
361    /// ```
362    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
363        &mut self,
364        samples: &[O],
365        interval: usize,
366        mut callback: F,
367    ) {
368        let interval = interval.max(1); // Prevent zero interval
369        for (i, sample) in samples.iter().enumerate() {
370            self.train_one(sample);
371            if (i + 1) % interval == 0 {
372                callback(i + 1);
373            }
374        }
375        // Final callback if the total isn't a multiple of interval
376        let total = samples.len();
377        if total % interval != 0 {
378            callback(total);
379        }
380    }
381
382    /// Train on a random subsample of a batch using reservoir sampling.
383    ///
384    /// When `max_samples < samples.len()`, selects a representative subset
385    /// using Algorithm R (Vitter, 1985) -- a uniform random sample without
386    /// replacement. The selected samples are then trained in their original
387    /// order to preserve sequential dependencies.
388    ///
389    /// This is ideal for large replay buffers where training on the full
390    /// dataset is prohibitively slow but a representative subset gives
391    /// equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).
392    ///
393    /// When `max_samples >= samples.len()`, all samples are trained.
394    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
395        if max_samples >= samples.len() {
396            self.train_batch(samples);
397            return;
398        }
399
400        // Reservoir sampling (Algorithm R) to select indices
401        let mut reservoir: Vec<usize> = (0..max_samples).collect();
402        let mut rng = self.rng_state;
403
404        for i in max_samples..samples.len() {
405            // Generate random index in [0, i]
406            rng ^= rng << 13;
407            rng ^= rng >> 7;
408            rng ^= rng << 17;
409            let j = (rng % (i as u64 + 1)) as usize;
410            if j < max_samples {
411                reservoir[j] = i;
412            }
413        }
414
415        self.rng_state = rng;
416
417        // Sort to preserve original order (important for EWMA/drift state)
418        reservoir.sort_unstable();
419
420        // Train on the selected subset
421        for &idx in &reservoir {
422            self.train_one(&samples[idx]);
423        }
424    }
425
426    /// Train on a batch with both subsampling and periodic callbacks.
427    ///
428    /// Combines reservoir subsampling with cooperative yield points.
429    /// Ideal for long-running daemon training where you need both
430    /// efficiency (subsampling) and cooperation (yielding).
431    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
432        &mut self,
433        samples: &[O],
434        max_samples: usize,
435        interval: usize,
436        mut callback: F,
437    ) {
438        if max_samples >= samples.len() {
439            self.train_batch_with_callback(samples, interval, callback);
440            return;
441        }
442
443        // Reservoir sampling
444        let mut reservoir: Vec<usize> = (0..max_samples).collect();
445        let mut rng = self.rng_state;
446
447        for i in max_samples..samples.len() {
448            rng ^= rng << 13;
449            rng ^= rng >> 7;
450            rng ^= rng << 17;
451            let j = (rng % (i as u64 + 1)) as usize;
452            if j < max_samples {
453                reservoir[j] = i;
454            }
455        }
456
457        self.rng_state = rng;
458        reservoir.sort_unstable();
459
460        let interval = interval.max(1);
461        for (i, &idx) in reservoir.iter().enumerate() {
462            self.train_one(&samples[idx]);
463            if (i + 1) % interval == 0 {
464                callback(i + 1);
465            }
466        }
467        let total = reservoir.len();
468        if total % interval != 0 {
469            callback(total);
470        }
471    }
472
473    /// Predict the raw output for a feature vector.
474    ///
475    /// Always uses sigmoid-blended soft routing with auto-calibrated per-feature
476    /// bandwidths derived from median split threshold gaps. Features that have
477    /// never been split on use hard routing (bandwidth = infinity).
478    pub fn predict(&self, features: &[f64]) -> f64 {
479        let mut pred = self.base_prediction;
480        if self.auto_bandwidths.is_empty() {
481            // No bandwidths computed yet (no training) — hard routing fallback
482            for step in &self.steps {
483                pred += self.config.learning_rate * step.predict(features);
484            }
485        } else {
486            for step in &self.steps {
487                pred += self.config.learning_rate
488                    * step.predict_smooth_auto(features, &self.auto_bandwidths);
489            }
490        }
491        pred
492    }
493
494    /// Predict using sigmoid-blended soft routing with an explicit bandwidth.
495    ///
496    /// Uses a single bandwidth for all features. For auto-calibrated per-feature
497    /// bandwidths, use [`predict()`](SGBT::predict) which always uses smooth routing.
498    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
499        let mut pred = self.base_prediction;
500        for step in &self.steps {
501            pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
502        }
503        pred
504    }
505
506    /// Per-feature auto-calibrated bandwidths used by `predict()`.
507    ///
508    /// Empty before the first training sample. Each entry corresponds to a
509    /// feature index; `f64::INFINITY` means that feature has no splits and
510    /// uses hard routing.
511    pub fn auto_bandwidths(&self) -> &[f64] {
512        &self.auto_bandwidths
513    }
514
515    /// Predict with parent-leaf linear interpolation.
516    ///
517    /// Blends each leaf prediction with its parent's preserved prediction
518    /// based on sample count, preventing stale predictions from fresh leaves.
519    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
520        let mut pred = self.base_prediction;
521        for step in &self.steps {
522            pred += self.config.learning_rate * step.predict_interpolated(features);
523        }
524        pred
525    }
526
527    /// Predict with sibling-based interpolation for feature-continuous predictions.
528    ///
529    /// At each split node near the threshold boundary, blends left and right
530    /// subtree predictions linearly based on distance from the threshold.
531    /// Uses auto-calibrated bandwidths as the interpolation margin.
532    /// Predictions vary continuously as features change, eliminating
533    /// step-function artifacts.
534    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
535        let mut pred = self.base_prediction;
536        for step in &self.steps {
537            pred += self.config.learning_rate
538                * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
539        }
540        pred
541    }
542
543    /// Predict with graduated active-shadow blending.
544    ///
545    /// Smoothly transitions between active and shadow trees during replacement,
546    /// eliminating prediction dips. Requires `shadow_warmup` to be configured.
547    /// When disabled, equivalent to `predict()`.
548    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
549        let mut pred = self.base_prediction;
550        for step in &self.steps {
551            pred += self.config.learning_rate * step.predict_graduated(features);
552        }
553        pred
554    }
555
556    /// Predict with graduated blending + sibling interpolation (premium path).
557    ///
558    /// Combines graduated active-shadow handoff (no prediction dips during
559    /// tree replacement) with feature-continuous sibling interpolation
560    /// (no step-function artifacts near split boundaries).
561    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
562        let mut pred = self.base_prediction;
563        for step in &self.steps {
564            pred += self.config.learning_rate
565                * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
566        }
567        pred
568    }
569
570    /// Predict with loss transform applied (e.g., sigmoid for logistic loss).
571    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
572        self.loss.predict_transform(self.predict(features))
573    }
574
    /// Predict probability (alias for `predict_transformed`).
    ///
    /// Only meaningful when the loss's transform maps raw scores to [0, 1]
    /// (e.g., logistic loss); for other losses this is just the transformed score.
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
579
580    /// Predict with confidence estimation.
581    ///
582    /// Returns `(prediction, confidence)` where confidence = 1 / sqrt(sum_variance).
583    /// Higher confidence indicates more certain predictions (leaves have seen
584    /// more hessian mass). Confidence of 0.0 means the model has no information.
585    ///
586    /// This enables execution engines to modulate aggressiveness:
587    /// - High confidence + favorable prediction → act immediately
588    /// - Low confidence → fall back to simpler models or wait for more data
589    ///
590    /// The variance per tree is estimated as `1 / (H_sum + lambda)` at the
591    /// leaf where the sample lands. The ensemble variance is the sum of
592    /// per-tree variances (scaled by learning_rate²), and confidence is
593    /// the reciprocal of the standard deviation.
594    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
595        let mut pred = self.base_prediction;
596        let mut total_variance = 0.0;
597        let lr2 = self.config.learning_rate * self.config.learning_rate;
598
599        for step in &self.steps {
600            let (value, variance) = step.predict_with_variance(features);
601            pred += self.config.learning_rate * value;
602            total_variance += lr2 * variance;
603        }
604
605        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
606            1.0 / crate::math::sqrt(total_variance)
607        } else {
608            0.0
609        };
610
611        (pred, confidence)
612    }
613
614    /// Batch prediction.
615    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
616        feature_matrix.iter().map(|f| self.predict(f)).collect()
617    }
618
619    /// Number of boosting steps.
620    pub fn n_steps(&self) -> usize {
621        self.steps.len()
622    }
623
624    /// Total trees (active + alternates).
625    pub fn n_trees(&self) -> usize {
626        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
627    }
628
629    /// Total leaves across all active trees.
630    pub fn total_leaves(&self) -> usize {
631        self.steps.iter().map(|s| s.n_leaves()).sum()
632    }
633
634    /// Total samples trained.
635    pub fn n_samples_seen(&self) -> u64 {
636        self.samples_seen
637    }
638
639    /// The current base prediction.
640    pub fn base_prediction(&self) -> f64 {
641        self.base_prediction
642    }
643
644    /// Whether the base prediction has been initialized.
645    pub fn is_initialized(&self) -> bool {
646        self.base_initialized
647    }
648
649    /// Access the configuration.
650    pub fn config(&self) -> &SGBTConfig {
651        &self.config
652    }
653
654    /// Set the learning rate for future boosting rounds.
655    ///
656    /// This allows external schedulers (e.g., [`lr_schedule::LRScheduler`]) to
657    /// adapt the rate over time without rebuilding the model.
658    ///
659    /// # Arguments
660    ///
661    /// * `lr` -- New learning rate (should be positive and finite)
662    #[inline]
663    pub fn set_learning_rate(&mut self, lr: f64) {
664        self.config.learning_rate = lr;
665    }
666
667    /// Immutable access to the boosting steps.
668    ///
669    /// Useful for model inspection and export (e.g., ONNX serialization).
670    pub fn steps(&self) -> &[BoostingStep] {
671        &self.steps
672    }
673
674    /// Immutable access to the loss function.
675    pub fn loss(&self) -> &L {
676        &self.loss
677    }
678
679    /// Feature importances based on accumulated split gains across all trees.
680    ///
681    /// Returns normalized importances (sum to 1.0) indexed by feature.
682    /// Returns an empty Vec if no splits have occurred yet.
683    pub fn feature_importances(&self) -> Vec<f64> {
684        // Aggregate split gains across all boosting steps.
685        let mut totals: Vec<f64> = Vec::new();
686        for step in &self.steps {
687            let gains = step.slot().split_gains();
688            if totals.is_empty() && !gains.is_empty() {
689                totals.resize(gains.len(), 0.0);
690            }
691            for (i, &g) in gains.iter().enumerate() {
692                if i < totals.len() {
693                    totals[i] += g;
694                }
695            }
696        }
697
698        // Normalize to sum to 1.0.
699        let sum: f64 = totals.iter().sum();
700        if sum > 0.0 {
701            totals.iter_mut().for_each(|v| *v /= sum);
702        }
703        totals
704    }
705
706    /// Feature names, if configured.
707    pub fn feature_names(&self) -> Option<&[String]> {
708        self.config.feature_names.as_deref()
709    }
710
711    /// Feature importances paired with their names.
712    ///
713    /// Returns `None` if feature names are not configured. Otherwise returns
714    /// `(name, importance)` pairs sorted by importance descending.
715    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
716        let names = self.config.feature_names.as_ref()?;
717        let importances = self.feature_importances();
718        let mut pairs: Vec<(String, f64)> = names
719            .iter()
720            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
721            .map(|(n, &v)| (n.clone(), v))
722            .collect();
723        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
724        Some(pairs)
725    }
726
727    /// Train on a single sample with named features.
728    ///
729    /// Converts a `HashMap<String, f64>` of named features into a positional
730    /// vector using the configured feature names. Missing features default to 0.0.
731    ///
732    /// # Panics
733    ///
734    /// Panics if `feature_names` is not configured.
735    #[cfg(feature = "std")]
736    pub fn train_one_named(
737        &mut self,
738        features: &std::collections::HashMap<alloc::string::String, f64>,
739        target: f64,
740    ) {
741        let names = self
742            .config
743            .feature_names
744            .as_ref()
745            .expect("train_one_named requires feature_names to be configured");
746        let vec: Vec<f64> = names
747            .iter()
748            .map(|name| features.get(name).copied().unwrap_or(0.0))
749            .collect();
750        self.train_one(&(&vec[..], target));
751    }
752
753    /// Predict with named features.
754    ///
755    /// Converts named features into a positional vector, same as `train_one_named`.
756    ///
757    /// # Panics
758    ///
759    /// Panics if `feature_names` is not configured.
760    #[cfg(feature = "std")]
761    pub fn predict_named(
762        &self,
763        features: &std::collections::HashMap<alloc::string::String, f64>,
764    ) -> f64 {
765        let names = self
766            .config
767            .feature_names
768            .as_ref()
769            .expect("predict_named requires feature_names to be configured");
770        let vec: Vec<f64> = names
771            .iter()
772            .map(|name| features.get(name).copied().unwrap_or(0.0))
773            .collect();
774        self.predict(&vec)
775    }
776
777    // NOTE: explain() and explain_named() require the `explain` module which
778    // lives in the full `irithyll` crate, not in `irithyll-core`. Those methods
779    // are provided via the re-export layer in `irithyll::ensemble`.
780
781    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
782    fn refresh_bandwidths(&mut self) {
783        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
784        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
785            self.auto_bandwidths = self.compute_auto_bandwidths();
786            self.last_replacement_sum = current_sum;
787        }
788    }
789
790    /// Compute per-feature auto-calibrated bandwidths from all trees.
791    ///
792    /// For each feature, collects all split thresholds across all trees,
793    /// computes the median gap between consecutive unique thresholds, and
794    /// returns `median_gap * K` (K = 2.0).
795    ///
796    /// Edge cases:
797    /// - Feature with < 3 unique thresholds: `range / n_bins * K`
798    /// - Feature never split on (< 2 unique thresholds): `f64::INFINITY` (hard routing)
799    fn compute_auto_bandwidths(&self) -> Vec<f64> {
800        const K: f64 = 2.0;
801
802        // Determine n_features from the trees
803        let n_features = self
804            .steps
805            .iter()
806            .filter_map(|s| s.slot().active_tree().n_features())
807            .max()
808            .unwrap_or(0);
809
810        if n_features == 0 {
811            return Vec::new();
812        }
813
814        // Collect all thresholds from all trees per feature
815        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
816
817        for step in &self.steps {
818            let tree_thresholds = step
819                .slot()
820                .active_tree()
821                .collect_split_thresholds_per_feature();
822            for (i, ts) in tree_thresholds.into_iter().enumerate() {
823                if i < n_features {
824                    all_thresholds[i].extend(ts);
825                }
826            }
827        }
828
829        let n_bins = self.config.n_bins as f64;
830
831        // Compute per-feature bandwidth
832        all_thresholds
833            .iter()
834            .map(|ts| {
835                if ts.is_empty() {
836                    return f64::INFINITY; // Never split on → hard routing
837                }
838
839                let mut sorted = ts.clone();
840                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
841                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
842
843                if sorted.len() < 2 {
844                    return f64::INFINITY; // Single threshold → hard routing
845                }
846
847                // Compute gaps between consecutive unique thresholds
848                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
849
850                if sorted.len() < 3 {
851                    // Fallback: feature_range / n_bins * K
852                    let range = sorted.last().unwrap() - sorted.first().unwrap();
853                    if range < 1e-15 {
854                        return f64::INFINITY;
855                    }
856                    return (range / n_bins) * K;
857                }
858
859                // Median gap
860                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
861                let median_gap = if gaps.len() % 2 == 0 {
862                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
863                } else {
864                    gaps[gaps.len() / 2]
865                };
866
867                if median_gap < 1e-15 {
868                    f64::INFINITY
869                } else {
870                    median_gap * K
871                }
872            })
873            .collect()
874    }
875
876    /// Reset the ensemble to initial state.
877    pub fn reset(&mut self) {
878        for step in &mut self.steps {
879            step.reset();
880        }
881        self.base_prediction = 0.0;
882        self.base_initialized = false;
883        self.initial_targets.clear();
884        self.samples_seen = 0;
885        self.rng_state = self.config.seed;
886        self.auto_bandwidths.clear();
887        self.last_replacement_sum = 0;
888    }
889
890    /// Serialize the model into a [`ModelState`](crate::serde_support::ModelState).
891    ///
892    /// Auto-detects the [`LossType`](crate::loss::LossType) from the loss
893    /// function's [`Loss::loss_type()`] implementation.
894    ///
895    /// # Errors
896    ///
897    /// Returns [`IrithyllError::Serialization`](crate::IrithyllError::Serialization)
898    /// if the loss does not implement `loss_type()` (returns `None`). For custom
899    /// losses, use [`to_model_state_with`](Self::to_model_state_with) instead.
900    #[cfg(feature = "_serde_support")]
901    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
902        let loss_type = self.loss.loss_type().ok_or_else(|| {
903            crate::error::IrithyllError::Serialization(
904                "cannot auto-detect loss type for serialization: \
905                 implement Loss::loss_type() or use to_model_state_with()"
906                    .into(),
907            )
908        })?;
909        Ok(self.to_model_state_with(loss_type))
910    }
911
912    /// Serialize the model with an explicit [`LossType`](crate::loss::LossType) tag.
913    ///
914    /// Use this for custom loss functions that don't implement `loss_type()`.
915    #[cfg(feature = "_serde_support")]
916    pub fn to_model_state_with(
917        &self,
918        loss_type: crate::loss::LossType,
919    ) -> crate::serde_support::ModelState {
920        use crate::serde_support::{ModelState, StepSnapshot};
921
922        let steps = self
923            .steps
924            .iter()
925            .map(|step| {
926                let slot = step.slot();
927                let tree_snap = snapshot_tree(slot.active_tree());
928                let alt_snap = slot.alternate_tree().map(snapshot_tree);
929                let drift_state = slot.detector().serialize_state();
930                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
931                StepSnapshot {
932                    tree: tree_snap,
933                    alternate_tree: alt_snap,
934                    drift_state,
935                    alt_drift_state,
936                }
937            })
938            .collect();
939
940        ModelState {
941            config: self.config.clone(),
942            loss_type,
943            base_prediction: self.base_prediction,
944            base_initialized: self.base_initialized,
945            initial_targets: self.initial_targets.clone(),
946            initial_target_count: self.initial_target_count,
947            samples_seen: self.samples_seen,
948            rng_state: self.rng_state,
949            steps,
950            rolling_mean_error: self.rolling_mean_error,
951            contribution_ewma: self.contribution_ewma.clone(),
952            low_contrib_count: self.low_contrib_count.clone(),
953        }
954    }
955}
956
957// ---------------------------------------------------------------------------
958// DynSGBT: deserialization returns a dynamically-dispatched model
959// ---------------------------------------------------------------------------
960
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Reconstruct an SGBT model from a [`ModelState`](crate::serde_support::ModelState).
    ///
    /// Returns a [`DynSGBT`] (`SGBT<Box<dyn Loss>>`) because the concrete
    /// loss type is determined at runtime from the serialized tag.
    ///
    /// Rebuilds the full ensemble including tree topology and leaf values.
    /// Histogram accumulators are left empty and will rebuild from continued
    /// training. If drift detector state was serialized, it is restored;
    /// otherwise a fresh detector is created from the config.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        // Resolve the concrete loss implementation from the serialized tag.
        let loss = state.loss_type.into_loss();

        // Derive the per-sample leaf decay factor from the configured
        // half-life (in samples): alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        // Rebuild each boosting step: tree config, active/alternate tree
        // topology, and drift-detector state.
        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Per-step tree config mirrors ensemble construction; the
                // seed is de-correlated per step by XOR-ing the step index.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .adaptive_depth_opt(state.config.adaptive_depth)
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                // Rebuild the active tree and, when snapshotted, the
                // alternate (candidate replacement) tree.
                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                // Fresh detector from config, then overwrite its state with
                // the serialized snapshot when one was stored.
                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                // The alternate tree's detector only exists after slot
                // construction, so its state is restored afterwards.
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        let has_pruning = state.config.quality_prune_alpha.is_some();

        // Restore pruning state if available, otherwise initialize
        // (zeroed vectors when pruning is enabled, empty otherwise).
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            // Not serialized: bandwidths are recomputed lazily on first use
            // (see refresh_bandwidths), so start from an empty cache.
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1068
1069// ---------------------------------------------------------------------------
1070// Shared snapshot/rebuild helpers for serde (used by SGBT + DistributionalSGBT)
1071// ---------------------------------------------------------------------------
1072
/// Snapshot a [`HoeffdingTree`] into a serializable [`TreeSnapshot`].
#[cfg(feature = "_serde_support")]
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;

    // The arena stores the tree as parallel arrays; copy each array and
    // flatten NodeId links down to their raw indices for serialization.
    let arena = tree.arena();
    let left = arena.left.iter().map(|id| id.0).collect();
    let right = arena.right.iter().map(|id| id.0).collect();

    TreeSnapshot {
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        feature_idx: arena.feature_idx.clone(),
        threshold: arena.threshold.clone(),
        left,
        right,
        leaf_value: arena.leaf_value.clone(),
        is_leaf: arena.is_leaf.clone(),
        depth: arena.depth.clone(),
        sample_count: arena.sample_count.clone(),
        categorical_mask: arena.categorical_mask.clone(),
    }
}
1096
/// Rebuild a [`HoeffdingTree`] from a [`TreeSnapshot`] and a [`TreeConfig`].
#[cfg(feature = "_serde_support")]
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    // Node count; snapshots produced by `snapshot_tree` keep all parallel
    // arrays at this same length.
    let n = snapshot.feature_idx.len();

    // Bulk-copy the parallel arrays instead of pushing element-by-element:
    // avoids per-push capacity growth and per-index bounds checks.
    let mut arena = TreeArena::new();
    arena.feature_idx.extend_from_slice(&snapshot.feature_idx);
    arena.threshold.extend_from_slice(&snapshot.threshold);
    arena.left.extend(snapshot.left.iter().map(|&id| NodeId(id)));
    arena.right.extend(snapshot.right.iter().map(|&id| NodeId(id)));
    arena.leaf_value.extend_from_slice(&snapshot.leaf_value);
    arena.is_leaf.extend_from_slice(&snapshot.is_leaf);
    arena.depth.extend_from_slice(&snapshot.depth);
    arena.sample_count.extend_from_slice(&snapshot.sample_count);
    // Older snapshots may carry fewer mask entries than nodes; missing
    // entries default to None, matching the index-checked lookup semantics.
    arena
        .categorical_mask
        .extend((0..n).map(|i| snapshot.categorical_mask.get(i).copied().flatten()));

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1130
1131#[cfg(test)]
1132mod tests {
1133    use super::*;
1134    use alloc::boxed::Box;
1135    use alloc::vec;
1136    use alloc::vec::Vec;
1137
1138    fn default_config() -> SGBTConfig {
1139        SGBTConfig::builder()
1140            .n_steps(10)
1141            .learning_rate(0.1)
1142            .grace_period(20)
1143            .max_depth(4)
1144            .n_bins(16)
1145            .build()
1146            .unwrap()
1147    }
1148
1149    #[test]
1150    fn new_model_predicts_zero() {
1151        let model = SGBT::new(default_config());
1152        let pred = model.predict(&[1.0, 2.0, 3.0]);
1153        assert!(pred.abs() < 1e-12);
1154    }
1155
1156    #[test]
1157    fn train_one_does_not_panic() {
1158        let mut model = SGBT::new(default_config());
1159        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
1160        assert_eq!(model.n_samples_seen(), 1);
1161    }
1162
1163    #[test]
1164    fn prediction_changes_after_training() {
1165        let mut model = SGBT::new(default_config());
1166        let features = vec![1.0, 2.0, 3.0];
1167        for i in 0..100 {
1168            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
1169        }
1170        let pred = model.predict(&features);
1171        assert!(pred.is_finite());
1172    }
1173
1174    #[test]
1175    fn linear_signal_rmse_improves() {
1176        let config = SGBTConfig::builder()
1177            .n_steps(20)
1178            .learning_rate(0.1)
1179            .grace_period(10)
1180            .max_depth(3)
1181            .n_bins(16)
1182            .build()
1183            .unwrap();
1184        let mut model = SGBT::new(config);
1185
1186        let mut rng: u64 = 12345;
1187        let mut early_errors = Vec::new();
1188        let mut late_errors = Vec::new();
1189
1190        for i in 0..500 {
1191            rng ^= rng << 13;
1192            rng ^= rng >> 7;
1193            rng ^= rng << 17;
1194            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1195            rng ^= rng << 13;
1196            rng ^= rng >> 7;
1197            rng ^= rng << 17;
1198            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1199            let target = 2.0 * x1 + 3.0 * x2;
1200
1201            let pred = model.predict(&[x1, x2]);
1202            let error = (pred - target).powi(2);
1203
1204            if (50..150).contains(&i) {
1205                early_errors.push(error);
1206            }
1207            if i >= 400 {
1208                late_errors.push(error);
1209            }
1210
1211            model.train_one(&Sample::new(vec![x1, x2], target));
1212        }
1213
1214        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
1215        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();
1216
1217        assert!(
1218            late_rmse < early_rmse,
1219            "RMSE should decrease: early={:.4}, late={:.4}",
1220            early_rmse,
1221            late_rmse
1222        );
1223    }
1224
1225    #[test]
1226    fn train_batch_equivalent_to_sequential() {
1227        let config = default_config();
1228        let mut model_seq = SGBT::new(config.clone());
1229        let mut model_batch = SGBT::new(config);
1230
1231        let samples: Vec<Sample> = (0..20)
1232            .map(|i| {
1233                let x = i as f64 * 0.5;
1234                Sample::new(vec![x, x * 2.0], x * 3.0)
1235            })
1236            .collect();
1237
1238        for s in &samples {
1239            model_seq.train_one(s);
1240        }
1241        model_batch.train_batch(&samples);
1242
1243        let pred_seq = model_seq.predict(&[1.0, 2.0]);
1244        let pred_batch = model_batch.predict(&[1.0, 2.0]);
1245
1246        assert!(
1247            (pred_seq - pred_batch).abs() < 1e-10,
1248            "seq={}, batch={}",
1249            pred_seq,
1250            pred_batch
1251        );
1252    }
1253
1254    #[test]
1255    fn reset_returns_to_initial() {
1256        let mut model = SGBT::new(default_config());
1257        for i in 0..100 {
1258            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
1259        }
1260        model.reset();
1261        assert_eq!(model.n_samples_seen(), 0);
1262        assert!(!model.is_initialized());
1263        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
1264    }
1265
1266    #[test]
1267    fn base_prediction_initializes() {
1268        let mut model = SGBT::new(default_config());
1269        for i in 0..50 {
1270            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
1271        }
1272        assert!(model.is_initialized());
1273        let expected = (100.0 + 149.0) / 2.0;
1274        assert!((model.base_prediction() - expected).abs() < 1.0);
1275    }
1276
1277    #[test]
1278    fn with_loss_uses_custom_loss() {
1279        use crate::loss::logistic::LogisticLoss;
1280        let model = SGBT::with_loss(default_config(), LogisticLoss);
1281        let pred = model.predict_transformed(&[1.0, 2.0]);
1282        assert!(
1283            (pred - 0.5).abs() < 1e-6,
1284            "sigmoid(0) should be 0.5, got {}",
1285            pred
1286        );
1287    }
1288
1289    #[test]
1290    fn ewma_config_propagates_and_trains() {
1291        let config = SGBTConfig::builder()
1292            .n_steps(5)
1293            .learning_rate(0.1)
1294            .grace_period(10)
1295            .max_depth(3)
1296            .n_bins(16)
1297            .leaf_half_life(50)
1298            .build()
1299            .unwrap();
1300        let mut model = SGBT::new(config);
1301
1302        for i in 0..200 {
1303            let x = (i as f64) * 0.1;
1304            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1305        }
1306
1307        let pred = model.predict(&[1.0, 2.0]);
1308        assert!(
1309            pred.is_finite(),
1310            "EWMA-enabled model should produce finite predictions, got {}",
1311            pred
1312        );
1313    }
1314
1315    #[test]
1316    fn max_tree_samples_config_propagates() {
1317        let config = SGBTConfig::builder()
1318            .n_steps(5)
1319            .learning_rate(0.1)
1320            .grace_period(10)
1321            .max_depth(3)
1322            .n_bins(16)
1323            .max_tree_samples(200)
1324            .build()
1325            .unwrap();
1326        let mut model = SGBT::new(config);
1327
1328        for i in 0..500 {
1329            let x = (i as f64) * 0.1;
1330            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1331        }
1332
1333        let pred = model.predict(&[1.0, 2.0]);
1334        assert!(
1335            pred.is_finite(),
1336            "max_tree_samples model should produce finite predictions, got {}",
1337            pred
1338        );
1339    }
1340
1341    #[test]
1342    fn split_reeval_config_propagates() {
1343        let config = SGBTConfig::builder()
1344            .n_steps(5)
1345            .learning_rate(0.1)
1346            .grace_period(10)
1347            .max_depth(2)
1348            .n_bins(16)
1349            .split_reeval_interval(50)
1350            .build()
1351            .unwrap();
1352        let mut model = SGBT::new(config);
1353
1354        let mut rng: u64 = 12345;
1355        for _ in 0..1000 {
1356            rng ^= rng << 13;
1357            rng ^= rng >> 7;
1358            rng ^= rng << 17;
1359            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1360            rng ^= rng << 13;
1361            rng ^= rng >> 7;
1362            rng ^= rng << 17;
1363            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1364            let target = 2.0 * x1 + 3.0 * x2;
1365            model.train_one(&Sample::new(vec![x1, x2], target));
1366        }
1367
1368        let pred = model.predict(&[1.0, 2.0]);
1369        assert!(
1370            pred.is_finite(),
1371            "split re-eval model should produce finite predictions, got {}",
1372            pred
1373        );
1374    }
1375
1376    #[test]
1377    fn loss_accessor_works() {
1378        use crate::loss::logistic::LogisticLoss;
1379        let model = SGBT::with_loss(default_config(), LogisticLoss);
1380        // Verify we can access the concrete loss type
1381        let _loss: &LogisticLoss = model.loss();
1382        assert_eq!(_loss.n_outputs(), 1);
1383    }
1384
1385    #[test]
1386    fn clone_produces_independent_copy() {
1387        let config = default_config();
1388        let mut model = SGBT::new(config);
1389
1390        // Train the original on some data
1391        let mut rng: u64 = 99999;
1392        for _ in 0..200 {
1393            rng ^= rng << 13;
1394            rng ^= rng >> 7;
1395            rng ^= rng << 17;
1396            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1397            let target = 2.0 * x + 1.0;
1398            model.train_one(&Sample::new(vec![x], target));
1399        }
1400
1401        // Clone the model
1402        let mut cloned = model.clone();
1403
1404        // Both should produce identical predictions
1405        let test_features = [3.0];
1406        let pred_original = model.predict(&test_features);
1407        let pred_cloned = cloned.predict(&test_features);
1408        assert!(
1409            (pred_original - pred_cloned).abs() < 1e-12,
1410            "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
1411        );
1412
1413        // Train only the clone further -- models should diverge
1414        for _ in 0..200 {
1415            rng ^= rng << 13;
1416            rng ^= rng >> 7;
1417            rng ^= rng << 17;
1418            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1419            let target = -3.0 * x + 5.0; // Different relationship
1420            cloned.train_one(&Sample::new(vec![x], target));
1421        }
1422
1423        let pred_original_after = model.predict(&test_features);
1424        let pred_cloned_after = cloned.predict(&test_features);
1425
1426        // Original should be unchanged
1427        assert!(
1428            (pred_original - pred_original_after).abs() < 1e-12,
1429            "original should be unchanged after training clone"
1430        );
1431
1432        // Clone should have diverged
1433        assert!(
1434            (pred_original_after - pred_cloned_after).abs() > 1e-6,
1435            "clone should diverge after independent training"
1436        );
1437    }
1438
1439    // -------------------------------------------------------------------
1440    // predict_with_confidence returns finite values
1441    // -------------------------------------------------------------------
1442    #[test]
1443    fn predict_with_confidence_finite() {
1444        let config = SGBTConfig::builder()
1445            .n_steps(5)
1446            .grace_period(10)
1447            .build()
1448            .unwrap();
1449        let mut model = SGBT::new(config);
1450
1451        // Train enough to initialize
1452        for i in 0..100 {
1453            let x = i as f64 * 0.1;
1454            model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1455        }
1456
1457        let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1458        assert!(pred.is_finite(), "prediction should be finite");
1459        assert!(confidence.is_finite(), "confidence should be finite");
1460        assert!(
1461            confidence > 0.0,
1462            "confidence should be positive after training"
1463        );
1464    }
1465
1466    // -------------------------------------------------------------------
1467    // predict_with_confidence positive after training
1468    // -------------------------------------------------------------------
1469    #[test]
1470    fn predict_with_confidence_positive_after_training() {
1471        let config = SGBTConfig::builder()
1472            .n_steps(5)
1473            .grace_period(10)
1474            .build()
1475            .unwrap();
1476        let mut model = SGBT::new(config);
1477
1478        // Train enough to initialize and build structure
1479        for i in 0..200 {
1480            let x = i as f64 * 0.05;
1481            model.train_one(&(&[x][..], x * 2.0));
1482        }
1483
1484        let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1485
1486        assert!(pred.is_finite(), "prediction should be finite");
1487        assert!(
1488            confidence > 0.0 && confidence.is_finite(),
1489            "confidence should be finite and positive, got {}",
1490            confidence,
1491        );
1492
1493        // Multiple queries should give consistent confidence
1494        let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1495        assert!(
1496            (pred - pred2).abs() < 1e-12,
1497            "same input should give same prediction"
1498        );
1499        assert!(
1500            (confidence - conf2).abs() < 1e-12,
1501            "same input should give same confidence"
1502        );
1503    }
1504
1505    // -------------------------------------------------------------------
1506    // predict_with_confidence agrees with predict on point estimate
1507    // -------------------------------------------------------------------
1508    #[test]
1509    fn predict_with_confidence_matches_predict() {
1510        let config = SGBTConfig::builder()
1511            .n_steps(10)
1512            .grace_period(10)
1513            .build()
1514            .unwrap();
1515        let mut model = SGBT::new(config);
1516
1517        for i in 0..200 {
1518            let x = (i as f64 - 100.0) * 0.01;
1519            model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1520        }
1521
1522        let pred = model.predict(&[0.5, 0.25]);
1523        let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1524
1525        assert!(
1526            (pred - conf_pred).abs() < 1e-10,
1527            "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1528            pred,
1529            conf_pred,
1530        );
1531    }
1532
1533    // -------------------------------------------------------------------
1534    // gradient clipping config round-trips through builder
1535    // -------------------------------------------------------------------
1536    #[test]
1537    fn gradient_clip_config_builder() {
1538        let config = SGBTConfig::builder()
1539            .n_steps(10)
1540            .gradient_clip_sigma(3.0)
1541            .build()
1542            .unwrap();
1543
1544        assert_eq!(config.gradient_clip_sigma, Some(3.0));
1545    }
1546
1547    // -------------------------------------------------------------------
1548    // monotonic constraints config round-trips through builder
1549    // -------------------------------------------------------------------
1550    #[test]
1551    fn monotone_constraints_config_builder() {
1552        let config = SGBTConfig::builder()
1553            .n_steps(10)
1554            .monotone_constraints(vec![1, -1, 0])
1555            .build()
1556            .unwrap();
1557
1558        assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1559    }
1560
1561    // -------------------------------------------------------------------
1562    // monotonic constraints validation rejects invalid values
1563    // -------------------------------------------------------------------
1564    #[test]
1565    fn monotone_constraints_invalid_value_rejected() {
1566        let result = SGBTConfig::builder()
1567            .n_steps(10)
1568            .monotone_constraints(vec![1, 2, 0])
1569            .build();
1570
1571        assert!(result.is_err(), "constraint value 2 should be rejected");
1572    }
1573
1574    // -------------------------------------------------------------------
1575    // gradient clipping validation rejects non-positive sigma
1576    // -------------------------------------------------------------------
1577    #[test]
1578    fn gradient_clip_sigma_negative_rejected() {
1579        let result = SGBTConfig::builder()
1580            .n_steps(10)
1581            .gradient_clip_sigma(-1.0)
1582            .build();
1583
1584        assert!(result.is_err(), "negative sigma should be rejected");
1585    }
1586
1587    // -------------------------------------------------------------------
1588    // gradient clipping ensemble-level reduces outlier impact
1589    // -------------------------------------------------------------------
1590    #[test]
1591    fn gradient_clipping_reduces_outlier_impact() {
1592        // Without clipping
1593        let config_no_clip = SGBTConfig::builder()
1594            .n_steps(5)
1595            .grace_period(10)
1596            .build()
1597            .unwrap();
1598        let mut model_no_clip = SGBT::new(config_no_clip);
1599
1600        // With clipping
1601        let config_clip = SGBTConfig::builder()
1602            .n_steps(5)
1603            .grace_period(10)
1604            .gradient_clip_sigma(3.0)
1605            .build()
1606            .unwrap();
1607        let mut model_clip = SGBT::new(config_clip);
1608
1609        // Train both on identical normal data
1610        for i in 0..100 {
1611            let x = (i as f64) * 0.01;
1612            let sample = (&[x][..], x * 2.0);
1613            model_no_clip.train_one(&sample);
1614            model_clip.train_one(&sample);
1615        }
1616
1617        let pred_no_clip_before = model_no_clip.predict(&[0.5]);
1618        let pred_clip_before = model_clip.predict(&[0.5]);
1619
1620        // Inject outlier
1621        let outlier = (&[0.5_f64][..], 10000.0);
1622        model_no_clip.train_one(&outlier);
1623        model_clip.train_one(&outlier);
1624
1625        let pred_no_clip_after = model_no_clip.predict(&[0.5]);
1626        let pred_clip_after = model_clip.predict(&[0.5]);
1627
1628        let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
1629        let delta_clip = (pred_clip_after - pred_clip_before).abs();
1630
1631        // Clipped model should be less affected by the outlier
1632        assert!(
1633            delta_clip <= delta_no_clip + 1e-10,
1634            "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
1635            delta_clip,
1636            delta_no_clip,
1637        );
1638    }
1639
1640    // -------------------------------------------------------------------
1641    // train_batch_with_callback fires at correct intervals
1642    // -------------------------------------------------------------------
1643    #[test]
1644    fn train_batch_with_callback_fires() {
1645        let config = SGBTConfig::builder()
1646            .n_steps(3)
1647            .grace_period(5)
1648            .build()
1649            .unwrap();
1650        let mut model = SGBT::new(config);
1651
1652        let data: Vec<(Vec<f64>, f64)> = (0..25)
1653            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1654            .collect();
1655
1656        let mut callbacks = Vec::new();
1657        model.train_batch_with_callback(&data, 10, |n| {
1658            callbacks.push(n);
1659        });
1660
1661        // Should fire at 10, 20, and 25 (final)
1662        assert_eq!(callbacks, vec![10, 20, 25]);
1663    }
1664
1665    // -------------------------------------------------------------------
1666    // train_batch_subsampled produces deterministic subset
1667    // -------------------------------------------------------------------
1668    #[test]
1669    fn train_batch_subsampled_trains_subset() {
1670        let config = SGBTConfig::builder()
1671            .n_steps(3)
1672            .grace_period(5)
1673            .build()
1674            .unwrap();
1675        let mut model = SGBT::new(config);
1676
1677        let data: Vec<(Vec<f64>, f64)> = (0..100)
1678            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1679            .collect();
1680
1681        // Train on only 20 of 100 samples
1682        model.train_batch_subsampled(&data, 20);
1683
1684        // Model should have seen some samples
1685        assert!(
1686            model.n_samples_seen() > 0,
1687            "model should have trained on subset"
1688        );
1689        assert!(
1690            model.n_samples_seen() <= 20,
1691            "model should have trained at most 20 samples, got {}",
1692            model.n_samples_seen(),
1693        );
1694    }
1695
1696    // -------------------------------------------------------------------
1697    // train_batch_subsampled full dataset = train_batch
1698    // -------------------------------------------------------------------
1699    #[test]
1700    fn train_batch_subsampled_full_equals_batch() {
1701        let config1 = SGBTConfig::builder()
1702            .n_steps(3)
1703            .grace_period(5)
1704            .build()
1705            .unwrap();
1706        let config2 = config1.clone();
1707
1708        let mut model1 = SGBT::new(config1);
1709        let mut model2 = SGBT::new(config2);
1710
1711        let data: Vec<(Vec<f64>, f64)> = (0..50)
1712            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1713            .collect();
1714
1715        model1.train_batch(&data);
1716        model2.train_batch_subsampled(&data, 1000); // max_samples > data.len()
1717
1718        // Both should have identical state
1719        assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1720        let pred1 = model1.predict(&[2.5]);
1721        let pred2 = model2.predict(&[2.5]);
1722        assert!(
1723            (pred1 - pred2).abs() < 1e-12,
1724            "full subsample should equal batch: {} vs {}",
1725            pred1,
1726            pred2,
1727        );
1728    }
1729
1730    // -------------------------------------------------------------------
1731    // train_batch_subsampled_with_callback combines both
1732    // -------------------------------------------------------------------
1733    #[test]
1734    fn train_batch_subsampled_with_callback_works() {
1735        let config = SGBTConfig::builder()
1736            .n_steps(3)
1737            .grace_period(5)
1738            .build()
1739            .unwrap();
1740        let mut model = SGBT::new(config);
1741
1742        let data: Vec<(Vec<f64>, f64)> = (0..200)
1743            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1744            .collect();
1745
1746        let mut callbacks = Vec::new();
1747        model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1748            callbacks.push(n);
1749        });
1750
1751        // Should have trained ~50 samples with callbacks at 10, 20, 30, 40, 50
1752        assert!(!callbacks.is_empty(), "should have received callbacks");
1753        assert_eq!(
1754            *callbacks.last().unwrap(),
1755            50,
1756            "final callback should be total samples"
1757        );
1758    }
1759
1760    // ---------------------------------------------------------------
1761    // Linear leaf model integration tests
1762    // ---------------------------------------------------------------
1763
1764    /// xorshift64 PRNG for deterministic test data.
1765    fn xorshift64(state: &mut u64) -> u64 {
1766        let mut s = *state;
1767        s ^= s << 13;
1768        s ^= s >> 7;
1769        s ^= s << 17;
1770        *state = s;
1771        s
1772    }
1773
1774    fn rand_f64(state: &mut u64) -> f64 {
1775        xorshift64(state) as f64 / u64::MAX as f64
1776    }
1777
1778    fn linear_leaves_config() -> SGBTConfig {
1779        SGBTConfig::builder()
1780            .n_steps(10)
1781            .learning_rate(0.1)
1782            .grace_period(20)
1783            .max_depth(2) // low depth -- linear leaves should shine
1784            .n_bins(16)
1785            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1786                learning_rate: 0.1,
1787                decay: None,
1788                use_adagrad: false,
1789            })
1790            .build()
1791            .unwrap()
1792    }
1793
1794    #[test]
1795    fn linear_leaves_trains_without_panic() {
1796        let mut model = SGBT::new(linear_leaves_config());
1797        let mut rng = 42u64;
1798        for _ in 0..200 {
1799            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1800            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1801            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1802            model.train_one(&Sample::new(vec![x1, x2], y));
1803        }
1804        assert_eq!(model.n_samples_seen(), 200);
1805    }
1806
1807    #[test]
1808    fn linear_leaves_prediction_finite() {
1809        let mut model = SGBT::new(linear_leaves_config());
1810        let mut rng = 42u64;
1811        for _ in 0..200 {
1812            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1813            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1814            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1815            model.train_one(&Sample::new(vec![x1, x2], y));
1816        }
1817        let pred = model.predict(&[0.5, -0.3]);
1818        assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1819    }
1820
1821    #[test]
1822    fn linear_leaves_learns_linear_target() {
1823        let mut model = SGBT::new(linear_leaves_config());
1824        let mut rng = 42u64;
1825        for _ in 0..500 {
1826            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1827            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1828            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1829            model.train_one(&Sample::new(vec![x1, x2], y));
1830        }
1831
1832        // Test on a few points -- should be reasonably close for a linear target.
1833        let mut total_error = 0.0;
1834        for _ in 0..50 {
1835            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1836            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1837            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1838            let pred = model.predict(&[x1, x2]);
1839            total_error += (pred - y).powi(2);
1840        }
1841        let mse = total_error / 50.0;
1842        assert!(
1843            mse < 5.0,
1844            "linear leaves MSE on linear target should be < 5.0, got {mse}"
1845        );
1846    }
1847
1848    #[test]
1849    fn linear_leaves_better_than_constant_at_low_depth() {
1850        // Train two models on a linear target at depth 2:
1851        // one with constant leaves, one with linear leaves.
1852        let constant_config = SGBTConfig::builder()
1853            .n_steps(10)
1854            .learning_rate(0.1)
1855            .grace_period(20)
1856            .max_depth(2)
1857            .n_bins(16)
1858            .seed(0xDEAD)
1859            .build()
1860            .unwrap();
1861        let linear_config = SGBTConfig::builder()
1862            .n_steps(10)
1863            .learning_rate(0.1)
1864            .grace_period(20)
1865            .max_depth(2)
1866            .n_bins(16)
1867            .seed(0xDEAD)
1868            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1869                learning_rate: 0.1,
1870                decay: None,
1871                use_adagrad: false,
1872            })
1873            .build()
1874            .unwrap();
1875
1876        let mut constant_model = SGBT::new(constant_config);
1877        let mut linear_model = SGBT::new(linear_config);
1878        let mut rng = 42u64;
1879
1880        for _ in 0..500 {
1881            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1882            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1883            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1884            let sample = Sample::new(vec![x1, x2], y);
1885            constant_model.train_one(&sample);
1886            linear_model.train_one(&sample);
1887        }
1888
1889        // Evaluate both on test set.
1890        let mut constant_mse = 0.0;
1891        let mut linear_mse = 0.0;
1892        for _ in 0..100 {
1893            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1894            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1895            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1896            constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
1897            linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
1898        }
1899        constant_mse /= 100.0;
1900        linear_mse /= 100.0;
1901
1902        // Linear leaves should outperform constant leaves on a linear target.
1903        assert!(
1904            linear_mse < constant_mse,
1905            "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
1906        );
1907    }
1908
1909    #[test]
1910    fn adaptive_leaves_trains_without_panic() {
1911        let config = SGBTConfig::builder()
1912            .n_steps(10)
1913            .learning_rate(0.1)
1914            .grace_period(20)
1915            .max_depth(3)
1916            .n_bins(16)
1917            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
1918                promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
1919                    learning_rate: 0.1,
1920                    decay: None,
1921                    use_adagrad: false,
1922                }),
1923            })
1924            .build()
1925            .unwrap();
1926
1927        let mut model = SGBT::new(config);
1928        let mut rng = 42u64;
1929        for _ in 0..500 {
1930            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1931            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1932            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1933            model.train_one(&Sample::new(vec![x1, x2], y));
1934        }
1935        let pred = model.predict(&[0.5, -0.3]);
1936        assert!(
1937            pred.is_finite(),
1938            "adaptive leaf prediction should be finite, got {pred}"
1939        );
1940    }
1941
1942    #[test]
1943    fn linear_leaves_with_decay_trains_without_panic() {
1944        let config = SGBTConfig::builder()
1945            .n_steps(10)
1946            .learning_rate(0.1)
1947            .grace_period(20)
1948            .max_depth(3)
1949            .n_bins(16)
1950            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
1951                learning_rate: 0.1,
1952                decay: Some(0.995),
1953                use_adagrad: false,
1954            })
1955            .build()
1956            .unwrap();
1957
1958        let mut model = SGBT::new(config);
1959        let mut rng = 42u64;
1960        for _ in 0..500 {
1961            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1962            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1963            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1964            model.train_one(&Sample::new(vec![x1, x2], y));
1965        }
1966        let pred = model.predict(&[0.5, -0.3]);
1967        assert!(
1968            pred.is_finite(),
1969            "decay leaf prediction should be finite, got {pred}"
1970        );
1971    }
1972
1973    // -------------------------------------------------------------------
1974    // predict_smooth returns finite values
1975    // -------------------------------------------------------------------
1976    #[test]
1977    fn predict_smooth_returns_finite() {
1978        let config = SGBTConfig::builder()
1979            .n_steps(5)
1980            .learning_rate(0.1)
1981            .grace_period(10)
1982            .build()
1983            .unwrap();
1984        let mut model = SGBT::new(config);
1985
1986        for i in 0..200 {
1987            let x = (i as f64) * 0.1;
1988            model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
1989        }
1990
1991        let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
1992        let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
1993
1994        assert!(pred_hard.is_finite(), "hard prediction should be finite");
1995        assert!(
1996            pred_smooth.is_finite(),
1997            "smooth prediction should be finite"
1998        );
1999    }
2000
2001    // -------------------------------------------------------------------
2002    // predict_smooth converges to hard predict at small bandwidth
2003    // -------------------------------------------------------------------
2004    #[test]
2005    fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2006        let config = SGBTConfig::builder()
2007            .n_steps(5)
2008            .learning_rate(0.1)
2009            .grace_period(10)
2010            .build()
2011            .unwrap();
2012        let mut model = SGBT::new(config);
2013
2014        for i in 0..300 {
2015            let x = (i as f64) * 0.1;
2016            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2017        }
2018
2019        let features = [5.0, 2.5];
2020        let hard = model.predict(&features);
2021        let smooth = model.predict_smooth(&features, 0.001);
2022
2023        assert!(
2024            (hard - smooth).abs() < 0.5,
2025            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2026            hard,
2027            smooth,
2028        );
2029    }
2030
2031    #[test]
2032    fn auto_bandwidth_computed_after_training() {
2033        let config = SGBTConfig::builder()
2034            .n_steps(5)
2035            .learning_rate(0.1)
2036            .grace_period(10)
2037            .build()
2038            .unwrap();
2039        let mut model = SGBT::new(config);
2040
2041        // Before training, no bandwidths
2042        assert!(model.auto_bandwidths().is_empty());
2043
2044        for i in 0..200 {
2045            let x = (i as f64) * 0.1;
2046            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2047        }
2048
2049        // After training, auto_bandwidths should be populated
2050        let bws = model.auto_bandwidths();
2051        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2052
2053        // predict() always uses smooth routing with auto-bandwidths
2054        let pred = model.predict(&[5.0, 2.5]);
2055        assert!(
2056            pred.is_finite(),
2057            "auto-bandwidth predict should be finite: {}",
2058            pred
2059        );
2060    }
2061
2062    #[test]
2063    fn predict_interpolated_returns_finite() {
2064        let config = SGBTConfig::builder()
2065            .n_steps(5)
2066            .learning_rate(0.01)
2067            .build()
2068            .unwrap();
2069        let mut model = SGBT::new(config);
2070
2071        for i in 0..200 {
2072            let x = (i as f64) * 0.1;
2073            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2074        }
2075
2076        let pred = model.predict_interpolated(&[1.0, 0.5]);
2077        assert!(
2078            pred.is_finite(),
2079            "interpolated prediction should be finite: {}",
2080            pred
2081        );
2082    }
2083
2084    #[test]
2085    fn predict_sibling_interpolated_varies_with_features() {
2086        let config = SGBTConfig::builder()
2087            .n_steps(10)
2088            .learning_rate(0.1)
2089            .grace_period(10)
2090            .max_depth(6)
2091            .delta(0.1)
2092            .build()
2093            .unwrap();
2094        let mut model = SGBT::new(config);
2095
2096        for i in 0..2000 {
2097            let x = (i as f64) * 0.01;
2098            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
2099            model.train_one(&Sample::new(vec![x, x * 0.3], y));
2100        }
2101
2102        // Verify the method is callable and produces finite predictions
2103        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
2104        assert!(pred.is_finite(), "sibling interpolated should be finite");
2105
2106        // If bandwidths are finite, verify sibling produces at least as much
2107        // variation as hard routing across a feature sweep
2108        let bws = model.auto_bandwidths();
2109        if bws.iter().any(|&b| b.is_finite()) {
2110            let hard: Vec<f64> = (0..200)
2111                .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
2112                .collect();
2113            let sib: Vec<f64> = (0..200)
2114                .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
2115                .collect();
2116            let hc = hard
2117                .windows(2)
2118                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2119                .count();
2120            let sc = sib
2121                .windows(2)
2122                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
2123                .count();
2124            assert!(
2125                sc >= hc,
2126                "sibling should produce >= hard changes: sib={}, hard={}",
2127                sc,
2128                hc
2129            );
2130        }
2131    }
2132
2133    #[test]
2134    fn predict_graduated_returns_finite() {
2135        let config = SGBTConfig::builder()
2136            .n_steps(5)
2137            .learning_rate(0.01)
2138            .max_tree_samples(200)
2139            .shadow_warmup(50)
2140            .build()
2141            .unwrap();
2142        let mut model = SGBT::new(config);
2143
2144        for i in 0..300 {
2145            let x = (i as f64) * 0.1;
2146            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2147        }
2148
2149        let pred = model.predict_graduated(&[1.0, 0.5]);
2150        assert!(
2151            pred.is_finite(),
2152            "graduated prediction should be finite: {}",
2153            pred
2154        );
2155
2156        let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2157        assert!(
2158            pred2.is_finite(),
2159            "graduated+sibling prediction should be finite: {}",
2160            pred2
2161        );
2162    }
2163
2164    #[test]
2165    fn shadow_warmup_validation() {
2166        let result = SGBTConfig::builder()
2167            .n_steps(5)
2168            .learning_rate(0.01)
2169            .shadow_warmup(0)
2170            .build();
2171        assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2172    }
2173}