irithyll_core/ensemble/core.rs

//! SGBT core: struct definition, Clone/Debug, and constructors.
//!
//! This module isolates the structural definition and initialization logic,
//! keeping the hot path (train_one, predict) separate for clarity.

use alloc::collections::VecDeque;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;

use core::fmt;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::step::BoostingStep;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::Observation;
#[allow(unused_imports)] // Used in doc links + tests
use crate::sample::Sample;

/// Cached diagnostic state for SGBT, separated from the core training state
/// to improve struct clarity and cache locality in the prediction path.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub(crate) struct DiagnosticCache {
    /// Previous per-tree contributions for residual alignment (cosine similarity).
    pub(crate) prev_contributions: Vec<f64>,
    /// Contributions from two calls ago, for delta-based alignment.
    pub(crate) prev_prev_contributions: Vec<f64>,
    /// Cached cosine similarity of consecutive tree contribution vectors.
    pub(crate) cached_residual_alignment: f64,
    /// Cached mean |G|/(H+λ)² across all leaves.
    pub(crate) cached_reg_sensitivity: f64,
    /// Cached F-statistic (between-leaf / within-leaf variance).
    pub(crate) cached_depth_sufficiency: f64,
    /// Cached trace(H/(H+λ)) across all leaves.
    pub(crate) cached_effective_dof: f64,
    /// Per-tree EWMA of signed contribution accuracy. Positive = helps, negative = hurts.
    pub(crate) contribution_accuracy: Vec<f64>,
    /// EWMA alpha for contribution accuracy tracking.
    pub(crate) prune_alpha: f64,
}

/// Streaming Gradient Boosted Trees ensemble.
///
/// The primary entry point for training and prediction. Generic over `L: Loss`
/// so the loss function's gradient/hessian calls are monomorphized (inlined)
/// into the boosting hot loop -- no virtual dispatch overhead.
///
/// The default type parameter `L = SquaredLoss` means `SGBT::new(config)`
/// creates a regression model without specifying the loss type explicitly.
///
/// # Examples
///
/// ```ignore
/// use irithyll::{SGBTConfig, SGBT};
///
/// // Regression with squared loss (default):
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::new(config);
/// ```
///
/// ```ignore
/// use irithyll::{SGBTConfig, SGBT};
/// use irithyll::loss::logistic::LogisticLoss;
///
/// // Classification with logistic loss -- no Box::new()!
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let model = SGBT::with_loss(config, LogisticLoss);
/// ```
pub struct SGBT<L: Loss = SquaredLoss> {
    /// Configuration.
    pub(crate) config: SGBTConfig,
    /// Boosting steps (one tree + drift detector each).
    pub(crate) steps: Vec<BoostingStep>,
    /// Loss function (monomorphized -- no vtable).
    pub(crate) loss: L,
    /// Base prediction (initial constant, computed from first batch of targets).
    pub(crate) base_prediction: f64,
    /// Whether `base_prediction` has been initialized.
    pub(crate) base_initialized: bool,
    /// Running collection of initial targets for computing `base_prediction`.
    pub(crate) initial_targets: Vec<f64>,
    /// Number of initial targets to collect before setting `base_prediction`.
    pub(crate) initial_target_count: usize,
    /// Total samples trained.
    pub(crate) samples_seen: u64,
    /// RNG state for variant skip logic.
    pub(crate) rng_state: u64,
    /// Per-step EWMA of |marginal contribution| for quality-based pruning.
    /// Empty when `quality_prune_alpha` is `None`.
    pub(crate) contribution_ewma: Vec<f64>,
    /// Per-step consecutive low-contribution sample counter.
    /// Empty when `quality_prune_alpha` is `None`.
    pub(crate) low_contrib_count: Vec<u64>,
    /// Rolling mean absolute error for error-weighted sample importance.
    /// Only used when `error_weight_alpha` is `Some`.
    pub(crate) rolling_mean_error: f64,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// Computed from median split threshold gaps across all trees.
    pub(crate) auto_bandwidths: Vec<f64>,
    /// Sum of replacement counts across all steps at the last bandwidth computation.
    /// Used to detect when trees have been replaced and bandwidths need a refresh.
    pub(crate) last_replacement_sum: u64,
    /// EWMA of contribution variance (sigma) across trees for adaptive_mts.
    /// Used as the denominator when computing sigma_ratio for tree lifetime modulation.
    pub(crate) rolling_contribution_sigma: f64,
    /// Ring buffer of sigma_ratio values for end-of-cycle adaptive MTS.
    /// Capacity = grace_period. MTS updates only at tree replacement boundaries.
    pub(crate) sigma_ring: VecDeque<f64>,
    /// Sum of replacement counts at the last MTS update (replacement boundary detection).
    pub(crate) mts_replacement_sum: u64,
    // -----------------------------------------------------------------------
    // Diagnostics
    // -----------------------------------------------------------------------
    /// Diagnostic caches; not used in the predict hot path.
    pub(crate) diag: DiagnosticCache,
}

impl<L: Loss + Clone> Clone for SGBT<L> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            contribution_ewma: self.contribution_ewma.clone(),
            low_contrib_count: self.low_contrib_count.clone(),
            rolling_mean_error: self.rolling_mean_error,
            auto_bandwidths: self.auto_bandwidths.clone(),
            last_replacement_sum: self.last_replacement_sum,
            rolling_contribution_sigma: self.rolling_contribution_sigma,
            sigma_ring: self.sigma_ring.clone(),
            mts_replacement_sum: self.mts_replacement_sum,
            diag: self.diag.clone(),
        }
    }
}

impl<L: Loss> fmt::Debug for SGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SGBT")
            .field("n_steps", &self.steps.len())
            .field("samples_seen", &self.samples_seen)
            .field("base_prediction", &self.base_prediction)
            .field("base_initialized", &self.base_initialized)
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Convenience constructor for the default loss (SquaredLoss)
// ---------------------------------------------------------------------------

impl SGBT<SquaredLoss> {
    /// Create a new SGBT ensemble with squared loss (regression).
    ///
    /// This is the most common constructor. For classification or custom
    /// losses, use [`with_loss`](SGBT::with_loss).
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}

// ---------------------------------------------------------------------------
// General impl for all Loss types
// ---------------------------------------------------------------------------

impl<L: Loss> SGBT<L> {
    /// Create a new SGBT ensemble with a specific loss function.
    ///
    /// The loss is stored by value (monomorphized), giving zero-cost
    /// gradient/hessian dispatch.
    ///
    /// ```ignore
    /// use irithyll::{SGBTConfig, SGBT};
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
    /// let model = SGBT::with_loss(config, LogisticLoss);
    /// ```
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        let tree_config = crate::ensemble::config::build_tree_config(&config)
            .leaf_decay_alpha_opt(leaf_decay_alpha);

        let max_tree_samples = config.max_tree_samples;

        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n = config.n_steps;
        let has_pruning =
            config.quality_prune_alpha.is_some() || config.proactive_prune_interval.is_some();
        let grace_period = config.grace_period;
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            contribution_ewma: if has_pruning {
                vec![0.0; n]
            } else {
                Vec::new()
            },
            low_contrib_count: if has_pruning { vec![0; n] } else { Vec::new() },
            rolling_mean_error: 0.0,
            rolling_contribution_sigma: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            sigma_ring: VecDeque::with_capacity(grace_period),
            mts_replacement_sum: 0,
            diag: DiagnosticCache {
                contribution_accuracy: vec![0.0; n],
                ..Default::default()
            },
        }
    }

    // ---------------------------------------------------------------------------
    // Training
    // ---------------------------------------------------------------------------

    /// Train on a single observation.
    ///
    /// Accepts any type implementing [`Observation`], including [`Sample`],
    /// `SampleRef`, or tuples like `(&[f64], f64)` for zero-copy training.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Guard: skip non-finite inputs to prevent NaN/Inf from corrupting model state.
        if !target.is_finite() || !features.iter().all(|f| f.is_finite()) {
            return;
        }

        // Initialize base prediction from first few targets
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Current prediction starts from base
        let mut current_pred = self.base_prediction;

        // Adaptive MTS: compute contribution variance and set effective max_tree_samples
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma = self.contribution_variance(features);
            self.rolling_contribution_sigma =
                0.999 * self.rolling_contribution_sigma + 0.001 * sigma;

            let normalized = if self.rolling_contribution_sigma > 1e-10 {
                sigma / self.rolling_contribution_sigma
            } else {
                1.0
            };
            let factor = 1.0 / (1.0 + k * normalized);
            let floor = (base_mts as f64 * self.config.adaptive_mts_floor)
                .max(self.config.grace_period as f64 * 2.0);
            let effective_mts = ((base_mts as f64) * factor).max(floor) as u64;
            for step in &mut self.steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        let prune_alpha = self
            .config
            .quality_prune_alpha
            .or_else(|| self.config.proactive_prune_interval.map(|_| 0.01));
        let prune_threshold = self.config.quality_prune_threshold;
        let prune_patience = self.config.quality_prune_patience;

        // Track which trees were replaced by quality pruning this step (for double-fire prevention).
        let mut replaced_this_step = vec![false; self.steps.len()];

        // Error-weighted sample importance: compute weight from prediction error
        let error_weight = if let Some(ew_alpha) = self.config.error_weight_alpha {
            let abs_error = crate::math::abs(target - current_pred);
            if self.rolling_mean_error > 1e-15 {
                let w = (1.0 + abs_error / (self.rolling_mean_error + 1e-15)).min(10.0);
                self.rolling_mean_error =
                    ew_alpha * abs_error + (1.0 - ew_alpha) * self.rolling_mean_error;
                w
            } else {
                self.rolling_mean_error = abs_error.max(1e-15);
                1.0 // first sample, no reweighting
            }
        } else {
            1.0
        };

        // Sequential boosting: each step targets the residual of all prior steps
        #[allow(clippy::needless_range_loop)]
        for s in 0..self.steps.len() {
            let gradient = self.loss.gradient(target, current_pred) * error_weight;
            let hessian = self.loss.hessian(target, current_pred) * error_weight;
            let train_count = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);

            let step_pred =
                self.steps[s].train_and_predict(features, gradient, hessian, train_count);

            current_pred += self.config.learning_rate * step_pred;

            // Quality-based tree pruning: track contribution and replace dead wood
            if let Some(alpha) = prune_alpha {
                let contribution = crate::math::abs(self.config.learning_rate * step_pred);
                self.contribution_ewma[s] =
                    alpha * contribution + (1.0 - alpha) * self.contribution_ewma[s];

                if self.contribution_ewma[s] < prune_threshold {
                    self.low_contrib_count[s] += 1;
                    if self.low_contrib_count[s] >= prune_patience {
                        self.steps[s].reset();
                        self.contribution_ewma[s] = 0.0;
                        self.low_contrib_count[s] = 0;
                        replaced_this_step[s] = true;
                    }
                } else {
                    self.low_contrib_count[s] = 0;
                }
            }
        }

        // Proactive pruning: replace worst-contributing tree at interval
        if let Some(interval) = self.config.proactive_prune_interval {
            if self.samples_seen % interval == 0
                && self.samples_seen > 0
                && !self.contribution_ewma.is_empty()
            {
                let min_age = interval / 2;

                // Collect (idx, ewma) for mature trees that weren't already replaced by quality pruning.
                let mature: Vec<(usize, f64)> = self
                    .steps
                    .iter()
                    .enumerate()
                    .zip(self.contribution_ewma.iter())
                    .filter(|((i, step), _)| {
                        step.n_samples_seen() >= min_age && !replaced_this_step[*i]
                    })
                    .map(|((i, _), &ewma)| (i, ewma))
                    .collect();

                if !mature.is_empty() {
                    // Compute p25 of contribution_ewma across mature trees
                    let mut sorted_ewma: Vec<f64> = mature.iter().map(|(_, e)| *e).collect();
                    sorted_ewma
                        .sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                    let p25_idx = (sorted_ewma.len().saturating_sub(1)) / 4;
                    let p25 = sorted_ewma[p25_idx];

                    // Only prune if the worst is below p25
                    let worst = mature.iter().min_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal)
                    });

                    if let Some(&(worst_idx, worst_ewma)) = worst {
                        if worst_ewma < p25 {
                            self.steps[worst_idx].reset();
                            self.contribution_ewma[worst_idx] = 0.0;
                            self.low_contrib_count[worst_idx] = 0;
                        }
                    }
                }
            }
        }

        // Refresh auto-bandwidths when trees have been replaced or bandwidths
        // haven't been computed yet.
        self.refresh_bandwidths();
    }

    /// Train on a batch of observations.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }

    /// Train on a batch with a periodic callback for cooperative yielding.
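    ///
    /// A sketch of the callback contract; `samples` is assumed to be any
    /// slice of [`Observation`] values supplied by the caller:
    ///
    /// ```ignore
    /// // Invoked after every 1_000 trained samples, and once more at the end
    /// // if the batch length is not a multiple of the interval.
    /// model.train_batch_with_callback(&samples, 1_000, |done| {
    ///     log_progress(done); // hypothetical progress hook
    /// });
    /// ```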
    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
        &mut self,
        samples: &[O],
        interval: usize,
        mut callback: F,
    ) {
        let interval = interval.max(1);
        for (i, sample) in samples.iter().enumerate() {
            self.train_one(sample);
            if (i + 1) % interval == 0 {
                callback(i + 1);
            }
        }
        let total = samples.len();
        if total % interval != 0 {
            callback(total);
        }
    }

    /// Train on a random subsample of a batch using reservoir sampling (Algorithm R).
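    ///
    /// Sketch (assumes `samples` is longer than `max_samples`; otherwise the
    /// whole batch is used):
    ///
    /// ```ignore
    /// // Train on a uniform random subset of at most 10_000 samples,
    /// // visited in index order.
    /// model.train_batch_subsampled(&samples, 10_000);
    /// ```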
    pub fn train_batch_subsampled<O: Observation>(&mut self, samples: &[O], max_samples: usize) {
        if max_samples >= samples.len() {
            self.train_batch(samples);
            return;
        }
        let mut reservoir: Vec<usize> = (0..max_samples).collect();
        let mut rng = self.rng_state;
        for i in max_samples..samples.len() {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let j = (rng % (i as u64 + 1)) as usize;
            if j < max_samples {
                reservoir[j] = i;
            }
        }
        self.rng_state = rng;
        reservoir.sort_unstable();
        for &idx in &reservoir {
            self.train_one(&samples[idx]);
        }
    }

    /// Train on a batch with both subsampling and periodic callbacks.
    pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
        &mut self,
        samples: &[O],
        max_samples: usize,
        interval: usize,
        mut callback: F,
    ) {
        if max_samples >= samples.len() {
            self.train_batch_with_callback(samples, interval, callback);
            return;
        }
        let mut reservoir: Vec<usize> = (0..max_samples).collect();
        let mut rng = self.rng_state;
        for i in max_samples..samples.len() {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let j = (rng % (i as u64 + 1)) as usize;
            if j < max_samples {
                reservoir[j] = i;
            }
        }
        self.rng_state = rng;
        reservoir.sort_unstable();
        let interval = interval.max(1);
        for (i, &idx) in reservoir.iter().enumerate() {
            self.train_one(&samples[idx]);
            if (i + 1) % interval == 0 {
                callback(i + 1);
            }
        }
        let total = reservoir.len();
        if total % interval != 0 {
            callback(total);
        }
    }

    // ---------------------------------------------------------------------------
    // Prediction
    // ---------------------------------------------------------------------------

    /// Predict the raw output for a feature vector.
    ///
    /// Uses auto-calibrated per-feature bandwidths for smooth (soft) routing.
    /// Falls back to hard routing before any training has occurred.
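    ///
    /// ```ignore
    /// // After some training, predict on a feature vector of the same arity
    /// // the model was trained with (values here are illustrative):
    /// let y = model.predict(&[0.5, 1.2, -0.3]);
    /// ```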
    pub fn predict(&self, features: &[f64]) -> f64 {
        let mut pred = self.base_prediction;
        if self.auto_bandwidths.is_empty() {
            for step in &self.steps {
                pred += self.config.learning_rate * step.predict(features);
            }
        } else {
            for step in &self.steps {
                pred += self.config.learning_rate
                    * step.predict_smooth_auto(features, &self.auto_bandwidths);
            }
        }
        pred
    }

    /// Predict using sigmoid-blended soft routing with an explicit bandwidth.
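    ///
    /// A sketch; the bandwidth is in feature units and `0.1` is illustrative:
    ///
    /// ```ignore
    /// let y = model.predict_smooth(&features, 0.1);
    /// ```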
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        let mut pred = self.base_prediction;
        for step in &self.steps {
            pred += self.config.learning_rate * step.predict_smooth(features, bandwidth);
        }
        pred
    }

    /// Per-feature auto-calibrated bandwidths used by `predict()`.
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }

    /// Predict with parent-leaf linear interpolation.
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        let mut pred = self.base_prediction;
        for step in &self.steps {
            pred += self.config.learning_rate * step.predict_interpolated(features);
        }
        pred
    }

    /// Predict with sibling-based interpolation for feature-continuous predictions.
    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64 {
        let mut pred = self.base_prediction;
        for step in &self.steps {
            pred += self.config.learning_rate
                * step.predict_sibling_interpolated(features, &self.auto_bandwidths);
        }
        pred
    }

    /// Predict with graduated active-shadow blending.
    pub fn predict_graduated(&self, features: &[f64]) -> f64 {
        let mut pred = self.base_prediction;
        for step in &self.steps {
            pred += self.config.learning_rate * step.predict_graduated(features);
        }
        pred
    }

    /// Predict with graduated blending + sibling interpolation.
    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64 {
        let mut pred = self.base_prediction;
        for step in &self.steps {
            pred += self.config.learning_rate
                * step.predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
        }
        pred
    }

    /// Predict with the loss transform applied (e.g., sigmoid for logistic loss).
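    ///
    /// A sketch for a classifier built via [`with_loss`](SGBT::with_loss);
    /// `config` and `features` are assumed to be in scope:
    ///
    /// ```ignore
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let mut model = SGBT::with_loss(config, LogisticLoss);
    /// // ... train ...
    /// let p = model.predict_transformed(&features); // sigmoid of the raw score
    /// ```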
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        self.loss.predict_transform(self.predict(features))
    }

    /// Predict probability (alias for `predict_transformed`).
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }

    /// Predict with confidence estimation.
    ///
    /// Returns `(prediction, confidence)`, where `confidence = 1 / sqrt(total_variance)`
    /// and `total_variance` accumulates `lr² * variance` across all steps.
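    ///
    /// ```ignore
    /// let (pred, confidence) = model.predict_with_confidence(&features);
    /// // Higher confidence means lower ensemble variance at this input;
    /// // 0.0 is returned when the total variance is zero or non-finite.
    /// ```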
    pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64) {
        let mut pred = self.base_prediction;
        let mut total_variance = 0.0;
        let lr2 = self.config.learning_rate * self.config.learning_rate;
        for step in &self.steps {
            let (value, variance) = step.predict_with_variance(features);
            pred += self.config.learning_rate * value;
            total_variance += lr2 * variance;
        }
        let confidence = if total_variance > 0.0 && total_variance.is_finite() {
            1.0 / crate::math::sqrt(total_variance)
        } else {
            0.0
        };
        (pred, confidence)
    }

    /// Batch prediction.
    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
        feature_matrix.iter().map(|f| self.predict(f)).collect()
    }

    // ---------------------------------------------------------------------------
    // Accessors
    // ---------------------------------------------------------------------------

    /// Number of boosting steps.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total trees (active + alternates).
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Total leaves across all active trees.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Total samples trained.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// The current base prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been initialized.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// Access the configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Set the learning rate for future boosting rounds.
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Immutable access to the boosting steps.
    pub fn steps(&self) -> &[BoostingStep] {
        &self.steps
    }

    /// Immutable access to the loss function.
    pub fn loss(&self) -> &L {
        &self.loss
    }

    /// Feature importances based on accumulated split gains across all trees.
    ///
    /// Returns normalized importances (sum to 1.0) indexed by feature.
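    ///
    /// ```ignore
    /// let importances = model.feature_importances();
    /// for (i, imp) in importances.iter().enumerate() {
    ///     // Each entry is a share of the total split gain (sums to 1.0).
    ///     println!("feature {i}: {imp:.3}");
    /// }
    /// ```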
    pub fn feature_importances(&self) -> Vec<f64> {
        let mut totals: Vec<f64> = Vec::new();
        for step in &self.steps {
            let gains = step.slot().split_gains();
            if totals.is_empty() && !gains.is_empty() {
                totals.resize(gains.len(), 0.0);
            }
            for (i, &g) in gains.iter().enumerate() {
                if i < totals.len() {
                    totals[i] += g;
                }
            }
        }
        let sum: f64 = totals.iter().sum();
        if sum > 0.0 {
            totals.iter_mut().for_each(|v| *v /= sum);
        }
        totals
    }

    /// Feature names, if configured.
    pub fn feature_names(&self) -> Option<&[String]> {
        self.config.feature_names.as_deref()
    }

    /// Feature importances paired with their names.
    ///
    /// Returns `None` if feature names are not configured. Otherwise returns
    /// `(name, importance)` pairs sorted by importance descending.
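    ///
    /// ```ignore
    /// if let Some(pairs) = model.named_feature_importances() {
    ///     for (name, imp) in &pairs {
    ///         println!("{name}: {imp:.3}");
    ///     }
    /// }
    /// ```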
    pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>> {
        let names = self.config.feature_names.as_ref()?;
        let importances = self.feature_importances();
        let mut pairs: Vec<(String, f64)> = names
            .iter()
            .zip(importances.iter().chain(core::iter::repeat(&0.0)))
            .map(|(n, &v)| (n.clone(), v))
            .collect();
        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        Some(pairs)
    }

    /// Train on a single sample with named features.
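    ///
    /// A sketch; the feature name is illustrative, and `feature_names` must
    /// already be set in the config (missing keys default to 0.0):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut row = HashMap::new();
    /// row.insert("age".to_string(), 42.0);
    /// model.train_one_named(&row, 1.0);
    /// ```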
    #[cfg(feature = "std")]
    pub fn train_one_named(
        &mut self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
        target: f64,
    ) {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("train_one_named requires feature_names to be configured");
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.train_one(&(&vec[..], target));
    }

    /// Predict with named features.
    #[cfg(feature = "std")]
    pub fn predict_named(
        &self,
        features: &std::collections::HashMap<alloc::string::String, f64>,
    ) -> f64 {
        let names = self
            .config
            .feature_names
            .as_ref()
            .expect("predict_named requires feature_names to be configured");
        let vec: Vec<f64> = names
            .iter()
            .map(|name| features.get(name).copied().unwrap_or(0.0))
            .collect();
        self.predict(&vec)
    }

    // ---------------------------------------------------------------------------
    // Reset
    // ---------------------------------------------------------------------------

    /// Reset the ensemble to initial state.
    pub fn reset(&mut self) {
        for step in &mut self.steps {
            step.reset();
        }
        self.base_prediction = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        self.rng_state = self.config.seed;
        self.rolling_mean_error = 0.0;
        self.rolling_contribution_sigma = 0.0;
        self.auto_bandwidths.clear();
        self.last_replacement_sum = 0;
        self.sigma_ring.clear();
        self.mts_replacement_sum = 0;
        self.diag = DiagnosticCache {
            contribution_accuracy: vec![0.0; self.steps.len()],
            ..Default::default()
        };
        if !self.contribution_ewma.is_empty() {
            self.contribution_ewma.iter_mut().for_each(|v| *v = 0.0);
        }
        if !self.low_contrib_count.is_empty() {
            self.low_contrib_count.iter_mut().for_each(|v| *v = 0);
        }
    }

    // ---------------------------------------------------------------------------
    // Internal helpers
    // ---------------------------------------------------------------------------

    /// Compute tree contribution standard deviation (σ proxy for adaptive_mts).
    fn contribution_variance(&self, features: &[f64]) -> f64 {
        let n = self.steps.len();
        if n <= 1 {
            return 0.0;
        }
        let lr = self.config.learning_rate;
        let mut sum = 0.0;
        let mut sq_sum = 0.0;
        for step in &self.steps {
            let c = lr * step.predict(features);
            sum += c;
            sq_sum += c * c;
        }
        let n_f = n as f64;
        let mean = sum / n_f;
        let var = (sq_sum / n_f) - (mean * mean);
        crate::math::sqrt((var.abs() * n_f / (n_f - 1.0)).max(0.0))
    }

    /// Refresh auto-bandwidths if any tree has been replaced since the last computation.
    fn refresh_bandwidths(&mut self) {
        let current_sum: u64 = self.steps.iter().map(|s| s.slot().replacements()).sum();
        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
            self.auto_bandwidths = self.compute_auto_bandwidths();
            self.last_replacement_sum = current_sum;
        }
    }

    /// Compute per-feature auto-calibrated bandwidths from all trees.
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        const K: f64 = 2.0;
        let n_features = self
            .steps
            .iter()
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
        for step in &self.steps {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                if ts.is_empty() {
                    return f64::INFINITY;
                }
                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);

                if sorted.len() < 2 {
                    return f64::INFINITY;
                }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                if sorted.len() < 3 {
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }
}