Skip to main content

irithyll_core/ensemble/
distributional.rs

1//! Distributional SGBT -- outputs Gaussian N(μ, σ²) instead of a point estimate.
2//!
3//! [`DistributionalSGBT`] supports two scale estimation modes via
4//! [`ScaleMode`]:
5//!
6//! ## Empirical σ (default)
7//!
8//! Tracks an EWMA of squared prediction errors:
9//!
10//! ```text
11//! err = target - mu
12//! ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err
13//! sigma = sqrt(ewma_sq_err)
14//! ```
15//!
16//! Always calibrated (σ literally *is* recent error magnitude), zero tuning,
17//! O(1) memory and compute.  When `uncertainty_modulated_lr` is enabled,
18//! high recent errors → σ large → location LR scales up → faster correction.
19//!
20//! ## Tree chain (NGBoost-style)
21//!
22//! Maintains two independent tree ensembles: one for location (μ), one for
23//! scale (log σ).  Gives feature-conditional uncertainty but requires strong
24//! scale-gradient signal for the trees to split.
25//!
26//! # References
27//!
28//! Duan et al. (2020). "NGBoost: Natural Gradient Boosting for Probabilistic Prediction."
29
30use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::ensemble::config::{SGBTConfig, ScaleMode};
34use crate::ensemble::step::BoostingStep;
35use crate::sample::{Observation, SampleRef};
36use crate::tree::builder::TreeConfig;
37
/// Cached packed f32 binary for fast location-only inference.
///
/// Re-exported periodically from the location ensemble. Predictions use
/// contiguous BFS-packed memory for cache-optimal tree traversal.
#[derive(Clone)]
struct PackedInferenceCache {
    /// Serialized BFS-packed ensemble bytes (parsed via `EnsembleView::from_bytes`).
    bytes: Vec<u8>,
    /// Location base prediction added to the packed ensemble's output.
    base: f64,
    /// Number of input features the packed ensemble expects.
    n_features: usize,
}
57
/// Prediction from a distributional model: full Gaussian N(μ, σ²).
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Location parameter (mean).
    pub mu: f64,
    /// Scale parameter (standard deviation, always > 0).
    pub sigma: f64,
    /// Log of scale parameter (raw model output for scale ensemble).
    pub log_sigma: f64,
}

impl GaussianPrediction {
    /// Lower bound of a symmetric confidence interval around `mu`.
    ///
    /// `z` is the standard-normal quantile; use `z = 1.96` for a 95% CI.
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu - half_width
    }

    /// Upper bound of a symmetric confidence interval around `mu`.
    ///
    /// Mirror of [`Self::lower`] for the same `z`.
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        let half_width = z * self.sigma;
        self.mu + half_width
    }
}
84
85// ---------------------------------------------------------------------------
86// Diagnostic structs
87// ---------------------------------------------------------------------------
88
/// Per-tree diagnostic summary.
///
/// Snapshot of one boosting step's tree: structure, data exposure, and
/// running prediction statistics.
#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    /// Number of leaf nodes in this tree.
    pub n_leaves: usize,
    /// Maximum depth reached by any leaf.
    pub max_depth_reached: usize,
    /// Total samples this tree has seen.
    pub samples_seen: u64,
    /// Leaf weight statistics, in tuple order: `(min, max, mean, std)`.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices this tree has split on (non-zero gain).
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts showing data distribution across leaves.
    pub leaf_sample_counts: Vec<u64>,
    /// Running mean of predictions from this tree (Welford online).
    pub prediction_mean: f64,
    /// Running standard deviation of predictions from this tree (Welford online).
    pub prediction_std: f64,
}
109
/// Full model diagnostics for [`DistributionalSGBT`].
///
/// Contains per-tree summaries, feature usage, base predictions, and
/// empirical σ state.  Snapshot semantics: values reflect the model at the
/// moment of collection, not a live view.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// Per-tree diagnostic summaries (location trees first, then scale trees).
    pub trees: Vec<TreeDiagnostic>,
    /// Location-tree summaries only. Owned `Vec`, not a borrowed view —
    /// presumably duplicates the location entries of `trees`; confirm at the
    /// collection site.
    pub location_trees: Vec<TreeDiagnostic>,
    /// Scale-tree summaries only. Owned `Vec`, not a borrowed view —
    /// presumably duplicates the scale entries of `trees`; confirm at the
    /// collection site.
    pub scale_trees: Vec<TreeDiagnostic>,
    /// How many trees each feature is used in (split count per feature).
    pub feature_split_counts: Vec<usize>,
    /// Base prediction for location (mean).
    pub location_base: f64,
    /// Base prediction for scale (log-sigma).
    pub scale_base: f64,
    /// Current empirical σ (`sqrt(ewma_sq_err)`), always available.
    pub empirical_sigma: f64,
    /// Scale mode in use.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that actually split (>1 leaf). 0 = frozen chain.
    pub scale_trees_active: usize,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// `f64::INFINITY` means that feature uses hard routing.
    pub auto_bandwidths: Vec<f64>,
    /// Ensemble-level gradient running mean.
    pub ensemble_grad_mean: f64,
    /// Ensemble-level gradient standard deviation.
    pub ensemble_grad_std: f64,
}
142
143/// Decomposed prediction showing each tree's contribution.
144#[derive(Debug, Clone)]
145pub struct DecomposedPrediction {
146    /// Base location prediction (mean of initial targets).
147    pub location_base: f64,
148    /// Base scale prediction (log-sigma of initial targets).
149    pub scale_base: f64,
150    /// Per-step location contributions: `learning_rate * tree_prediction`.
151    /// `location_base + sum(location_contributions)` = μ.
152    pub location_contributions: Vec<f64>,
153    /// Per-step scale contributions: `learning_rate * tree_prediction`.
154    /// `scale_base + sum(scale_contributions)` = log(σ).
155    pub scale_contributions: Vec<f64>,
156}
157
158impl DecomposedPrediction {
159    /// Reconstruct the final μ from base + contributions.
160    pub fn mu(&self) -> f64 {
161        self.location_base + self.location_contributions.iter().sum::<f64>()
162    }
163
164    /// Reconstruct the final log(σ) from base + contributions.
165    pub fn log_sigma(&self) -> f64 {
166        self.scale_base + self.scale_contributions.iter().sum::<f64>()
167    }
168
169    /// Reconstruct the final σ (exponentiated).
170    pub fn sigma(&self) -> f64 {
171        crate::math::exp(self.log_sigma()).max(1e-8)
172    }
173}
174
175// ---------------------------------------------------------------------------
176// DistributionalSGBT
177// ---------------------------------------------------------------------------
178
179/// NGBoost-style distributional streaming gradient boosted trees.
180///
181/// Outputs a full Gaussian predictive distribution N(μ, σ²) by maintaining two
182/// independent ensembles -- one for location (mean) and one for scale (log-sigma).
183///
184/// # Example
185///
186/// ```text
187/// use irithyll::SGBTConfig;
188/// use irithyll::ensemble::distributional::DistributionalSGBT;
189///
190/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
191/// let mut model = DistributionalSGBT::new(config);
192///
193/// // Train on streaming data
194/// model.train_one(&(vec![1.0, 2.0], 3.5));
195///
196/// // Get full distributional prediction
197/// let pred = model.predict(&[1.0, 2.0]);
198/// println!("mean={}, sigma={}", pred.mu, pred.sigma);
199/// ```
200pub struct DistributionalSGBT {
201    /// Configuration (shared between location and scale ensembles).
202    config: SGBTConfig,
203    /// Location (mean) boosting steps.
204    location_steps: Vec<BoostingStep>,
205    /// Scale (log-sigma) boosting steps (only used in `TreeChain` mode).
206    scale_steps: Vec<BoostingStep>,
207    /// Base prediction for location (mean of initial targets).
208    location_base: f64,
209    /// Base prediction for scale (log of std of initial targets).
210    scale_base: f64,
211    /// Whether base predictions have been initialized.
212    base_initialized: bool,
213    /// Running collection of initial targets for computing base predictions.
214    initial_targets: Vec<f64>,
215    /// Number of initial targets to collect before setting base predictions.
216    initial_target_count: usize,
217    /// Total samples trained.
218    samples_seen: u64,
219    /// RNG state for variant logic.
220    rng_state: u64,
221    /// Whether σ-modulated learning rate is enabled.
222    uncertainty_modulated_lr: bool,
223    /// EWMA of the model's predicted σ -- used as the denominator in σ-ratio.
224    ///
225    /// Updated with alpha = 0.001 (slow adaptation) after each training step.
226    /// Initialized from the standard deviation of the initial target collection.
227    rolling_sigma_mean: f64,
228    /// Scale estimation mode: Empirical (default) or TreeChain.
229    scale_mode: ScaleMode,
230    /// EWMA of squared prediction errors (empirical σ mode).
231    ///
232    /// `sigma = sqrt(ewma_sq_err)`. Updated every training step with
233    /// `ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err`.
234    ewma_sq_err: f64,
235    /// EWMA alpha for empirical σ.
236    empirical_sigma_alpha: f64,
237    /// Previous empirical σ value for computing σ velocity.
238    prev_sigma: f64,
239    /// EWMA-smoothed derivative of empirical σ (σ velocity).
240    /// Positive = σ increasing (errors growing), negative = σ decreasing.
241    sigma_velocity: f64,
242    /// Per-feature auto-calibrated bandwidths for smooth prediction.
243    auto_bandwidths: Vec<f64>,
244    /// Sum of replacement counts at last bandwidth computation.
245    last_replacement_sum: u64,
246    /// Ensemble-level gradient running mean (Welford).
247    ensemble_grad_mean: f64,
248    /// Ensemble-level gradient M2 (Welford sum of squared deviations).
249    ensemble_grad_m2: f64,
250    /// Ensemble-level gradient sample count.
251    ensemble_grad_count: u64,
252    /// Packed f32 cache for fast location-only inference (dual-path).
253    packed_cache: Option<PackedInferenceCache>,
254    /// Samples trained since last packed cache refresh.
255    samples_since_refresh: u64,
256    /// How often to refresh the packed cache (0 = disabled).
257    packed_refresh_interval: u64,
258}
259
260impl Clone for DistributionalSGBT {
261    fn clone(&self) -> Self {
262        Self {
263            config: self.config.clone(),
264            location_steps: self.location_steps.clone(),
265            scale_steps: self.scale_steps.clone(),
266            location_base: self.location_base,
267            scale_base: self.scale_base,
268            base_initialized: self.base_initialized,
269            initial_targets: self.initial_targets.clone(),
270            initial_target_count: self.initial_target_count,
271            samples_seen: self.samples_seen,
272            rng_state: self.rng_state,
273            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
274            rolling_sigma_mean: self.rolling_sigma_mean,
275            scale_mode: self.scale_mode,
276            ewma_sq_err: self.ewma_sq_err,
277            empirical_sigma_alpha: self.empirical_sigma_alpha,
278            prev_sigma: self.prev_sigma,
279            sigma_velocity: self.sigma_velocity,
280            auto_bandwidths: self.auto_bandwidths.clone(),
281            last_replacement_sum: self.last_replacement_sum,
282            ensemble_grad_mean: self.ensemble_grad_mean,
283            ensemble_grad_m2: self.ensemble_grad_m2,
284            ensemble_grad_count: self.ensemble_grad_count,
285            packed_cache: self.packed_cache.clone(),
286            samples_since_refresh: self.samples_since_refresh,
287            packed_refresh_interval: self.packed_refresh_interval,
288        }
289    }
290}
291
292impl core::fmt::Debug for DistributionalSGBT {
293    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
294        let mut s = f.debug_struct("DistributionalSGBT");
295        s.field("n_steps", &self.location_steps.len())
296            .field("samples_seen", &self.samples_seen)
297            .field("location_base", &self.location_base)
298            .field("scale_mode", &self.scale_mode)
299            .field("base_initialized", &self.base_initialized);
300        match self.scale_mode {
301            ScaleMode::Empirical => {
302                s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
303            }
304            ScaleMode::TreeChain => {
305                s.field("scale_base", &self.scale_base);
306            }
307        }
308        if self.uncertainty_modulated_lr {
309            s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
310        }
311        s.finish()
312    }
313}
314
315impl DistributionalSGBT {
316    /// Create a new distributional SGBT with the given configuration.
317    ///
318    /// When `scale_mode` is `Empirical` (default), scale trees are still allocated
319    /// but never trained — only the EWMA error tracker produces σ.  When
320    /// `scale_mode` is `TreeChain`, both location and scale ensembles are active.
321    pub fn new(config: SGBTConfig) -> Self {
322        let leaf_decay_alpha = config
323            .leaf_half_life
324            .map(|hl| crate::math::exp(-(crate::math::ln(2.0_f64)) / hl as f64));
325
326        let tree_config = TreeConfig::new()
327            .max_depth(config.max_depth)
328            .n_bins(config.n_bins)
329            .lambda(config.lambda)
330            .gamma(config.gamma)
331            .grace_period(config.grace_period)
332            .delta(config.delta)
333            .feature_subsample_rate(config.feature_subsample_rate)
334            .leaf_decay_alpha_opt(leaf_decay_alpha)
335            .split_reeval_interval_opt(config.split_reeval_interval)
336            .feature_types_opt(config.feature_types.clone())
337            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
338            .monotone_constraints_opt(config.monotone_constraints.clone())
339            .max_leaf_output_opt(config.max_leaf_output)
340            .adaptive_depth_opt(config.adaptive_depth)
341            .min_hessian_sum_opt(config.min_hessian_sum)
342            .leaf_model_type(config.leaf_model_type.clone());
343
344        let max_tree_samples = config.max_tree_samples;
345
346        // Location ensemble (seed offset: 0)
347        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
348        let location_steps: Vec<BoostingStep> = (0..config.n_steps)
349            .map(|i| {
350                let mut tc = tree_config.clone();
351                tc.seed = config.seed ^ (i as u64);
352                let detector = config.drift_detector.create();
353                if shadow_warmup > 0 {
354                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
355                } else {
356                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
357                }
358            })
359            .collect();
360
361        // Scale ensemble (seed offset: 0xSCALE) -- only trained in TreeChain mode
362        let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
363            .map(|i| {
364                let mut tc = tree_config.clone();
365                tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
366                let detector = config.drift_detector.create();
367                if shadow_warmup > 0 {
368                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
369                } else {
370                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
371                }
372            })
373            .collect();
374
375        let seed = config.seed;
376        let initial_target_count = config.initial_target_count;
377        let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
378        let scale_mode = config.scale_mode;
379        let empirical_sigma_alpha = config.empirical_sigma_alpha;
380        let packed_refresh_interval = config.packed_refresh_interval;
381        Self {
382            config,
383            location_steps,
384            scale_steps,
385            location_base: 0.0,
386            scale_base: 0.0,
387            base_initialized: false,
388            initial_targets: Vec::new(),
389            initial_target_count,
390            samples_seen: 0,
391            rng_state: seed,
392            uncertainty_modulated_lr,
393            rolling_sigma_mean: 1.0, // overwritten during base initialization
394            scale_mode,
395            ewma_sq_err: 1.0, // overwritten during base initialization
396            empirical_sigma_alpha,
397            prev_sigma: 0.0,
398            sigma_velocity: 0.0,
399            auto_bandwidths: Vec::new(),
400            last_replacement_sum: 0,
401            ensemble_grad_mean: 0.0,
402            ensemble_grad_m2: 0.0,
403            ensemble_grad_count: 0,
404            packed_cache: None,
405            samples_since_refresh: 0,
406            packed_refresh_interval,
407        }
408    }
409
410    /// Train on a single observation.
411    pub fn train_one(&mut self, sample: &impl Observation) {
412        self.samples_seen += 1;
413        let target = sample.target();
414        let features = sample.features();
415
416        // Initialize base predictions from first few targets
417        if !self.base_initialized {
418            self.initial_targets.push(target);
419            if self.initial_targets.len() >= self.initial_target_count {
420                // Location base = mean
421                let sum: f64 = self.initial_targets.iter().sum();
422                let mean = sum / self.initial_targets.len() as f64;
423                self.location_base = mean;
424
425                // Scale base = log(std) -- clamped for stability
426                let var: f64 = self
427                    .initial_targets
428                    .iter()
429                    .map(|&y| (y - mean) * (y - mean))
430                    .sum::<f64>()
431                    / self.initial_targets.len() as f64;
432                let initial_std = crate::math::sqrt(var).max(1e-6);
433                self.scale_base = crate::math::ln(initial_std);
434
435                // Initialize rolling sigma mean and ewma from initial targets std
436                self.rolling_sigma_mean = initial_std;
437                self.ewma_sq_err = var.max(1e-12);
438
439                // Initialize PD sigma state
440                self.prev_sigma = initial_std;
441                self.sigma_velocity = 0.0;
442
443                self.base_initialized = true;
444                self.initial_targets.clear();
445                self.initial_targets.shrink_to_fit();
446            }
447            return;
448        }
449
450        match self.scale_mode {
451            ScaleMode::Empirical => self.train_one_empirical(target, features),
452            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
453        }
454
455        // Refresh auto-bandwidths when trees have been replaced
456        self.refresh_bandwidths();
457    }
458
459    /// Empirical-σ training: location trees only, σ from EWMA of squared errors.
460    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
461        // Current location prediction (before this step's update)
462        let mut mu = self.location_base;
463        for s in 0..self.location_steps.len() {
464            mu += self.config.learning_rate * self.location_steps[s].predict(features);
465        }
466
467        // Empirical sigma from EWMA of squared prediction errors
468        let err = target - mu;
469        let alpha = self.empirical_sigma_alpha;
470        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
471        let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
472
473        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller)
474        let sigma_ratio = if self.uncertainty_modulated_lr {
475            // Compute sigma velocity (derivative of sigma over time)
476            let d_sigma = empirical_sigma - self.prev_sigma;
477            self.prev_sigma = empirical_sigma;
478
479            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
480            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;
481
482            // Adaptive derivative gain: self-calibrating, no config needed
483            let k_d = if self.rolling_sigma_mean > 1e-12 {
484                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
485            } else {
486                0.0
487            };
488
489            // PD ratio: proportional + derivative
490            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
491            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);
492
493            // Update rolling sigma mean with slow EWMA
494            const SIGMA_EWMA_ALPHA: f64 = 0.001;
495            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
496                + SIGMA_EWMA_ALPHA * empirical_sigma;
497
498            ratio
499        } else {
500            1.0
501        };
502
503        let base_lr = self.config.learning_rate;
504
505        // Train location steps only -- no scale trees needed
506        let mut mu_accum = self.location_base;
507        for s in 0..self.location_steps.len() {
508            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
509            // Welford update for ensemble gradient stats
510            self.update_ensemble_grad_stats(g_mu);
511            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
512            let loc_pred =
513                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
514            mu_accum += (base_lr * sigma_ratio) * loc_pred;
515        }
516
517        // Refresh packed cache if interval reached
518        self.maybe_refresh_packed_cache();
519    }
520
    /// Tree-chain training: full NGBoost dual-chain with location + scale trees.
    ///
    /// Boosts the location (μ) and scale (log σ) ensembles sequentially on the
    /// same sample.  Also keeps the empirical EWMA σ tracker updated so
    /// diagnostics remain meaningful in this mode.
    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
        let mut mu = self.location_base;
        let mut log_sigma = self.scale_base;

        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller).
        let sigma_ratio = if self.uncertainty_modulated_lr {
            // NOTE(review): at this point `log_sigma` equals `scale_base` only --
            // no scale-tree contributions have been added yet, so the controller
            // sees the base σ, not the feature-conditional σ. Confirm intended.
            let current_sigma = crate::math::exp(log_sigma).max(1e-8);

            // Compute sigma velocity (first difference of σ over time).
            let d_sigma = current_sigma - self.prev_sigma;
            self.prev_sigma = current_sigma;

            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization).
            let alpha = self.empirical_sigma_alpha;
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain: self-calibrating, no config needed.
            // Zero-guard avoids division by a vanishing rolling mean.
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            // PD ratio: proportional + derivative, clamped to keep the
            // effective learning rate within [0.1x, 10x] of the base rate.
            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Slow EWMA of σ's long-run level (the ratio's denominator).
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * current_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Sequential boosting: both ensembles target their respective residuals,
        // with μ and log σ updated after each step so later steps see the
        // partially-corrected prediction.
        for s in 0..self.location_steps.len() {
            let sigma = crate::math::exp(log_sigma).max(1e-8);
            // Standardized residual for the Gaussian NLL.
            let z = (target - mu) / sigma;

            // Location gradients (squared or Huber loss w.r.t. mu).
            let (g_mu, h_mu) = self.location_gradient(mu, target);
            // Welford update for ensemble gradient stats.
            self.update_ensemble_grad_stats(g_mu);

            // Scale gradients: for Gaussian NLL, ∂NLL/∂(log σ) = 1 - z²; the
            // Hessian 2z² is clamped to keep leaf updates well-conditioned.
            let g_sigma = 1.0 - z * z;
            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);

            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);

            // Train location step -- σ-modulated LR when enabled.
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu += (base_lr * sigma_ratio) * loc_pred;

            // Train scale step -- ALWAYS at unmodulated base rate.
            let scale_pred =
                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
            log_sigma += base_lr * scale_pred;
        }

        // Also update empirical sigma tracker for diagnostics.
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;

        // Refresh packed cache if interval reached.
        self.maybe_refresh_packed_cache();
    }
595
596    /// Predict the full Gaussian distribution for a feature vector.
597    ///
598    /// When a packed cache is available, uses it for the location (μ) prediction
599    /// via contiguous BFS-packed memory traversal. Falls back to full tree
600    /// traversal if the cache is absent or produces non-finite results.
601    ///
602    /// Sigma computation always uses the primary path (EWMA or scale chain)
603    /// and is unaffected by the packed cache.
604    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
605        // Try packed cache for mu if available
606        let mu = if let Some(ref cache) = self.packed_cache {
607            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
608            match crate::EnsembleView::from_bytes(&cache.bytes) {
609                Ok(view) => {
610                    let packed_mu = cache.base + view.predict(&features_f32) as f64;
611                    if packed_mu.is_finite() {
612                        packed_mu
613                    } else {
614                        self.predict_full_trees(features)
615                    }
616                }
617                Err(_) => self.predict_full_trees(features),
618            }
619        } else {
620            self.predict_full_trees(features)
621        };
622
623        let (sigma, log_sigma) = match self.scale_mode {
624            ScaleMode::Empirical => {
625                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
626                (s, crate::math::ln(s))
627            }
628            ScaleMode::TreeChain => {
629                let mut ls = self.scale_base;
630                if self.auto_bandwidths.is_empty() {
631                    for s in 0..self.scale_steps.len() {
632                        ls += self.config.learning_rate * self.scale_steps[s].predict(features);
633                    }
634                } else {
635                    for s in 0..self.scale_steps.len() {
636                        ls += self.config.learning_rate
637                            * self.scale_steps[s]
638                                .predict_smooth_auto(features, &self.auto_bandwidths);
639                    }
640                }
641                (crate::math::exp(ls).max(1e-8), ls)
642            }
643        };
644
645        GaussianPrediction {
646            mu,
647            sigma,
648            log_sigma,
649        }
650    }
651
652    /// Full-tree location prediction (fallback when packed cache is unavailable).
653    fn predict_full_trees(&self, features: &[f64]) -> f64 {
654        let mut mu = self.location_base;
655        if self.auto_bandwidths.is_empty() {
656            for s in 0..self.location_steps.len() {
657                mu += self.config.learning_rate * self.location_steps[s].predict(features);
658            }
659        } else {
660            for s in 0..self.location_steps.len() {
661                mu += self.config.learning_rate
662                    * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
663            }
664        }
665        mu
666    }
667
668    /// Predict using sigmoid-blended soft routing for smooth interpolation.
669    ///
670    /// Instead of hard left/right routing at tree split nodes, each split
671    /// uses sigmoid blending: `alpha = sigmoid((threshold - feature) / bandwidth)`.
672    /// The result is a continuous function that varies smoothly with every
673    /// feature change.
674    ///
675    /// `bandwidth` controls transition sharpness: smaller = sharper (closer
676    /// to hard splits), larger = smoother.
677    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
678        let mut mu = self.location_base;
679        for s in 0..self.location_steps.len() {
680            mu += self.config.learning_rate
681                * self.location_steps[s].predict_smooth(features, bandwidth);
682        }
683
684        let (sigma, log_sigma) = match self.scale_mode {
685            ScaleMode::Empirical => {
686                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
687                (s, crate::math::ln(s))
688            }
689            ScaleMode::TreeChain => {
690                let mut ls = self.scale_base;
691                for s in 0..self.scale_steps.len() {
692                    ls += self.config.learning_rate
693                        * self.scale_steps[s].predict_smooth(features, bandwidth);
694                }
695                (crate::math::exp(ls).max(1e-8), ls)
696            }
697        };
698
699        GaussianPrediction {
700            mu,
701            sigma,
702            log_sigma,
703        }
704    }
705
706    /// Predict with parent-leaf linear interpolation.
707    ///
708    /// Blends each leaf prediction with its parent's preserved prediction
709    /// based on sample count, preventing stale predictions from fresh leaves.
710    pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
711        let mut mu = self.location_base;
712        for s in 0..self.location_steps.len() {
713            mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
714        }
715
716        let (sigma, log_sigma) = match self.scale_mode {
717            ScaleMode::Empirical => {
718                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
719                (s, crate::math::ln(s))
720            }
721            ScaleMode::TreeChain => {
722                let mut ls = self.scale_base;
723                for s in 0..self.scale_steps.len() {
724                    ls += self.config.learning_rate
725                        * self.scale_steps[s].predict_interpolated(features);
726                }
727                (crate::math::exp(ls).max(1e-8), ls)
728            }
729        };
730
731        GaussianPrediction {
732            mu,
733            sigma,
734            log_sigma,
735        }
736    }
737
738    /// Predict with sibling-based interpolation for feature-continuous predictions.
739    ///
740    /// At each split node near the threshold boundary, blends left and right
741    /// subtree predictions linearly. Uses auto-calibrated bandwidths as the
742    /// interpolation margin. Predictions vary continuously as features change.
743    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
744        let mut mu = self.location_base;
745        for s in 0..self.location_steps.len() {
746            mu += self.config.learning_rate
747                * self.location_steps[s]
748                    .predict_sibling_interpolated(features, &self.auto_bandwidths);
749        }
750
751        let (sigma, log_sigma) = match self.scale_mode {
752            ScaleMode::Empirical => {
753                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
754                (s, crate::math::ln(s))
755            }
756            ScaleMode::TreeChain => {
757                let mut ls = self.scale_base;
758                for s in 0..self.scale_steps.len() {
759                    ls += self.config.learning_rate
760                        * self.scale_steps[s]
761                            .predict_sibling_interpolated(features, &self.auto_bandwidths);
762                }
763                (crate::math::exp(ls).max(1e-8), ls)
764            }
765        };
766
767        GaussianPrediction {
768            mu,
769            sigma,
770            log_sigma,
771        }
772    }
773
774    /// Predict with graduated active-shadow blending.
775    ///
776    /// Smoothly transitions between active and shadow trees during replacement.
777    /// Requires `shadow_warmup` to be configured.
778    pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
779        let mut mu = self.location_base;
780        for s in 0..self.location_steps.len() {
781            mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
782        }
783
784        let (sigma, log_sigma) = match self.scale_mode {
785            ScaleMode::Empirical => {
786                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
787                (s, crate::math::ln(s))
788            }
789            ScaleMode::TreeChain => {
790                let mut ls = self.scale_base;
791                for s in 0..self.scale_steps.len() {
792                    ls +=
793                        self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
794                }
795                (crate::math::exp(ls).max(1e-8), ls)
796            }
797        };
798
799        GaussianPrediction {
800            mu,
801            sigma,
802            log_sigma,
803        }
804    }
805
806    /// Predict with graduated blending + sibling interpolation (premium path).
807    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
808        let mut mu = self.location_base;
809        for s in 0..self.location_steps.len() {
810            mu += self.config.learning_rate
811                * self.location_steps[s]
812                    .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
813        }
814
815        let (sigma, log_sigma) = match self.scale_mode {
816            ScaleMode::Empirical => {
817                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
818                (s, crate::math::ln(s))
819            }
820            ScaleMode::TreeChain => {
821                let mut ls = self.scale_base;
822                for s in 0..self.scale_steps.len() {
823                    ls += self.config.learning_rate
824                        * self.scale_steps[s].predict_graduated_sibling_interpolated(
825                            features,
826                            &self.auto_bandwidths,
827                        );
828                }
829                (crate::math::exp(ls).max(1e-8), ls)
830            }
831        };
832
833        GaussianPrediction {
834            mu,
835            sigma,
836            log_sigma,
837        }
838    }
839
840    /// Predict with σ-ratio diagnostic exposed.
841    ///
842    /// Returns `(mu, sigma, sigma_ratio)` where `sigma_ratio` is
843    /// `current_sigma / rolling_sigma_mean` -- the multiplier applied to the
844    /// location learning rate when [`uncertainty_modulated_lr`](SGBTConfig::uncertainty_modulated_lr)
845    /// is enabled.
846    ///
847    /// When σ-modulation is disabled, `sigma_ratio` is always `1.0`.
848    pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
849        let pred = self.predict(features);
850        let sigma_ratio = if self.uncertainty_modulated_lr {
851            (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
852        } else {
853            1.0
854        };
855        (pred.mu, pred.sigma, sigma_ratio)
856    }
857
858    /// Current empirical sigma (`sqrt(ewma_sq_err)`).
859    ///
860    /// Returns the model's recent error magnitude. Available in both scale modes.
861    #[inline]
862    pub fn empirical_sigma(&self) -> f64 {
863        crate::math::sqrt(self.ewma_sq_err)
864    }
865
866    /// Current scale mode.
867    #[inline]
868    pub fn scale_mode(&self) -> ScaleMode {
869        self.scale_mode
870    }
871
872    /// Current σ velocity -- the EWMA-smoothed derivative of empirical σ.
873    ///
874    /// Positive values indicate growing prediction errors (model deteriorating
875    /// or regime change). Negative values indicate improving predictions.
876    /// Only meaningful when `ScaleMode::Empirical` is active.
877    #[inline]
878    pub fn sigma_velocity(&self) -> f64 {
879        self.sigma_velocity
880    }
881
882    /// Predict the mean (location parameter) only.
883    #[inline]
884    pub fn predict_mu(&self, features: &[f64]) -> f64 {
885        self.predict(features).mu
886    }
887
888    /// Predict the standard deviation (scale parameter) only.
889    #[inline]
890    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
891        self.predict(features).sigma
892    }
893
894    /// Predict a symmetric confidence interval.
895    ///
896    /// `confidence` is the Z-score multiplier:
897    /// - 1.0 → 68% CI
898    /// - 1.96 → 95% CI
899    /// - 2.576 → 99% CI
900    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
901        let pred = self.predict(features);
902        (pred.lower(confidence), pred.upper(confidence))
903    }
904
905    /// Batch prediction.
906    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
907        feature_matrix.iter().map(|f| self.predict(f)).collect()
908    }
909
910    /// Train on a batch of observations.
911    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
912        for sample in samples {
913            self.train_one(sample);
914        }
915    }
916
917    /// Train on a batch with periodic callback.
918    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
919        &mut self,
920        samples: &[O],
921        interval: usize,
922        mut callback: F,
923    ) {
924        let interval = interval.max(1);
925        for (i, sample) in samples.iter().enumerate() {
926            self.train_one(sample);
927            if (i + 1) % interval == 0 {
928                callback(i + 1);
929            }
930        }
931        let total = samples.len();
932        if total % interval != 0 {
933            callback(total);
934        }
935    }
936
937    /// Compute location gradient (squared or adaptive Huber).
938    ///
939    /// When `huber_k` is configured, uses Huber loss with adaptive
940    /// `delta = k * empirical_sigma` for bounded gradients.
941    #[inline]
942    fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
943        if let Some(k) = self.config.huber_k {
944            let delta = k * crate::math::sqrt(self.ewma_sq_err).max(1e-8);
945            let residual = mu - target;
946            if crate::math::abs(residual) <= delta {
947                (residual, 1.0)
948            } else {
949                (delta * residual.signum(), 1e-6)
950            }
951        } else {
952            (mu - target, 1.0)
953        }
954    }
955
956    /// Welford update for ensemble-level gradient statistics.
957    #[inline]
958    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
959        self.ensemble_grad_count += 1;
960        let delta = gradient - self.ensemble_grad_mean;
961        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
962        let delta2 = gradient - self.ensemble_grad_mean;
963        self.ensemble_grad_m2 += delta * delta2;
964    }
965
966    /// Ensemble-level gradient standard deviation.
967    pub fn ensemble_grad_std(&self) -> f64 {
968        if self.ensemble_grad_count < 2 {
969            return 0.0;
970        }
971        crate::math::fmax(
972            crate::math::sqrt(self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64),
973            0.0,
974        )
975    }
976
977    /// Ensemble-level gradient mean.
978    pub fn ensemble_grad_mean(&self) -> f64 {
979        self.ensemble_grad_mean
980    }
981
982    /// Check if packed cache should be refreshed and do so if interval reached.
983    fn maybe_refresh_packed_cache(&mut self) {
984        if self.packed_refresh_interval > 0 {
985            self.samples_since_refresh += 1;
986            if self.samples_since_refresh >= self.packed_refresh_interval {
987                self.refresh_packed_cache();
988                self.samples_since_refresh = 0;
989            }
990        }
991    }
992
    /// Re-export the location ensemble into the packed cache.
    ///
    /// In `irithyll-core` this is a no-op because `export_embedded` is not
    /// available. The full `irithyll` crate overrides this via its re-export
    /// layer to populate the cache with packed f32 bytes.
    fn refresh_packed_cache(&mut self) {
        // Intentionally empty: export_distributional_packed lives in irithyll
        // (not irithyll-core), so the packed cache is never populated in
        // core-only builds. The full irithyll crate hooks into this via its
        // own DistributionalSGBT re-export, which adds the packed cache logic.
        // Callers (maybe_refresh_packed_cache / enable_packed_cache) tolerate
        // the cache staying `None`.
    }
1004
1005    /// Enable or reconfigure the packed inference cache at runtime.
1006    ///
1007    /// Sets the refresh interval and immediately builds the initial cache
1008    /// if the model has been initialized. Pass `0` to disable.
1009    pub fn enable_packed_cache(&mut self, interval: u64) {
1010        self.packed_refresh_interval = interval;
1011        self.samples_since_refresh = 0;
1012        if interval > 0 && self.base_initialized {
1013            self.refresh_packed_cache();
1014        } else if interval == 0 {
1015            self.packed_cache = None;
1016        }
1017    }
1018
1019    /// Whether the packed inference cache is currently populated.
1020    #[inline]
1021    pub fn has_packed_cache(&self) -> bool {
1022        self.packed_cache.is_some()
1023    }
1024
1025    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
1026    fn refresh_bandwidths(&mut self) {
1027        let current_sum: u64 = self
1028            .location_steps
1029            .iter()
1030            .chain(self.scale_steps.iter())
1031            .map(|s| s.slot().replacements())
1032            .sum();
1033        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1034            self.auto_bandwidths = self.compute_auto_bandwidths();
1035            self.last_replacement_sum = current_sum;
1036        }
1037    }
1038
    /// Compute per-feature auto-calibrated bandwidths from all trees.
    ///
    /// For each feature, gathers every split threshold used by any active
    /// tree (location and scale) and derives a bandwidth from the median gap
    /// between adjacent distinct thresholds, scaled by `K`. A bandwidth of
    /// `f64::INFINITY` marks a feature with too little split structure to
    /// interpolate over. Returns an empty vector if no tree reports features.
    fn compute_auto_bandwidths(&self) -> Vec<f64> {
        // Safety margin applied on top of the raw gap estimate.
        const K: f64 = 2.0;

        // Trees may disagree on feature count; size buffers to the widest.
        let n_features = self
            .location_steps
            .iter()
            .chain(self.scale_steps.iter())
            .filter_map(|s| s.slot().active_tree().n_features())
            .max()
            .unwrap_or(0);

        if n_features == 0 {
            return Vec::new();
        }

        // Pool split thresholds per feature across every active tree.
        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];

        for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
            let tree_thresholds = step
                .slot()
                .active_tree()
                .collect_split_thresholds_per_feature();
            for (i, ts) in tree_thresholds.into_iter().enumerate() {
                if i < n_features {
                    all_thresholds[i].extend(ts);
                }
            }
        }

        let n_bins = self.config.n_bins as f64;

        all_thresholds
            .iter()
            .map(|ts| {
                // No splits on this feature → interpolation disabled.
                if ts.is_empty() {
                    return f64::INFINITY;
                }

                // Sort and collapse near-duplicate thresholds (within 1e-15).
                let mut sorted = ts.clone();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);

                // A single distinct threshold gives no gap to measure.
                if sorted.len() < 2 {
                    return f64::INFINITY;
                }

                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();

                // With exactly two thresholds the "median gap" is just the
                // range; fall back to range / n_bins as a bin-width proxy.
                if sorted.len() < 3 {
                    let range = sorted.last().unwrap() - sorted.first().unwrap();
                    if range < 1e-15 {
                        return f64::INFINITY;
                    }
                    return (range / n_bins) * K;
                }

                // Median of adjacent gaps: robust against one huge outlier gap.
                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
                let median_gap = if gaps.len() % 2 == 0 {
                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
                } else {
                    gaps[gaps.len() / 2]
                };

                if median_gap < 1e-15 {
                    f64::INFINITY
                } else {
                    median_gap * K
                }
            })
            .collect()
    }
1111
1112    /// Per-feature auto-calibrated bandwidths used by `predict()`.
1113    pub fn auto_bandwidths(&self) -> &[f64] {
1114        &self.auto_bandwidths
1115    }
1116
1117    /// Reset to initial untrained state.
1118    pub fn reset(&mut self) {
1119        for step in &mut self.location_steps {
1120            step.reset();
1121        }
1122        for step in &mut self.scale_steps {
1123            step.reset();
1124        }
1125        self.location_base = 0.0;
1126        self.scale_base = 0.0;
1127        self.base_initialized = false;
1128        self.initial_targets.clear();
1129        self.samples_seen = 0;
1130        self.rng_state = self.config.seed;
1131        self.rolling_sigma_mean = 1.0;
1132        self.ewma_sq_err = 1.0;
1133        self.prev_sigma = 0.0;
1134        self.sigma_velocity = 0.0;
1135        self.auto_bandwidths.clear();
1136        self.last_replacement_sum = 0;
1137        self.ensemble_grad_mean = 0.0;
1138        self.ensemble_grad_m2 = 0.0;
1139        self.ensemble_grad_count = 0;
1140        self.packed_cache = None;
1141        self.samples_since_refresh = 0;
1142    }
1143
1144    /// Total samples trained.
1145    #[inline]
1146    pub fn n_samples_seen(&self) -> u64 {
1147        self.samples_seen
1148    }
1149
1150    /// Number of boosting steps (same for location and scale).
1151    #[inline]
1152    pub fn n_steps(&self) -> usize {
1153        self.location_steps.len()
1154    }
1155
1156    /// Total trees (location + scale, active + alternates).
1157    pub fn n_trees(&self) -> usize {
1158        let loc = self.location_steps.len()
1159            + self
1160                .location_steps
1161                .iter()
1162                .filter(|s| s.has_alternate())
1163                .count();
1164        let scale = self.scale_steps.len()
1165            + self
1166                .scale_steps
1167                .iter()
1168                .filter(|s| s.has_alternate())
1169                .count();
1170        loc + scale
1171    }
1172
1173    /// Total leaves across all active trees (location + scale).
1174    pub fn total_leaves(&self) -> usize {
1175        let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1176        let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1177        loc + scale
1178    }
1179
1180    /// Whether base predictions have been initialized.
1181    #[inline]
1182    pub fn is_initialized(&self) -> bool {
1183        self.base_initialized
1184    }
1185
1186    /// Access the configuration.
1187    #[inline]
1188    pub fn config(&self) -> &SGBTConfig {
1189        &self.config
1190    }
1191
1192    /// Access the location boosting steps (for export/inspection).
1193    pub fn location_steps(&self) -> &[BoostingStep] {
1194        &self.location_steps
1195    }
1196
1197    /// Base prediction for the location (mean) ensemble.
1198    #[inline]
1199    pub fn location_base(&self) -> f64 {
1200        self.location_base
1201    }
1202
1203    /// Learning rate from the model configuration.
1204    #[inline]
1205    pub fn learning_rate(&self) -> f64 {
1206        self.config.learning_rate
1207    }
1208
1209    /// Current rolling σ mean (EWMA of predicted σ).
1210    ///
1211    /// Returns `1.0` if the model hasn't been initialized yet.
1212    #[inline]
1213    pub fn rolling_sigma_mean(&self) -> f64 {
1214        self.rolling_sigma_mean
1215    }
1216
1217    /// Whether σ-modulated learning rate is active.
1218    #[inline]
1219    pub fn is_uncertainty_modulated(&self) -> bool {
1220        self.uncertainty_modulated_lr
1221    }
1222
1223    // -------------------------------------------------------------------
1224    // Diagnostics
1225    // -------------------------------------------------------------------
1226
    /// Full model diagnostics: per-tree structure, feature usage, base predictions.
    ///
    /// The `trees` vector contains location trees first (indices `0..n_steps`),
    /// then scale trees (`n_steps..2*n_steps`).
    ///
    /// `scale_trees_active` counts how many scale trees have actually split
    /// (more than 1 leaf). If this is 0, the scale chain is effectively frozen.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        // Lazily sized from the first non-empty gains vector encountered.
        let mut feature_split_counts: Vec<usize> = Vec::new();

        // Collect one TreeDiagnostic per step, appending into `trees` and
        // accumulating per-feature split counts across all trees.
        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                // Leaf values and counts are scanned from the flat node arena,
                // filtering on the is_leaf mask.
                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                // (min, max, mean, std) over leaf values; all-zero for an
                // unsplit/empty tree.
                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    // Population variance (divide by N, not N-1).
                    let var: f64 = leaf_values
                        .iter()
                        .map(|v| crate::math::powi(v - mean, 2))
                        .sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, crate::math::sqrt(var))
                };

                // Features with any positive accumulated split gain.
                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        // Guard: a later tree may report more features than
                        // the vector was sized for.
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        // Order matters: location trees occupy trees[..n], scale trees trees[n..].
        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: crate::math::sqrt(self.ewma_sq_err),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }
1337
1338    /// Per-tree contribution to the final prediction.
1339    ///
1340    /// Returns two vectors: location contributions and scale contributions.
1341    /// Each entry is `learning_rate * tree_prediction` -- the additive
1342    /// contribution of that boosting step to the final μ or log(σ).
1343    ///
1344    /// Summing `location_base + sum(location_contributions)` recovers μ.
1345    /// Summing `scale_base + sum(scale_contributions)` recovers log(σ).
1346    ///
1347    /// In `Empirical` scale mode, `scale_base` is `ln(empirical_sigma)` and
1348    /// `scale_contributions` are all zero (σ is not tree-derived).
1349    pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1350        let lr = self.config.learning_rate;
1351        let location: Vec<f64> = self
1352            .location_steps
1353            .iter()
1354            .map(|s| lr * s.predict(features))
1355            .collect();
1356
1357        let (sb, scale) = match self.scale_mode {
1358            ScaleMode::Empirical => {
1359                let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1360                (
1361                    crate::math::ln(empirical_sigma),
1362                    vec![0.0; self.location_steps.len()],
1363                )
1364            }
1365            ScaleMode::TreeChain => {
1366                let s: Vec<f64> = self
1367                    .scale_steps
1368                    .iter()
1369                    .map(|s| lr * s.predict(features))
1370                    .collect();
1371                (self.scale_base, s)
1372            }
1373        };
1374
1375        DecomposedPrediction {
1376            location_base: self.location_base,
1377            scale_base: sb,
1378            location_contributions: location,
1379            scale_contributions: scale,
1380        }
1381    }
1382
1383    /// Feature importances based on accumulated split gains across all trees.
1384    ///
1385    /// Aggregates gains from both location and scale ensembles, then
1386    /// normalizes to sum to 1.0. Indexed by feature.
1387    /// Returns an empty Vec if no splits have occurred yet.
1388    pub fn feature_importances(&self) -> Vec<f64> {
1389        let mut totals: Vec<f64> = Vec::new();
1390        for steps in [&self.location_steps, &self.scale_steps] {
1391            for step in steps {
1392                let gains = step.slot().split_gains();
1393                if totals.is_empty() && !gains.is_empty() {
1394                    totals.resize(gains.len(), 0.0);
1395                }
1396                for (i, &g) in gains.iter().enumerate() {
1397                    if i < totals.len() {
1398                        totals[i] += g;
1399                    }
1400                }
1401            }
1402        }
1403        let sum: f64 = totals.iter().sum();
1404        if sum > 0.0 {
1405            totals.iter_mut().for_each(|v| *v /= sum);
1406        }
1407        totals
1408    }
1409
1410    /// Feature importances split by ensemble: `(location_importances, scale_importances)`.
1411    ///
1412    /// Each vector is independently normalized to sum to 1.0.
1413    /// Useful for understanding which features drive the mean vs. the uncertainty.
1414    pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1415        fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1416            let mut totals: Vec<f64> = Vec::new();
1417            for step in steps {
1418                let gains = step.slot().split_gains();
1419                if totals.is_empty() && !gains.is_empty() {
1420                    totals.resize(gains.len(), 0.0);
1421                }
1422                for (i, &g) in gains.iter().enumerate() {
1423                    if i < totals.len() {
1424                        totals[i] += g;
1425                    }
1426                }
1427            }
1428            let sum: f64 = totals.iter().sum();
1429            if sum > 0.0 {
1430                totals.iter_mut().for_each(|v| *v /= sum);
1431            }
1432            totals
1433        }
1434        (
1435            aggregate(&self.location_steps),
1436            aggregate(&self.scale_steps),
1437        )
1438    }
1439
1440    /// Convert this model into a serializable [`crate::serde_support::DistributionalModelState`].
1441    ///
1442    /// Captures the full ensemble state (both location and scale trees) for
1443    /// persistence. Histogram accumulators are NOT serialized -- they rebuild
1444    /// naturally from continued training.
1445    #[cfg(feature = "_serde_support")]
1446    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
1447        use super::snapshot_tree;
1448        use crate::serde_support::{DistributionalModelState, StepSnapshot};
1449
1450        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
1451            let slot = step.slot();
1452            let tree_snap = snapshot_tree(slot.active_tree());
1453            let alt_snap = slot.alternate_tree().map(snapshot_tree);
1454            let drift_state = slot.detector().serialize_state();
1455            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
1456            StepSnapshot {
1457                tree: tree_snap,
1458                alternate_tree: alt_snap,
1459                drift_state,
1460                alt_drift_state,
1461            }
1462        }
1463
1464        DistributionalModelState {
1465            config: self.config.clone(),
1466            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
1467            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
1468            location_base: self.location_base,
1469            scale_base: self.scale_base,
1470            base_initialized: self.base_initialized,
1471            initial_targets: self.initial_targets.clone(),
1472            initial_target_count: self.initial_target_count,
1473            samples_seen: self.samples_seen,
1474            rng_state: self.rng_state,
1475            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
1476            rolling_sigma_mean: self.rolling_sigma_mean,
1477            ewma_sq_err: self.ewma_sq_err,
1478        }
1479    }
1480
1481    /// Reconstruct a [`DistributionalSGBT`] from a serialized [`crate::serde_support::DistributionalModelState`].
1482    ///
1483    /// Rebuilds both location and scale ensembles including tree topology
1484    /// and leaf values. Histogram accumulators are left empty and will
1485    /// rebuild from continued training.
1486    #[cfg(feature = "_serde_support")]
1487    pub fn from_distributional_state(
1488        state: crate::serde_support::DistributionalModelState,
1489    ) -> Self {
1490        use super::rebuild_tree;
1491        use crate::ensemble::replacement::TreeSlot;
1492        use crate::serde_support::StepSnapshot;
1493
1494        let leaf_decay_alpha = state
1495            .config
1496            .leaf_half_life
1497            .map(|hl| crate::math::exp((-(crate::math::ln(2.0_f64)) / hl as f64)));
1498        let max_tree_samples = state.config.max_tree_samples;
1499
1500        let base_tree_config = TreeConfig::new()
1501            .max_depth(state.config.max_depth)
1502            .n_bins(state.config.n_bins)
1503            .lambda(state.config.lambda)
1504            .gamma(state.config.gamma)
1505            .grace_period(state.config.grace_period)
1506            .delta(state.config.delta)
1507            .feature_subsample_rate(state.config.feature_subsample_rate)
1508            .leaf_decay_alpha_opt(leaf_decay_alpha)
1509            .split_reeval_interval_opt(state.config.split_reeval_interval)
1510            .feature_types_opt(state.config.feature_types.clone())
1511            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
1512            .monotone_constraints_opt(state.config.monotone_constraints.clone())
1513            .adaptive_depth_opt(state.config.adaptive_depth)
1514            .leaf_model_type(state.config.leaf_model_type.clone());
1515
1516        // Rebuild a Vec<BoostingStep> from step snapshots with a given seed transform.
1517        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
1518            snaps
1519                .iter()
1520                .enumerate()
1521                .map(|(i, snap)| {
1522                    let tc = base_tree_config
1523                        .clone()
1524                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);
1525
1526                    let active = rebuild_tree(&snap.tree, tc.clone());
1527                    let alternate = snap
1528                        .alternate_tree
1529                        .as_ref()
1530                        .map(|s| rebuild_tree(s, tc.clone()));
1531
1532                    let mut detector = state.config.drift_detector.create();
1533                    if let Some(ref ds) = snap.drift_state {
1534                        detector.restore_state(ds);
1535                    }
1536                    let mut slot =
1537                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
1538                    if let Some(ref ads) = snap.alt_drift_state {
1539                        if let Some(alt_det) = slot.alt_detector_mut() {
1540                            alt_det.restore_state(ads);
1541                        }
1542                    }
1543                    BoostingStep::from_slot(slot)
1544                })
1545                .collect()
1546        };
1547
1548        // Location: seed offset 0, Scale: seed offset 0x0005_CA1E_0000_0000
1549        let location_steps = rebuild_steps(&state.location_steps, 0);
1550        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);
1551
1552        let scale_mode = state.config.scale_mode;
1553        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
1554        let packed_refresh_interval = state.config.packed_refresh_interval;
1555        Self {
1556            config: state.config,
1557            location_steps,
1558            scale_steps,
1559            location_base: state.location_base,
1560            scale_base: state.scale_base,
1561            base_initialized: state.base_initialized,
1562            initial_targets: state.initial_targets,
1563            initial_target_count: state.initial_target_count,
1564            samples_seen: state.samples_seen,
1565            rng_state: state.rng_state,
1566            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
1567            rolling_sigma_mean: state.rolling_sigma_mean,
1568            scale_mode,
1569            ewma_sq_err: state.ewma_sq_err,
1570            empirical_sigma_alpha,
1571            prev_sigma: 0.0,
1572            sigma_velocity: 0.0,
1573            auto_bandwidths: Vec::new(),
1574            last_replacement_sum: 0,
1575            ensemble_grad_mean: 0.0,
1576            ensemble_grad_m2: 0.0,
1577            ensemble_grad_count: 0,
1578            packed_cache: None,
1579            samples_since_refresh: 0,
1580            packed_refresh_interval,
1581        }
1582    }
1583}
1584
1585// ---------------------------------------------------------------------------
1586// StreamingLearner impl
1587// ---------------------------------------------------------------------------
1588
1589use crate::learner::StreamingLearner;
1590
1591impl StreamingLearner for DistributionalSGBT {
1592    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
1593        let sample = SampleRef::weighted(features, target, weight);
1594        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
1595        DistributionalSGBT::train_one(self, &sample);
1596    }
1597
1598    /// Returns the mean (μ) of the predicted Gaussian distribution.
1599    fn predict(&self, features: &[f64]) -> f64 {
1600        DistributionalSGBT::predict(self, features).mu
1601    }
1602
1603    fn n_samples_seen(&self) -> u64 {
1604        self.samples_seen
1605    }
1606
1607    fn reset(&mut self) {
1608        DistributionalSGBT::reset(self);
1609    }
1610}
1611
1612// ---------------------------------------------------------------------------
1613// Tests
1614// ---------------------------------------------------------------------------
1615
// Unit tests for DistributionalSGBT: basic prediction/sigma invariants,
// σ-modulated learning rate, diagnostics, empirical-σ mode, PD sigma
// modulation, and the optional packed inference cache.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;
    use alloc::vec::Vec;

    // Baseline config shared by most tests: 10 boosting steps, quick
    // initialization (10 targets), default (empirical) scale mode.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap()
    }

    // An untrained model predicts μ = 0 but still reports a positive σ.
    #[test]
    fn fresh_model_predicts_zero() {
        let model = DistributionalSGBT::new(test_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.mu.abs() < 1e-12);
        assert!(pred.sigma > 0.0);
    }

    #[test]
    fn sigma_always_positive() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on various data
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        // Check multiple predictions
        for i in 0..20 {
            let x = i as f64 * 0.5;
            let pred = model.predict(&[x, x * 0.5]);
            assert!(
                pred.sigma > 0.0,
                "sigma must be positive, got {}",
                pred.sigma
            );
            assert!(pred.sigma.is_finite(), "sigma must be finite");
        }
    }

    #[test]
    fn constant_target_has_small_sigma() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on constant target = 5.0 with varying features
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], 5.0));
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        // Sigma should be relatively small for constant target
        // (no noise to explain)
    }

    #[test]
    fn noisy_target_has_finite_predictions() {
        let mut model = DistributionalSGBT::new(test_config());

        // Simple xorshift for deterministic "noise"
        let mut rng: u64 = 42;
        for i in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 500.0 - 1.0; // [-1, 1]
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + noise));
        }

        let pred = model.predict(&[5.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }

    // predict_interval with z = 1.96 must be the symmetric μ ± z·σ band.
    #[test]
    fn predict_interval_bounds_correct() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
        let pred = model.predict(&[5.0]);

        assert!(lo < pred.mu, "lower bound should be < mu");
        assert!(hi > pred.mu, "upper bound should be > mu");
        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
    }

    // predict_batch must agree exactly with per-sample predict.
    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], x));
        }

        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
        let batch = model.predict_batch(&features);

        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
            let individual = model.predict(feat);
            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut model = DistributionalSGBT::new(test_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        assert!(model.n_samples_seen() > 0);
        model.reset();

        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
    }

    // lower()/upper() on a hand-built prediction are plain μ ± z·σ.
    #[test]
    fn gaussian_prediction_lower_upper() {
        let pred = GaussianPrediction {
            mu: 10.0,
            sigma: 2.0,
            log_sigma: 2.0_f64.ln(),
        };

        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
    }

    #[test]
    fn train_batch_works() {
        let mut model = DistributionalSGBT::new(test_config());
        let samples: Vec<(Vec<f64>, f64)> = (0..100)
            .map(|i| {
                let x = i as f64 * 0.1;
                (vec![x], x * 2.0)
            })
            .collect();

        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 100);
    }

    #[test]
    fn debug_format_works() {
        let model = DistributionalSGBT::new(test_config());
        let debug = format!("{:?}", model);
        assert!(debug.contains("DistributionalSGBT"));
    }

    #[test]
    fn n_trees_counts_both_ensembles() {
        let model = DistributionalSGBT::new(test_config());
        // 10 location + 10 scale = 20 trees minimum
        assert!(model.n_trees() >= 20);
    }

    // -- σ-modulated learning rate tests --

    // Same as test_config() but with uncertainty_modulated_lr enabled.
    fn modulated_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap()
    }

    #[test]
    fn sigma_modulated_initializes_rolling_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());
        assert!(model.is_uncertainty_modulated());

        // Before initialization, rolling_sigma_mean is 1.0 (placeholder)
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);

        // Train past initialization
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        // After training, rolling_sigma_mean should have adapted from initial std
        assert!(model.rolling_sigma_mean() > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    // With modulation on, sigma_ratio should stay in a sane band around 1.
    #[test]
    fn predict_distributional_returns_sigma_ratio() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(mu.is_finite());
        assert!(sigma > 0.0);
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={}",
            sigma_ratio
        );
    }

    #[test]
    fn predict_distributional_without_modulation_returns_one() {
        let mut model = DistributionalSGBT::new(test_config());
        assert!(!model.is_uncertainty_modulated());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(
            (sigma_ratio - 1.0).abs() < 1e-12,
            "should be 1.0 when disabled"
        );
    }

    // Noise amplitude jumps 5x mid-stream; everything must stay finite.
    #[test]
    fn modulated_model_sigma_finite_under_varying_noise() {
        let mut model = DistributionalSGBT::new(modulated_config());

        let mut rng: u64 = 123;
        for i in 0..500 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 100.0 - 5.0; // [-5, 5]
            let x = i as f64 * 0.1;
            // Regime shift at i=250: noise amplitude increases
            let scale = if i < 250 { 1.0 } else { 5.0 };
            model.train_one(&(vec![x], x * 2.0 + noise * scale));
        }

        let pred = model.predict(&[10.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(model.rolling_sigma_mean().is_finite());
    }

    #[test]
    fn reset_clears_rolling_sigma_mean() {
        let mut model = DistributionalSGBT::new(modulated_config());

        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }

        let sigma_before = model.rolling_sigma_mean();
        assert!(sigma_before > 0.0);

        model.reset();
        // After reset, rolling_sigma_mean returns to the 1.0 placeholder.
        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
    }

    // The trait-level predict must be exactly the inherent predict's mu.
    #[test]
    fn streaming_learner_returns_mu() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
        }
        let pred = StreamingLearner::predict(&model, &[5.0]);
        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
        assert!(
            (pred - gaussian.mu).abs() < 1e-12,
            "StreamingLearner::predict should return mu"
        );
    }

    // -- Diagnostics tests --

    // A model trained on 500 samples of 3-feature linear data, used by all
    // diagnostics tests below.
    fn trained_model() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10) // low grace period to ensure splits
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }
        model
    }

    #[test]
    fn diagnostics_returns_correct_tree_count() {
        let model = trained_model();
        let diag = model.diagnostics();
        // 10 location + 10 scale = 20 trees
        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
    }

    #[test]
    fn diagnostics_trees_have_leaves() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
        }
        // At least some trees should have seen samples.
        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
        assert!(
            total_samples > 0,
            "at least some trees should have seen samples"
        );
    }

    #[test]
    fn diagnostics_leaf_weight_stats_finite() {
        let model = trained_model();
        let diag = model.diagnostics();
        for (i, tree) in diag.trees.iter().enumerate() {
            let (min, max, mean, std) = tree.leaf_weight_stats;
            assert!(min.is_finite(), "tree {i} min not finite");
            assert!(max.is_finite(), "tree {i} max not finite");
            assert!(mean.is_finite(), "tree {i} mean not finite");
            assert!(std.is_finite(), "tree {i} std not finite");
            assert!(min <= max, "tree {i} min > max");
        }
    }

    // Loose sanity bound only: location_base just has to be in the same
    // ballpark as a prediction at the origin.
    #[test]
    fn diagnostics_base_predictions_match() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
            "location_base should be plausible"
        );
    }

    // Summing per-step contributions must reconstruct predict() exactly.
    #[test]
    fn predict_decomposed_reconstructs_prediction() {
        let model = trained_model();
        let features = [5.0, 2.5, 1.0];
        let pred = model.predict(&features);
        let decomp = model.predict_decomposed(&features);

        assert!(
            (decomp.mu() - pred.mu).abs() < 1e-10,
            "decomposed mu ({}) != predict mu ({})",
            decomp.mu(),
            pred.mu
        );
        assert!(
            (decomp.sigma() - pred.sigma).abs() < 1e-10,
            "decomposed sigma ({}) != predict sigma ({})",
            decomp.sigma(),
            pred.sigma
        );
    }

    #[test]
    fn predict_decomposed_correct_lengths() {
        let model = trained_model();
        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
        assert_eq!(
            decomp.location_contributions.len(),
            model.n_steps(),
            "location contributions should have n_steps entries"
        );
        assert_eq!(
            decomp.scale_contributions.len(),
            model.n_steps(),
            "scale contributions should have n_steps entries"
        );
    }

    #[test]
    fn feature_importances_work() {
        let model = trained_model();
        let imp = model.feature_importances();
        // After enough training, importances should be non-empty and non-negative.
        // They may or may not sum to 1.0 if no splits have occurred yet.
        for (i, &v) in imp.iter().enumerate() {
            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
            assert!(v.is_finite(), "importance {i} should be finite");
        }
        let sum: f64 = imp.iter().sum();
        if sum > 0.0 {
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "non-zero importances should sum to 1.0, got {sum}"
            );
        }
    }

    #[test]
    fn feature_importances_split_works() {
        let model = trained_model();
        let (loc_imp, scale_imp) = model.feature_importances_split();
        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                assert!(
                    (sum - 1.0).abs() < 1e-10,
                    "{name} importances should sum to 1.0, got {sum}"
                );
            }
            for &v in imp.iter() {
                assert!(v >= 0.0 && v.is_finite());
            }
        }
    }

    // -- Empirical σ tests --

    #[test]
    fn empirical_sigma_default_mode() {
        use crate::ensemble::config::ScaleMode;
        let config = test_config();
        let model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
    }

    // EWMA of squared errors should grow when the noise regime gets louder.
    #[test]
    fn empirical_sigma_tracks_errors() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on clean linear data
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
        }

        let sigma_clean = model.empirical_sigma();
        assert!(sigma_clean > 0.0, "sigma should be positive");
        assert!(sigma_clean.is_finite(), "sigma should be finite");

        // Now train on noisy data — sigma should increase
        let mut rng: u64 = 42;
        for i in 200..400 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 10000) as f64 / 100.0 - 50.0; // big noise
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
        }

        let sigma_noisy = model.empirical_sigma();
        assert!(
            sigma_noisy > sigma_clean,
            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
        );
    }

    #[test]
    fn empirical_sigma_modulated_lr_adapts() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train and verify sigma_ratio changes
        for i in 0..300 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + 1.0));
        }

        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
        assert!(sigma_ratio.is_finite());
        assert!(
            (0.1..=10.0).contains(&sigma_ratio),
            "sigma_ratio={sigma_ratio}"
        );
    }

    // NGBoost-style mode: scale comes from its own tree ensemble.
    #[test]
    fn tree_chain_mode_trains_scale_trees() {
        use crate::ensemble::config::ScaleMode;
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .scale_mode(ScaleMode::TreeChain)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);

        for i in 0..500 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[5.0, 2.5, 1.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
        assert!(pred.sigma.is_finite());
    }

    #[test]
    fn diagnostics_shows_empirical_sigma() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert!(
            diag.empirical_sigma > 0.0,
            "empirical_sigma should be positive"
        );
        assert!(
            diag.empirical_sigma.is_finite(),
            "empirical_sigma should be finite"
        );
    }

    #[test]
    fn diagnostics_scale_trees_split_fields() {
        let model = trained_model();
        let diag = model.diagnostics();
        assert_eq!(diag.location_trees.len(), model.n_steps());
        assert_eq!(diag.scale_trees.len(), model.n_steps());
        // In empirical mode, scale_trees_active might be 0 (trees not trained)
        // This is expected and actually the point.
    }

    #[test]
    fn reset_clears_empirical_sigma() {
        let mut model = DistributionalSGBT::new(test_config());
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0));
        }
        model.reset();
        // After reset, ewma_sq_err resets to 1.0
        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn predict_smooth_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
        assert!(pred.mu.is_finite(), "smooth mu should be finite");
        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
    }

    // -- PD sigma modulation tests --

    // A sudden regime change should push sigma_velocity upward (the "D" term
    // of the PD modulation reacting to rising error).
    #[test]
    fn sigma_velocity_responds_to_error_spike() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Stable phase
        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        let velocity_before = model.sigma_velocity();

        // Error spike -- sudden regime change
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
        }

        let velocity_after = model.sigma_velocity();

        // Velocity should be positive (sigma increasing due to errors)
        assert!(
            velocity_after > velocity_before,
            "sigma velocity should increase after error spike: before={}, after={}",
            velocity_before,
            velocity_after,
        );
    }

    #[test]
    fn sigma_velocity_getter_works() {
        let config = SGBTConfig::builder()
            .n_steps(2)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let model = DistributionalSGBT::new(config);
        // Fresh model should have zero velocity
        assert_eq!(model.sigma_velocity(), 0.0);
    }

    #[test]
    fn diagnostics_leaf_sample_counts_populated() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            let features = vec![x, x.sin()];
            let target = 2.0 * x + 1.0;
            model.train_one(&(features, target));
        }

        let diags = model.diagnostics();
        for (ti, tree) in diags.trees.iter().enumerate() {
            assert_eq!(
                tree.leaf_sample_counts.len(),
                tree.n_leaves,
                "tree {} should have sample count per leaf",
                ti,
            );
            // Total samples across leaves should equal samples_seen for trees with data
            if tree.samples_seen > 0 {
                let total: u64 = tree.leaf_sample_counts.iter().sum();
                assert!(
                    total > 0,
                    "tree {} has {} samples_seen but leaf counts sum to 0",
                    ti,
                    tree.samples_seen,
                );
            }
        }
    }

    // -------------------------------------------------------------------
    // Packed cache tests (dual-path inference)
    // -------------------------------------------------------------------

    #[test]
    fn packed_cache_disabled_by_default() {
        let model = DistributionalSGBT::new(test_config());
        assert!(!model.has_packed_cache());
        assert_eq!(model.config().packed_refresh_interval, 0);
    }

    // NOTE(review): gated behind a feature that is never enabled, so this
    // test is currently compiled out — confirm whether it should be revived.
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_refreshes_after_interval() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // Train past initialization (10 samples) + refresh interval (20 samples)
        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Cache should exist after enough training
        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        // Predictions should be finite
        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }

    // NOTE(review): also compiled out via the same disabled feature gate.
    #[test]
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // Train
        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Get full-tree prediction (no cache)
        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Enable cache and get cached prediction
        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Should match within f32 precision (f64->f32 conversion loses some precision)
        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        // Sigma should be identical (not affected by packed cache)
        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }
}
2412
#[cfg(test)]
#[cfg(feature = "_serde_support")]
mod serde_tests {
    // NOTE(review): only `json_round_trip_preserves_predictions` and
    // `state_preserves_rolling_sigma_mean` actually exercise serde. The
    // remaining tests here do not touch serialization, yet they are skipped
    // whenever `_serde_support` is disabled — consider relocating them to an
    // unconditional test module. TODO confirm this gating is intentional.
    use super::*;
    use crate::SGBTConfig;

    /// Shared fixture: a small distributional model trained on `sin(x)`
    /// over 50 evenly spaced points — cheap enough for per-test use.
    fn make_trained_distributional() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        model
    }

    /// JSON round trip (state -> JSON -> state) must reproduce both mu and
    /// sigma to near machine precision at several probe points.
    #[test]
    fn json_round_trip_preserves_predictions() {
        let model = make_trained_distributional();
        let state = model.to_distributional_state();
        let json = crate::serde_support::save_distributional_model(&state).unwrap();
        let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
        let restored = DistributionalSGBT::from_distributional_state(loaded_state);

        let test_points = [0.5, 1.0, 2.0, 3.0];
        for &x in &test_points {
            let orig = model.predict(&[x]);
            let rest = restored.predict(&[x]);
            assert!(
                (orig.mu - rest.mu).abs() < 1e-10,
                "JSON round-trip mu mismatch at x={}: {} vs {}",
                x,
                orig.mu,
                rest.mu
            );
            assert!(
                (orig.sigma - rest.sigma).abs() < 1e-10,
                "JSON round-trip sigma mismatch at x={}: {} vs {}",
                x,
                orig.sigma,
                rest.sigma
            );
        }
    }

    /// With `uncertainty_modulated_lr` enabled, the exported state must carry
    /// the flag and a non-negative rolling sigma mean, and restoring from it
    /// must preserve the sample counter.
    #[test]
    fn state_preserves_rolling_sigma_mean() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        let state = model.to_distributional_state();
        assert!(state.uncertainty_modulated_lr);
        assert!(state.rolling_sigma_mean >= 0.0);

        let restored = DistributionalSGBT::from_distributional_state(state);
        assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
    }

    /// After enough training, per-feature auto-bandwidths should be
    /// populated, surfaced in diagnostics, and predictions should stay
    /// finite with a positive sigma.
    #[test]
    fn auto_bandwidth_computed_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        // auto_bandwidths should be populated
        let bws = model.auto_bandwidths();
        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");

        // Diagnostics should include auto_bandwidths
        let diag = model.diagnostics();
        assert_eq!(diag.auto_bandwidths.len(), 2);

        // TreeDiagnostic should have prediction stats
        assert!(diag.location_trees[0].prediction_mean.is_finite());
        assert!(diag.location_trees[0].prediction_std.is_finite());

        let pred = model.predict(&[1.0, 1.0_f64.sin()]);
        assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
        assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
    }

    /// `max_leaf_output` should keep predictions finite even when a large
    /// learning rate plus alternating +-100 targets would otherwise produce
    /// extreme leaf weights.
    #[test]
    fn max_leaf_output_clamps_predictions() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(1.0) // Intentionally large to force extreme leaf weights
            .max_leaf_output(0.5)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with extreme targets to force large leaf weights
        for i in 0..200 {
            let target = if i % 2 == 0 { 100.0 } else { -100.0 };
            let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
            model.train_one(&sample);
        }

        // Each tree's prediction should be clamped to [-0.5, 0.5]
        let pred = model.predict(&[2.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with clamping"
        );
    }

    /// A high `min_hessian_sum` suppresses outputs from under-populated
    /// leaves; predictions must remain finite on a tiny training run.
    #[test]
    fn min_hessian_sum_suppresses_fresh_leaves() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.01)
            .min_hessian_sum(50.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train minimal samples
        for i in 0..60 {
            let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
            model.train_one(&sample);
        }

        let pred = model.predict(&[30.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with min_hessian_sum"
        );
    }

    /// Smoke test: `predict_interpolated` yields a finite mu and a positive
    /// sigma on a trained model.
    #[test]
    fn predict_interpolated_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let pred = model.predict_interpolated(&[1.0, 0.5]);
        assert!(pred.mu.is_finite(), "interpolated mu should be finite");
        assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
    }

    /// With `huber_k` set, occasional extreme outlier targets (1000.0 every
    /// 50th sample) must not blow up the prediction.
    #[test]
    fn huber_k_bounds_gradients() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(1.345)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with occasional extreme outliers
        for i in 0..300 {
            let target = if i % 50 == 0 {
                1000.0
            } else {
                (i as f64 * 0.1).sin()
            };
            let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[5.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "Huber-loss mu should be finite despite outliers"
        );
        assert!(pred.sigma > 0.0, "sigma should be positive");
    }

    /// Diagnostics must expose finite ensemble-level gradient statistics
    /// (mean finite, std finite and non-negative) after training.
    #[test]
    fn ensemble_gradient_stats_populated() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let diag = model.diagnostics();
        assert!(
            diag.ensemble_grad_mean.is_finite(),
            "ensemble grad mean should be finite"
        );
        assert!(
            diag.ensemble_grad_std >= 0.0,
            "ensemble grad std should be non-negative"
        );
        assert!(
            diag.ensemble_grad_std.is_finite(),
            "ensemble grad std should be finite"
        );
    }

    /// Builder must reject a negative `huber_k`.
    #[test]
    fn huber_k_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(-1.0)
            .build();
        assert!(result.is_err(), "negative huber_k should fail validation");
    }

    /// Builder must reject a negative `max_leaf_output`.
    #[test]
    fn max_leaf_output_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .max_leaf_output(-1.0)
            .build();
        assert!(
            result.is_err(),
            "negative max_leaf_output should fail validation"
        );
    }

    /// Sibling interpolation should be at least as "smooth" as hard routing:
    /// sweeping a feature must produce no fewer prediction changes than the
    /// hard-split predictor does over the same sweep.
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // A wiggly 1-D target so trees actually split and interpolation
        // has structure to smooth over.
        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            let sample = crate::Sample::new(vec![x, x * 0.3], y);
            model.train_one(&sample);
        }

        // Verify the method runs correctly and produces finite predictions
        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(
            pred.mu.is_finite(),
            "sibling interpolated mu should be finite"
        );
        assert!(
            pred.sigma > 0.0,
            "sibling interpolated sigma should be positive"
        );

        // Sweep — at minimum, sibling should produce same or more variation than hard.
        // Guarded on finite bandwidths: without them interpolation degenerates
        // to hard routing and the comparison is vacuous.
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            let hard_preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict(&[x, x * 0.3]).mu
                })
                .collect();
            // Count adjacent pairs whose predictions differ (beyond epsilon).
            let hard_changes = hard_preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();

            let preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict_sibling_interpolated(&[x, x * 0.3]).mu
                })
                .collect();

            let sibling_changes = preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sibling_changes >= hard_changes,
                "sibling should produce >= hard changes: sibling={}, hard={}",
                sibling_changes,
                hard_changes
            );
        }
    }
}