// irithyll_core/ensemble/mod.rs
1//! SGBT ensemble orchestrator -- the core boosting loop.
2//!
3//! Implements Streaming Gradient Boosted Trees (Gunasekara et al., 2024):
4//! a sequence of boosting steps, each owning a streaming tree and drift detector,
5//! with automatic tree replacement when concept drift is detected.
6//!
7//! # Algorithm
8//!
9//! For each incoming sample `(x, y)`:
10//! 1. Compute the current ensemble prediction: `F(x) = base + lr * Σ tree_s(x)`
11//! 2. For each boosting step `s = 1..N`:
12//!    - Compute gradient `g = loss.gradient(y, current_pred)`
13//!    - Compute hessian `h = loss.hessian(y, current_pred)`
14//!    - Feed `(x, g, h)` to tree `s` (which internally uses weighted squared loss)
15//!    - Update `current_pred += lr * tree_s.predict(x)`
16//! 3. The ensemble adapts incrementally, with each tree targeting the residual
17//!    of all preceding trees.
18
19pub mod adaptive;
20pub mod adaptive_forest;
21pub mod bagged;
22pub mod config;
23pub mod distributional;
24pub mod lr_schedule;
25pub mod moe;
26pub mod moe_distributional;
27pub mod multi_target;
28pub mod multiclass;
29#[cfg(feature = "parallel")]
30pub mod parallel;
31pub mod quantile_regressor;
32pub mod replacement;
33pub mod stacked;
34pub mod step;
35pub mod variants;
36
37use alloc::boxed::Box;
38use alloc::string::String;
39use alloc::vec;
40use alloc::vec::Vec;
41
42use core::fmt;
43
44use crate::ensemble::config::SGBTConfig;
45use crate::ensemble::step::BoostingStep;
46use crate::loss::squared::SquaredLoss;
47use crate::loss::Loss;
48use crate::sample::Observation;
49#[allow(unused_imports)] // Used in doc links + tests
50use crate::sample::Sample;
51use crate::tree::builder::TreeConfig;
52
/// Type alias for an SGBT model using dynamic (boxed) loss dispatch.
///
/// Use this when the loss function is determined at runtime (e.g., when
/// deserializing a model from JSON where the loss type is stored as a tag).
/// Each gradient/hessian call goes through a vtable, so prefer static
/// dispatch in hot paths.
///
/// For compile-time loss dispatch (preferred for performance), use
/// `SGBT<LogisticLoss>`, `SGBT<HuberLoss>`, etc.
pub type DynSGBT = SGBT<Box<dyn Loss>>;
61
/// Streaming Gradient Boosted Trees ensemble.
///
/// The primary entry point for training and prediction. Generic over `L: Loss`
/// so the loss function's gradient/hessian calls are monomorphized (inlined)
/// into the boosting hot loop -- no virtual dispatch overhead.
///
/// The default type parameter `L = SquaredLoss` means `SGBT::new(config)`
/// creates a regression model without specifying the loss type explicitly.
///
/// # Examples
///
/// ```ignore
/// use irithyll::{SGBTConfig, SGBT};
///
/// // Regression with squared loss (default):
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::new(config);
/// ```
///
/// ```ignore
/// use irithyll::{SGBTConfig, SGBT};
/// use irithyll::loss::logistic::LogisticLoss;
///
/// // Classification with logistic loss -- no Box::new()!
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::with_loss(config, LogisticLoss);
/// ```
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Configuration.
    config: SGBTConfig,
    /// Boosting steps (one tree + drift detector each).
    steps: Vec<BoostingStep>,
    /// Loss function (monomorphized -- no vtable).
    loss: L,
    /// Base prediction (initial constant, computed from first batch of targets).
    base_prediction: f64,
    /// Whether base_prediction has been initialized.
    base_initialized: bool,
    /// Running collection of initial targets for computing base_prediction.
    /// Cleared and shrunk once `base_prediction` is set.
    initial_targets: Vec<f64>,
    /// Number of initial targets to collect before setting base_prediction.
    initial_target_count: usize,
    /// Total samples trained.
    samples_seen: u64,
    /// RNG state for variant skip logic (also advanced by reservoir subsampling).
    rng_state: u64,
    /// Per-step EWMA of |marginal contribution| for quality-based pruning.
    /// Empty when `quality_prune_alpha` is `None`.
    contribution_ewma: Vec<f64>,
    /// Per-step consecutive low-contribution sample counter.
    /// Empty when `quality_prune_alpha` is `None`.
    low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error for error-weighted sample importance.
    /// Only used when `error_weight_alpha` is `Some`.
    rolling_mean_error: f64,
    /// EWMA of contribution standard deviation (σ proxy for adaptive_mts).
    /// Only updated when `adaptive_mts` is `Some`.
    rolling_contribution_sigma: f64,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// Computed from median split threshold gaps across all trees.
    auto_bandwidths: Vec<f64>,
    /// Sum of replacement counts across all steps at last bandwidth computation.
    /// Used to detect when trees have been replaced and bandwidths need refresh.
    last_replacement_sum: u64,
}
127
128impl<L: Loss + Clone> Clone for SGBT<L> {
129    fn clone(&self) -> Self {
130        Self {
131            config: self.config.clone(),
132            steps: self.steps.clone(),
133            loss: self.loss.clone(),
134            base_prediction: self.base_prediction,
135            base_initialized: self.base_initialized,
136            initial_targets: self.initial_targets.clone(),
137            initial_target_count: self.initial_target_count,
138            samples_seen: self.samples_seen,
139            rng_state: self.rng_state,
140            contribution_ewma: self.contribution_ewma.clone(),
141            low_contrib_count: self.low_contrib_count.clone(),
142            rolling_mean_error: self.rolling_mean_error,
143            rolling_contribution_sigma: self.rolling_contribution_sigma,
144            auto_bandwidths: self.auto_bandwidths.clone(),
145            last_replacement_sum: self.last_replacement_sum,
146        }
147    }
148}
149
150impl<L: Loss> fmt::Debug for SGBT<L> {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("SGBT")
153            .field("n_steps", &self.steps.len())
154            .field("samples_seen", &self.samples_seen)
155            .field("base_prediction", &self.base_prediction)
156            .field("base_initialized", &self.base_initialized)
157            .finish()
158    }
159}
160
161// ---------------------------------------------------------------------------
162// Convenience constructor for the default loss (SquaredLoss)
163// ---------------------------------------------------------------------------
164
165impl SGBT<SquaredLoss> {
166    /// Create a new SGBT ensemble with squared loss (regression).
167    ///
168    /// This is the most common constructor. For classification or custom
169    /// losses, use [`with_loss`](SGBT::with_loss).
170    pub fn new(config: SGBTConfig) -> Self {
171        Self::with_loss(config, SquaredLoss)
172    }
173}
174
175// ---------------------------------------------------------------------------
176// General impl for all Loss types
177// ---------------------------------------------------------------------------
178
179impl<L: Loss> SGBT<L> {
    /// Create a new SGBT ensemble with a specific loss function.
    ///
    /// The loss is stored by value (monomorphized), giving zero-cost
    /// gradient/hessian dispatch.
    ///
    /// ```ignore
    /// use irithyll::{SGBTConfig, SGBT};
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
    /// let model = SGBT::with_loss(config, LogisticLoss);
    /// ```
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        // Convert a half-life (in samples) to a per-sample exponential decay
        // factor: alpha = exp(-ln 2 / half_life), so leaf statistics halve
        // every `leaf_half_life` samples.
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        // Template tree configuration shared by all boosting steps; only the
        // seed differs per step (set below).
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_leaf_bound_opt(config.adaptive_leaf_bound)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // One step per boosting round, each with its own drift detector and a
        // step-specific seed (XOR of config seed and step index).
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    // Graduated replacement: shadow tree warms up before handoff.
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        // Copy scalars out of `config` before it is moved into `Self`.
        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        // Pruning buffers are only allocated when either pruning mode is on.
        let has_pruning =
            config.quality_prune_alpha.is_some() || config.proactive_prune_interval.is_some();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
259
    /// Train on a single observation.
    ///
    /// Accepts any type implementing [`Observation`], including [`Sample`],
    /// [`SampleRef`](crate::SampleRef), or tuples like `(&[f64], f64)` for
    /// zero-copy training.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Initialize base prediction from first few targets. Note: the sample
        // that completes the warmup is still boosted below, using the freshly
        // computed base.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                // Free the warmup buffer -- it is never needed again.
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Current prediction starts from base
        let mut current_pred = self.base_prediction;

        // Adaptive MTS: compute contribution variance and set effective max_tree_samples.
        // High per-step disagreement (sigma above its slow EWMA) shrinks the
        // effective tree-sample budget, down to a configured floor.
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma = self.contribution_variance(features);
            // Very slow EWMA (half-life ~693 samples) of the sigma baseline.
            self.rolling_contribution_sigma =
                0.999 * self.rolling_contribution_sigma + 0.001 * sigma;

            let normalized = if self.rolling_contribution_sigma > 1e-10 {
                sigma / self.rolling_contribution_sigma
            } else {
                1.0
            };
            let factor = 1.0 / (1.0 + k * normalized);
            // Floor keeps trees alive at least twice the grace period.
            let floor = (base_mts as f64 * self.config.adaptive_mts_floor)
                .max(self.config.grace_period as f64 * 2.0);
            let effective_mts = ((base_mts as f64) * factor).max(floor) as u64;
            for step in &mut self.steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        // Proactive pruning implies a default EWMA alpha of 0.01 even when
        // quality pruning is not explicitly configured.
        let prune_alpha = self
            .config
            .quality_prune_alpha
            .or_else(|| self.config.proactive_prune_interval.map(|_| 0.01));
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Error-weighted sample importance: compute weight from prediction error
        // relative to the rolling mean error, capped at 10x.
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0 // first sample, no reweighting
            }
        } else {
            1.0
        };

        // Sequential boosting: each step targets the residual of all prior steps.
        // Gradients/hessians are recomputed per step against the running prediction.
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            // Variant may skip or repeat training (e.g. Poisson-style bagging);
            // advances `rng_state`.
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality-based tree pruning: track contribution and replace dead wood
            // after `prune_patience` consecutive below-threshold samples.
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                    }
                } else {
                    // Any above-threshold sample resets the patience counter.
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Proactive pruning: replace the worst-contributing tree at a fixed
        // interval, skipping trees younger than half an interval so fresh
        // replacements get a chance to grow.
        if let Some(interval) = self.config.proactive_prune_interval {
            if self.samples_seen % interval == 0
                && self.samples_seen > 0
                && !self.contribution_ewma.is_empty()
            {
                let min_age = interval / 2;
                let worst_idx = self
                    .steps
                    .iter()
                    .enumerate()
                    .zip(self.contribution_ewma.iter())
                    .filter(|((_, step), _)| step.n_samples_seen() >= min_age)
                    .min_by(|((_, _), a_ewma), ((_, _), b_ewma)| {
                        a_ewma
                            .partial_cmp(b_ewma)
                            .unwrap_or(core::cmp::Ordering::Equal)
                    })
                    .map(|((i, _), _)| i);

                if let Some(idx) = worst_idx {
                    self.steps[idx].reset();
                    self.contribution_ewma[idx] = 0.0;
                    self.low_contrib_count[idx] = 0;
                }
            }
        }

        // Refresh auto-bandwidths when trees have been replaced or not yet computed.
        self.refresh_bandwidths();
    }
391
392    /// Train on a batch of observations.
393    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
394        for sample in samples {
395            self.train_one(sample);
396        }
397    }
398
399    /// Train on a batch with periodic callback for cooperative yielding.
400    ///
401    /// The callback is invoked every `interval` samples with the number of
402    /// samples processed so far. This allows long-running training to yield
403    /// to other tasks in an async runtime, update progress bars, or perform
404    /// periodic checkpointing.
405    ///
406    /// # Example
407    ///
408    /// ```ignore
409    /// use irithyll::{SGBTConfig, SGBT};
410    ///
411    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
412    /// let mut model = SGBT::new(config);
413    /// let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data
414    ///
415    /// model.train_batch_with_callback(&data, 1000, |processed| {
416    ///     println!("Trained {} samples", processed);
417    /// });
418    /// ```
419    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
420        &mut self,
421        samples: &[O],
422        interval: usize,
423        mut callback: F,
424    ) {
425        let interval = interval.max(1); // Prevent zero interval
426        for (i, sample) in samples.iter().enumerate() {
427            self.train_one(sample);
428            if (i + 1) % interval == 0 {
429                callback(i + 1);
430            }
431        }
432        // Final callback if the total isn't a multiple of interval
433        let total = samples.len();
434        if total % interval != 0 {
435            callback(total);
436        }
437    }
438
439    /// Train on a random subsample of a batch using reservoir sampling.
440    ///
441    /// When `max_samples < samples.len()`, selects a representative subset
442    /// using Algorithm R (Vitter, 1985) -- a uniform random sample without
443    /// replacement. The selected samples are then trained in their original
444    /// order to preserve sequential dependencies.
445    ///
446    /// This is ideal for large replay buffers where training on the full
447    /// dataset is prohibitively slow but a representative subset gives
448    /// equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).
449    ///
450    /// When `max_samples >= samples.len()`, all samples are trained.
451    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
452        if max_samples >= samples.len() {
453            self.train_batch(samples);
454            return;
455        }
456
457        // Reservoir sampling (Algorithm R) to select indices
458        let mut reservoir: Vec<usize> = (0..max_samples).collect();
459        let mut rng = self.rng_state;
460
461        for i in max_samples..samples.len() {
462            // Generate random index in [0, i]
463            rng ^= rng << 13;
464            rng ^= rng >> 7;
465            rng ^= rng << 17;
466            let j = (rng % (i as u64 + 1)) as usize;
467            if j < max_samples {
468                reservoir[j] = i;
469            }
470        }
471
472        self.rng_state = rng;
473
474        // Sort to preserve original order (important for EWMA/drift state)
475        reservoir.sort_unstable();
476
477        // Train on the selected subset
478        for &idx in &reservoir {
479            self.train_one(&samples[idx]);
480        }
481    }
482
483    /// Train on a batch with both subsampling and periodic callbacks.
484    ///
485    /// Combines reservoir subsampling with cooperative yield points.
486    /// Ideal for long-running daemon training where you need both
487    /// efficiency (subsampling) and cooperation (yielding).
488    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
489        &mut self,
490        samples: &[O],
491        max_samples: usize,
492        interval: usize,
493        mut callback: F,
494    ) {
495        if max_samples >= samples.len() {
496            self.train_batch_with_callback(samples, interval, callback);
497            return;
498        }
499
500        // Reservoir sampling
501        let mut reservoir: Vec<usize> = (0..max_samples).collect();
502        let mut rng = self.rng_state;
503
504        for i in max_samples..samples.len() {
505            rng ^= rng << 13;
506            rng ^= rng >> 7;
507            rng ^= rng << 17;
508            let j = (rng % (i as u64 + 1)) as usize;
509            if j < max_samples {
510                reservoir[j] = i;
511            }
512        }
513
514        self.rng_state = rng;
515        reservoir.sort_unstable();
516
517        let interval = interval.max(1);
518        for (i, &idx) in reservoir.iter().enumerate() {
519            self.train_one(&samples[idx]);
520            if (i + 1) % interval == 0 {
521                callback(i + 1);
522            }
523        }
524        let total = reservoir.len();
525        if total % interval != 0 {
526            callback(total);
527        }
528    }
529
530    /// Predict the raw output for a feature vector.
531    ///
532    /// Always uses sigmoid-blended soft routing with auto-calibrated per-feature
533    /// bandwidths derived from median split threshold gaps. Features that have
534    /// never been split on use hard routing (bandwidth = infinity).
535    pub fn predict(&self, features: &[f64]) -> f64 {
536        let mut pred = self.base_prediction;
537        if self.auto_bandwidths.is_empty() {
538            // No bandwidths computed yet (no training) — hard routing fallback
539            for step in &self.steps {
540                pred += self.config.learning_rate * step.predict(features);
541            }
542        } else {
543            for step in &self.steps {
544                pred += self.config.learning_rate
545                    * step.predict_smooth_auto(features, &self.auto_bandwidths);
546            }
547        }
548        pred
549    }
550
551    /// Predict using sigmoid-blended soft routing with an explicit bandwidth.
552    ///
553    /// Uses a single bandwidth for all features. For auto-calibrated per-feature
554    /// bandwidths, use [`predict()`](SGBT::predict) which always uses smooth routing.
555    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
556        let mut pred = self.base_prediction;
557        for step in &self.steps {
558            pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
559        }
560        pred
561    }
562
563    /// Per-feature auto-calibrated bandwidths used by `predict()`.
564    ///
565    /// Empty before the first training sample. Each entry corresponds to a
566    /// feature index; `f64::INFINITY` means that feature has no splits and
567    /// uses hard routing.
568    pub fn auto_bandwidths(&self) -> &[f64] {
569        &self.auto_bandwidths
570    }
571
572    /// Predict with parent-leaf linear interpolation.
573    ///
574    /// Blends each leaf prediction with its parent's preserved prediction
575    /// based on sample count, preventing stale predictions from fresh leaves.
576    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
577        let mut pred = self.base_prediction;
578        for step in &self.steps {
579            pred += self.config.learning_rate * step.predict_interpolated(features);
580        }
581        pred
582    }
583
584    /// Predict with sibling-based interpolation for feature-continuous predictions.
585    ///
586    /// At each split node near the threshold boundary, blends left and right
587    /// subtree predictions linearly based on distance from the threshold.
588    /// Uses auto-calibrated bandwidths as the interpolation margin.
589    /// Predictions vary continuously as features change, eliminating
590    /// step-function artifacts.
591    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
592        let mut pred = self.base_prediction;
593        for step in &self.steps {
594            pred += self.config.learning_rate
595                * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
596        }
597        pred
598    }
599
600    /// Predict with graduated active-shadow blending.
601    ///
602    /// Smoothly transitions between active and shadow trees during replacement,
603    /// eliminating prediction dips. Requires `shadow_warmup` to be configured.
604    /// When disabled, equivalent to `predict()`.
605    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
606        let mut pred = self.base_prediction;
607        for step in &self.steps {
608            pred += self.config.learning_rate * step.predict_graduated(features);
609        }
610        pred
611    }
612
613    /// Predict with graduated blending + sibling interpolation (premium path).
614    ///
615    /// Combines graduated active-shadow handoff (no prediction dips during
616    /// tree replacement) with feature-continuous sibling interpolation
617    /// (no step-function artifacts near split boundaries).
618    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
619        let mut pred = self.base_prediction;
620        for step in &self.steps {
621            pred += self.config.learning_rate
622                * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
623        }
624        pred
625    }
626
627    /// Predict with loss transform applied (e.g., sigmoid for logistic loss).
628    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
629        self.loss.predict_transform(self.predict(features))
630    }
631
    /// Predict probability (alias for `predict_transformed`).
    ///
    /// Named for familiarity with sklearn-style APIs; meaningful only when the
    /// loss transform maps to a probability (e.g., logistic).
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
636
637    /// Predict with confidence estimation.
638    ///
639    /// Returns `(prediction, confidence)` where confidence = 1 / sqrt(sum_variance).
640    /// Higher confidence indicates more certain predictions (leaves have seen
641    /// more hessian mass). Confidence of 0.0 means the model has no information.
642    ///
643    /// This enables execution engines to modulate aggressiveness:
644    /// - High confidence + favorable prediction → act immediately
645    /// - Low confidence → fall back to simpler models or wait for more data
646    ///
647    /// The variance per tree is estimated as `1 / (H_sum + lambda)` at the
648    /// leaf where the sample lands. The ensemble variance is the sum of
649    /// per-tree variances (scaled by learning_rate²), and confidence is
650    /// the reciprocal of the standard deviation.
651    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
652        let mut pred = self.base_prediction;
653        let mut total_variance = 0.0;
654        let lr2 = self.config.learning_rate * self.config.learning_rate;
655
656        for step in &self.steps {
657            let (value, variance) = step.predict_with_variance(features);
658            pred += self.config.learning_rate * value;
659            total_variance += lr2 * variance;
660        }
661
662        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
663            1.0 / crate::math::sqrt(total_variance)
664        } else {
665            0.0
666        };
667
668        (pred, confidence)
669    }
670
671    /// Batch prediction.
672    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
673        feature_matrix.iter().map(|f| self.predict(f)).collect()
674    }
675
    /// Number of boosting steps.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total trees (active + alternates).
    ///
    /// NOTE(review): "alternate" appears to be a background replacement tree
    /// grown inside a step -- confirm against `BoostingStep::has_alternate`.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Total leaves across all active trees.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Total samples trained.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// The current base prediction (0.0 until initialization completes).
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been initialized.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// Access the configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Set the learning rate for future boosting rounds.
    ///
    /// This allows external schedulers (e.g., [`lr_schedule::LRScheduler`]) to
    /// adapt the rate over time without rebuilding the model.
    ///
    /// # Arguments
    ///
    /// * `lr` -- New learning rate (should be positive and finite; not validated here)
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Immutable access to the boosting steps.
    ///
    /// Useful for model inspection and export (e.g., ONNX serialization).
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }

    /// Immutable access to the loss function.
    pub fn loss(&self) -> &L {
        &self.loss
    }
735
736    /// Feature importances based on accumulated split gains across all trees.
737    ///
738    /// Returns normalized importances (sum to 1.0) indexed by feature.
739    /// Returns an empty Vec if no splits have occurred yet.
740    pub fn feature_importances(&self) -> Vec<f64> {
741        // Aggregate split gains across all boosting steps.
742        let mut totals: Vec<f64> = Vec::new();
743        for step in &self.steps {
744            let gains = step.slot().split_gains();
745            if totals.is_empty() && !gains.is_empty() {
746                totals.resize(gains.len(), 0.0);
747            }
748            for (i, &g) in gains.iter().enumerate() {
749                if i < totals.len() {
750                    totals[i] += g;
751                }
752            }
753        }
754
755        // Normalize to sum to 1.0.
756        let sum: f64 = totals.iter().sum();
757        if sum > 0.0 {
758            totals.iter_mut().for_each(|v| *v /= sum);
759        }
760        totals
761    }
762
763    /// Feature names, if configured.
764    pub fn feature_names(&self) -> Option<&[String]> {
765        self.config.feature_names.as_deref()
766    }
767
768    /// Feature importances paired with their names.
769    ///
770    /// Returns `None` if feature names are not configured. Otherwise returns
771    /// `(name, importance)` pairs sorted by importance descending.
772    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
773        let names = self.config.feature_names.as_ref()?;
774        let importances = self.feature_importances();
775        let mut pairs: Vec<(String, f64)> = names
776            .iter()
777            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
778            .map(|(n, &v)| (n.clone(), v))
779            .collect();
780        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
781        Some(pairs)
782    }
783
784    /// Train on a single sample with named features.
785    ///
786    /// Converts a `HashMap<String, f64>` of named features into a positional
787    /// vector using the configured feature names. Missing features default to 0.0.
788    ///
789    /// # Panics
790    ///
791    /// Panics if `feature_names` is not configured.
792    #[cfg(feature = "std")]
793    pub fn train_one_named(
794        &mut self,
795        features: &std::collections::HashMap<alloc::string::String, f64>,
796        target: f64,
797    ) {
798        let names = self
799            .config
800            .feature_names
801            .as_ref()
802            .expect("train_one_named requires feature_names to be configured");
803        let vec: Vec<f64> = names
804            .iter()
805            .map(|name| features.get(name).copied().unwrap_or(0.0))
806            .collect();
807        self.train_one(&(&vec[..], target));
808    }
809
810    /// Predict with named features.
811    ///
812    /// Converts named features into a positional vector, same as `train_one_named`.
813    ///
814    /// # Panics
815    ///
816    /// Panics if `feature_names` is not configured.
817    #[cfg(feature = "std")]
818    pub fn predict_named(
819        &self,
820        features: &std::collections::HashMap<alloc::string::String, f64>,
821    ) -> f64 {
822        let names = self
823            .config
824            .feature_names
825            .as_ref()
826            .expect("predict_named requires feature_names to be configured");
827        let vec: Vec<f64> = names
828            .iter()
829            .map(|name| features.get(name).copied().unwrap_or(0.0))
830            .collect();
831        self.predict(&vec)
832    }
833
834    // NOTE: explain() and explain_named() require the `explain` module which
835    // lives in the full `irithyll` crate, not in `irithyll-core`. Those methods
836    // are provided via the re-export layer in `irithyll::ensemble`.
837
838    /// Compute tree contribution standard deviation (σ proxy for adaptive_mts).
839    ///
840    /// Measures how much individual tree predictions vary across the ensemble,
841    /// which serves as a proxy for model uncertainty in the base SGBT
842    /// (the distributional variant uses honest_sigma instead).
843    fn contribution_variance(&self, features: &[f64]) -> f64 {
844        let n = self.steps.len();
845        if n <= 1 {
846            return 0.0;
847        }
848
849        let lr = self.config.learning_rate;
850        let mut sum = 0.0;
851        let mut sq_sum = 0.0;
852        for step in &self.steps {
853            let c = lr * step.predict(features);
854            sum += c;
855            sq_sum += c * c;
856        }
857        let n_f = n as f64;
858        let mean = sum / n_f;
859        let var = (sq_sum / n_f) - (mean * mean);
860        // Bessel's correction + numerical safety
861        crate::math::sqrt((var.abs() * n_f / (n_f - 1.0)).max(0.0))
862    }
863
864    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
865    fn refresh_bandwidths(&mut self) {
866        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
867        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
868            self.auto_bandwidths = self.compute_auto_bandwidths();
869            self.last_replacement_sum = current_sum;
870        }
871    }
872
873    /// Compute per-feature auto-calibrated bandwidths from all trees.
874    ///
875    /// For each feature, collects all split thresholds across all trees,
876    /// computes the median gap between consecutive unique thresholds, and
877    /// returns `median_gap * K` (K = 2.0).
878    ///
879    /// Edge cases:
880    /// - Feature with < 3 unique thresholds: `range / n_bins * K`
881    /// - Feature never split on (< 2 unique thresholds): `f64::INFINITY` (hard routing)
882    fn compute_auto_bandwidths(&self) -> Vec<f64> {
883        const K: f64 = 2.0;
884
885        // Determine n_features from the trees
886        let n_features = self
887            .steps
888            .iter()
889            .filter_map(|s| s.slot().active_tree().n_features())
890            .max()
891            .unwrap_or(0);
892
893        if n_features == 0 {
894            return Vec::new();
895        }
896
897        // Collect all thresholds from all trees per feature
898        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
899
900        for step in &self.steps {
901            let tree_thresholds = step
902                .slot()
903                .active_tree()
904                .collect_split_thresholds_per_feature();
905            for (i, ts) in tree_thresholds.into_iter().enumerate() {
906                if i < n_features {
907                    all_thresholds[i].extend(ts);
908                }
909            }
910        }
911
912        let n_bins = self.config.n_bins as f64;
913
914        // Compute per-feature bandwidth
915        all_thresholds
916            .iter()
917            .map(|ts| {
918                if ts.is_empty() {
919                    return f64::INFINITY; // Never split on → hard routing
920                }
921
922                let mut sorted = ts.clone();
923                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
924                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
925
926                if sorted.len() < 2 {
927                    return f64::INFINITY; // Single threshold → hard routing
928                }
929
930                // Compute gaps between consecutive unique thresholds
931                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
932
933                if sorted.len() < 3 {
934                    // Fallback: feature_range / n_bins * K
935                    let range = sorted.last().unwrap() - sorted.first().unwrap();
936                    if range < 1e-15 {
937                        return f64::INFINITY;
938                    }
939                    return (range / n_bins) * K;
940                }
941
942                // Median gap
943                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
944                let median_gap = if gaps.len() % 2 == 0 {
945                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
946                } else {
947                    gaps[gaps.len() / 2]
948                };
949
950                if median_gap < 1e-15 {
951                    f64::INFINITY
952                } else {
953                    median_gap * K
954                }
955            })
956            .collect()
957    }
958
959    /// Reset the ensemble to initial state.
960    pub fn reset(&mut self) {
961        for step in &mut self.steps {
962            step.reset();
963        }
964        self.base_prediction = 0.0;
965        self.base_initialized = false;
966        self.initial_targets.clear();
967        self.samples_seen = 0;
968        self.rng_state = self.config.seed;
969        self.rolling_contribution_sigma = 0.0;
970        self.auto_bandwidths.clear();
971        self.last_replacement_sum = 0;
972    }
973
974    /// Serialize the model into a [`ModelState`](crate::serde_support::ModelState).
975    ///
976    /// Auto-detects the [`LossType`](crate::loss::LossType) from the loss
977    /// function's [`Loss::loss_type()`] implementation.
978    ///
979    /// # Errors
980    ///
981    /// Returns [`IrithyllError::Serialization`](crate::IrithyllError::Serialization)
982    /// if the loss does not implement `loss_type()` (returns `None`). For custom
983    /// losses, use [`to_model_state_with`](Self::to_model_state_with) instead.
984    #[cfg(feature = "_serde_support")]
985    pub fn to_model_state(&self) -> crate::error::Result<crate::serde_support::ModelState> {
986        let loss_type = self.loss.loss_type().ok_or_else(|| {
987            crate::error::IrithyllError::Serialization(
988                "cannot auto-detect loss type for serialization: \
989                 implement Loss::loss_type() or use to_model_state_with()"
990                    .into(),
991            )
992        })?;
993        Ok(self.to_model_state_with(loss_type))
994    }
995
996    /// Serialize the model with an explicit [`LossType`](crate::loss::LossType) tag.
997    ///
998    /// Use this for custom loss functions that don't implement `loss_type()`.
999    #[cfg(feature = "_serde_support")]
1000    pub fn to_model_state_with(
1001        &self,
1002        loss_type: crate::loss::LossType,
1003    ) -> crate::serde_support::ModelState {
1004        use crate::serde_support::{ModelState, StepSnapshot};
1005
1006        let steps = self
1007            .steps
1008            .iter()
1009            .map(|step| {
1010                let slot = step.slot();
1011                let tree_snap = snapshot_tree(slot.active_tree());
1012                let alt_snap = slot.alternate_tree().map(snapshot_tree);
1013                let drift_state = slot.detector().serialize_state();
1014                let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
1015                StepSnapshot {
1016                    tree: tree_snap,
1017                    alternate_tree: alt_snap,
1018                    drift_state,
1019                    alt_drift_state,
1020                }
1021            })
1022            .collect();
1023
1024        ModelState {
1025            config: self.config.clone(),
1026            loss_type,
1027            base_prediction: self.base_prediction,
1028            base_initialized: self.base_initialized,
1029            initial_targets: self.initial_targets.clone(),
1030            initial_target_count: self.initial_target_count,
1031            samples_seen: self.samples_seen,
1032            rng_state: self.rng_state,
1033            steps,
1034            rolling_mean_error: self.rolling_mean_error,
1035            contribution_ewma: self.contribution_ewma.clone(),
1036            low_contrib_count: self.low_contrib_count.clone(),
1037        }
1038    }
1039}
1040
1041// ---------------------------------------------------------------------------
1042// DynSGBT: deserialization returns a dynamically-dispatched model
1043// ---------------------------------------------------------------------------
1044
#[cfg(feature = "_serde_support")]
impl SGBT<Box<dyn Loss>> {
    /// Reconstruct an SGBT model from a [`ModelState`](crate::serde_support::ModelState).
    ///
    /// Returns a [`DynSGBT`] (`SGBT<Box<dyn Loss>>`) because the concrete
    /// loss type is determined at runtime from the serialized tag.
    ///
    /// Rebuilds the full ensemble including tree topology and leaf values.
    /// Histogram accumulators are left empty and will rebuild from continued
    /// training. If drift detector state was serialized, it is restored;
    /// otherwise a fresh detector is created from the config.
    pub fn from_model_state(state: crate::serde_support::ModelState) -> Self {
        use crate::ensemble::replacement::TreeSlot;

        // Instantiate the concrete (boxed) loss from its serialized tag.
        let loss = state.loss_type.into_loss();

        // leaf_half_life is stored as a half-life; convert to the per-sample
        // exponential decay factor alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let max_tree_samples = state.config.max_tree_samples;

        // Rebuild every boosting step from its snapshot.
        let steps: Vec<BoostingStep> = state
            .steps
            .iter()
            .enumerate()
            .map(|(i, step_snap)| {
                // Per-step tree config derived from the ensemble config.
                // The seed is XOR-ed with the step index so each step keeps
                // its own deterministic RNG stream.
                let tree_config = TreeConfig::new()
                    .max_depth(state.config.max_depth)
                    .n_bins(state.config.n_bins)
                    .lambda(state.config.lambda)
                    .gamma(state.config.gamma)
                    .grace_period(state.config.grace_period)
                    .delta(state.config.delta)
                    .feature_subsample_rate(state.config.feature_subsample_rate)
                    .leaf_decay_alpha_opt(leaf_decay_alpha)
                    .split_reeval_interval_opt(state.config.split_reeval_interval)
                    .feature_types_opt(state.config.feature_types.clone())
                    .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
                    .monotone_constraints_opt(state.config.monotone_constraints.clone())
                    .adaptive_depth_opt(state.config.adaptive_depth)
                    .leaf_model_type(state.config.leaf_model_type.clone())
                    .seed(state.config.seed ^ (i as u64));

                // Restore the active tree and (if present) the alternate tree.
                let active = rebuild_tree(&step_snap.tree, tree_config.clone());
                let alternate = step_snap
                    .alternate_tree
                    .as_ref()
                    .map(|snap| rebuild_tree(snap, tree_config.clone()));

                // Fresh detector from config; overwrite with serialized state
                // when the snapshot carried one.
                let mut detector = state.config.drift_detector.create();
                if let Some(ref ds) = step_snap.drift_state {
                    detector.restore_state(ds);
                }
                let mut slot = TreeSlot::from_trees(
                    active,
                    alternate,
                    tree_config,
                    detector,
                    max_tree_samples,
                );
                // The alternate tree's detector (if any) is created by the
                // slot itself; restore its state after construction.
                if let Some(ref ads) = step_snap.alt_drift_state {
                    if let Some(alt_det) = slot.alt_detector_mut() {
                        alt_det.restore_state(ads);
                    }
                }
                BoostingStep::from_slot(slot)
            })
            .collect();

        let n = steps.len();
        // Pruning bookkeeping is only kept when either pruning mode is on.
        let has_pruning = state.config.quality_prune_alpha.is_some()
            || state.config.proactive_prune_interval.is_some();

        // Restore pruning state if available, otherwise initialize
        let contribution_ewma = if !state.contribution_ewma.is_empty() {
            state.contribution_ewma
        } else if has_pruning {
            vec![0.0; n]
        } else {
            Vec::new()
        };
        let low_contrib_count = if !state.low_contrib_count.is_empty() {
            state.low_contrib_count
        } else if has_pruning {
            vec![0; n]
        } else {
            Vec::new()
        };

        Self {
            config: state.config,
            steps,
            loss,
            base_prediction: state.base_prediction,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            contribution_ewma,
            low_contrib_count,
            rolling_mean_error: state.rolling_mean_error,
            // Caches are not serialized; they rebuild lazily after loading.
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
        }
    }
}
1154
1155// ---------------------------------------------------------------------------
1156// Shared snapshot/rebuild helpers for serde (used by SGBT + DistributionalSGBT)
1157// ---------------------------------------------------------------------------
1158
/// Snapshot a [`HoeffdingTree`] into a serializable [`TreeSnapshot`].
#[cfg(feature = "_serde_support")]
pub(crate) fn snapshot_tree(
    tree: &crate::tree::hoeffding::HoeffdingTree,
) -> crate::serde_support::TreeSnapshot {
    use crate::serde_support::TreeSnapshot;
    use crate::tree::StreamingTree;

    let arena = tree.arena();
    // Child links are serialized as raw arena indices (NodeId unwrapped).
    let left = arena.left.iter().map(|id| id.0).collect();
    let right = arena.right.iter().map(|id| id.0).collect();

    TreeSnapshot {
        n_features: tree.n_features(),
        samples_seen: tree.n_samples_seen(),
        rng_state: tree.rng_state(),
        feature_idx: arena.feature_idx.clone(),
        threshold: arena.threshold.clone(),
        left,
        right,
        leaf_value: arena.leaf_value.clone(),
        is_leaf: arena.is_leaf.clone(),
        depth: arena.depth.clone(),
        sample_count: arena.sample_count.clone(),
        categorical_mask: arena.categorical_mask.clone(),
    }
}
1182
/// Rebuild a [`HoeffdingTree`] from a [`TreeSnapshot`] and a [`TreeConfig`].
#[cfg(feature = "_serde_support")]
pub(crate) fn rebuild_tree(
    snapshot: &crate::serde_support::TreeSnapshot,
    tree_config: TreeConfig,
) -> crate::tree::hoeffding::HoeffdingTree {
    use crate::tree::hoeffding::HoeffdingTree;
    use crate::tree::node::{NodeId, TreeArena};

    let n = snapshot.feature_idx.len();
    let mut arena = TreeArena::new();

    // Bulk-copy the parallel per-node arrays back into the arena.
    arena.feature_idx.extend_from_slice(&snapshot.feature_idx);
    arena.threshold.extend_from_slice(&snapshot.threshold);
    arena.left.extend(snapshot.left.iter().map(|&raw| NodeId(raw)));
    arena.right.extend(snapshot.right.iter().map(|&raw| NodeId(raw)));
    arena.leaf_value.extend_from_slice(&snapshot.leaf_value);
    arena.is_leaf.extend_from_slice(&snapshot.is_leaf);
    arena.depth.extend_from_slice(&snapshot.depth);
    arena.sample_count.extend_from_slice(&snapshot.sample_count);
    // Tolerates a categorical_mask shorter than n entries (padded with None),
    // matching the original per-index `.get(i)` behavior.
    arena
        .categorical_mask
        .extend((0..n).map(|i| snapshot.categorical_mask.get(i).copied().flatten()));

    HoeffdingTree::from_arena(
        tree_config,
        arena,
        snapshot.n_features,
        snapshot.samples_seen,
        snapshot.rng_state,
    )
}
1216
1217#[cfg(test)]
1218mod tests {
1219    use super::*;
1220    use alloc::boxed::Box;
1221    use alloc::vec;
1222    use alloc::vec::Vec;
1223
1224    fn default_config() -> SGBTConfig {
1225        SGBTConfig::builder()
1226            .n_steps(10)
1227            .learning_rate(0.1)
1228            .grace_period(20)
1229            .max_depth(4)
1230            .n_bins(16)
1231            .build()
1232            .unwrap()
1233    }
1234
1235    #[test]
1236    fn new_model_predicts_zero() {
1237        let model = SGBT::new(default_config());
1238        let pred = model.predict(&[1.0, 2.0, 3.0]);
1239        assert!(pred.abs() < 1e-12);
1240    }
1241
1242    #[test]
1243    fn train_one_does_not_panic() {
1244        let mut model = SGBT::new(default_config());
1245        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
1246        assert_eq!(model.n_samples_seen(), 1);
1247    }
1248
1249    #[test]
1250    fn prediction_changes_after_training() {
1251        let mut model = SGBT::new(default_config());
1252        let features = vec![1.0, 2.0, 3.0];
1253        for i in 0..100 {
1254            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
1255        }
1256        let pred = model.predict(&features);
1257        assert!(pred.is_finite());
1258    }
1259
1260    #[test]
1261    fn linear_signal_rmse_improves() {
1262        let config = SGBTConfig::builder()
1263            .n_steps(20)
1264            .learning_rate(0.1)
1265            .grace_period(10)
1266            .max_depth(3)
1267            .n_bins(16)
1268            .build()
1269            .unwrap();
1270        let mut model = SGBT::new(config);
1271
1272        let mut rng: u64 = 12345;
1273        let mut early_errors = Vec::new();
1274        let mut late_errors = Vec::new();
1275
1276        for i in 0..500 {
1277            rng ^= rng << 13;
1278            rng ^= rng >> 7;
1279            rng ^= rng << 17;
1280            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1281            rng ^= rng << 13;
1282            rng ^= rng >> 7;
1283            rng ^= rng << 17;
1284            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1285            let target = 2.0 * x1 + 3.0 * x2;
1286
1287            let pred = model.predict(&[x1, x2]);
1288            let error = (pred - target).powi(2);
1289
1290            if (50..150).contains(&i) {
1291                early_errors.push(error);
1292            }
1293            if i >= 400 {
1294                late_errors.push(error);
1295            }
1296
1297            model.train_one(&Sample::new(vec![x1, x2], target));
1298        }
1299
1300        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
1301        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();
1302
1303        assert!(
1304            late_rmse < early_rmse,
1305            "RMSE should decrease: early={:.4}, late={:.4}",
1306            early_rmse,
1307            late_rmse
1308        );
1309    }
1310
1311    #[test]
1312    fn train_batch_equivalent_to_sequential() {
1313        let config = default_config();
1314        let mut model_seq = SGBT::new(config.clone());
1315        let mut model_batch = SGBT::new(config);
1316
1317        let samples: Vec<Sample> = (0..20)
1318            .map(|i| {
1319                let x = i as f64 * 0.5;
1320                Sample::new(vec![x, x * 2.0], x * 3.0)
1321            })
1322            .collect();
1323
1324        for s in &samples {
1325            model_seq.train_one(s);
1326        }
1327        model_batch.train_batch(&samples);
1328
1329        let pred_seq = model_seq.predict(&[1.0, 2.0]);
1330        let pred_batch = model_batch.predict(&[1.0, 2.0]);
1331
1332        assert!(
1333            (pred_seq - pred_batch).abs() < 1e-10,
1334            "seq={}, batch={}",
1335            pred_seq,
1336            pred_batch
1337        );
1338    }
1339
1340    #[test]
1341    fn reset_returns_to_initial() {
1342        let mut model = SGBT::new(default_config());
1343        for i in 0..100 {
1344            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
1345        }
1346        model.reset();
1347        assert_eq!(model.n_samples_seen(), 0);
1348        assert!(!model.is_initialized());
1349        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
1350    }
1351
1352    #[test]
1353    fn base_prediction_initializes() {
1354        let mut model = SGBT::new(default_config());
1355        for i in 0..50 {
1356            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
1357        }
1358        assert!(model.is_initialized());
1359        let expected = (100.0 + 149.0) / 2.0;
1360        assert!((model.base_prediction() - expected).abs() < 1.0);
1361    }
1362
1363    #[test]
1364    fn with_loss_uses_custom_loss() {
1365        use crate::loss::logistic::LogisticLoss;
1366        let model = SGBT::with_loss(default_config(), LogisticLoss);
1367        let pred = model.predict_transformed(&[1.0, 2.0]);
1368        assert!(
1369            (pred - 0.5).abs() < 1e-6,
1370            "sigmoid(0) should be 0.5, got {}",
1371            pred
1372        );
1373    }
1374
1375    #[test]
1376    fn ewma_config_propagates_and_trains() {
1377        let config = SGBTConfig::builder()
1378            .n_steps(5)
1379            .learning_rate(0.1)
1380            .grace_period(10)
1381            .max_depth(3)
1382            .n_bins(16)
1383            .leaf_half_life(50)
1384            .build()
1385            .unwrap();
1386        let mut model = SGBT::new(config);
1387
1388        for i in 0..200 {
1389            let x = (i as f64) * 0.1;
1390            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1391        }
1392
1393        let pred = model.predict(&[1.0, 2.0]);
1394        assert!(
1395            pred.is_finite(),
1396            "EWMA-enabled model should produce finite predictions, got {}",
1397            pred
1398        );
1399    }
1400
1401    #[test]
1402    fn max_tree_samples_config_propagates() {
1403        let config = SGBTConfig::builder()
1404            .n_steps(5)
1405            .learning_rate(0.1)
1406            .grace_period(10)
1407            .max_depth(3)
1408            .n_bins(16)
1409            .max_tree_samples(200)
1410            .build()
1411            .unwrap();
1412        let mut model = SGBT::new(config);
1413
1414        for i in 0..500 {
1415            let x = (i as f64) * 0.1;
1416            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
1417        }
1418
1419        let pred = model.predict(&[1.0, 2.0]);
1420        assert!(
1421            pred.is_finite(),
1422            "max_tree_samples model should produce finite predictions, got {}",
1423            pred
1424        );
1425    }
1426
1427    #[test]
1428    fn split_reeval_config_propagates() {
1429        let config = SGBTConfig::builder()
1430            .n_steps(5)
1431            .learning_rate(0.1)
1432            .grace_period(10)
1433            .max_depth(2)
1434            .n_bins(16)
1435            .split_reeval_interval(50)
1436            .build()
1437            .unwrap();
1438        let mut model = SGBT::new(config);
1439
1440        let mut rng: u64 = 12345;
1441        for _ in 0..1000 {
1442            rng ^= rng << 13;
1443            rng ^= rng >> 7;
1444            rng ^= rng << 17;
1445            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1446            rng ^= rng << 13;
1447            rng ^= rng >> 7;
1448            rng ^= rng << 17;
1449            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1450            let target = 2.0 * x1 + 3.0 * x2;
1451            model.train_one(&Sample::new(vec![x1, x2], target));
1452        }
1453
1454        let pred = model.predict(&[1.0, 2.0]);
1455        assert!(
1456            pred.is_finite(),
1457            "split re-eval model should produce finite predictions, got {}",
1458            pred
1459        );
1460    }
1461
1462    #[test]
1463    fn loss_accessor_works() {
1464        use crate::loss::logistic::LogisticLoss;
1465        let model = SGBT::with_loss(default_config(), LogisticLoss);
1466        // Verify we can access the concrete loss type
1467        let _loss: &LogisticLoss = model.loss();
1468        assert_eq!(_loss.n_outputs(), 1);
1469    }
1470
1471    #[test]
1472    fn clone_produces_independent_copy() {
1473        let config = default_config();
1474        let mut model = SGBT::new(config);
1475
1476        // Train the original on some data
1477        let mut rng: u64 = 99999;
1478        for _ in 0..200 {
1479            rng ^= rng << 13;
1480            rng ^= rng >> 7;
1481            rng ^= rng << 17;
1482            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1483            let target = 2.0 * x + 1.0;
1484            model.train_one(&Sample::new(vec![x], target));
1485        }
1486
1487        // Clone the model
1488        let mut cloned = model.clone();
1489
1490        // Both should produce identical predictions
1491        let test_features = [3.0];
1492        let pred_original = model.predict(&test_features);
1493        let pred_cloned = cloned.predict(&test_features);
1494        assert!(
1495            (pred_original - pred_cloned).abs() < 1e-12,
1496            "clone should predict identically: original={pred_original}, cloned={pred_cloned}"
1497        );
1498
1499        // Train only the clone further -- models should diverge
1500        for _ in 0..200 {
1501            rng ^= rng << 13;
1502            rng ^= rng >> 7;
1503            rng ^= rng << 17;
1504            let x = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
1505            let target = -3.0 * x + 5.0; // Different relationship
1506            cloned.train_one(&Sample::new(vec![x], target));
1507        }
1508
1509        let pred_original_after = model.predict(&test_features);
1510        let pred_cloned_after = cloned.predict(&test_features);
1511
1512        // Original should be unchanged
1513        assert!(
1514            (pred_original - pred_original_after).abs() < 1e-12,
1515            "original should be unchanged after training clone"
1516        );
1517
1518        // Clone should have diverged
1519        assert!(
1520            (pred_original_after - pred_cloned_after).abs() > 1e-6,
1521            "clone should diverge after independent training"
1522        );
1523    }
1524
1525    // -------------------------------------------------------------------
1526    // predict_with_confidence returns finite values
1527    // -------------------------------------------------------------------
1528    #[test]
1529    fn predict_with_confidence_finite() {
1530        let config = SGBTConfig::builder()
1531            .n_steps(5)
1532            .grace_period(10)
1533            .build()
1534            .unwrap();
1535        let mut model = SGBT::new(config);
1536
1537        // Train enough to initialize
1538        for i in 0..100 {
1539            let x = i as f64 * 0.1;
1540            model.train_one(&(&[x, x * 2.0][..], x + 1.0));
1541        }
1542
1543        let (pred, confidence) = model.predict_with_confidence(&[1.0, 2.0]);
1544        assert!(pred.is_finite(), "prediction should be finite");
1545        assert!(confidence.is_finite(), "confidence should be finite");
1546        assert!(
1547            confidence > 0.0,
1548            "confidence should be positive after training"
1549        );
1550    }
1551
1552    // -------------------------------------------------------------------
1553    // predict_with_confidence positive after training
1554    // -------------------------------------------------------------------
1555    #[test]
1556    fn predict_with_confidence_positive_after_training() {
1557        let config = SGBTConfig::builder()
1558            .n_steps(5)
1559            .grace_period(10)
1560            .build()
1561            .unwrap();
1562        let mut model = SGBT::new(config);
1563
1564        // Train enough to initialize and build structure
1565        for i in 0..200 {
1566            let x = i as f64 * 0.05;
1567            model.train_one(&(&[x][..], x * 2.0));
1568        }
1569
1570        let (pred, confidence) = model.predict_with_confidence(&[1.0]);
1571
1572        assert!(pred.is_finite(), "prediction should be finite");
1573        assert!(
1574            confidence > 0.0 && confidence.is_finite(),
1575            "confidence should be finite and positive, got {}",
1576            confidence,
1577        );
1578
1579        // Multiple queries should give consistent confidence
1580        let (pred2, conf2) = model.predict_with_confidence(&[1.0]);
1581        assert!(
1582            (pred - pred2).abs() < 1e-12,
1583            "same input should give same prediction"
1584        );
1585        assert!(
1586            (confidence - conf2).abs() < 1e-12,
1587            "same input should give same confidence"
1588        );
1589    }
1590
1591    // -------------------------------------------------------------------
1592    // predict_with_confidence agrees with predict on point estimate
1593    // -------------------------------------------------------------------
1594    #[test]
1595    fn predict_with_confidence_matches_predict() {
1596        let config = SGBTConfig::builder()
1597            .n_steps(10)
1598            .grace_period(10)
1599            .build()
1600            .unwrap();
1601        let mut model = SGBT::new(config);
1602
1603        for i in 0..200 {
1604            let x = (i as f64 - 100.0) * 0.01;
1605            model.train_one(&(&[x, x * x][..], x * 3.0 + 1.0));
1606        }
1607
1608        let pred = model.predict(&[0.5, 0.25]);
1609        let (conf_pred, _) = model.predict_with_confidence(&[0.5, 0.25]);
1610
1611        assert!(
1612            (pred - conf_pred).abs() < 1e-10,
1613            "prediction mismatch: predict()={} vs predict_with_confidence()={}",
1614            pred,
1615            conf_pred,
1616        );
1617    }
1618
1619    // -------------------------------------------------------------------
1620    // gradient clipping config round-trips through builder
1621    // -------------------------------------------------------------------
1622    #[test]
1623    fn gradient_clip_config_builder() {
1624        let config = SGBTConfig::builder()
1625            .n_steps(10)
1626            .gradient_clip_sigma(3.0)
1627            .build()
1628            .unwrap();
1629
1630        assert_eq!(config.gradient_clip_sigma, Some(3.0));
1631    }
1632
1633    // -------------------------------------------------------------------
1634    // monotonic constraints config round-trips through builder
1635    // -------------------------------------------------------------------
1636    #[test]
1637    fn monotone_constraints_config_builder() {
1638        let config = SGBTConfig::builder()
1639            .n_steps(10)
1640            .monotone_constraints(vec![1, -1, 0])
1641            .build()
1642            .unwrap();
1643
1644        assert_eq!(config.monotone_constraints, Some(vec![1, -1, 0]));
1645    }
1646
1647    // -------------------------------------------------------------------
1648    // monotonic constraints validation rejects invalid values
1649    // -------------------------------------------------------------------
1650    #[test]
1651    fn monotone_constraints_invalid_value_rejected() {
1652        let result = SGBTConfig::builder()
1653            .n_steps(10)
1654            .monotone_constraints(vec![1, 2, 0])
1655            .build();
1656
1657        assert!(result.is_err(), "constraint value 2 should be rejected");
1658    }
1659
1660    // -------------------------------------------------------------------
1661    // gradient clipping validation rejects non-positive sigma
1662    // -------------------------------------------------------------------
1663    #[test]
1664    fn gradient_clip_sigma_negative_rejected() {
1665        let result = SGBTConfig::builder()
1666            .n_steps(10)
1667            .gradient_clip_sigma(-1.0)
1668            .build();
1669
1670        assert!(result.is_err(), "negative sigma should be rejected");
1671    }
1672
1673    // -------------------------------------------------------------------
1674    // gradient clipping ensemble-level reduces outlier impact
1675    // -------------------------------------------------------------------
    #[test]
    fn gradient_clipping_reduces_outlier_impact() {
        // Two models trained in lockstep on identical data; the only
        // difference between them is the 3-sigma gradient clip.
        // Without clipping
        let config_no_clip = SGBTConfig::builder()
            .n_steps(5)
            .grace_period(10)
            .build()
            .unwrap();
        let mut model_no_clip = SGBT::new(config_no_clip);

        // With clipping
        let config_clip = SGBTConfig::builder()
            .n_steps(5)
            .grace_period(10)
            .gradient_clip_sigma(3.0)
            .build()
            .unwrap();
        let mut model_clip = SGBT::new(config_clip);

        // Train both on identical normal data
        for i in 0..100 {
            let x = (i as f64) * 0.01;
            let sample = (&[x][..], x * 2.0);
            model_no_clip.train_one(&sample);
            model_clip.train_one(&sample);
        }

        // Snapshot predictions at the outlier's feature point before injecting it.
        let pred_no_clip_before = model_no_clip.predict(&[0.5]);
        let pred_clip_before = model_clip.predict(&[0.5]);

        // Inject outlier
        // (target 10000.0 is far outside the [0, 2] range of the normal targets)
        let outlier = (&[0.5_f64][..], 10000.0);
        model_no_clip.train_one(&outlier);
        model_clip.train_one(&outlier);

        let pred_no_clip_after = model_no_clip.predict(&[0.5]);
        let pred_clip_after = model_clip.predict(&[0.5]);

        // Measure how far each model's prediction moved due to the one outlier.
        let delta_no_clip = (pred_no_clip_after - pred_no_clip_before).abs();
        let delta_clip = (pred_clip_after - pred_clip_before).abs();

        // Clipped model should be less affected by the outlier
        // (<= with an epsilon: the two deltas may coincide if the clip did not
        // engage — TODO confirm against the clipping implementation).
        assert!(
            delta_clip <= delta_no_clip + 1e-10,
            "clipped model should be less affected: delta_clip={}, delta_no_clip={}",
            delta_clip,
            delta_no_clip,
        );
    }
1725
1726    // -------------------------------------------------------------------
1727    // train_batch_with_callback fires at correct intervals
1728    // -------------------------------------------------------------------
1729    #[test]
1730    fn train_batch_with_callback_fires() {
1731        let config = SGBTConfig::builder()
1732            .n_steps(3)
1733            .grace_period(5)
1734            .build()
1735            .unwrap();
1736        let mut model = SGBT::new(config);
1737
1738        let data: Vec<(Vec<f64>, f64)> = (0..25)
1739            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1740            .collect();
1741
1742        let mut callbacks = Vec::new();
1743        model.train_batch_with_callback(&data, 10, |n| {
1744            callbacks.push(n);
1745        });
1746
1747        // Should fire at 10, 20, and 25 (final)
1748        assert_eq!(callbacks, vec![10, 20, 25]);
1749    }
1750
1751    // -------------------------------------------------------------------
1752    // train_batch_subsampled produces deterministic subset
1753    // -------------------------------------------------------------------
1754    #[test]
1755    fn train_batch_subsampled_trains_subset() {
1756        let config = SGBTConfig::builder()
1757            .n_steps(3)
1758            .grace_period(5)
1759            .build()
1760            .unwrap();
1761        let mut model = SGBT::new(config);
1762
1763        let data: Vec<(Vec<f64>, f64)> = (0..100)
1764            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1765            .collect();
1766
1767        // Train on only 20 of 100 samples
1768        model.train_batch_subsampled(&data, 20);
1769
1770        // Model should have seen some samples
1771        assert!(
1772            model.n_samples_seen() > 0,
1773            "model should have trained on subset"
1774        );
1775        assert!(
1776            model.n_samples_seen() <= 20,
1777            "model should have trained at most 20 samples, got {}",
1778            model.n_samples_seen(),
1779        );
1780    }
1781
1782    // -------------------------------------------------------------------
1783    // train_batch_subsampled full dataset = train_batch
1784    // -------------------------------------------------------------------
1785    #[test]
1786    fn train_batch_subsampled_full_equals_batch() {
1787        let config1 = SGBTConfig::builder()
1788            .n_steps(3)
1789            .grace_period(5)
1790            .build()
1791            .unwrap();
1792        let config2 = config1.clone();
1793
1794        let mut model1 = SGBT::new(config1);
1795        let mut model2 = SGBT::new(config2);
1796
1797        let data: Vec<(Vec<f64>, f64)> = (0..50)
1798            .map(|i| (vec![i as f64 * 0.1], i as f64 * 0.5))
1799            .collect();
1800
1801        model1.train_batch(&data);
1802        model2.train_batch_subsampled(&data, 1000); // max_samples > data.len()
1803
1804        // Both should have identical state
1805        assert_eq!(model1.n_samples_seen(), model2.n_samples_seen());
1806        let pred1 = model1.predict(&[2.5]);
1807        let pred2 = model2.predict(&[2.5]);
1808        assert!(
1809            (pred1 - pred2).abs() < 1e-12,
1810            "full subsample should equal batch: {} vs {}",
1811            pred1,
1812            pred2,
1813        );
1814    }
1815
1816    // -------------------------------------------------------------------
1817    // train_batch_subsampled_with_callback combines both
1818    // -------------------------------------------------------------------
1819    #[test]
1820    fn train_batch_subsampled_with_callback_works() {
1821        let config = SGBTConfig::builder()
1822            .n_steps(3)
1823            .grace_period(5)
1824            .build()
1825            .unwrap();
1826        let mut model = SGBT::new(config);
1827
1828        let data: Vec<(Vec<f64>, f64)> = (0..200)
1829            .map(|i| (vec![i as f64 * 0.01], i as f64 * 0.1))
1830            .collect();
1831
1832        let mut callbacks = Vec::new();
1833        model.train_batch_subsampled_with_callback(&data, 50, 10, |n| {
1834            callbacks.push(n);
1835        });
1836
1837        // Should have trained ~50 samples with callbacks at 10, 20, 30, 40, 50
1838        assert!(!callbacks.is_empty(), "should have received callbacks");
1839        assert_eq!(
1840            *callbacks.last().unwrap(),
1841            50,
1842            "final callback should be total samples"
1843        );
1844    }
1845
1846    // ---------------------------------------------------------------
1847    // Linear leaf model integration tests
1848    // ---------------------------------------------------------------
1849
1850    /// xorshift64 PRNG for deterministic test data.
1851    fn xorshift64(state: &mut u64) -> u64 {
1852        let mut s = *state;
1853        s ^= s << 13;
1854        s ^= s >> 7;
1855        s ^= s << 17;
1856        *state = s;
1857        s
1858    }
1859
1860    fn rand_f64(state: &mut u64) -> f64 {
1861        xorshift64(state) as f64 / u64::MAX as f64
1862    }
1863
    /// Shared config for the linear-leaf integration tests.
    ///
    /// Uses deliberately shallow trees (depth 2) so that the per-leaf linear
    /// models, rather than the tree structure, must carry most of the fit.
    fn linear_leaves_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(2) // low depth -- linear leaves should shine
            .n_bins(16)
            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
                learning_rate: 0.1,
                decay: None,
                use_adagrad: false,
            })
            .build()
            .unwrap()
    }
1879
1880    #[test]
1881    fn linear_leaves_trains_without_panic() {
1882        let mut model = SGBT::new(linear_leaves_config());
1883        let mut rng = 42u64;
1884        for _ in 0..200 {
1885            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1886            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1887            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1888            model.train_one(&Sample::new(vec![x1, x2], y));
1889        }
1890        assert_eq!(model.n_samples_seen(), 200);
1891    }
1892
1893    #[test]
1894    fn linear_leaves_prediction_finite() {
1895        let mut model = SGBT::new(linear_leaves_config());
1896        let mut rng = 42u64;
1897        for _ in 0..200 {
1898            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1899            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1900            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1901            model.train_one(&Sample::new(vec![x1, x2], y));
1902        }
1903        let pred = model.predict(&[0.5, -0.3]);
1904        assert!(pred.is_finite(), "prediction should be finite, got {pred}");
1905    }
1906
1907    #[test]
1908    fn linear_leaves_learns_linear_target() {
1909        let mut model = SGBT::new(linear_leaves_config());
1910        let mut rng = 42u64;
1911        for _ in 0..500 {
1912            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1913            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1914            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1915            model.train_one(&Sample::new(vec![x1, x2], y));
1916        }
1917
1918        // Test on a few points -- should be reasonably close for a linear target.
1919        let mut total_error = 0.0;
1920        for _ in 0..50 {
1921            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1922            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1923            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
1924            let pred = model.predict(&[x1, x2]);
1925            total_error += (pred - y).powi(2);
1926        }
1927        let mse = total_error / 50.0;
1928        assert!(
1929            mse < 5.0,
1930            "linear leaves MSE on linear target should be < 5.0, got {mse}"
1931        );
1932    }
1933
    #[test]
    fn linear_leaves_better_than_constant_at_low_depth() {
        // Train two models on a linear target at depth 2:
        // one with constant leaves, one with linear leaves.
        // Both configs share every parameter, including the seed, so the
        // only difference between the models is the leaf model type.
        let constant_config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .seed(0xDEAD)
            .build()
            .unwrap();
        let linear_config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .seed(0xDEAD)
            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
                learning_rate: 0.1,
                decay: None,
                use_adagrad: false,
            })
            .build()
            .unwrap();

        let mut constant_model = SGBT::new(constant_config);
        let mut linear_model = SGBT::new(linear_config);
        let mut rng = 42u64;

        // Lockstep training: both models see the exact same sample stream.
        for _ in 0..500 {
            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
            let sample = Sample::new(vec![x1, x2], y);
            constant_model.train_one(&sample);
            linear_model.train_one(&sample);
        }

        // Evaluate both on test set.
        // (rng continues from the training stream, so test points are fresh.)
        let mut constant_mse = 0.0;
        let mut linear_mse = 0.0;
        for _ in 0..100 {
            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
            constant_mse += (constant_model.predict(&[x1, x2]) - y).powi(2);
            linear_mse += (linear_model.predict(&[x1, x2]) - y).powi(2);
        }
        constant_mse /= 100.0;
        linear_mse /= 100.0;

        // Linear leaves should outperform constant leaves on a linear target.
        assert!(
            linear_mse < constant_mse,
            "linear leaves MSE ({linear_mse:.4}) should be less than constant ({constant_mse:.4})"
        );
    }
1994
1995    #[test]
1996    fn adaptive_leaves_trains_without_panic() {
1997        let config = SGBTConfig::builder()
1998            .n_steps(10)
1999            .learning_rate(0.1)
2000            .grace_period(20)
2001            .max_depth(3)
2002            .n_bins(16)
2003            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Adaptive {
2004                promote_to: Box::new(crate::tree::leaf_model::LeafModelType::Linear {
2005                    learning_rate: 0.1,
2006                    decay: None,
2007                    use_adagrad: false,
2008                }),
2009            })
2010            .build()
2011            .unwrap();
2012
2013        let mut model = SGBT::new(config);
2014        let mut rng = 42u64;
2015        for _ in 0..500 {
2016            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2017            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2018            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2019            model.train_one(&Sample::new(vec![x1, x2], y));
2020        }
2021        let pred = model.predict(&[0.5, -0.3]);
2022        assert!(
2023            pred.is_finite(),
2024            "adaptive leaf prediction should be finite, got {pred}"
2025        );
2026    }
2027
2028    #[test]
2029    fn linear_leaves_with_decay_trains_without_panic() {
2030        let config = SGBTConfig::builder()
2031            .n_steps(10)
2032            .learning_rate(0.1)
2033            .grace_period(20)
2034            .max_depth(3)
2035            .n_bins(16)
2036            .leaf_model_type(crate::tree::leaf_model::LeafModelType::Linear {
2037                learning_rate: 0.1,
2038                decay: Some(0.995),
2039                use_adagrad: false,
2040            })
2041            .build()
2042            .unwrap();
2043
2044        let mut model = SGBT::new(config);
2045        let mut rng = 42u64;
2046        for _ in 0..500 {
2047            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
2048            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
2049            let y = 3.0 * x1 + 2.0 * x2 + 1.0;
2050            model.train_one(&Sample::new(vec![x1, x2], y));
2051        }
2052        let pred = model.predict(&[0.5, -0.3]);
2053        assert!(
2054            pred.is_finite(),
2055            "decay leaf prediction should be finite, got {pred}"
2056        );
2057    }
2058
2059    // -------------------------------------------------------------------
2060    // predict_smooth returns finite values
2061    // -------------------------------------------------------------------
2062    #[test]
2063    fn predict_smooth_returns_finite() {
2064        let config = SGBTConfig::builder()
2065            .n_steps(5)
2066            .learning_rate(0.1)
2067            .grace_period(10)
2068            .build()
2069            .unwrap();
2070        let mut model = SGBT::new(config);
2071
2072        for i in 0..200 {
2073            let x = (i as f64) * 0.1;
2074            model.train_one(&Sample::new(vec![x, x.sin()], 2.0 * x + 1.0));
2075        }
2076
2077        let pred_hard = model.predict(&[1.0, 1.0_f64.sin()]);
2078        let pred_smooth = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
2079
2080        assert!(pred_hard.is_finite(), "hard prediction should be finite");
2081        assert!(
2082            pred_smooth.is_finite(),
2083            "smooth prediction should be finite"
2084        );
2085    }
2086
2087    // -------------------------------------------------------------------
2088    // predict_smooth converges to hard predict at small bandwidth
2089    // -------------------------------------------------------------------
2090    #[test]
2091    fn predict_smooth_converges_to_hard_at_small_bandwidth() {
2092        let config = SGBTConfig::builder()
2093            .n_steps(5)
2094            .learning_rate(0.1)
2095            .grace_period(10)
2096            .build()
2097            .unwrap();
2098        let mut model = SGBT::new(config);
2099
2100        for i in 0..300 {
2101            let x = (i as f64) * 0.1;
2102            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2103        }
2104
2105        let features = [5.0, 2.5];
2106        let hard = model.predict(&features);
2107        let smooth = model.predict_smooth(&features, 0.001);
2108
2109        assert!(
2110            (hard - smooth).abs() < 0.5,
2111            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
2112            hard,
2113            smooth,
2114        );
2115    }
2116
2117    #[test]
2118    fn auto_bandwidth_computed_after_training() {
2119        let config = SGBTConfig::builder()
2120            .n_steps(5)
2121            .learning_rate(0.1)
2122            .grace_period(10)
2123            .build()
2124            .unwrap();
2125        let mut model = SGBT::new(config);
2126
2127        // Before training, no bandwidths
2128        assert!(model.auto_bandwidths().is_empty());
2129
2130        for i in 0..200 {
2131            let x = (i as f64) * 0.1;
2132            model.train_one(&Sample::new(vec![x, x * 0.5], 2.0 * x + 1.0));
2133        }
2134
2135        // After training, auto_bandwidths should be populated
2136        let bws = model.auto_bandwidths();
2137        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");
2138
2139        // predict() always uses smooth routing with auto-bandwidths
2140        let pred = model.predict(&[5.0, 2.5]);
2141        assert!(
2142            pred.is_finite(),
2143            "auto-bandwidth predict should be finite: {}",
2144            pred
2145        );
2146    }
2147
2148    #[test]
2149    fn predict_interpolated_returns_finite() {
2150        let config = SGBTConfig::builder()
2151            .n_steps(5)
2152            .learning_rate(0.01)
2153            .build()
2154            .unwrap();
2155        let mut model = SGBT::new(config);
2156
2157        for i in 0..200 {
2158            let x = (i as f64) * 0.1;
2159            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2160        }
2161
2162        let pred = model.predict_interpolated(&[1.0, 0.5]);
2163        assert!(
2164            pred.is_finite(),
2165            "interpolated prediction should be finite: {}",
2166            pred
2167        );
2168    }
2169
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        // Deep trees plus a sensitive drift delta on a wavy target, so the
        // ensemble grows enough structure for sibling interpolation to matter.
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            model.train_one(&Sample::new(vec![x, x * 0.3], y));
        }

        // Verify the method is callable and produces finite predictions
        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(pred.is_finite(), "sibling interpolated should be finite");

        // If bandwidths are finite, verify sibling produces at least as much
        // variation as hard routing across a feature sweep
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            // Sweep 200 points along the trained feature ray.
            let hard: Vec<f64> = (0..200)
                .map(|i| model.predict(&[i as f64 * 0.1, i as f64 * 0.03]))
                .collect();
            let sib: Vec<f64> = (0..200)
                .map(|i| model.predict_sibling_interpolated(&[i as f64 * 0.1, i as f64 * 0.03]))
                .collect();
            // Count adjacent pairs whose predictions differ -- a proxy for how
            // often each prediction surface changes along the sweep.
            let hc = hard
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            let sc = sib
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            // Sibling interpolation is expected to change at least as often as
            // hard routing: it blends leaves rather than merging them.
            assert!(
                sc >= hc,
                "sibling should produce >= hard changes: sib={}, hard={}",
                sc,
                hc
            );
        }
    }
2218
2219    #[test]
2220    fn predict_graduated_returns_finite() {
2221        let config = SGBTConfig::builder()
2222            .n_steps(5)
2223            .learning_rate(0.01)
2224            .max_tree_samples(200)
2225            .shadow_warmup(50)
2226            .build()
2227            .unwrap();
2228        let mut model = SGBT::new(config);
2229
2230        for i in 0..300 {
2231            let x = (i as f64) * 0.1;
2232            model.train_one(&Sample::new(vec![x, x.sin()], x.cos()));
2233        }
2234
2235        let pred = model.predict_graduated(&[1.0, 0.5]);
2236        assert!(
2237            pred.is_finite(),
2238            "graduated prediction should be finite: {}",
2239            pred
2240        );
2241
2242        let pred2 = model.predict_graduated_sibling_interpolated(&[1.0, 0.5]);
2243        assert!(
2244            pred2.is_finite(),
2245            "graduated+sibling prediction should be finite: {}",
2246            pred2
2247        );
2248    }
2249
2250    #[test]
2251    fn shadow_warmup_validation() {
2252        let result = SGBTConfig::builder()
2253            .n_steps(5)
2254            .learning_rate(0.01)
2255            .shadow_warmup(0)
2256            .build();
2257        assert!(result.is_err(), "shadow_warmup=0 should fail validation");
2258    }
2259
2260    // -------------------------------------------------------------------
2261    // adaptive_mts config tests
2262    // -------------------------------------------------------------------
2263
2264    #[test]
2265    fn adaptive_mts_defaults_to_none() {
2266        let cfg = SGBTConfig::default();
2267        assert!(
2268            cfg.adaptive_mts.is_none(),
2269            "adaptive_mts should default to None"
2270        );
2271    }
2272
2273    #[test]
2274    fn adaptive_mts_config_builder() {
2275        let cfg = SGBTConfig::builder()
2276            .n_steps(10)
2277            .adaptive_mts(500, 2.0)
2278            .build()
2279            .unwrap();
2280        assert_eq!(
2281            cfg.adaptive_mts,
2282            Some((500, 2.0)),
2283            "adaptive_mts should store (base_mts, k)"
2284        );
2285    }
2286
2287    #[test]
2288    fn adaptive_mts_validation_rejects_low_base() {
2289        let result = SGBTConfig::builder()
2290            .n_steps(5)
2291            .adaptive_mts(50, 1.0)
2292            .build();
2293        assert!(
2294            result.is_err(),
2295            "adaptive_mts with base_mts < 100 should fail"
2296        );
2297    }
2298
2299    #[test]
2300    fn adaptive_mts_validation_rejects_zero_k() {
2301        let result = SGBTConfig::builder()
2302            .n_steps(5)
2303            .adaptive_mts(500, 0.0)
2304            .build();
2305        assert!(result.is_err(), "adaptive_mts with k=0 should fail");
2306    }
2307
2308    #[test]
2309    fn adaptive_mts_trains_without_panic() {
2310        let config = SGBTConfig::builder()
2311            .n_steps(5)
2312            .learning_rate(0.1)
2313            .grace_period(10)
2314            .max_depth(3)
2315            .n_bins(16)
2316            .adaptive_mts(200, 1.0)
2317            .build()
2318            .unwrap();
2319        let mut model = SGBT::new(config);
2320
2321        for i in 0..500 {
2322            let x = (i as f64) * 0.1;
2323            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2324        }
2325
2326        let pred = model.predict(&[1.0, 2.0]);
2327        assert!(
2328            pred.is_finite(),
2329            "adaptive_mts model should produce finite predictions, got {}",
2330            pred
2331        );
2332    }
2333
2334    // -------------------------------------------------------------------
2335    // proactive_prune config tests
2336    // -------------------------------------------------------------------
2337
2338    #[test]
2339    fn proactive_prune_defaults_to_none() {
2340        let cfg = SGBTConfig::default();
2341        assert!(
2342            cfg.proactive_prune_interval.is_none(),
2343            "proactive_prune_interval should default to None"
2344        );
2345    }
2346
2347    #[test]
2348    fn proactive_prune_config_builder() {
2349        let cfg = SGBTConfig::builder()
2350            .n_steps(10)
2351            .proactive_prune_interval(500)
2352            .build()
2353            .unwrap();
2354        assert_eq!(
2355            cfg.proactive_prune_interval,
2356            Some(500),
2357            "proactive_prune_interval should be set"
2358        );
2359    }
2360
2361    #[test]
2362    fn proactive_prune_validation_rejects_low_interval() {
2363        let result = SGBTConfig::builder()
2364            .n_steps(5)
2365            .proactive_prune_interval(50)
2366            .build();
2367        assert!(
2368            result.is_err(),
2369            "proactive_prune_interval < 100 should fail"
2370        );
2371    }
2372
2373    #[test]
2374    fn proactive_prune_enables_contribution_tracking() {
2375        let config = SGBTConfig::builder()
2376            .n_steps(5)
2377            .learning_rate(0.1)
2378            .grace_period(10)
2379            .max_depth(3)
2380            .n_bins(16)
2381            .proactive_prune_interval(200)
2382            .build()
2383            .unwrap();
2384
2385        // quality_prune_alpha is None, but proactive_prune_interval is set
2386        assert!(config.quality_prune_alpha.is_none());
2387
2388        let mut model = SGBT::new(config);
2389
2390        // Train a few samples so contribution_ewma gets populated
2391        for i in 0..100 {
2392            let x = (i as f64) * 0.1;
2393            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2394        }
2395
2396        // contribution_ewma should be populated (not empty)
2397        // We can verify by checking that model trains without panic and produces
2398        // finite predictions -- the tracking is happening internally.
2399        let pred = model.predict(&[1.0, 2.0]);
2400        assert!(
2401            pred.is_finite(),
2402            "proactive_prune model should produce finite predictions, got {}",
2403            pred
2404        );
2405    }
2406
2407    #[test]
2408    fn proactive_prune_trains_without_panic() {
2409        let config = SGBTConfig::builder()
2410            .n_steps(5)
2411            .learning_rate(0.1)
2412            .grace_period(10)
2413            .max_depth(3)
2414            .n_bins(16)
2415            .proactive_prune_interval(200)
2416            .build()
2417            .unwrap();
2418        let mut model = SGBT::new(config);
2419
2420        // Train past the prune interval to exercise the proactive prune path
2421        for i in 0..500 {
2422            let x = (i as f64) * 0.1;
2423            model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
2424        }
2425
2426        let pred = model.predict(&[1.0, 2.0]);
2427        assert!(
2428            pred.is_finite(),
2429            "proactive_prune model should produce finite predictions after pruning, got {}",
2430            pred
2431        );
2432    }
2433}