// irithyll_core/ensemble/distributional.rs

1//! Distributional SGBT -- outputs Gaussian N(μ, σ²) instead of a point estimate.
2//!
3//! [`DistributionalSGBT`] supports two scale estimation modes via
4//! [`ScaleMode`]:
5//!
6//! ## Empirical σ (default)
7//!
8//! Tracks an EWMA of squared prediction errors:
9//!
10//! ```text
11//! err = target - mu
12//! ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err
13//! sigma = sqrt(ewma_sq_err)
14//! ```
15//!
16//! Always calibrated (σ literally *is* recent error magnitude), zero tuning,
17//! O(1) memory and compute.  When `uncertainty_modulated_lr` is enabled,
18//! high recent errors → σ large → location LR scales up → faster correction.
19//!
20//! ## Tree chain (NGBoost-style)
21//!
22//! Maintains two independent tree ensembles: one for location (μ), one for
23//! scale (log σ).  Gives feature-conditional uncertainty but requires strong
24//! scale-gradient signal for the trees to split.
25//!
26//! # References
27//!
28//! Duan et al. (2020). "NGBoost: Natural Gradient Boosting for Probabilistic Prediction."
29
30use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::ensemble::config::{SGBTConfig, ScaleMode};
34use crate::ensemble::step::BoostingStep;
35use crate::sample::{Observation, SampleRef};
36use crate::tree::builder::TreeConfig;
37
/// Cached packed f32 binary for fast location-only inference.
///
/// Re-exported periodically from the location ensemble. Predictions use
/// contiguous BFS-packed memory for cache-optimal tree traversal.
///
/// `Clone` is derived: every field is itself `Clone`, so a hand-written
/// field-by-field impl would be pure boilerplate.
#[derive(Clone)]
struct PackedInferenceCache {
    /// Serialized packed-ensemble bytes, consumed by `EnsembleView::from_bytes`.
    bytes: Vec<u8>,
    /// Location base prediction added on top of the packed ensemble output.
    base: f64,
    /// Feature count recorded at export time.
    /// NOTE(review): not consulted anywhere in this chunk — confirm callers use it.
    n_features: usize,
}
57
/// Prediction from a distributional model: full Gaussian N(μ, σ²).
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Location parameter (mean).
    pub mu: f64,
    /// Scale parameter (standard deviation, always > 0).
    pub sigma: f64,
    /// Log of scale parameter (raw model output for scale ensemble).
    pub log_sigma: f64,
}

impl GaussianPrediction {
    /// Lower bound of a symmetric confidence interval, `z` standard
    /// deviations below the mean.
    ///
    /// For 95% CI, use `z = 1.96`.
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu - half_width
    }

    /// Upper bound of a symmetric confidence interval, `z` standard
    /// deviations above the mean.
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu + half_width
    }
}
84
85// ---------------------------------------------------------------------------
86// Diagnostic structs
87// ---------------------------------------------------------------------------
88
/// Per-tree diagnostic summary.
///
/// One snapshot of a single boosting step's tree; aggregated into
/// [`ModelDiagnostics`].
#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    /// Number of leaf nodes in this tree.
    pub n_leaves: usize,
    /// Maximum depth reached by any leaf.
    pub max_depth_reached: usize,
    /// Total samples this tree has seen.
    pub samples_seen: u64,
    /// Leaf weight statistics: `(min, max, mean, std)`.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices this tree has split on (non-zero gain).
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts showing data distribution across leaves.
    pub leaf_sample_counts: Vec<u64>,
    /// Running mean of predictions from this tree (Welford online).
    pub prediction_mean: f64,
    /// Running standard deviation of predictions from this tree
    /// (from the same Welford accumulator as `prediction_mean`).
    pub prediction_std: f64,
}
109
/// Full model diagnostics for [`DistributionalSGBT`].
///
/// Contains per-tree summaries, feature usage, base predictions, and
/// empirical σ state. All fields are owned snapshots, not live views.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// Per-tree diagnostic summaries (location trees first, then scale trees).
    pub trees: Vec<TreeDiagnostic>,
    /// Location trees only (owned copies of the corresponding `trees` entries).
    pub location_trees: Vec<TreeDiagnostic>,
    /// Scale trees only (owned copies of the corresponding `trees` entries).
    pub scale_trees: Vec<TreeDiagnostic>,
    /// How many trees each feature is used in (split count per feature).
    pub feature_split_counts: Vec<usize>,
    /// Base prediction for location (mean).
    pub location_base: f64,
    /// Base prediction for scale (log-sigma).
    pub scale_base: f64,
    /// Current empirical σ (`sqrt(ewma_sq_err)`), always available
    /// regardless of `scale_mode`.
    pub empirical_sigma: f64,
    /// Scale mode in use.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that actually split (>1 leaf). 0 = frozen chain.
    pub scale_trees_active: usize,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// `f64::INFINITY` means that feature uses hard routing.
    pub auto_bandwidths: Vec<f64>,
    /// Ensemble-level gradient running mean.
    pub ensemble_grad_mean: f64,
    /// Ensemble-level gradient standard deviation.
    pub ensemble_grad_std: f64,
}
142
143/// Decomposed prediction showing each tree's contribution.
144#[derive(Debug, Clone)]
145pub struct DecomposedPrediction {
146    /// Base location prediction (mean of initial targets).
147    pub location_base: f64,
148    /// Base scale prediction (log-sigma of initial targets).
149    pub scale_base: f64,
150    /// Per-step location contributions: `learning_rate * tree_prediction`.
151    /// `location_base + sum(location_contributions)` = μ.
152    pub location_contributions: Vec<f64>,
153    /// Per-step scale contributions: `learning_rate * tree_prediction`.
154    /// `scale_base + sum(scale_contributions)` = log(σ).
155    pub scale_contributions: Vec<f64>,
156}
157
158impl DecomposedPrediction {
159    /// Reconstruct the final μ from base + contributions.
160    pub fn mu(&self) -> f64 {
161        self.location_base + self.location_contributions.iter().sum::<f64>()
162    }
163
164    /// Reconstruct the final log(σ) from base + contributions.
165    pub fn log_sigma(&self) -> f64 {
166        self.scale_base + self.scale_contributions.iter().sum::<f64>()
167    }
168
169    /// Reconstruct the final σ (exponentiated).
170    pub fn sigma(&self) -> f64 {
171        crate::math::exp(self.log_sigma()).max(1e-8)
172    }
173}
174
175// ---------------------------------------------------------------------------
176// DistributionalSGBT
177// ---------------------------------------------------------------------------
178
179/// NGBoost-style distributional streaming gradient boosted trees.
180///
181/// Outputs a full Gaussian predictive distribution N(μ, σ²) by maintaining two
182/// independent ensembles -- one for location (mean) and one for scale (log-sigma).
183///
184/// # Example
185///
186/// ```text
187/// use irithyll::SGBTConfig;
188/// use irithyll::ensemble::distributional::DistributionalSGBT;
189///
190/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
191/// let mut model = DistributionalSGBT::new(config);
192///
193/// // Train on streaming data
194/// model.train_one(&(vec![1.0, 2.0], 3.5));
195///
196/// // Get full distributional prediction
197/// let pred = model.predict(&[1.0, 2.0]);
198/// println!("mean={}, sigma={}", pred.mu, pred.sigma);
199/// ```
200pub struct DistributionalSGBT {
201    /// Configuration (shared between location and scale ensembles).
202    config: SGBTConfig,
203    /// Location (mean) boosting steps.
204    location_steps: Vec<BoostingStep>,
205    /// Scale (log-sigma) boosting steps (only used in `TreeChain` mode).
206    scale_steps: Vec<BoostingStep>,
207    /// Base prediction for location (mean of initial targets).
208    location_base: f64,
209    /// Base prediction for scale (log of std of initial targets).
210    scale_base: f64,
211    /// Whether base predictions have been initialized.
212    base_initialized: bool,
213    /// Running collection of initial targets for computing base predictions.
214    initial_targets: Vec<f64>,
215    /// Number of initial targets to collect before setting base predictions.
216    initial_target_count: usize,
217    /// Total samples trained.
218    samples_seen: u64,
219    /// RNG state for variant logic.
220    rng_state: u64,
221    /// Whether σ-modulated learning rate is enabled.
222    uncertainty_modulated_lr: bool,
223    /// EWMA of the model's predicted σ -- used as the denominator in σ-ratio.
224    ///
225    /// Updated with alpha = 0.001 (slow adaptation) after each training step.
226    /// Initialized from the standard deviation of the initial target collection.
227    rolling_sigma_mean: f64,
228    /// Scale estimation mode: Empirical (default) or TreeChain.
229    scale_mode: ScaleMode,
230    /// EWMA of squared prediction errors (empirical σ mode).
231    ///
232    /// `sigma = sqrt(ewma_sq_err)`. Updated every training step with
233    /// `ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err`.
234    ewma_sq_err: f64,
235    /// EWMA alpha for empirical σ.
236    empirical_sigma_alpha: f64,
237    /// Previous empirical σ value for computing σ velocity.
238    prev_sigma: f64,
239    /// EWMA-smoothed derivative of empirical σ (σ velocity).
240    /// Positive = σ increasing (errors growing), negative = σ decreasing.
241    sigma_velocity: f64,
242    /// Per-feature auto-calibrated bandwidths for smooth prediction.
243    auto_bandwidths: Vec<f64>,
244    /// Sum of replacement counts at last bandwidth computation.
245    last_replacement_sum: u64,
246    /// Ensemble-level gradient running mean (Welford).
247    ensemble_grad_mean: f64,
248    /// Ensemble-level gradient M2 (Welford sum of squared deviations).
249    ensemble_grad_m2: f64,
250    /// Ensemble-level gradient sample count.
251    ensemble_grad_count: u64,
252    /// Packed f32 cache for fast location-only inference (dual-path).
253    packed_cache: Option<PackedInferenceCache>,
254    /// Samples trained since last packed cache refresh.
255    samples_since_refresh: u64,
256    /// How often to refresh the packed cache (0 = disabled).
257    packed_refresh_interval: u64,
258}
259
260impl Clone for DistributionalSGBT {
261    fn clone(&self) -> Self {
262        Self {
263            config: self.config.clone(),
264            location_steps: self.location_steps.clone(),
265            scale_steps: self.scale_steps.clone(),
266            location_base: self.location_base,
267            scale_base: self.scale_base,
268            base_initialized: self.base_initialized,
269            initial_targets: self.initial_targets.clone(),
270            initial_target_count: self.initial_target_count,
271            samples_seen: self.samples_seen,
272            rng_state: self.rng_state,
273            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
274            rolling_sigma_mean: self.rolling_sigma_mean,
275            scale_mode: self.scale_mode,
276            ewma_sq_err: self.ewma_sq_err,
277            empirical_sigma_alpha: self.empirical_sigma_alpha,
278            prev_sigma: self.prev_sigma,
279            sigma_velocity: self.sigma_velocity,
280            auto_bandwidths: self.auto_bandwidths.clone(),
281            last_replacement_sum: self.last_replacement_sum,
282            ensemble_grad_mean: self.ensemble_grad_mean,
283            ensemble_grad_m2: self.ensemble_grad_m2,
284            ensemble_grad_count: self.ensemble_grad_count,
285            packed_cache: self.packed_cache.clone(),
286            samples_since_refresh: self.samples_since_refresh,
287            packed_refresh_interval: self.packed_refresh_interval,
288        }
289    }
290}
291
292impl core::fmt::Debug for DistributionalSGBT {
293    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
294        let mut s = f.debug_struct("DistributionalSGBT");
295        s.field("n_steps", &self.location_steps.len())
296            .field("samples_seen", &self.samples_seen)
297            .field("location_base", &self.location_base)
298            .field("scale_mode", &self.scale_mode)
299            .field("base_initialized", &self.base_initialized);
300        match self.scale_mode {
301            ScaleMode::Empirical => {
302                s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
303            }
304            ScaleMode::TreeChain => {
305                s.field("scale_base", &self.scale_base);
306            }
307        }
308        if self.uncertainty_modulated_lr {
309            s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
310        }
311        s.finish()
312    }
313}
314
315impl DistributionalSGBT {
316    /// Create a new distributional SGBT with the given configuration.
317    ///
318    /// When `scale_mode` is `Empirical` (default), scale trees are still allocated
319    /// but never trained — only the EWMA error tracker produces σ.  When
320    /// `scale_mode` is `TreeChain`, both location and scale ensembles are active.
321    pub fn new(config: SGBTConfig) -> Self {
322        let leaf_decay_alpha = config
323            .leaf_half_life
324            .map(|hl| crate::math::exp(-(crate::math::ln(2.0_f64)) / hl as f64));
325
326        let tree_config = TreeConfig::new()
327            .max_depth(config.max_depth)
328            .n_bins(config.n_bins)
329            .lambda(config.lambda)
330            .gamma(config.gamma)
331            .grace_period(config.grace_period)
332            .delta(config.delta)
333            .feature_subsample_rate(config.feature_subsample_rate)
334            .leaf_decay_alpha_opt(leaf_decay_alpha)
335            .split_reeval_interval_opt(config.split_reeval_interval)
336            .feature_types_opt(config.feature_types.clone())
337            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
338            .monotone_constraints_opt(config.monotone_constraints.clone())
339            .max_leaf_output_opt(config.max_leaf_output)
340            .min_hessian_sum_opt(config.min_hessian_sum)
341            .leaf_model_type(config.leaf_model_type.clone());
342
343        let max_tree_samples = config.max_tree_samples;
344
345        // Location ensemble (seed offset: 0)
346        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
347        let location_steps: Vec<BoostingStep> = (0..config.n_steps)
348            .map(|i| {
349                let mut tc = tree_config.clone();
350                tc.seed = config.seed ^ (i as u64);
351                let detector = config.drift_detector.create();
352                if shadow_warmup > 0 {
353                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
354                } else {
355                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
356                }
357            })
358            .collect();
359
360        // Scale ensemble (seed offset: 0xSCALE) -- only trained in TreeChain mode
361        let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
362            .map(|i| {
363                let mut tc = tree_config.clone();
364                tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
365                let detector = config.drift_detector.create();
366                if shadow_warmup > 0 {
367                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
368                } else {
369                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
370                }
371            })
372            .collect();
373
374        let seed = config.seed;
375        let initial_target_count = config.initial_target_count;
376        let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
377        let scale_mode = config.scale_mode;
378        let empirical_sigma_alpha = config.empirical_sigma_alpha;
379        let packed_refresh_interval = config.packed_refresh_interval;
380        Self {
381            config,
382            location_steps,
383            scale_steps,
384            location_base: 0.0,
385            scale_base: 0.0,
386            base_initialized: false,
387            initial_targets: Vec::new(),
388            initial_target_count,
389            samples_seen: 0,
390            rng_state: seed,
391            uncertainty_modulated_lr,
392            rolling_sigma_mean: 1.0, // overwritten during base initialization
393            scale_mode,
394            ewma_sq_err: 1.0, // overwritten during base initialization
395            empirical_sigma_alpha,
396            prev_sigma: 0.0,
397            sigma_velocity: 0.0,
398            auto_bandwidths: Vec::new(),
399            last_replacement_sum: 0,
400            ensemble_grad_mean: 0.0,
401            ensemble_grad_m2: 0.0,
402            ensemble_grad_count: 0,
403            packed_cache: None,
404            samples_since_refresh: 0,
405            packed_refresh_interval,
406        }
407    }
408
409    /// Train on a single observation.
410    pub fn train_one(&mut self, sample: &impl Observation) {
411        self.samples_seen += 1;
412        let target = sample.target();
413        let features = sample.features();
414
415        // Initialize base predictions from first few targets
416        if !self.base_initialized {
417            self.initial_targets.push(target);
418            if self.initial_targets.len() >= self.initial_target_count {
419                // Location base = mean
420                let sum: f64 = self.initial_targets.iter().sum();
421                let mean = sum / self.initial_targets.len() as f64;
422                self.location_base = mean;
423
424                // Scale base = log(std) -- clamped for stability
425                let var: f64 = self
426                    .initial_targets
427                    .iter()
428                    .map(|&y| (y - mean) * (y - mean))
429                    .sum::<f64>()
430                    / self.initial_targets.len() as f64;
431                let initial_std = crate::math::sqrt(var).max(1e-6);
432                self.scale_base = crate::math::ln(initial_std);
433
434                // Initialize rolling sigma mean and ewma from initial targets std
435                self.rolling_sigma_mean = initial_std;
436                self.ewma_sq_err = var.max(1e-12);
437
438                // Initialize PD sigma state
439                self.prev_sigma = initial_std;
440                self.sigma_velocity = 0.0;
441
442                self.base_initialized = true;
443                self.initial_targets.clear();
444                self.initial_targets.shrink_to_fit();
445            }
446            return;
447        }
448
449        match self.scale_mode {
450            ScaleMode::Empirical => self.train_one_empirical(target, features),
451            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
452        }
453
454        // Refresh auto-bandwidths when trees have been replaced
455        self.refresh_bandwidths();
456    }
457
458    /// Empirical-σ training: location trees only, σ from EWMA of squared errors.
459    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
460        // Current location prediction (before this step's update)
461        let mut mu = self.location_base;
462        for s in 0..self.location_steps.len() {
463            mu += self.config.learning_rate * self.location_steps[s].predict(features);
464        }
465
466        // Empirical sigma from EWMA of squared prediction errors
467        let err = target - mu;
468        let alpha = self.empirical_sigma_alpha;
469        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
470        let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
471
472        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller)
473        let sigma_ratio = if self.uncertainty_modulated_lr {
474            // Compute sigma velocity (derivative of sigma over time)
475            let d_sigma = empirical_sigma - self.prev_sigma;
476            self.prev_sigma = empirical_sigma;
477
478            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
479            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;
480
481            // Adaptive derivative gain: self-calibrating, no config needed
482            let k_d = if self.rolling_sigma_mean > 1e-12 {
483                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
484            } else {
485                0.0
486            };
487
488            // PD ratio: proportional + derivative
489            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
490            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);
491
492            // Update rolling sigma mean with slow EWMA
493            const SIGMA_EWMA_ALPHA: f64 = 0.001;
494            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
495                + SIGMA_EWMA_ALPHA * empirical_sigma;
496
497            ratio
498        } else {
499            1.0
500        };
501
502        let base_lr = self.config.learning_rate;
503
504        // Train location steps only -- no scale trees needed
505        let mut mu_accum = self.location_base;
506        for s in 0..self.location_steps.len() {
507            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
508            // Welford update for ensemble gradient stats
509            self.update_ensemble_grad_stats(g_mu);
510            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
511            let loc_pred =
512                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
513            mu_accum += (base_lr * sigma_ratio) * loc_pred;
514        }
515
516        // Refresh packed cache if interval reached
517        self.maybe_refresh_packed_cache();
518    }
519
520    /// Tree-chain training: full NGBoost dual-chain with location + scale trees.
521    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
522        let mut mu = self.location_base;
523        let mut log_sigma = self.scale_base;
524
525        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller).
526        let sigma_ratio = if self.uncertainty_modulated_lr {
527            let current_sigma = crate::math::exp(log_sigma).max(1e-8);
528
529            // Compute sigma velocity (derivative of sigma over time)
530            let d_sigma = current_sigma - self.prev_sigma;
531            self.prev_sigma = current_sigma;
532
533            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
534            let alpha = self.empirical_sigma_alpha;
535            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;
536
537            // Adaptive derivative gain: self-calibrating, no config needed
538            let k_d = if self.rolling_sigma_mean > 1e-12 {
539                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
540            } else {
541                0.0
542            };
543
544            // PD ratio: proportional + derivative
545            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
546            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);
547
548            const SIGMA_EWMA_ALPHA: f64 = 0.001;
549            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
550                + SIGMA_EWMA_ALPHA * current_sigma;
551
552            ratio
553        } else {
554            1.0
555        };
556
557        let base_lr = self.config.learning_rate;
558
559        // Sequential boosting: both ensembles target their respective residuals
560        for s in 0..self.location_steps.len() {
561            let sigma = crate::math::exp(log_sigma).max(1e-8);
562            let z = (target - mu) / sigma;
563
564            // Location gradients (squared or Huber loss w.r.t. mu)
565            let (g_mu, h_mu) = self.location_gradient(mu, target);
566            // Welford update for ensemble gradient stats
567            self.update_ensemble_grad_stats(g_mu);
568
569            // Scale gradients (NLL w.r.t. log_sigma)
570            let g_sigma = 1.0 - z * z;
571            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);
572
573            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
574
575            // Train location step -- σ-modulated LR when enabled
576            let loc_pred =
577                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
578            mu += (base_lr * sigma_ratio) * loc_pred;
579
580            // Train scale step -- ALWAYS at unmodulated base rate.
581            let scale_pred =
582                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
583            log_sigma += base_lr * scale_pred;
584        }
585
586        // Also update empirical sigma tracker for diagnostics
587        let err = target - mu;
588        let alpha = self.empirical_sigma_alpha;
589        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
590
591        // Refresh packed cache if interval reached
592        self.maybe_refresh_packed_cache();
593    }
594
595    /// Predict the full Gaussian distribution for a feature vector.
596    ///
597    /// When a packed cache is available, uses it for the location (μ) prediction
598    /// via contiguous BFS-packed memory traversal. Falls back to full tree
599    /// traversal if the cache is absent or produces non-finite results.
600    ///
601    /// Sigma computation always uses the primary path (EWMA or scale chain)
602    /// and is unaffected by the packed cache.
603    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
604        // Try packed cache for mu if available
605        let mu = if let Some(ref cache) = self.packed_cache {
606            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
607            match crate::EnsembleView::from_bytes(&cache.bytes) {
608                Ok(view) => {
609                    let packed_mu = cache.base + view.predict(&features_f32) as f64;
610                    if packed_mu.is_finite() {
611                        packed_mu
612                    } else {
613                        self.predict_full_trees(features)
614                    }
615                }
616                Err(_) => self.predict_full_trees(features),
617            }
618        } else {
619            self.predict_full_trees(features)
620        };
621
622        let (sigma, log_sigma) = match self.scale_mode {
623            ScaleMode::Empirical => {
624                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
625                (s, crate::math::ln(s))
626            }
627            ScaleMode::TreeChain => {
628                let mut ls = self.scale_base;
629                if self.auto_bandwidths.is_empty() {
630                    for s in 0..self.scale_steps.len() {
631                        ls += self.config.learning_rate * self.scale_steps[s].predict(features);
632                    }
633                } else {
634                    for s in 0..self.scale_steps.len() {
635                        ls += self.config.learning_rate
636                            * self.scale_steps[s]
637                                .predict_smooth_auto(features, &self.auto_bandwidths);
638                    }
639                }
640                (crate::math::exp(ls).max(1e-8), ls)
641            }
642        };
643
644        GaussianPrediction {
645            mu,
646            sigma,
647            log_sigma,
648        }
649    }
650
651    /// Full-tree location prediction (fallback when packed cache is unavailable).
652    fn predict_full_trees(&self, features: &[f64]) -> f64 {
653        let mut mu = self.location_base;
654        if self.auto_bandwidths.is_empty() {
655            for s in 0..self.location_steps.len() {
656                mu += self.config.learning_rate * self.location_steps[s].predict(features);
657            }
658        } else {
659            for s in 0..self.location_steps.len() {
660                mu += self.config.learning_rate
661                    * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
662            }
663        }
664        mu
665    }
666
667    /// Predict using sigmoid-blended soft routing for smooth interpolation.
668    ///
669    /// Instead of hard left/right routing at tree split nodes, each split
670    /// uses sigmoid blending: `alpha = sigmoid((threshold - feature) / bandwidth)`.
671    /// The result is a continuous function that varies smoothly with every
672    /// feature change.
673    ///
674    /// `bandwidth` controls transition sharpness: smaller = sharper (closer
675    /// to hard splits), larger = smoother.
676    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
677        let mut mu = self.location_base;
678        for s in 0..self.location_steps.len() {
679            mu += self.config.learning_rate
680                * self.location_steps[s].predict_smooth(features, bandwidth);
681        }
682
683        let (sigma, log_sigma) = match self.scale_mode {
684            ScaleMode::Empirical => {
685                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
686                (s, crate::math::ln(s))
687            }
688            ScaleMode::TreeChain => {
689                let mut ls = self.scale_base;
690                for s in 0..self.scale_steps.len() {
691                    ls += self.config.learning_rate
692                        * self.scale_steps[s].predict_smooth(features, bandwidth);
693                }
694                (crate::math::exp(ls).max(1e-8), ls)
695            }
696        };
697
698        GaussianPrediction {
699            mu,
700            sigma,
701            log_sigma,
702        }
703    }
704
705    /// Predict with parent-leaf linear interpolation.
706    ///
707    /// Blends each leaf prediction with its parent's preserved prediction
708    /// based on sample count, preventing stale predictions from fresh leaves.
709    pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
710        let mut mu = self.location_base;
711        for s in 0..self.location_steps.len() {
712            mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
713        }
714
715        let (sigma, log_sigma) = match self.scale_mode {
716            ScaleMode::Empirical => {
717                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
718                (s, crate::math::ln(s))
719            }
720            ScaleMode::TreeChain => {
721                let mut ls = self.scale_base;
722                for s in 0..self.scale_steps.len() {
723                    ls += self.config.learning_rate
724                        * self.scale_steps[s].predict_interpolated(features);
725                }
726                (crate::math::exp(ls).max(1e-8), ls)
727            }
728        };
729
730        GaussianPrediction {
731            mu,
732            sigma,
733            log_sigma,
734        }
735    }
736
737    /// Predict with sibling-based interpolation for feature-continuous predictions.
738    ///
739    /// At each split node near the threshold boundary, blends left and right
740    /// subtree predictions linearly. Uses auto-calibrated bandwidths as the
741    /// interpolation margin. Predictions vary continuously as features change.
742    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
743        let mut mu = self.location_base;
744        for s in 0..self.location_steps.len() {
745            mu += self.config.learning_rate
746                * self.location_steps[s]
747                    .predict_sibling_interpolated(features, &self.auto_bandwidths);
748        }
749
750        let (sigma, log_sigma) = match self.scale_mode {
751            ScaleMode::Empirical => {
752                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
753                (s, crate::math::ln(s))
754            }
755            ScaleMode::TreeChain => {
756                let mut ls = self.scale_base;
757                for s in 0..self.scale_steps.len() {
758                    ls += self.config.learning_rate
759                        * self.scale_steps[s]
760                            .predict_sibling_interpolated(features, &self.auto_bandwidths);
761                }
762                (crate::math::exp(ls).max(1e-8), ls)
763            }
764        };
765
766        GaussianPrediction {
767            mu,
768            sigma,
769            log_sigma,
770        }
771    }
772
773    /// Predict with graduated active-shadow blending.
774    ///
775    /// Smoothly transitions between active and shadow trees during replacement.
776    /// Requires `shadow_warmup` to be configured.
777    pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
778        let mut mu = self.location_base;
779        for s in 0..self.location_steps.len() {
780            mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
781        }
782
783        let (sigma, log_sigma) = match self.scale_mode {
784            ScaleMode::Empirical => {
785                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
786                (s, crate::math::ln(s))
787            }
788            ScaleMode::TreeChain => {
789                let mut ls = self.scale_base;
790                for s in 0..self.scale_steps.len() {
791                    ls +=
792                        self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
793                }
794                (crate::math::exp(ls).max(1e-8), ls)
795            }
796        };
797
798        GaussianPrediction {
799            mu,
800            sigma,
801            log_sigma,
802        }
803    }
804
805    /// Predict with graduated blending + sibling interpolation (premium path).
806    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
807        let mut mu = self.location_base;
808        for s in 0..self.location_steps.len() {
809            mu += self.config.learning_rate
810                * self.location_steps[s]
811                    .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
812        }
813
814        let (sigma, log_sigma) = match self.scale_mode {
815            ScaleMode::Empirical => {
816                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
817                (s, crate::math::ln(s))
818            }
819            ScaleMode::TreeChain => {
820                let mut ls = self.scale_base;
821                for s in 0..self.scale_steps.len() {
822                    ls += self.config.learning_rate
823                        * self.scale_steps[s].predict_graduated_sibling_interpolated(
824                            features,
825                            &self.auto_bandwidths,
826                        );
827                }
828                (crate::math::exp(ls).max(1e-8), ls)
829            }
830        };
831
832        GaussianPrediction {
833            mu,
834            sigma,
835            log_sigma,
836        }
837    }
838
839    /// Predict with σ-ratio diagnostic exposed.
840    ///
841    /// Returns `(mu, sigma, sigma_ratio)` where `sigma_ratio` is
842    /// `current_sigma / rolling_sigma_mean` -- the multiplier applied to the
843    /// location learning rate when [`uncertainty_modulated_lr`](SGBTConfig::uncertainty_modulated_lr)
844    /// is enabled.
845    ///
846    /// When σ-modulation is disabled, `sigma_ratio` is always `1.0`.
847    pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
848        let pred = self.predict(features);
849        let sigma_ratio = if self.uncertainty_modulated_lr {
850            (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
851        } else {
852            1.0
853        };
854        (pred.mu, pred.sigma, sigma_ratio)
855    }
856
    /// Current empirical sigma (`sqrt(ewma_sq_err)`).
    ///
    /// Returns the model's recent error magnitude: the square root of the
    /// EWMA of squared prediction errors. Available in both scale modes.
    ///
    /// Note: no `1e-8` floor is applied here, unlike the σ returned by the
    /// prediction paths.
    #[inline]
    pub fn empirical_sigma(&self) -> f64 {
        crate::math::sqrt(self.ewma_sq_err)
    }
864
    /// Current scale mode ([`ScaleMode::Empirical`] or [`ScaleMode::TreeChain`]).
    #[inline]
    pub fn scale_mode(&self) -> ScaleMode {
        self.scale_mode
    }
870
    /// Current σ velocity -- the EWMA-smoothed derivative of empirical σ.
    ///
    /// Positive values indicate growing prediction errors (model deteriorating
    /// or regime change). Negative values indicate improving predictions.
    /// Only meaningful when `ScaleMode::Empirical` is active.
    #[inline]
    pub fn sigma_velocity(&self) -> f64 {
        self.sigma_velocity
    }
880
881    /// Predict the mean (location parameter) only.
882    #[inline]
883    pub fn predict_mu(&self, features: &[f64]) -> f64 {
884        self.predict(features).mu
885    }
886
887    /// Predict the standard deviation (scale parameter) only.
888    #[inline]
889    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
890        self.predict(features).sigma
891    }
892
893    /// Predict a symmetric confidence interval.
894    ///
895    /// `confidence` is the Z-score multiplier:
896    /// - 1.0 → 68% CI
897    /// - 1.96 → 95% CI
898    /// - 2.576 → 99% CI
899    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
900        let pred = self.predict(features);
901        (pred.lower(confidence), pred.upper(confidence))
902    }
903
904    /// Batch prediction.
905    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
906        feature_matrix.iter().map(|f| self.predict(f)).collect()
907    }
908
909    /// Train on a batch of observations.
910    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
911        for sample in samples {
912            self.train_one(sample);
913        }
914    }
915
916    /// Train on a batch with periodic callback.
917    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
918        &mut self,
919        samples: &[O],
920        interval: usize,
921        mut callback: F,
922    ) {
923        let interval = interval.max(1);
924        for (i, sample) in samples.iter().enumerate() {
925            self.train_one(sample);
926            if (i + 1) % interval == 0 {
927                callback(i + 1);
928            }
929        }
930        let total = samples.len();
931        if total % interval != 0 {
932            callback(total);
933        }
934    }
935
936    /// Compute location gradient (squared or adaptive Huber).
937    ///
938    /// When `huber_k` is configured, uses Huber loss with adaptive
939    /// `delta = k * empirical_sigma` for bounded gradients.
940    #[inline]
941    fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
942        if let Some(k) = self.config.huber_k {
943            let delta = k * crate::math::sqrt(self.ewma_sq_err).max(1e-8);
944            let residual = mu - target;
945            if crate::math::abs(residual) <= delta {
946                (residual, 1.0)
947            } else {
948                (delta * residual.signum(), 1e-6)
949            }
950        } else {
951            (mu - target, 1.0)
952        }
953    }
954
    /// Welford update for ensemble-level gradient statistics.
    ///
    /// Single-pass, numerically stable online mean/variance accumulation:
    /// `ensemble_grad_m2` holds the running sum of squared deviations, so
    /// variance is `m2 / (count - 1)` (see [`Self::ensemble_grad_std`]).
    #[inline]
    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
        self.ensemble_grad_count += 1;
        // `delta` uses the OLD mean, `delta2` the updated mean; the product
        // of the two is what makes Welford's M2 update numerically stable.
        let delta = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
        let delta2 = gradient - self.ensemble_grad_mean;
        self.ensemble_grad_m2 += delta * delta2;
    }
964
965    /// Ensemble-level gradient standard deviation.
966    pub fn ensemble_grad_std(&self) -> f64 {
967        if self.ensemble_grad_count < 2 {
968            return 0.0;
969        }
970        crate::math::fmax(
971            crate::math::sqrt(self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64),
972            0.0,
973        )
974    }
975
    /// Ensemble-level gradient mean (Welford running mean; see
    /// [`Self::update_ensemble_grad_stats`]).
    pub fn ensemble_grad_mean(&self) -> f64 {
        self.ensemble_grad_mean
    }
980
981    /// Check if packed cache should be refreshed and do so if interval reached.
982    fn maybe_refresh_packed_cache(&mut self) {
983        if self.packed_refresh_interval > 0 {
984            self.samples_since_refresh += 1;
985            if self.samples_since_refresh >= self.packed_refresh_interval {
986                self.refresh_packed_cache();
987                self.samples_since_refresh = 0;
988            }
989        }
990    }
991
    /// Re-export the location ensemble into the packed cache.
    ///
    /// In `irithyll-core` this is a deliberate no-op because
    /// `export_embedded` is not available here. The full `irithyll` crate
    /// overrides this via its re-export layer to populate the cache with
    /// packed f32 bytes.
    fn refresh_packed_cache(&mut self) {
        // Intentionally empty: export_distributional_packed lives in
        // irithyll (not irithyll-core), so the packed cache is never
        // populated in core-only builds. The full irithyll crate hooks in
        // via its own DistributionalSGBT re-export, which adds the packed
        // cache logic on top of this stub.
    }
1003
1004    /// Enable or reconfigure the packed inference cache at runtime.
1005    ///
1006    /// Sets the refresh interval and immediately builds the initial cache
1007    /// if the model has been initialized. Pass `0` to disable.
1008    pub fn enable_packed_cache(&mut self, interval: u64) {
1009        self.packed_refresh_interval = interval;
1010        self.samples_since_refresh = 0;
1011        if interval > 0 && self.base_initialized {
1012            self.refresh_packed_cache();
1013        } else if interval == 0 {
1014            self.packed_cache = None;
1015        }
1016    }
1017
    /// Whether the packed inference cache is currently populated.
    ///
    /// Always `false` in core-only builds, where
    /// [`Self::refresh_packed_cache`] is a no-op.
    #[inline]
    pub fn has_packed_cache(&self) -> bool {
        self.packed_cache.is_some()
    }
1023
1024    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
1025    fn refresh_bandwidths(&mut self) {
1026        let current_sum: u64 = self
1027            .location_steps
1028            .iter()
1029            .chain(self.scale_steps.iter())
1030            .map(|s| s.slot().replacements())
1031            .sum();
1032        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1033            self.auto_bandwidths = self.compute_auto_bandwidths();
1034            self.last_replacement_sum = current_sum;
1035        }
1036    }
1037
1038    /// Compute per-feature auto-calibrated bandwidths from all trees.
1039    fn compute_auto_bandwidths(&self) -> Vec<f64> {
1040        const K: f64 = 2.0;
1041
1042        let n_features = self
1043            .location_steps
1044            .iter()
1045            .chain(self.scale_steps.iter())
1046            .filter_map(|s| s.slot().active_tree().n_features())
1047            .max()
1048            .unwrap_or(0);
1049
1050        if n_features == 0 {
1051            return Vec::new();
1052        }
1053
1054        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
1055
1056        for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
1057            let tree_thresholds = step
1058                .slot()
1059                .active_tree()
1060                .collect_split_thresholds_per_feature();
1061            for (i, ts) in tree_thresholds.into_iter().enumerate() {
1062                if i < n_features {
1063                    all_thresholds[i].extend(ts);
1064                }
1065            }
1066        }
1067
1068        let n_bins = self.config.n_bins as f64;
1069
1070        all_thresholds
1071            .iter()
1072            .map(|ts| {
1073                if ts.is_empty() {
1074                    return f64::INFINITY;
1075                }
1076
1077                let mut sorted = ts.clone();
1078                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1079                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
1080
1081                if sorted.len() < 2 {
1082                    return f64::INFINITY;
1083                }
1084
1085                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
1086
1087                if sorted.len() < 3 {
1088                    let range = sorted.last().unwrap() - sorted.first().unwrap();
1089                    if range < 1e-15 {
1090                        return f64::INFINITY;
1091                    }
1092                    return (range / n_bins) * K;
1093                }
1094
1095                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1096                let median_gap = if gaps.len() % 2 == 0 {
1097                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
1098                } else {
1099                    gaps[gaps.len() / 2]
1100                };
1101
1102                if median_gap < 1e-15 {
1103                    f64::INFINITY
1104                } else {
1105                    median_gap * K
1106                }
1107            })
1108            .collect()
1109    }
1110
    /// Per-feature auto-calibrated bandwidths used by the interpolating
    /// prediction paths. Empty until first computed by
    /// [`Self::compute_auto_bandwidths`].
    pub fn auto_bandwidths(&self) -> &[f64] {
        &self.auto_bandwidths
    }
1115
    /// Reset to initial untrained state.
    ///
    /// Resets every boosting step (location and scale), clears base
    /// predictions and all running statistics, re-seeds the RNG from the
    /// configured seed, and drops derived caches. The configuration itself
    /// is untouched.
    pub fn reset(&mut self) {
        for step in &mut self.location_steps {
            step.reset();
        }
        for step in &mut self.scale_steps {
            step.reset();
        }
        // Base predictions and warm-up bookkeeping.
        self.location_base = 0.0;
        self.scale_base = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        self.rng_state = self.config.seed;
        // Uncertainty tracking back to its neutral starting values.
        self.rolling_sigma_mean = 1.0;
        self.ewma_sq_err = 1.0;
        self.prev_sigma = 0.0;
        self.sigma_velocity = 0.0;
        // Derived caches: bandwidths and the replacement fingerprint.
        self.auto_bandwidths.clear();
        self.last_replacement_sum = 0;
        // Welford gradient statistics.
        self.ensemble_grad_mean = 0.0;
        self.ensemble_grad_m2 = 0.0;
        self.ensemble_grad_count = 0;
        // Packed inference cache.
        self.packed_cache = None;
        self.samples_since_refresh = 0;
    }
1142
    /// Total samples trained since construction (or the last [`Self::reset`]).
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
1148
    /// Number of boosting steps (the location and scale chains have the
    /// same length, so the location count is reported).
    #[inline]
    pub fn n_steps(&self) -> usize {
        self.location_steps.len()
    }
1154
1155    /// Total trees (location + scale, active + alternates).
1156    pub fn n_trees(&self) -> usize {
1157        let loc = self.location_steps.len()
1158            + self
1159                .location_steps
1160                .iter()
1161                .filter(|s| s.has_alternate())
1162                .count();
1163        let scale = self.scale_steps.len()
1164            + self
1165                .scale_steps
1166                .iter()
1167                .filter(|s| s.has_alternate())
1168                .count();
1169        loc + scale
1170    }
1171
1172    /// Total leaves across all active trees (location + scale).
1173    pub fn total_leaves(&self) -> usize {
1174        let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1175        let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1176        loc + scale
1177    }
1178
    /// Whether base predictions have been initialized (i.e. the warm-up
    /// phase has completed).
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }
1184
    /// Access the model configuration (read-only).
    #[inline]
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }
1190
    /// Access the location boosting steps (for export/inspection).
    pub fn location_steps(&self) -> &[BoostingStep] {
        &self.location_steps
    }
1195
    /// Base prediction for the location (mean) ensemble -- the constant
    /// all location tree outputs are added to.
    #[inline]
    pub fn location_base(&self) -> f64 {
        self.location_base
    }
1201
    /// Learning rate from the model configuration (shared by both ensembles).
    #[inline]
    pub fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }
1207
    /// Current rolling σ mean (EWMA of predicted σ), the denominator of the
    /// σ-ratio used for LR modulation.
    ///
    /// Returns `1.0` if the model hasn't been initialized yet.
    #[inline]
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }
1215
    /// Whether σ-modulated learning rate is active (see
    /// [`Self::predict_distributional`] for the applied ratio).
    #[inline]
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }
1221
1222    // -------------------------------------------------------------------
1223    // Diagnostics
1224    // -------------------------------------------------------------------
1225
    /// Full model diagnostics: per-tree structure, feature usage, base predictions.
    ///
    /// The `trees` vector contains location trees first (indices `0..n_steps`),
    /// then scale trees (`n_steps..2*n_steps`).
    ///
    /// `scale_trees_active` counts how many scale trees have actually split
    /// (more than 1 leaf). If this is 0, the scale chain is effectively frozen.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        let mut feature_split_counts: Vec<usize> = Vec::new();

        // Inner fn (not a closure) so it can be called twice while `trees`
        // and `feature_split_counts` are mutably borrowed each time.
        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                // Values of all leaf nodes in the arena (parallel arrays
                // indexed by node id; `is_leaf` selects the leaves).
                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                // Deepest leaf; 0 for an empty/stub tree.
                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                // (min, max, mean, std) of leaf values; zeros when there
                // are no leaves.
                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    // Population variance (divides by n, not n-1).
                    let var: f64 = leaf_values
                        .iter()
                        .map(|v| crate::math::powi(v - mean, 2))
                        .sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, crate::math::sqrt(var))
                };

                // Features with any positive accumulated split gain.
                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    // Size the global counter from the first tree that
                    // reports gains; later indices beyond that are dropped.
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        // Location first, then scale -- this fixes the index layout
        // documented above.
        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: crate::math::sqrt(self.ewma_sq_err),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }
1336
1337    /// Per-tree contribution to the final prediction.
1338    ///
1339    /// Returns two vectors: location contributions and scale contributions.
1340    /// Each entry is `learning_rate * tree_prediction` -- the additive
1341    /// contribution of that boosting step to the final μ or log(σ).
1342    ///
1343    /// Summing `location_base + sum(location_contributions)` recovers μ.
1344    /// Summing `scale_base + sum(scale_contributions)` recovers log(σ).
1345    ///
1346    /// In `Empirical` scale mode, `scale_base` is `ln(empirical_sigma)` and
1347    /// `scale_contributions` are all zero (σ is not tree-derived).
1348    pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1349        let lr = self.config.learning_rate;
1350        let location: Vec<f64> = self
1351            .location_steps
1352            .iter()
1353            .map(|s| lr * s.predict(features))
1354            .collect();
1355
1356        let (sb, scale) = match self.scale_mode {
1357            ScaleMode::Empirical => {
1358                let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1359                (
1360                    crate::math::ln(empirical_sigma),
1361                    vec![0.0; self.location_steps.len()],
1362                )
1363            }
1364            ScaleMode::TreeChain => {
1365                let s: Vec<f64> = self
1366                    .scale_steps
1367                    .iter()
1368                    .map(|s| lr * s.predict(features))
1369                    .collect();
1370                (self.scale_base, s)
1371            }
1372        };
1373
1374        DecomposedPrediction {
1375            location_base: self.location_base,
1376            scale_base: sb,
1377            location_contributions: location,
1378            scale_contributions: scale,
1379        }
1380    }
1381
1382    /// Feature importances based on accumulated split gains across all trees.
1383    ///
1384    /// Aggregates gains from both location and scale ensembles, then
1385    /// normalizes to sum to 1.0. Indexed by feature.
1386    /// Returns an empty Vec if no splits have occurred yet.
1387    pub fn feature_importances(&self) -> Vec<f64> {
1388        let mut totals: Vec<f64> = Vec::new();
1389        for steps in [&self.location_steps, &self.scale_steps] {
1390            for step in steps {
1391                let gains = step.slot().split_gains();
1392                if totals.is_empty() && !gains.is_empty() {
1393                    totals.resize(gains.len(), 0.0);
1394                }
1395                for (i, &g) in gains.iter().enumerate() {
1396                    if i < totals.len() {
1397                        totals[i] += g;
1398                    }
1399                }
1400            }
1401        }
1402        let sum: f64 = totals.iter().sum();
1403        if sum > 0.0 {
1404            totals.iter_mut().for_each(|v| *v /= sum);
1405        }
1406        totals
1407    }
1408
1409    /// Feature importances split by ensemble: `(location_importances, scale_importances)`.
1410    ///
1411    /// Each vector is independently normalized to sum to 1.0.
1412    /// Useful for understanding which features drive the mean vs. the uncertainty.
1413    pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1414        fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1415            let mut totals: Vec<f64> = Vec::new();
1416            for step in steps {
1417                let gains = step.slot().split_gains();
1418                if totals.is_empty() && !gains.is_empty() {
1419                    totals.resize(gains.len(), 0.0);
1420                }
1421                for (i, &g) in gains.iter().enumerate() {
1422                    if i < totals.len() {
1423                        totals[i] += g;
1424                    }
1425                }
1426            }
1427            let sum: f64 = totals.iter().sum();
1428            if sum > 0.0 {
1429                totals.iter_mut().for_each(|v| *v /= sum);
1430            }
1431            totals
1432        }
1433        (
1434            aggregate(&self.location_steps),
1435            aggregate(&self.scale_steps),
1436        )
1437    }
1438
    /// Convert this model into a serializable [`crate::serde_support::DistributionalModelState`].
    ///
    /// Captures the full ensemble state (both location and scale trees) for
    /// persistence. Histogram accumulators are NOT serialized -- they rebuild
    /// naturally from continued training.
    #[cfg(feature = "_serde_support")]
    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
        use super::snapshot_tree;
        use crate::serde_support::{DistributionalModelState, StepSnapshot};

        // Snapshot one boosting step: active tree, optional alternate tree,
        // and the serialized state of both drift detectors.
        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
            let slot = step.slot();
            let tree_snap = snapshot_tree(slot.active_tree());
            let alt_snap = slot.alternate_tree().map(snapshot_tree);
            let drift_state = slot.detector().serialize_state();
            // Alternate detector may be absent, and its serialization may
            // itself be optional -- hence and_then.
            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
            StepSnapshot {
                tree: tree_snap,
                alternate_tree: alt_snap,
                drift_state,
                alt_drift_state,
            }
        }

        DistributionalModelState {
            config: self.config.clone(),
            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
            location_base: self.location_base,
            scale_base: self.scale_base,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
            rolling_sigma_mean: self.rolling_sigma_mean,
            ewma_sq_err: self.ewma_sq_err,
        }
    }
1479
    /// Reconstruct a [`DistributionalSGBT`] from a serialized [`crate::serde_support::DistributionalModelState`].
    ///
    /// Rebuilds both location and scale ensembles including tree topology
    /// and leaf values. Histogram accumulators are left empty and will
    /// rebuild from continued training. Transient state (σ velocity,
    /// bandwidth caches, gradient stats, packed cache) restarts at zero.
    #[cfg(feature = "_serde_support")]
    pub fn from_distributional_state(
        state: crate::serde_support::DistributionalModelState,
    ) -> Self {
        use super::rebuild_tree;
        use crate::ensemble::replacement::TreeSlot;
        use crate::serde_support::StepSnapshot;

        // Per-sample decay factor derived from the configured half-life:
        // alpha = exp(-ln 2 / half_life).
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp((-(crate::math::ln(2.0_f64)) / hl as f64)));
        let max_tree_samples = state.config.max_tree_samples;

        // Shared tree configuration; each step below only varies the seed.
        let base_tree_config = TreeConfig::new()
            .max_depth(state.config.max_depth)
            .n_bins(state.config.n_bins)
            .lambda(state.config.lambda)
            .gamma(state.config.gamma)
            .grace_period(state.config.grace_period)
            .delta(state.config.delta)
            .feature_subsample_rate(state.config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(state.config.split_reeval_interval)
            .feature_types_opt(state.config.feature_types.clone())
            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
            .monotone_constraints_opt(state.config.monotone_constraints.clone())
            .leaf_model_type(state.config.leaf_model_type.clone());

        // Rebuild a Vec<BoostingStep> from step snapshots with a given seed transform.
        // The seed derivation (config seed XOR step index XOR ensemble tag)
        // must match the one used at construction time so restored trees
        // keep their original RNG streams.
        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
            snaps
                .iter()
                .enumerate()
                .map(|(i, snap)| {
                    let tc = base_tree_config
                        .clone()
                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);

                    let active = rebuild_tree(&snap.tree, tc.clone());
                    let alternate = snap
                        .alternate_tree
                        .as_ref()
                        .map(|s| rebuild_tree(s, tc.clone()));

                    // Fresh detector, then overlay the serialized state if any.
                    let mut detector = state.config.drift_detector.create();
                    if let Some(ref ds) = snap.drift_state {
                        detector.restore_state(ds);
                    }
                    let mut slot =
                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
                    // Alternate detector state only applies if the slot
                    // actually rebuilt an alternate detector.
                    if let Some(ref ads) = snap.alt_drift_state {
                        if let Some(alt_det) = slot.alt_detector_mut() {
                            alt_det.restore_state(ads);
                        }
                    }
                    BoostingStep::from_slot(slot)
                })
                .collect()
        };

        // Location: seed offset 0, Scale: seed offset 0x0005_CA1E_0000_0000
        let location_steps = rebuild_steps(&state.location_steps, 0);
        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);

        let scale_mode = state.config.scale_mode;
        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
        let packed_refresh_interval = state.config.packed_refresh_interval;
        Self {
            config: state.config,
            location_steps,
            scale_steps,
            location_base: state.location_base,
            scale_base: state.scale_base,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
            rolling_sigma_mean: state.rolling_sigma_mean,
            scale_mode,
            ewma_sq_err: state.ewma_sq_err,
            empirical_sigma_alpha,
            // Transient state below is intentionally NOT persisted; it
            // re-accumulates from continued training.
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
        }
    }
1581}
1582
1583// ---------------------------------------------------------------------------
1584// StreamingLearner impl
1585// ---------------------------------------------------------------------------
1586
1587use crate::learner::StreamingLearner;
1588
1589impl StreamingLearner for DistributionalSGBT {
1590    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
1591        let sample = SampleRef::weighted(features, target, weight);
1592        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
1593        DistributionalSGBT::train_one(self, &sample);
1594    }
1595
1596    /// Returns the mean (μ) of the predicted Gaussian distribution.
1597    fn predict(&self, features: &[f64]) -> f64 {
1598        DistributionalSGBT::predict(self, features).mu
1599    }
1600
1601    fn n_samples_seen(&self) -> u64 {
1602        self.samples_seen
1603    }
1604
1605    fn reset(&mut self) {
1606        DistributionalSGBT::reset(self);
1607    }
1608}
1609
1610// ---------------------------------------------------------------------------
1611// Tests
1612// ---------------------------------------------------------------------------
1613
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Baseline config used by most tests: 10 boosting steps, empirical σ
    /// (the default [`ScaleMode`]), no LR modulation.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap()
    }

    // An untrained model should output μ = 0 but still a strictly positive σ
    // (σ must never collapse to zero, even with no data).
    #[test]
    fn fresh_model_predicts_zero() {
        let model = DistributionalSGBT::new(test_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.mu.abs() < 1e-12);
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn sigma_always_positive() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on various data
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        // Check multiple predictions
        for i in 0..20 {
            let x = i as f64 * 0.5;
            let pred = model.predict(&[x, x * 0.5]);
            assert!(
                pred.sigma > 0.0,
                "sigma must be positive, got {}",
                pred.sigma
            );
            assert!(pred.sigma.is_finite(), "sigma must be finite");
        }
    }

    #[test]
    fn constant_target_has_small_sigma() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on constant target = 5.0 with varying features
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], 5.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        // Sigma should be relatively small for constant target
        // (no noise to explain)
    }

    #[test]
    fn noisy_target_has_finite_predictions() {
        let mut model = DistributionalSGBT::new(test_config());

        // Simple xorshift for deterministic "noise"
        let mut rng: u64 = 42;
        for i in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 500.0 - 1.0; // [-1, 1]
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + noise));
        }

        let pred = model.predict(&[5.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }

    // Interval width must be exactly 2·z·σ around μ (Gaussian quantile bounds).
    #[test]
    fn predict_interval_bounds_correct() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
        let pred = model.predict(&[5.0]);

        assert!(lo < pred.mu, "lower bound should be < mu");
        assert!(hi > pred.mu, "upper bound should be > mu");
        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
    }

    // Batch API must be a pure convenience wrapper: bit-identical (within
    // 1e-12) to calling predict() per row.
    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], x));
        }

        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
        let batch = model.predict_batch(&features);

        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
            let individual = model.predict(feat);
            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        assert!(model.n_samples_seen() > 0);
        model.reset();

        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
    }

    // Pure-struct test: lower/upper are μ ∓ z·σ, no model involved.
    #[test]
    fn gaussian_prediction_lower_upper() {
        let pred = GaussianPrediction {
            mu: 10.0,
            sigma: 2.0,
            log_sigma: 2.0_f64.ln(),
        };

        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
    }

    #[test]
    fn train_batch_works() {
        let mut model = DistributionalSGBT::new(test_config());
        let samples: Vec<(Vec<f64>, f64)> = (0..100)
            .map(|i| {
                let x = i as f64 * 0.1;
                (vec![x], x * 2.0)
            })
            .collect();

        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 100);
    }

    #[test]
    fn debug_format_works() {
        let model = DistributionalSGBT::new(test_config());
        let debug = format!("{:?}", model);
        assert!(debug.contains("DistributionalSGBT"));
    }

    // Location and scale ensembles each hold n_steps trees, so the total is
    // at least 2 * n_steps even before any splits.
    #[test]
    fn n_trees_counts_both_ensembles() {
        let model = DistributionalSGBT::new(test_config());
        // 10 location + 10 scale = 20 trees minimum
        assert!(model.n_trees() >= 20);
    }

    // -- σ-modulated learning rate tests --

    /// Same as [`test_config`] but with `uncertainty_modulated_lr` enabled,
    /// so the location LR scales with σ relative to its rolling mean.
    fn modulated_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap()
    }

    #[test]
    fn sigma_modulated_initializes_rolling_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());
        assert!(model.is_uncertainty_modulated());

        // Before initialization, rolling_sigma_mean is 1.0 (placeholder)
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);

        // Train past initialization
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        // After training, rolling_sigma_mean should have adapted from initial std
        assert!(model.rolling_sigma_mean() > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn predict_distributional_returns_sigma_ratio() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(mu.is_finite());
        assert!(sigma > 0.0);
        // The ratio σ/rolling_mean is expected to stay within a sane band
        // (clamping bounds presumably enforced by the model — the [0.1, 10]
        // band here mirrors that; confirm against the modulation code).
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={}",
            sigma_ratio
        );
    }

    #[test]
    fn predict_distributional_without_modulation_returns_one() {
        let mut model = DistributionalSGBT::new(test_config());
        assert!(!model.is_uncertainty_modulated());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(
            (sigma_ratio - 1.0).abs() < 1e-12,
            "should be 1.0 when disabled"
        );
    }

    #[test]
    fn modulated_model_sigma_finite_under_varying_noise() {
        let mut model = DistributionalSGBT::new(modulated_config());

        let mut rng: u64 = 123;
        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 100.0 - 5.0; // [-5, 5]
            let x = i as f64 * 0.1;
            // Regime shift at i=250: noise amplitude increases
            let scale = if i < 250 { 1.0 } else { 5.0 };
            model.train_one(&(vec![x], x * 2.0 + noise * scale));
        }

        let pred = model.predict(&[10.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn reset_clears_rolling_sigma_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let sigma_before = model.rolling_sigma_mean();
        assert!(sigma_before > 0.0);

        model.reset();
        // Reset restores the pre-initialization placeholder value of 1.0.
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
    }

    // The StreamingLearner trait facade must agree with the inherent API:
    // trait predict() == inherent predict().mu.
    // NOTE(review): `StreamingLearner::train` is presumably a provided trait
    // method delegating to `train_one` with unit weight — confirm in learner.rs.
    #[test]
    fn streaming_learner_returns_mu() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
        }
        let pred = StreamingLearner::predict(&model, &[5.0]);
        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
        assert!(
            (pred - gaussian.mu).abs() < 1e-12,
            "StreamingLearner::predict should return mu"
        );
    }

    // -- Diagnostics tests --

    /// Shared fixture: a model trained on 500 samples of a noiseless linear
    /// target over 3 features, with a low grace period so trees actually split.
    fn trained_model() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10) // low grace period to ensure splits
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }
        model
    }

    #[test]
    fn diagnostics_returns_correct_tree_count() {
        let model = trained_model();
        let diag = model.diagnostics();
        // 10 location + 10 scale = 20 trees
        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
    }

    #[test]
    fn diagnostics_trees_have_leaves() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
        }
        // At least some trees should have seen samples.
        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
        assert!(
            total_samples > 0,
            "at least some trees should have seen samples"
        );
    }

    #[test]
    fn diagnostics_leaf_weight_stats_finite() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            let (min, max, mean, std) = tree.leaf_weight_stats;
            assert!(min.is_finite(), "tree {i} min not finite");
            assert!(max.is_finite(), "tree {i} max not finite");
            assert!(mean.is_finite(), "tree {i} mean not finite");
            assert!(std.is_finite(), "tree {i} std not finite");
            assert!(min <= max, "tree {i} min > max");
        }
    }

    // Deliberately loose (< 100.0): only sanity-checks that the reported base
    // is in the same ballpark as an actual prediction, not an exact match.
    #[test]
    fn diagnostics_base_predictions_match() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
            "location_base should be plausible"
        );
    }

    // Summing the per-step contributions must reconstruct the plain
    // prediction exactly (up to float round-off).
    #[test]
    fn predict_decomposed_reconstructs_prediction() {
        let model = trained_model();
        let features = [5.0, 2.5, 1.0];
        let pred = model.predict(&features);
        let decomp = model.predict_decomposed(&features);

        assert!(
            (decomp.mu() - pred.mu).abs() < 1e-10,
            "decomposed mu ({}) != predict mu ({})",
            decomp.mu(),
            pred.mu
        );
        assert!(
            (decomp.sigma() - pred.sigma).abs() < 1e-10,
            "decomposed sigma ({}) != predict sigma ({})",
            decomp.sigma(),
            pred.sigma
        );
    }

    #[test]
    fn predict_decomposed_correct_lengths() {
        let model = trained_model();
        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
        assert_eq!(
            decomp.location_contributions.len(),
            model.n_steps(),
            "location contributions should have n_steps entries"
        );
        assert_eq!(
            decomp.scale_contributions.len(),
            model.n_steps(),
            "scale contributions should have n_steps entries"
        );
    }

    #[test]
    fn feature_importances_work() {
        let model = trained_model();
        let imp = model.feature_importances();
        // After enough training, importances should be non-empty and non-negative.
        // They may or may not sum to 1.0 if no splits have occurred yet.
        for (i, &v) in imp.iter().enumerate() {
            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
            assert!(v.is_finite(), "importance {i} should be finite");
        }
        let sum: f64 = imp.iter().sum();
        if sum > 0.0 {
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "non-zero importances should sum to 1.0, got {sum}"
            );
        }
    }

    // Location and scale importances are normalized independently; each side
    // sums to 1.0 when any split exists on that side.
    #[test]
    fn feature_importances_split_works() {
        let model = trained_model();
        let (loc_imp, scale_imp) = model.feature_importances_split();
        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                assert!(
                    (sum - 1.0).abs() < 1e-10,
                    "{name} importances should sum to 1.0, got {sum}"
                );
            }
            for &v in imp.iter() {
                assert!(v >= 0.0 && v.is_finite());
            }
        }
    }

    // -- Empirical σ tests --

    #[test]
    fn empirical_sigma_default_mode() {
        use crate::ensemble::config::ScaleMode;
        let config = test_config();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
    }

    // Core property of empirical σ (see module docs): it is an EWMA of recent
    // squared errors, so switching from clean to noisy data must raise it.
    #[test]
    fn empirical_sigma_tracks_errors() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on clean linear data
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let sigma_clean = model.empirical_sigma();
        assert!(sigma_clean > 0.0, "sigma should be positive");
        assert!(sigma_clean.is_finite(), "sigma should be finite");

        // Now train on noisy data — sigma should increase
        let mut rng: u64 = 42;
        for i in 200..400 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 10000) as f64 / 100.0 - 50.0; // big noise
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
        }

        let sigma_noisy = model.empirical_sigma();
        assert!(
            sigma_noisy > sigma_clean,
            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
        );
    }

    #[test]
    fn empirical_sigma_modulated_lr_adapts() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train and verify sigma_ratio changes
        for i in 0..300 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(sigma_ratio.is_finite());
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={sigma_ratio}"
        );
    }

    // NGBoost-style mode: scale trees are trained instead of the EWMA path.
    #[test]
    fn tree_chain_mode_trains_scale_trees() {
        use crate::ensemble::config::ScaleMode;
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .scale_mode(ScaleMode::TreeChain)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[5.0, 2.5, 1.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(pred.sigma.is_finite());
    }

    #[test]
    fn diagnostics_shows_empirical_sigma() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            diag.empirical_sigma > 0.0,
            "empirical_sigma should be positive"
        );
        assert!(
            diag.empirical_sigma.is_finite(),
            "empirical_sigma should be finite"
        );
    }

    #[test]
    fn diagnostics_scale_trees_split_fields() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.location_trees.len(), model.n_steps());
        assert_eq!(diag.scale_trees.len(), model.n_steps());
        // In empirical mode, scale_trees_active might be 0 (trees not trained)
        // This is expected and actually the point.
    }

    #[test]
    fn reset_clears_empirical_sigma() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        model.reset();
        // After reset, ewma_sq_err resets to 1.0
        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn predict_smooth_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        // 0.5 is presumably a smoothing bandwidth/temperature — sanity check
        // only; exact semantics live in predict_smooth's own docs.
        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
        assert!(pred.mu.is_finite(), "smooth mu should be finite");
        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
    }

    // -- PD sigma modulation tests --

    // After an abrupt regime change, σ should be rising, i.e. its velocity
    // (derivative term of the PD modulation) should exceed the stable-phase value.
    #[test]
    fn sigma_velocity_responds_to_error_spike() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Stable phase
        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let velocity_before = model.sigma_velocity();

        // Error spike -- sudden regime change
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
        }

        let velocity_after = model.sigma_velocity();

        // Velocity should be positive (sigma increasing due to errors)
        assert!(
            velocity_after > velocity_before,
            "sigma velocity should increase after error spike: before={}, after={}",
            velocity_before,
            velocity_after,
        );
    }

    #[test]
    fn sigma_velocity_getter_works() {
        let config = SGBTConfig::builder()
            .n_steps(2)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let model = DistributionalSGBT::new(config);
        // Fresh model should have zero velocity
        // (exact equality is fine here: the field is initialized to literal 0.0)
        assert_eq!(model.sigma_velocity(), 0.0);
    }

    #[test]
    fn diagnostics_leaf_sample_counts_populated() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let diags = model.diagnostics();
        for (ti, tree) in diags.trees.iter().enumerate() {
            assert_eq!(
                tree.leaf_sample_counts.len(),
                tree.n_leaves,
                "tree {} should have sample count per leaf",
                ti,
            );
            // Total samples across leaves should equal samples_seen for trees with data
            if tree.samples_seen > 0 {
                let total: u64 = tree.leaf_sample_counts.iter().sum();
                assert!(
                    total > 0,
                    "tree {} has {} samples_seen but leaf counts sum to 0",
                    ti,
                    tree.samples_seen,
                );
            }
        }
    }

    // -------------------------------------------------------------------
    // Packed cache tests (dual-path inference)
    // -------------------------------------------------------------------

    #[test]
    fn packed_cache_disabled_by_default() {
        let model = DistributionalSGBT::new(test_config());
        assert!(!model.has_packed_cache());
        assert_eq!(model.config().packed_refresh_interval, 0);
    }

    // NOTE(review): the "_packed_cache_tests_disabled" feature gate below
    // appears to be a sentinel that is never enabled, i.e. this test is
    // intentionally compiled out — confirm intent and either re-enable or
    // replace with #[ignore] so it still type-checks.
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_refreshes_after_interval() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // Train past initialization (10 samples) + refresh interval (20 samples)
        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Cache should exist after enough training
        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        // Predictions should be finite
        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }

    // NOTE(review): also compiled out via the same sentinel feature — see above.
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // Train
        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Get full-tree prediction (no cache)
        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Enable cache and get cached prediction
        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Should match within f32 precision (f64->f32 conversion loses some precision)
        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        // Sigma should be identical (not affected by packed cache)
        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }
}
2410
#[cfg(test)]
#[cfg(feature = "_serde_support")]
mod serde_tests {
    use super::*;
    use crate::SGBTConfig;

    /// Trains a small distributional model on 50 samples of `sin(x)` so the
    /// round-trip tests below have non-trivial trees and σ state to preserve.
    fn make_trained_distributional() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        model
    }

    /// Save → JSON → load must reproduce both μ and σ (to within 1e-10) at
    /// several probe points.
    #[test]
    fn json_round_trip_preserves_predictions() {
        let model = make_trained_distributional();
        let state = model.to_distributional_state();
        let json = crate::serde_support::save_distributional_model(&state).unwrap();
        let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
        let restored = DistributionalSGBT::from_distributional_state(loaded_state);

        let test_points = [0.5, 1.0, 2.0, 3.0];
        for &x in &test_points {
            let orig = model.predict(&[x]);
            let rest = restored.predict(&[x]);
            assert!(
                (orig.mu - rest.mu).abs() < 1e-10,
                "JSON round-trip mu mismatch at x={}: {} vs {}",
                x,
                orig.mu,
                rest.mu
            );
            assert!(
                (orig.sigma - rest.sigma).abs() < 1e-10,
                "JSON round-trip sigma mismatch at x={}: {} vs {}",
                x,
                orig.sigma,
                rest.sigma
            );
        }
    }

    /// The serialized state must carry the uncertainty-modulated-LR flag and
    /// the rolling σ mean, and restoring must preserve the sample counter.
    #[test]
    fn state_preserves_rolling_sigma_mean() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        let state = model.to_distributional_state();
        assert!(state.uncertainty_modulated_lr);
        assert!(state.rolling_sigma_mean >= 0.0);

        let restored = DistributionalSGBT::from_distributional_state(state);
        assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
    }
}

/// Config-knob and alternative-prediction-path tests.
///
/// None of these call anything in `crate::serde_support`; they previously
/// lived inside the `_serde_support`-gated `serde_tests` module and were
/// silently skipped from the suite whenever that feature was disabled.
///
/// NOTE(review): these use only core APIs (`train_one`, `predict*`,
/// `diagnostics`, `auto_bandwidths`, config builder) — confirm none of those
/// are themselves feature-gated.
#[cfg(test)]
mod non_serde_config_tests {
    use super::*;
    use crate::SGBTConfig;

    /// With no explicit bandwidth configured, training should populate
    /// per-feature auto bandwidths, expose them in diagnostics, and still
    /// yield a finite, positive-σ prediction.
    #[test]
    fn auto_bandwidth_computed_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        // auto_bandwidths should be populated
        let bws = model.auto_bandwidths();
        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");

        // Diagnostics should include auto_bandwidths
        let diag = model.diagnostics();
        assert_eq!(diag.auto_bandwidths.len(), 2);

        // TreeDiagnostic should have prediction stats
        assert!(diag.location_trees[0].prediction_mean.is_finite());
        assert!(diag.location_trees[0].prediction_std.is_finite());

        let pred = model.predict(&[1.0, 1.0_f64.sin()]);
        assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
        assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
    }

    /// An intentionally huge learning rate plus alternating ±100 targets
    /// would normally blow up leaf weights; `max_leaf_output` must keep the
    /// final prediction finite.
    #[test]
    fn max_leaf_output_clamps_predictions() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(1.0) // Intentionally large to force extreme leaf weights
            .max_leaf_output(0.5)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with extreme targets to force large leaf weights
        for i in 0..200 {
            let target = if i % 2 == 0 { 100.0 } else { -100.0 };
            let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
            model.train_one(&sample);
        }

        // With per-tree outputs clamped to [-0.5, 0.5] the aggregate must
        // stay finite (we observe the aggregate; per-tree clamping is not
        // inspected directly here).
        let pred = model.predict(&[2.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with clamping"
        );
    }

    /// A large `min_hessian_sum` suppresses contributions from leaves that
    /// have seen little data; predictions must still be finite.
    #[test]
    fn min_hessian_sum_suppresses_fresh_leaves() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.01)
            .min_hessian_sum(50.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train minimal samples
        for i in 0..60 {
            let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
            model.train_one(&sample);
        }

        let pred = model.predict(&[30.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with min_hessian_sum"
        );
    }

    /// The smooth (interpolated) prediction path must return finite μ and
    /// positive σ on a trained model.
    #[test]
    fn predict_interpolated_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let pred = model.predict_interpolated(&[1.0, 0.5]);
        assert!(pred.mu.is_finite(), "interpolated mu should be finite");
        assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
    }

    /// Huber gradient clipping (`huber_k`) must keep μ finite even when the
    /// stream contains extreme outliers (target = 1000 every 50th sample).
    #[test]
    fn huber_k_bounds_gradients() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(1.345)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with occasional extreme outliers
        for i in 0..300 {
            let target = if i % 50 == 0 {
                1000.0
            } else {
                (i as f64 * 0.1).sin()
            };
            let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[5.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "Huber-loss mu should be finite despite outliers"
        );
        assert!(pred.sigma > 0.0, "sigma should be positive");
    }

    /// After training, ensemble-level gradient statistics in diagnostics must
    /// be finite (mean) and finite non-negative (std).
    #[test]
    fn ensemble_gradient_stats_populated() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let diag = model.diagnostics();
        assert!(
            diag.ensemble_grad_mean.is_finite(),
            "ensemble grad mean should be finite"
        );
        assert!(
            diag.ensemble_grad_std >= 0.0,
            "ensemble grad std should be non-negative"
        );
        assert!(
            diag.ensemble_grad_std.is_finite(),
            "ensemble grad std should be finite"
        );
    }

    /// Builder must reject a negative `huber_k`.
    #[test]
    fn huber_k_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(-1.0)
            .build();
        assert!(result.is_err(), "negative huber_k should fail validation");
    }

    /// Builder must reject a negative `max_leaf_output`.
    #[test]
    fn max_leaf_output_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .max_leaf_output(-1.0)
            .build();
        assert!(
            result.is_err(),
            "negative max_leaf_output should fail validation"
        );
    }

    /// Sibling interpolation should be at least as responsive to feature
    /// sweeps as hard (non-interpolated) prediction: counting adjacent sweep
    /// points where the prediction changes, sibling >= hard.
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            let sample = crate::Sample::new(vec![x, x * 0.3], y);
            model.train_one(&sample);
        }

        // Verify the method runs correctly and produces finite predictions
        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(
            pred.mu.is_finite(),
            "sibling interpolated mu should be finite"
        );
        assert!(
            pred.sigma > 0.0,
            "sibling interpolated sigma should be positive"
        );

        // Sweep — at minimum, sibling should produce same or more variation than hard
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            let hard_preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict(&[x, x * 0.3]).mu
                })
                .collect();
            let hard_changes = hard_preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();

            let preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict_sibling_interpolated(&[x, x * 0.3]).mu
                })
                .collect();

            let sibling_changes = preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sibling_changes >= hard_changes,
                "sibling should produce >= hard changes: sibling={}, hard={}",
                sibling_changes,
                hard_changes
            );
        }
    }
}