// irithyll_core/ensemble/distributional.rs
//! Distributional SGBT -- outputs Gaussian N(μ, σ²) instead of a point estimate.
//!
//! [`DistributionalSGBT`] supports two scale estimation modes via
//! [`ScaleMode`]:
//!
//! ## Empirical σ (default)
//!
//! Tracks an EWMA of squared prediction errors:
//!
//! ```text
//! err = target - mu
//! ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err
//! sigma = sqrt(ewma_sq_err)
//! ```
//!
//! Always calibrated (σ literally *is* recent error magnitude), zero tuning,
//! O(1) memory and compute.  When `uncertainty_modulated_lr` is enabled,
//! high recent errors → σ large → location LR scales up → faster correction.
//!
//! ## Tree chain (NGBoost-style)
//!
//! Maintains two independent tree ensembles: one for location (μ), one for
//! scale (log σ).  Gives feature-conditional uncertainty but requires strong
//! scale-gradient signal for the trees to split.
//!
//! # References
//!
//! Duan et al. (2020). "NGBoost: Natural Gradient Boosting for Probabilistic Prediction."

use alloc::vec;
use alloc::vec::Vec;

use crate::ensemble::config::{SGBTConfig, ScaleMode};
use crate::ensemble::step::BoostingStep;
use crate::sample::{Observation, SampleRef};
use crate::tree::builder::TreeConfig;
37
/// Cached packed f32 binary for fast location-only inference.
///
/// Re-exported periodically from the location ensemble. Predictions use
/// contiguous BFS-packed memory for cache-optimal tree traversal.
///
/// `Clone` is derived: every field is plain data (`Vec<u8>`, `f64`, `usize`),
/// so the derive is exactly equivalent to the previous hand-written
/// field-by-field impl and cannot fall out of sync when fields are added.
#[derive(Clone)]
struct PackedInferenceCache {
    /// BFS-packed serialized tree bytes consumed by `EnsembleView::from_bytes`.
    bytes: Vec<u8>,
    /// Location base prediction added on top of the packed ensemble output.
    base: f64,
    /// Feature-vector width the packed ensemble was built for.
    n_features: usize,
}
57
/// Prediction from a distributional model: full Gaussian N(μ, σ²).
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Location parameter (mean).
    pub mu: f64,
    /// Scale parameter (standard deviation, always > 0).
    pub sigma: f64,
    /// Log of scale parameter (raw model output for scale ensemble).
    pub log_sigma: f64,
    /// Tree contribution variance (epistemic uncertainty).
    ///
    /// Standard deviation of individual location-tree contributions,
    /// computed via one-pass Welford variance with Bessel's correction.
    /// Reacts instantly when trees disagree (no EWMA lag), making it
    /// superior to empirical sigma for regime-change detection.
    ///
    /// Zero when the model has 0 or 1 active location trees.
    pub honest_sigma: f64,
}
77
78impl GaussianPrediction {
79    /// Lower bound of a symmetric confidence interval.
80    ///
81    /// For 95% CI, use `z = 1.96`.
82    #[inline]
83    pub fn lower(&self, z: f64) -> f64 {
84        self.mu - z * self.sigma
85    }
86
87    /// Upper bound of a symmetric confidence interval.
88    #[inline]
89    pub fn upper(&self, z: f64) -> f64 {
90        self.mu + z * self.sigma
91    }
92}
93
// ---------------------------------------------------------------------------
// Diagnostic structs
// ---------------------------------------------------------------------------

/// Per-tree diagnostic summary.
///
/// A read-only snapshot of one tree's structure and running prediction
/// statistics, collected into [`ModelDiagnostics`].
#[derive(Debug, Clone)]
pub struct TreeDiagnostic {
    /// Number of leaf nodes in this tree.
    pub n_leaves: usize,
    /// Maximum depth reached by any leaf.
    pub max_depth_reached: usize,
    /// Total samples this tree has seen.
    pub samples_seen: u64,
    /// Leaf weight statistics: `(min, max, mean, std)`.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices this tree has split on (non-zero gain).
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts showing data distribution across leaves.
    pub leaf_sample_counts: Vec<u64>,
    /// Running mean of predictions from this tree (Welford online).
    pub prediction_mean: f64,
    /// Running standard deviation of predictions from this tree.
    pub prediction_std: f64,
}
118
/// Full model diagnostics for [`DistributionalSGBT`].
///
/// Contains per-tree summaries, feature usage, base predictions, and
/// empirical σ state.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// Per-tree diagnostic summaries (location trees first, then scale trees).
    pub trees: Vec<TreeDiagnostic>,
    /// Location trees only (view into `trees`).
    pub location_trees: Vec<TreeDiagnostic>,
    /// Scale trees only (view into `trees`).
    pub scale_trees: Vec<TreeDiagnostic>,
    /// How many trees each feature is used in (split count per feature).
    pub feature_split_counts: Vec<usize>,
    /// Base prediction for location (mean).
    pub location_base: f64,
    /// Base prediction for scale (log-sigma).
    pub scale_base: f64,
    /// Current empirical σ (`sqrt(ewma_sq_err)`), always available.
    pub empirical_sigma: f64,
    /// Scale mode in use.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that actually split (>1 leaf). 0 = frozen chain.
    pub scale_trees_active: usize,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// `f64::INFINITY` means that feature uses hard routing.
    pub auto_bandwidths: Vec<f64>,
    /// Ensemble-level gradient running mean.
    pub ensemble_grad_mean: f64,
    /// Ensemble-level gradient standard deviation.
    pub ensemble_grad_std: f64,
}
151
/// Decomposed prediction showing each tree's contribution.
///
/// Useful for attribution/debugging: the final μ and log(σ) are exactly the
/// base value plus the sum of the per-step contributions.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Base location prediction (mean of initial targets).
    pub location_base: f64,
    /// Base scale prediction (log-sigma of initial targets).
    pub scale_base: f64,
    /// Per-step location contributions: `learning_rate * tree_prediction`.
    /// `location_base + sum(location_contributions)` = μ.
    pub location_contributions: Vec<f64>,
    /// Per-step scale contributions: `learning_rate * tree_prediction`.
    /// `scale_base + sum(scale_contributions)` = log(σ).
    pub scale_contributions: Vec<f64>,
}
166
167impl DecomposedPrediction {
168    /// Reconstruct the final μ from base + contributions.
169    pub fn mu(&self) -> f64 {
170        self.location_base + self.location_contributions.iter().sum::<f64>()
171    }
172
173    /// Reconstruct the final log(σ) from base + contributions.
174    pub fn log_sigma(&self) -> f64 {
175        self.scale_base + self.scale_contributions.iter().sum::<f64>()
176    }
177
178    /// Reconstruct the final σ (exponentiated).
179    pub fn sigma(&self) -> f64 {
180        crate::math::exp(self.log_sigma()).max(1e-8)
181    }
182}
183
184// ---------------------------------------------------------------------------
185// DistributionalSGBT
186// ---------------------------------------------------------------------------
187
188/// NGBoost-style distributional streaming gradient boosted trees.
189///
190/// Outputs a full Gaussian predictive distribution N(μ, σ²) by maintaining two
191/// independent ensembles -- one for location (mean) and one for scale (log-sigma).
192///
193/// # Example
194///
195/// ```text
196/// use irithyll::SGBTConfig;
197/// use irithyll::ensemble::distributional::DistributionalSGBT;
198///
199/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
200/// let mut model = DistributionalSGBT::new(config);
201///
202/// // Train on streaming data
203/// model.train_one(&(vec![1.0, 2.0], 3.5));
204///
205/// // Get full distributional prediction
206/// let pred = model.predict(&[1.0, 2.0]);
207/// println!("mean={}, sigma={}", pred.mu, pred.sigma);
208/// ```
209pub struct DistributionalSGBT {
210    /// Configuration (shared between location and scale ensembles).
211    config: SGBTConfig,
212    /// Location (mean) boosting steps.
213    location_steps: Vec<BoostingStep>,
214    /// Scale (log-sigma) boosting steps (only used in `TreeChain` mode).
215    scale_steps: Vec<BoostingStep>,
216    /// Base prediction for location (mean of initial targets).
217    location_base: f64,
218    /// Base prediction for scale (log of std of initial targets).
219    scale_base: f64,
220    /// Whether base predictions have been initialized.
221    base_initialized: bool,
222    /// Running collection of initial targets for computing base predictions.
223    initial_targets: Vec<f64>,
224    /// Number of initial targets to collect before setting base predictions.
225    initial_target_count: usize,
226    /// Total samples trained.
227    samples_seen: u64,
228    /// RNG state for variant logic.
229    rng_state: u64,
230    /// Whether σ-modulated learning rate is enabled.
231    uncertainty_modulated_lr: bool,
232    /// EWMA of the model's predicted σ -- used as the denominator in σ-ratio.
233    ///
234    /// Updated with alpha = 0.001 (slow adaptation) after each training step.
235    /// Initialized from the standard deviation of the initial target collection.
236    rolling_sigma_mean: f64,
237    /// Scale estimation mode: Empirical (default) or TreeChain.
238    scale_mode: ScaleMode,
239    /// EWMA of squared prediction errors (empirical σ mode).
240    ///
241    /// `sigma = sqrt(ewma_sq_err)`. Updated every training step with
242    /// `ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err`.
243    ewma_sq_err: f64,
244    /// EWMA alpha for empirical σ.
245    empirical_sigma_alpha: f64,
246    /// Previous empirical σ value for computing σ velocity.
247    prev_sigma: f64,
248    /// EWMA-smoothed derivative of empirical σ (σ velocity).
249    /// Positive = σ increasing (errors growing), negative = σ decreasing.
250    sigma_velocity: f64,
251    /// Per-feature auto-calibrated bandwidths for smooth prediction.
252    auto_bandwidths: Vec<f64>,
253    /// Sum of replacement counts at last bandwidth computation.
254    last_replacement_sum: u64,
255    /// Ensemble-level gradient running mean (Welford).
256    ensemble_grad_mean: f64,
257    /// Ensemble-level gradient M2 (Welford sum of squared deviations).
258    ensemble_grad_m2: f64,
259    /// Ensemble-level gradient sample count.
260    ensemble_grad_count: u64,
261    /// EWMA of honest_sigma (tree contribution std dev) for σ-modulated LR.
262    ///
263    /// When `uncertainty_modulated_lr` is enabled, the honest_sigma ratio
264    /// (`honest_sigma / rolling_honest_sigma_mean`) modulates the effective
265    /// learning rate, reacting instantly to tree disagreement.
266    rolling_honest_sigma_mean: f64,
267    /// Packed f32 cache for fast location-only inference (dual-path).
268    packed_cache: Option<PackedInferenceCache>,
269    /// Samples trained since last packed cache refresh.
270    samples_since_refresh: u64,
271    /// How often to refresh the packed cache (0 = disabled).
272    packed_refresh_interval: u64,
273}
274
275impl Clone for DistributionalSGBT {
276    fn clone(&self) -> Self {
277        Self {
278            config: self.config.clone(),
279            location_steps: self.location_steps.clone(),
280            scale_steps: self.scale_steps.clone(),
281            location_base: self.location_base,
282            scale_base: self.scale_base,
283            base_initialized: self.base_initialized,
284            initial_targets: self.initial_targets.clone(),
285            initial_target_count: self.initial_target_count,
286            samples_seen: self.samples_seen,
287            rng_state: self.rng_state,
288            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
289            rolling_sigma_mean: self.rolling_sigma_mean,
290            scale_mode: self.scale_mode,
291            ewma_sq_err: self.ewma_sq_err,
292            empirical_sigma_alpha: self.empirical_sigma_alpha,
293            prev_sigma: self.prev_sigma,
294            sigma_velocity: self.sigma_velocity,
295            auto_bandwidths: self.auto_bandwidths.clone(),
296            last_replacement_sum: self.last_replacement_sum,
297            ensemble_grad_mean: self.ensemble_grad_mean,
298            ensemble_grad_m2: self.ensemble_grad_m2,
299            ensemble_grad_count: self.ensemble_grad_count,
300            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
301            packed_cache: self.packed_cache.clone(),
302            samples_since_refresh: self.samples_since_refresh,
303            packed_refresh_interval: self.packed_refresh_interval,
304        }
305    }
306}
307
308impl core::fmt::Debug for DistributionalSGBT {
309    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
310        let mut s = f.debug_struct("DistributionalSGBT");
311        s.field("n_steps", &self.location_steps.len())
312            .field("samples_seen", &self.samples_seen)
313            .field("location_base", &self.location_base)
314            .field("scale_mode", &self.scale_mode)
315            .field("base_initialized", &self.base_initialized);
316        match self.scale_mode {
317            ScaleMode::Empirical => {
318                s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
319            }
320            ScaleMode::TreeChain => {
321                s.field("scale_base", &self.scale_base);
322            }
323        }
324        if self.uncertainty_modulated_lr {
325            s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
326        }
327        s.finish()
328    }
329}
330
331impl DistributionalSGBT {
    /// Create a new distributional SGBT with the given configuration.
    ///
    /// When `scale_mode` is `Empirical` (default), scale trees are still allocated
    /// but never trained — only the EWMA error tracker produces σ.  When
    /// `scale_mode` is `TreeChain`, both location and scale ensembles are active.
    pub fn new(config: SGBTConfig) -> Self {
        // Convert an optional half-life (in samples) into a per-sample EWMA
        // decay factor: alpha = exp(-ln(2) / half_life).
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-(crate::math::ln(2.0_f64)) / hl as f64));

        // Shared per-tree configuration; cloned per step with a distinct seed.
        let tree_config = TreeConfig::new()
            .max_depth(config.max_depth)
            .n_bins(config.n_bins)
            .lambda(config.lambda)
            .gamma(config.gamma)
            .grace_period(config.grace_period)
            .delta(config.delta)
            .feature_subsample_rate(config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(config.split_reeval_interval)
            .feature_types_opt(config.feature_types.clone())
            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
            .monotone_constraints_opt(config.monotone_constraints.clone())
            .max_leaf_output_opt(config.max_leaf_output)
            .adaptive_depth_opt(config.adaptive_depth)
            .min_hessian_sum_opt(config.min_hessian_sum)
            .leaf_model_type(config.leaf_model_type.clone());

        let max_tree_samples = config.max_tree_samples;

        // Location ensemble (seed offset: 0). Each step XORs the step index
        // into the seed so per-step RNG streams are decorrelated.
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
        let location_steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        // Scale ensemble (seed offset: 0xSCALE) -- only trained in TreeChain mode
        let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                // Extra XOR constant ("SCALE" in hex-leet) keeps scale-tree
                // RNG streams disjoint from the location streams.
                tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
                let detector = config.drift_detector.create();
                if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                }
            })
            .collect();

        // Pull out the config fields consumed below before `config` is moved
        // into the struct.
        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
        let scale_mode = config.scale_mode;
        let empirical_sigma_alpha = config.empirical_sigma_alpha;
        let packed_refresh_interval = config.packed_refresh_interval;
        Self {
            config,
            location_steps,
            scale_steps,
            location_base: 0.0,
            scale_base: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            uncertainty_modulated_lr,
            rolling_sigma_mean: 1.0, // overwritten during base initialization
            scale_mode,
            ewma_sq_err: 1.0, // overwritten during base initialization
            empirical_sigma_alpha,
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            rolling_honest_sigma_mean: 0.0,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
        }
    }
426
427    /// Compute honest_sigma from location tree contributions for a feature vector.
428    ///
429    /// Returns the Bessel-corrected standard deviation of individual tree
430    /// contributions (`learning_rate * tree.predict(features)`). This is
431    /// O(n_steps) — the same cost as a single prediction pass.
432    fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
433        let n = self.location_steps.len();
434        if n <= 1 {
435            return 0.0;
436        }
437        let lr = self.config.learning_rate;
438        let mut sum = 0.0_f64;
439        let mut sq_sum = 0.0_f64;
440        for step in &self.location_steps {
441            let c = lr * step.predict(features);
442            sum += c;
443            sq_sum += c * c;
444        }
445        let nf = n as f64;
446        let mean_c = sum / nf;
447        let var = (sq_sum / nf) - (mean_c * mean_c);
448        let var_corrected = var * nf / (nf - 1.0);
449        crate::math::sqrt(var_corrected.max(0.0))
450    }
451
452    /// Train on a single observation.
453    pub fn train_one(&mut self, sample: &impl Observation) {
454        self.samples_seen += 1;
455        let target = sample.target();
456        let features = sample.features();
457
458        // Initialize base predictions from first few targets
459        if !self.base_initialized {
460            self.initial_targets.push(target);
461            if self.initial_targets.len() >= self.initial_target_count {
462                // Location base = mean
463                let sum: f64 = self.initial_targets.iter().sum();
464                let mean = sum / self.initial_targets.len() as f64;
465                self.location_base = mean;
466
467                // Scale base = log(std) -- clamped for stability
468                let var: f64 = self
469                    .initial_targets
470                    .iter()
471                    .map(|&y| (y - mean) * (y - mean))
472                    .sum::<f64>()
473                    / self.initial_targets.len() as f64;
474                let initial_std = crate::math::sqrt(var).max(1e-6);
475                self.scale_base = crate::math::ln(initial_std);
476
477                // Initialize rolling sigma mean and ewma from initial targets std
478                self.rolling_sigma_mean = initial_std;
479                self.ewma_sq_err = var.max(1e-12);
480
481                // Initialize PD sigma state
482                self.prev_sigma = initial_std;
483                self.sigma_velocity = 0.0;
484
485                self.base_initialized = true;
486                self.initial_targets.clear();
487                self.initial_targets.shrink_to_fit();
488            }
489            return;
490        }
491
492        match self.scale_mode {
493            ScaleMode::Empirical => self.train_one_empirical(target, features),
494            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
495        }
496
497        // Refresh auto-bandwidths when trees have been replaced
498        self.refresh_bandwidths();
499    }
500
    /// Empirical-σ training: location trees only, σ from EWMA of squared errors.
    ///
    /// Per-sample pipeline:
    /// 1. Predict μ with the current ensemble (pre-update).
    /// 2. Fold this sample into the honest-σ and empirical-σ EWMA trackers.
    /// 3. If `uncertainty_modulated_lr`, derive a PD-controlled σ-ratio that
    ///    scales the location learning rate for this sample only.
    /// 4. Boost each location step sequentially against the running residual.
    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
        // Current location prediction (before this step's update)
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict(features);
        }

        // Update rolling honest_sigma mean (tree contribution variance EWMA)
        let honest_sigma = self.compute_honest_sigma(features);
        const HONEST_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
            * self.rolling_honest_sigma_mean
            + HONEST_SIGMA_ALPHA * honest_sigma;

        // Empirical sigma from EWMA of squared prediction errors
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
        let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);

        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller)
        let sigma_ratio = if self.uncertainty_modulated_lr {
            // Compute sigma velocity (derivative of sigma over time)
            let d_sigma = empirical_sigma - self.prev_sigma;
            self.prev_sigma = empirical_sigma;

            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain: self-calibrating, no config needed
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            // PD ratio: proportional + derivative, clamped so the effective
            // LR stays within [0.1x, 10x] of the base rate.
            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Update rolling sigma mean with slow EWMA
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * empirical_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Train location steps only -- no scale trees needed.
        // mu_accum tracks the partial prediction so each step boosts on the
        // residual left by the steps before it (sequential boosting).
        let mut mu_accum = self.location_base;
        for s in 0..self.location_steps.len() {
            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
            // Welford update for ensemble gradient stats
            self.update_ensemble_grad_stats(g_mu);
            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu_accum += (base_lr * sigma_ratio) * loc_pred;
        }

        // Refresh packed cache if interval reached
        self.maybe_refresh_packed_cache();
    }
569
    /// Tree-chain training: full NGBoost dual-chain with location + scale trees.
    ///
    /// Both ensembles are boosted sequentially per step: the location chain on
    /// the μ-gradient, the scale chain on the Gaussian NLL gradient w.r.t.
    /// log σ.  Only the location LR is σ-modulated; the scale chain always
    /// trains at the unmodulated base rate.
    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
        let mut mu = self.location_base;
        let mut log_sigma = self.scale_base;

        // Update rolling honest_sigma mean (tree contribution variance EWMA)
        let honest_sigma = self.compute_honest_sigma(features);
        const HONEST_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
            * self.rolling_honest_sigma_mean
            + HONEST_SIGMA_ALPHA * honest_sigma;

        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller).
        let sigma_ratio = if self.uncertainty_modulated_lr {
            // NOTE: at this point log_sigma is still the base, so the ratio is
            // driven by the base-scale estimate, not the per-feature scale.
            let current_sigma = crate::math::exp(log_sigma).max(1e-8);

            // Compute sigma velocity (derivative of sigma over time)
            let d_sigma = current_sigma - self.prev_sigma;
            self.prev_sigma = current_sigma;

            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
            let alpha = self.empirical_sigma_alpha;
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain: self-calibrating, no config needed
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                crate::math::abs(self.sigma_velocity) / self.rolling_sigma_mean
            } else {
                0.0
            };

            // PD ratio: proportional + derivative
            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * current_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Sequential boosting: both ensembles target their respective residuals
        for s in 0..self.location_steps.len() {
            let sigma = crate::math::exp(log_sigma).max(1e-8);
            // Standardized residual under the current N(mu, sigma²) estimate.
            let z = (target - mu) / sigma;

            // Location gradients (squared or Huber loss w.r.t. mu)
            let (g_mu, h_mu) = self.location_gradient(mu, target);
            // Welford update for ensemble gradient stats
            self.update_ensemble_grad_stats(g_mu);

            // Scale gradients (NLL w.r.t. log_sigma): d/dlogσ NLL = 1 - z².
            // Hessian is clamped to keep the Newton step well-conditioned.
            let g_sigma = 1.0 - z * z;
            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);

            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);

            // Train location step -- σ-modulated LR when enabled
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu += (base_lr * sigma_ratio) * loc_pred;

            // Train scale step -- ALWAYS at unmodulated base rate.
            let scale_pred =
                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
            log_sigma += base_lr * scale_pred;
        }

        // Also update empirical sigma tracker for diagnostics
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;

        // Refresh packed cache if interval reached
        self.maybe_refresh_packed_cache();
    }
651
652    /// Predict the full Gaussian distribution for a feature vector.
653    ///
654    /// When a packed cache is available, uses it for the location (μ) prediction
655    /// via contiguous BFS-packed memory traversal. Falls back to full tree
656    /// traversal if the cache is absent or produces non-finite results.
657    ///
658    /// Sigma computation always uses the primary path (EWMA or scale chain)
659    /// and is unaffected by the packed cache.
660    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
661        // Try packed cache for mu if available
662        let mu = if let Some(ref cache) = self.packed_cache {
663            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
664            match crate::EnsembleView::from_bytes(&cache.bytes) {
665                Ok(view) => {
666                    let packed_mu = cache.base + view.predict(&features_f32) as f64;
667                    if packed_mu.is_finite() {
668                        packed_mu
669                    } else {
670                        self.predict_full_trees(features)
671                    }
672                }
673                Err(_) => self.predict_full_trees(features),
674            }
675        } else {
676            self.predict_full_trees(features)
677        };
678
679        let (sigma, log_sigma) = match self.scale_mode {
680            ScaleMode::Empirical => {
681                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
682                (s, crate::math::ln(s))
683            }
684            ScaleMode::TreeChain => {
685                let mut ls = self.scale_base;
686                if self.auto_bandwidths.is_empty() {
687                    for s in 0..self.scale_steps.len() {
688                        ls += self.config.learning_rate * self.scale_steps[s].predict(features);
689                    }
690                } else {
691                    for s in 0..self.scale_steps.len() {
692                        ls += self.config.learning_rate
693                            * self.scale_steps[s]
694                                .predict_smooth_auto(features, &self.auto_bandwidths);
695                    }
696                }
697                (crate::math::exp(ls).max(1e-8), ls)
698            }
699        };
700
701        let honest_sigma = self.compute_honest_sigma(features);
702
703        GaussianPrediction {
704            mu,
705            sigma,
706            log_sigma,
707            honest_sigma,
708        }
709    }
710
711    /// Full-tree location prediction (fallback when packed cache is unavailable).
712    fn predict_full_trees(&self, features: &[f64]) -> f64 {
713        let mut mu = self.location_base;
714        if self.auto_bandwidths.is_empty() {
715            for s in 0..self.location_steps.len() {
716                mu += self.config.learning_rate * self.location_steps[s].predict(features);
717            }
718        } else {
719            for s in 0..self.location_steps.len() {
720                mu += self.config.learning_rate
721                    * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
722            }
723        }
724        mu
725    }
726
727    /// Predict using sigmoid-blended soft routing for smooth interpolation.
728    ///
729    /// Instead of hard left/right routing at tree split nodes, each split
730    /// uses sigmoid blending: `alpha = sigmoid((threshold - feature) / bandwidth)`.
731    /// The result is a continuous function that varies smoothly with every
732    /// feature change.
733    ///
734    /// `bandwidth` controls transition sharpness: smaller = sharper (closer
735    /// to hard splits), larger = smoother.
736    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
737        let mut mu = self.location_base;
738        for s in 0..self.location_steps.len() {
739            mu += self.config.learning_rate
740                * self.location_steps[s].predict_smooth(features, bandwidth);
741        }
742
743        let (sigma, log_sigma) = match self.scale_mode {
744            ScaleMode::Empirical => {
745                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
746                (s, crate::math::ln(s))
747            }
748            ScaleMode::TreeChain => {
749                let mut ls = self.scale_base;
750                for s in 0..self.scale_steps.len() {
751                    ls += self.config.learning_rate
752                        * self.scale_steps[s].predict_smooth(features, bandwidth);
753                }
754                (crate::math::exp(ls).max(1e-8), ls)
755            }
756        };
757
758        let honest_sigma = self.compute_honest_sigma(features);
759
760        GaussianPrediction {
761            mu,
762            sigma,
763            log_sigma,
764            honest_sigma,
765        }
766    }
767
768    /// Predict with parent-leaf linear interpolation.
769    ///
770    /// Blends each leaf prediction with its parent's preserved prediction
771    /// based on sample count, preventing stale predictions from fresh leaves.
772    pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
773        let mut mu = self.location_base;
774        for s in 0..self.location_steps.len() {
775            mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
776        }
777
778        let (sigma, log_sigma) = match self.scale_mode {
779            ScaleMode::Empirical => {
780                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
781                (s, crate::math::ln(s))
782            }
783            ScaleMode::TreeChain => {
784                let mut ls = self.scale_base;
785                for s in 0..self.scale_steps.len() {
786                    ls += self.config.learning_rate
787                        * self.scale_steps[s].predict_interpolated(features);
788                }
789                (crate::math::exp(ls).max(1e-8), ls)
790            }
791        };
792
793        let honest_sigma = self.compute_honest_sigma(features);
794
795        GaussianPrediction {
796            mu,
797            sigma,
798            log_sigma,
799            honest_sigma,
800        }
801    }
802
803    /// Predict with sibling-based interpolation for feature-continuous predictions.
804    ///
805    /// At each split node near the threshold boundary, blends left and right
806    /// subtree predictions linearly. Uses auto-calibrated bandwidths as the
807    /// interpolation margin. Predictions vary continuously as features change.
808    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
809        let mut mu = self.location_base;
810        for s in 0..self.location_steps.len() {
811            mu += self.config.learning_rate
812                * self.location_steps[s]
813                    .predict_sibling_interpolated(features, &self.auto_bandwidths);
814        }
815
816        let (sigma, log_sigma) = match self.scale_mode {
817            ScaleMode::Empirical => {
818                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
819                (s, crate::math::ln(s))
820            }
821            ScaleMode::TreeChain => {
822                let mut ls = self.scale_base;
823                for s in 0..self.scale_steps.len() {
824                    ls += self.config.learning_rate
825                        * self.scale_steps[s]
826                            .predict_sibling_interpolated(features, &self.auto_bandwidths);
827                }
828                (crate::math::exp(ls).max(1e-8), ls)
829            }
830        };
831
832        let honest_sigma = self.compute_honest_sigma(features);
833
834        GaussianPrediction {
835            mu,
836            sigma,
837            log_sigma,
838            honest_sigma,
839        }
840    }
841
842    /// Predict with graduated active-shadow blending.
843    ///
844    /// Smoothly transitions between active and shadow trees during replacement.
845    /// Requires `shadow_warmup` to be configured.
846    pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
847        let mut mu = self.location_base;
848        for s in 0..self.location_steps.len() {
849            mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
850        }
851
852        let (sigma, log_sigma) = match self.scale_mode {
853            ScaleMode::Empirical => {
854                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
855                (s, crate::math::ln(s))
856            }
857            ScaleMode::TreeChain => {
858                let mut ls = self.scale_base;
859                for s in 0..self.scale_steps.len() {
860                    ls +=
861                        self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
862                }
863                (crate::math::exp(ls).max(1e-8), ls)
864            }
865        };
866
867        let honest_sigma = self.compute_honest_sigma(features);
868
869        GaussianPrediction {
870            mu,
871            sigma,
872            log_sigma,
873            honest_sigma,
874        }
875    }
876
877    /// Predict with graduated blending + sibling interpolation (premium path).
878    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
879        let mut mu = self.location_base;
880        for s in 0..self.location_steps.len() {
881            mu += self.config.learning_rate
882                * self.location_steps[s]
883                    .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
884        }
885
886        let (sigma, log_sigma) = match self.scale_mode {
887            ScaleMode::Empirical => {
888                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
889                (s, crate::math::ln(s))
890            }
891            ScaleMode::TreeChain => {
892                let mut ls = self.scale_base;
893                for s in 0..self.scale_steps.len() {
894                    ls += self.config.learning_rate
895                        * self.scale_steps[s].predict_graduated_sibling_interpolated(
896                            features,
897                            &self.auto_bandwidths,
898                        );
899                }
900                (crate::math::exp(ls).max(1e-8), ls)
901            }
902        };
903
904        let honest_sigma = self.compute_honest_sigma(features);
905
906        GaussianPrediction {
907            mu,
908            sigma,
909            log_sigma,
910            honest_sigma,
911        }
912    }
913
914    /// Predict using per-node auto-bandwidth soft routing.
915    ///
916    /// Every prediction is a continuous weighted blend instead of a
917    /// piecewise-constant step function. No training changes.
918    pub fn predict_soft_routed(&self, features: &[f64]) -> GaussianPrediction {
919        // Location chain with soft routing
920        let mut mu = self.location_base;
921        for step in &self.location_steps {
922            mu += self.config.learning_rate * step.predict_soft_routed(features);
923        }
924
925        // Scale chain
926        let (sigma, log_sigma) = match self.scale_mode {
927            ScaleMode::Empirical => {
928                let s = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
929                (s, crate::math::ln(s))
930            }
931            ScaleMode::TreeChain => {
932                let mut ls = self.scale_base;
933                for step in &self.scale_steps {
934                    ls += self.config.learning_rate * step.predict_soft_routed(features);
935                }
936                (crate::math::exp(ls).max(1e-8), ls)
937            }
938        };
939
940        let honest_sigma = self.compute_honest_sigma(features);
941
942        GaussianPrediction {
943            mu,
944            sigma,
945            log_sigma,
946            honest_sigma,
947        }
948    }
949
950    /// Predict with σ-ratio diagnostic exposed.
951    ///
952    /// Returns `(mu, sigma, sigma_ratio)` where `sigma_ratio` is
953    /// `current_sigma / rolling_sigma_mean` -- the multiplier applied to the
954    /// location learning rate when [`uncertainty_modulated_lr`](SGBTConfig::uncertainty_modulated_lr)
955    /// is enabled.
956    ///
957    /// When σ-modulation is disabled, `sigma_ratio` is always `1.0`.
958    pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
959        let pred = self.predict(features);
960        let sigma_ratio = if self.uncertainty_modulated_lr {
961            (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
962        } else {
963            1.0
964        };
965        (pred.mu, pred.sigma, sigma_ratio)
966    }
967
968    /// Current empirical sigma (`sqrt(ewma_sq_err)`).
969    ///
970    /// Returns the model's recent error magnitude. Available in both scale modes.
971    #[inline]
972    pub fn empirical_sigma(&self) -> f64 {
973        crate::math::sqrt(self.ewma_sq_err)
974    }
975
976    /// Current scale mode.
977    #[inline]
978    pub fn scale_mode(&self) -> ScaleMode {
979        self.scale_mode
980    }
981
982    /// Current σ velocity -- the EWMA-smoothed derivative of empirical σ.
983    ///
984    /// Positive values indicate growing prediction errors (model deteriorating
985    /// or regime change). Negative values indicate improving predictions.
986    /// Only meaningful when `ScaleMode::Empirical` is active.
987    #[inline]
988    pub fn sigma_velocity(&self) -> f64 {
989        self.sigma_velocity
990    }
991
992    /// Predict the mean (location parameter) only.
993    #[inline]
994    pub fn predict_mu(&self, features: &[f64]) -> f64 {
995        self.predict(features).mu
996    }
997
998    /// Predict the standard deviation (scale parameter) only.
999    #[inline]
1000    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
1001        self.predict(features).sigma
1002    }
1003
1004    /// Predict a symmetric confidence interval.
1005    ///
1006    /// `confidence` is the Z-score multiplier:
1007    /// - 1.0 → 68% CI
1008    /// - 1.96 → 95% CI
1009    /// - 2.576 → 99% CI
1010    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
1011        let pred = self.predict(features);
1012        (pred.lower(confidence), pred.upper(confidence))
1013    }
1014
1015    /// Batch prediction.
1016    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
1017        feature_matrix.iter().map(|f| self.predict(f)).collect()
1018    }
1019
1020    /// Train on a batch of observations.
1021    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
1022        for sample in samples {
1023            self.train_one(sample);
1024        }
1025    }
1026
1027    /// Train on a batch with periodic callback.
1028    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
1029        &mut self,
1030        samples: &[O],
1031        interval: usize,
1032        mut callback: F,
1033    ) {
1034        let interval = interval.max(1);
1035        for (i, sample) in samples.iter().enumerate() {
1036            self.train_one(sample);
1037            if (i + 1) % interval == 0 {
1038                callback(i + 1);
1039            }
1040        }
1041        let total = samples.len();
1042        if total % interval != 0 {
1043            callback(total);
1044        }
1045    }
1046
1047    /// Compute location gradient (squared or adaptive Huber).
1048    ///
1049    /// When `huber_k` is configured, uses Huber loss with adaptive
1050    /// `delta = k * empirical_sigma` for bounded gradients.
1051    #[inline]
1052    fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
1053        if let Some(k) = self.config.huber_k {
1054            let delta = k * crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1055            let residual = mu - target;
1056            if crate::math::abs(residual) <= delta {
1057                (residual, 1.0)
1058            } else {
1059                (delta * residual.signum(), 1e-6)
1060            }
1061        } else {
1062            (mu - target, 1.0)
1063        }
1064    }
1065
1066    /// Welford update for ensemble-level gradient statistics.
1067    #[inline]
1068    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
1069        self.ensemble_grad_count += 1;
1070        let delta = gradient - self.ensemble_grad_mean;
1071        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
1072        let delta2 = gradient - self.ensemble_grad_mean;
1073        self.ensemble_grad_m2 += delta * delta2;
1074    }
1075
1076    /// Ensemble-level gradient standard deviation.
1077    pub fn ensemble_grad_std(&self) -> f64 {
1078        if self.ensemble_grad_count < 2 {
1079            return 0.0;
1080        }
1081        crate::math::fmax(
1082            crate::math::sqrt(self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64),
1083            0.0,
1084        )
1085    }
1086
1087    /// Ensemble-level gradient mean.
1088    pub fn ensemble_grad_mean(&self) -> f64 {
1089        self.ensemble_grad_mean
1090    }
1091
1092    /// Check if packed cache should be refreshed and do so if interval reached.
1093    fn maybe_refresh_packed_cache(&mut self) {
1094        if self.packed_refresh_interval > 0 {
1095            self.samples_since_refresh += 1;
1096            if self.samples_since_refresh >= self.packed_refresh_interval {
1097                self.refresh_packed_cache();
1098                self.samples_since_refresh = 0;
1099            }
1100        }
1101    }
1102
    /// Re-export the location ensemble into the packed cache.
    ///
    /// In `irithyll-core` this is a no-op because `export_embedded` is not
    /// available. The full `irithyll` crate overrides this via its re-export
    /// layer to populate the cache with packed f32 bytes.
    fn refresh_packed_cache(&mut self) {
        // Intentionally empty: export_distributional_packed lives in irithyll
        // (not irithyll-core), so the packed cache is never populated in
        // core-only builds. The full irithyll crate hooks into this via its
        // own DistributionalSGBT re-export, which adds the packed cache logic.
    }
1114
1115    /// Enable or reconfigure the packed inference cache at runtime.
1116    ///
1117    /// Sets the refresh interval and immediately builds the initial cache
1118    /// if the model has been initialized. Pass `0` to disable.
1119    pub fn enable_packed_cache(&mut self, interval: u64) {
1120        self.packed_refresh_interval = interval;
1121        self.samples_since_refresh = 0;
1122        if interval > 0 && self.base_initialized {
1123            self.refresh_packed_cache();
1124        } else if interval == 0 {
1125            self.packed_cache = None;
1126        }
1127    }
1128
1129    /// Whether the packed inference cache is currently populated.
1130    #[inline]
1131    pub fn has_packed_cache(&self) -> bool {
1132        self.packed_cache.is_some()
1133    }
1134
1135    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
1136    fn refresh_bandwidths(&mut self) {
1137        let current_sum: u64 = self
1138            .location_steps
1139            .iter()
1140            .chain(self.scale_steps.iter())
1141            .map(|s| s.slot().replacements())
1142            .sum();
1143        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1144            self.auto_bandwidths = self.compute_auto_bandwidths();
1145            self.last_replacement_sum = current_sum;
1146        }
1147    }
1148
1149    /// Compute per-feature auto-calibrated bandwidths from all trees.
1150    fn compute_auto_bandwidths(&self) -> Vec<f64> {
1151        const K: f64 = 2.0;
1152
1153        let n_features = self
1154            .location_steps
1155            .iter()
1156            .chain(self.scale_steps.iter())
1157            .filter_map(|s| s.slot().active_tree().n_features())
1158            .max()
1159            .unwrap_or(0);
1160
1161        if n_features == 0 {
1162            return Vec::new();
1163        }
1164
1165        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
1166
1167        for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
1168            let tree_thresholds = step
1169                .slot()
1170                .active_tree()
1171                .collect_split_thresholds_per_feature();
1172            for (i, ts) in tree_thresholds.into_iter().enumerate() {
1173                if i < n_features {
1174                    all_thresholds[i].extend(ts);
1175                }
1176            }
1177        }
1178
1179        let n_bins = self.config.n_bins as f64;
1180
1181        all_thresholds
1182            .iter()
1183            .map(|ts| {
1184                if ts.is_empty() {
1185                    return f64::INFINITY;
1186                }
1187
1188                let mut sorted = ts.clone();
1189                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1190                sorted.dedup_by(|a, b| crate::math::abs(*a - *b) < 1e-15);
1191
1192                if sorted.len() < 2 {
1193                    return f64::INFINITY;
1194                }
1195
1196                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
1197
1198                if sorted.len() < 3 {
1199                    let range = sorted.last().unwrap() - sorted.first().unwrap();
1200                    if range < 1e-15 {
1201                        return f64::INFINITY;
1202                    }
1203                    return (range / n_bins) * K;
1204                }
1205
1206                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
1207                let median_gap = if gaps.len() % 2 == 0 {
1208                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
1209                } else {
1210                    gaps[gaps.len() / 2]
1211                };
1212
1213                if median_gap < 1e-15 {
1214                    f64::INFINITY
1215                } else {
1216                    median_gap * K
1217                }
1218            })
1219            .collect()
1220    }
1221
1222    /// Per-feature auto-calibrated bandwidths used by `predict()`.
1223    pub fn auto_bandwidths(&self) -> &[f64] {
1224        &self.auto_bandwidths
1225    }
1226
    /// Reset to initial untrained state.
    ///
    /// Resets both tree chains and every running statistic to its
    /// construction-time default; the configuration itself is untouched.
    pub fn reset(&mut self) {
        // Reset each boosting step (its trees and any per-slot state).
        for step in &mut self.location_steps {
            step.reset();
        }
        for step in &mut self.scale_steps {
            step.reset();
        }
        // Base predictions and warm-up target collection.
        self.location_base = 0.0;
        self.scale_base = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        // NOTE(review): `initial_target_count` is not reset here --
        // presumably it is a configured collection threshold rather than
        // running state; confirm against the initialization path.
        self.samples_seen = 0;
        // RNG restarts from the configured seed.
        self.rng_state = self.config.seed;
        // Sigma statistics restart at 1.0 (the documented uninitialized value
        // of `rolling_sigma_mean`).
        self.rolling_sigma_mean = 1.0;
        self.ewma_sq_err = 1.0;
        self.prev_sigma = 0.0;
        self.sigma_velocity = 0.0;
        // Bandwidths are recomputed lazily on the next refresh.
        self.auto_bandwidths.clear();
        self.last_replacement_sum = 0;
        // Welford gradient accumulators.
        self.ensemble_grad_mean = 0.0;
        self.ensemble_grad_m2 = 0.0;
        self.ensemble_grad_count = 0;
        self.rolling_honest_sigma_mean = 0.0;
        // Packed inference cache is dropped; the refresh counter restarts.
        self.packed_cache = None;
        self.samples_since_refresh = 0;
    }
1254
1255    /// Total samples trained.
1256    #[inline]
1257    pub fn n_samples_seen(&self) -> u64 {
1258        self.samples_seen
1259    }
1260
1261    /// Number of boosting steps (same for location and scale).
1262    #[inline]
1263    pub fn n_steps(&self) -> usize {
1264        self.location_steps.len()
1265    }
1266
1267    /// Total trees (location + scale, active + alternates).
1268    pub fn n_trees(&self) -> usize {
1269        let loc = self.location_steps.len()
1270            + self
1271                .location_steps
1272                .iter()
1273                .filter(|s| s.has_alternate())
1274                .count();
1275        let scale = self.scale_steps.len()
1276            + self
1277                .scale_steps
1278                .iter()
1279                .filter(|s| s.has_alternate())
1280                .count();
1281        loc + scale
1282    }
1283
1284    /// Total leaves across all active trees (location + scale).
1285    pub fn total_leaves(&self) -> usize {
1286        let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1287        let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1288        loc + scale
1289    }
1290
1291    /// Whether base predictions have been initialized.
1292    #[inline]
1293    pub fn is_initialized(&self) -> bool {
1294        self.base_initialized
1295    }
1296
1297    /// Access the configuration.
1298    #[inline]
1299    pub fn config(&self) -> &SGBTConfig {
1300        &self.config
1301    }
1302
1303    /// Access the location boosting steps (for export/inspection).
1304    pub fn location_steps(&self) -> &[BoostingStep] {
1305        &self.location_steps
1306    }
1307
1308    /// Base prediction for the location (mean) ensemble.
1309    #[inline]
1310    pub fn location_base(&self) -> f64 {
1311        self.location_base
1312    }
1313
1314    /// Learning rate from the model configuration.
1315    #[inline]
1316    pub fn learning_rate(&self) -> f64 {
1317        self.config.learning_rate
1318    }
1319
1320    /// Current rolling σ mean (EWMA of predicted σ).
1321    ///
1322    /// Returns `1.0` if the model hasn't been initialized yet.
1323    #[inline]
1324    pub fn rolling_sigma_mean(&self) -> f64 {
1325        self.rolling_sigma_mean
1326    }
1327
1328    /// Whether σ-modulated learning rate is active.
1329    #[inline]
1330    pub fn is_uncertainty_modulated(&self) -> bool {
1331        self.uncertainty_modulated_lr
1332    }
1333
1334    /// Current rolling honest_sigma mean (EWMA of tree contribution std dev).
1335    ///
1336    /// Returns `0.0` before any training samples have been processed.
1337    #[inline]
1338    pub fn rolling_honest_sigma_mean(&self) -> f64 {
1339        self.rolling_honest_sigma_mean
1340    }
1341
1342    // -------------------------------------------------------------------
1343    // Diagnostics
1344    // -------------------------------------------------------------------
1345
1346    /// Full model diagnostics: per-tree structure, feature usage, base predictions.
1347    ///
1348    /// The `trees` vector contains location trees first (indices `0..n_steps`),
1349    /// then scale trees (`n_steps..2*n_steps`).
1350    ///
1351    /// `scale_trees_active` counts how many scale trees have actually split
1352    /// (more than 1 leaf). If this is 0, the scale chain is effectively frozen.
1353    pub fn diagnostics(&self) -> ModelDiagnostics {
1354        let n = self.location_steps.len();
1355        let mut trees = Vec::with_capacity(2 * n);
1356        let mut feature_split_counts: Vec<usize> = Vec::new();
1357
1358        fn collect_tree_diags(
1359            steps: &[BoostingStep],
1360            trees: &mut Vec<TreeDiagnostic>,
1361            feature_split_counts: &mut Vec<usize>,
1362        ) {
1363            for step in steps {
1364                let slot = step.slot();
1365                let tree = slot.active_tree();
1366                let arena = tree.arena();
1367
1368                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
1369                    .filter(|&i| arena.is_leaf[i])
1370                    .map(|i| arena.leaf_value[i])
1371                    .collect();
1372
1373                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
1374                    .filter(|&i| arena.is_leaf[i])
1375                    .map(|i| arena.sample_count[i])
1376                    .collect();
1377
1378                let max_depth_reached = (0..arena.is_leaf.len())
1379                    .filter(|&i| arena.is_leaf[i])
1380                    .map(|i| arena.depth[i] as usize)
1381                    .max()
1382                    .unwrap_or(0);
1383
1384                let leaf_weight_stats = if leaf_values.is_empty() {
1385                    (0.0, 0.0, 0.0, 0.0)
1386                } else {
1387                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
1388                    let max = leaf_values
1389                        .iter()
1390                        .cloned()
1391                        .fold(f64::NEG_INFINITY, f64::max);
1392                    let sum: f64 = leaf_values.iter().sum();
1393                    let mean = sum / leaf_values.len() as f64;
1394                    let var: f64 = leaf_values
1395                        .iter()
1396                        .map(|v| crate::math::powi(v - mean, 2))
1397                        .sum::<f64>()
1398                        / leaf_values.len() as f64;
1399                    (min, max, mean, crate::math::sqrt(var))
1400                };
1401
1402                let gains = slot.split_gains();
1403                let split_features: Vec<usize> = gains
1404                    .iter()
1405                    .enumerate()
1406                    .filter(|(_, &g)| g > 0.0)
1407                    .map(|(i, _)| i)
1408                    .collect();
1409
1410                if !gains.is_empty() {
1411                    if feature_split_counts.is_empty() {
1412                        feature_split_counts.resize(gains.len(), 0);
1413                    }
1414                    for &fi in &split_features {
1415                        if fi < feature_split_counts.len() {
1416                            feature_split_counts[fi] += 1;
1417                        }
1418                    }
1419                }
1420
1421                trees.push(TreeDiagnostic {
1422                    n_leaves: leaf_values.len(),
1423                    max_depth_reached,
1424                    samples_seen: step.n_samples_seen(),
1425                    leaf_weight_stats,
1426                    split_features,
1427                    leaf_sample_counts,
1428                    prediction_mean: slot.prediction_mean(),
1429                    prediction_std: slot.prediction_std(),
1430                });
1431            }
1432        }
1433
1434        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
1435        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);
1436
1437        let location_trees = trees[..n].to_vec();
1438        let scale_trees = trees[n..].to_vec();
1439        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();
1440
1441        ModelDiagnostics {
1442            trees,
1443            location_trees,
1444            scale_trees,
1445            feature_split_counts,
1446            location_base: self.location_base,
1447            scale_base: self.scale_base,
1448            empirical_sigma: crate::math::sqrt(self.ewma_sq_err),
1449            scale_mode: self.scale_mode,
1450            scale_trees_active,
1451            auto_bandwidths: self.auto_bandwidths.clone(),
1452            ensemble_grad_mean: self.ensemble_grad_mean,
1453            ensemble_grad_std: self.ensemble_grad_std(),
1454        }
1455    }
1456
1457    /// Per-tree contribution to the final prediction.
1458    ///
1459    /// Returns two vectors: location contributions and scale contributions.
1460    /// Each entry is `learning_rate * tree_prediction` -- the additive
1461    /// contribution of that boosting step to the final μ or log(σ).
1462    ///
1463    /// Summing `location_base + sum(location_contributions)` recovers μ.
1464    /// Summing `scale_base + sum(scale_contributions)` recovers log(σ).
1465    ///
1466    /// In `Empirical` scale mode, `scale_base` is `ln(empirical_sigma)` and
1467    /// `scale_contributions` are all zero (σ is not tree-derived).
1468    pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1469        let lr = self.config.learning_rate;
1470        let location: Vec<f64> = self
1471            .location_steps
1472            .iter()
1473            .map(|s| lr * s.predict(features))
1474            .collect();
1475
1476        let (sb, scale) = match self.scale_mode {
1477            ScaleMode::Empirical => {
1478                let empirical_sigma = crate::math::sqrt(self.ewma_sq_err).max(1e-8);
1479                (
1480                    crate::math::ln(empirical_sigma),
1481                    vec![0.0; self.location_steps.len()],
1482                )
1483            }
1484            ScaleMode::TreeChain => {
1485                let s: Vec<f64> = self
1486                    .scale_steps
1487                    .iter()
1488                    .map(|s| lr * s.predict(features))
1489                    .collect();
1490                (self.scale_base, s)
1491            }
1492        };
1493
1494        DecomposedPrediction {
1495            location_base: self.location_base,
1496            scale_base: sb,
1497            location_contributions: location,
1498            scale_contributions: scale,
1499        }
1500    }
1501
1502    /// Feature importances based on accumulated split gains across all trees.
1503    ///
1504    /// Aggregates gains from both location and scale ensembles, then
1505    /// normalizes to sum to 1.0. Indexed by feature.
1506    /// Returns an empty Vec if no splits have occurred yet.
1507    pub fn feature_importances(&self) -> Vec<f64> {
1508        let mut totals: Vec<f64> = Vec::new();
1509        for steps in [&self.location_steps, &self.scale_steps] {
1510            for step in steps {
1511                let gains = step.slot().split_gains();
1512                if totals.is_empty() && !gains.is_empty() {
1513                    totals.resize(gains.len(), 0.0);
1514                }
1515                for (i, &g) in gains.iter().enumerate() {
1516                    if i < totals.len() {
1517                        totals[i] += g;
1518                    }
1519                }
1520            }
1521        }
1522        let sum: f64 = totals.iter().sum();
1523        if sum > 0.0 {
1524            totals.iter_mut().for_each(|v| *v /= sum);
1525        }
1526        totals
1527    }
1528
1529    /// Feature importances split by ensemble: `(location_importances, scale_importances)`.
1530    ///
1531    /// Each vector is independently normalized to sum to 1.0.
1532    /// Useful for understanding which features drive the mean vs. the uncertainty.
1533    pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1534        fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1535            let mut totals: Vec<f64> = Vec::new();
1536            for step in steps {
1537                let gains = step.slot().split_gains();
1538                if totals.is_empty() && !gains.is_empty() {
1539                    totals.resize(gains.len(), 0.0);
1540                }
1541                for (i, &g) in gains.iter().enumerate() {
1542                    if i < totals.len() {
1543                        totals[i] += g;
1544                    }
1545                }
1546            }
1547            let sum: f64 = totals.iter().sum();
1548            if sum > 0.0 {
1549                totals.iter_mut().for_each(|v| *v /= sum);
1550            }
1551            totals
1552        }
1553        (
1554            aggregate(&self.location_steps),
1555            aggregate(&self.scale_steps),
1556        )
1557    }
1558
    /// Convert this model into a serializable [`crate::serde_support::DistributionalModelState`].
    ///
    /// Captures the full ensemble state (both location and scale trees) for
    /// persistence. Histogram accumulators are NOT serialized -- they rebuild
    /// naturally from continued training.
    ///
    /// Round-trips with [`Self::from_distributional_state`].
    #[cfg(feature = "_serde_support")]
    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
        use super::snapshot_tree;
        use crate::serde_support::{DistributionalModelState, StepSnapshot};

        // Snapshot one boosting step: the active tree, the optional
        // alternate (background) tree, and each tree's drift-detector
        // state (detectors whose `serialize_state` yields nothing are
        // stored as `None`).
        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
            let slot = step.slot();
            let tree_snap = snapshot_tree(slot.active_tree());
            let alt_snap = slot.alternate_tree().map(snapshot_tree);
            let drift_state = slot.detector().serialize_state();
            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
            StepSnapshot {
                tree: tree_snap,
                alternate_tree: alt_snap,
                drift_state,
                alt_drift_state,
            }
        }

        // Persist every durable field. Transient runtime state (packed
        // cache, sigma velocity, gradient stats) is intentionally omitted;
        // `from_distributional_state` resets it to defaults.
        DistributionalModelState {
            config: self.config.clone(),
            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
            location_base: self.location_base,
            scale_base: self.scale_base,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
            rolling_sigma_mean: self.rolling_sigma_mean,
            ewma_sq_err: self.ewma_sq_err,
            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
        }
    }
1600
    /// Reconstruct a [`DistributionalSGBT`] from a serialized [`crate::serde_support::DistributionalModelState`].
    ///
    /// Rebuilds both location and scale ensembles including tree topology
    /// and leaf values. Histogram accumulators are left empty and will
    /// rebuild from continued training.
    ///
    /// Inverse of [`Self::to_distributional_state`].
    #[cfg(feature = "_serde_support")]
    pub fn from_distributional_state(
        state: crate::serde_support::DistributionalModelState,
    ) -> Self {
        use super::rebuild_tree;
        use crate::ensemble::replacement::TreeSlot;
        use crate::serde_support::StepSnapshot;

        // Leaf decay factor from the configured half-life:
        // alpha = exp(-ln 2 / half_life), so a leaf's weight halves every
        // `leaf_half_life` samples. `None` disables decay.
        let leaf_decay_alpha = state
            .config
            .leaf_half_life
            .map(|hl| crate::math::exp((-(crate::math::ln(2.0_f64)) / hl as f64)));
        let max_tree_samples = state.config.max_tree_samples;

        // Tree config shared by every rebuilt step; only the per-step RNG
        // seed differs (applied inside `rebuild_steps` below).
        let base_tree_config = TreeConfig::new()
            .max_depth(state.config.max_depth)
            .n_bins(state.config.n_bins)
            .lambda(state.config.lambda)
            .gamma(state.config.gamma)
            .grace_period(state.config.grace_period)
            .delta(state.config.delta)
            .feature_subsample_rate(state.config.feature_subsample_rate)
            .leaf_decay_alpha_opt(leaf_decay_alpha)
            .split_reeval_interval_opt(state.config.split_reeval_interval)
            .feature_types_opt(state.config.feature_types.clone())
            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
            .monotone_constraints_opt(state.config.monotone_constraints.clone())
            .adaptive_depth_opt(state.config.adaptive_depth)
            .leaf_model_type(state.config.leaf_model_type.clone());

        // Rebuild a Vec<BoostingStep> from step snapshots with a given seed transform.
        // Seed formula: config.seed ^ step_index ^ seed_xor.
        // NOTE(review): assumed to mirror the seeding used at original
        // construction so continued training reuses the same per-step RNG
        // streams — confirm against the constructor.
        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
            snaps
                .iter()
                .enumerate()
                .map(|(i, snap)| {
                    let tc = base_tree_config
                        .clone()
                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);

                    let active = rebuild_tree(&snap.tree, tc.clone());
                    let alternate = snap
                        .alternate_tree
                        .as_ref()
                        .map(|s| rebuild_tree(s, tc.clone()));

                    // Fresh detector, then overwrite with persisted state
                    // when a snapshot was captured.
                    let mut detector = state.config.drift_detector.create();
                    if let Some(ref ds) = snap.drift_state {
                        detector.restore_state(ds);
                    }
                    let mut slot =
                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
                    // Restore the alternate tree's detector state, if the
                    // rebuilt slot has an alternate detector to receive it.
                    if let Some(ref ads) = snap.alt_drift_state {
                        if let Some(alt_det) = slot.alt_detector_mut() {
                            alt_det.restore_state(ads);
                        }
                    }
                    BoostingStep::from_slot(slot)
                })
                .collect()
        };

        // Location: seed offset 0, Scale: seed offset 0x0005_CA1E_0000_0000
        let location_steps = rebuild_steps(&state.location_steps, 0);
        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);

        let scale_mode = state.config.scale_mode;
        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
        let packed_refresh_interval = state.config.packed_refresh_interval;
        Self {
            config: state.config,
            location_steps,
            scale_steps,
            location_base: state.location_base,
            scale_base: state.scale_base,
            base_initialized: state.base_initialized,
            initial_targets: state.initial_targets,
            initial_target_count: state.initial_target_count,
            samples_seen: state.samples_seen,
            rng_state: state.rng_state,
            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
            rolling_sigma_mean: state.rolling_sigma_mean,
            scale_mode,
            ewma_sq_err: state.ewma_sq_err,
            empirical_sigma_alpha,
            // Transient runtime state: not persisted, reset to defaults,
            // re-derived during continued training.
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            rolling_honest_sigma_mean: state.rolling_honest_sigma_mean,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval,
        }
    }
1704}
1705
1706// ---------------------------------------------------------------------------
1707// StreamingLearner impl
1708// ---------------------------------------------------------------------------
1709
1710use crate::learner::StreamingLearner;
1711
1712impl StreamingLearner for DistributionalSGBT {
1713    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
1714        let sample = SampleRef::weighted(features, target, weight);
1715        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
1716        DistributionalSGBT::train_one(self, &sample);
1717    }
1718
1719    /// Returns the mean (μ) of the predicted Gaussian distribution.
1720    fn predict(&self, features: &[f64]) -> f64 {
1721        DistributionalSGBT::predict(self, features).mu
1722    }
1723
1724    fn n_samples_seen(&self) -> u64 {
1725        self.samples_seen
1726    }
1727
1728    fn reset(&mut self) {
1729        DistributionalSGBT::reset(self);
1730    }
1731}
1732
1733// ---------------------------------------------------------------------------
1734// Tests
1735// ---------------------------------------------------------------------------
1736
1737#[cfg(test)]
1738mod tests {
1739    use super::*;
1740    use alloc::format;
1741    use alloc::vec;
1742    use alloc::vec::Vec;
1743
1744    fn test_config() -> SGBTConfig {
1745        SGBTConfig::builder()
1746            .n_steps(10)
1747            .learning_rate(0.1)
1748            .grace_period(20)
1749            .max_depth(4)
1750            .n_bins(16)
1751            .initial_target_count(10)
1752            .build()
1753            .unwrap()
1754    }
1755
1756    #[test]
1757    fn fresh_model_predicts_zero() {
1758        let model = DistributionalSGBT::new(test_config());
1759        let pred = model.predict(&[1.0, 2.0, 3.0]);
1760        assert!(pred.mu.abs() < 1e-12);
1761        assert!(pred.sigma > 0.0);
1762    }
1763
1764    #[test]
1765    fn sigma_always_positive() {
1766        let mut model = DistributionalSGBT::new(test_config());
1767
1768        // Train on various data
1769        for i in 0..200 {
1770            let x = i as f64 * 0.1;
1771            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
1772        }
1773
1774        // Check multiple predictions
1775        for i in 0..20 {
1776            let x = i as f64 * 0.5;
1777            let pred = model.predict(&[x, x * 0.5]);
1778            assert!(
1779                pred.sigma > 0.0,
1780                "sigma must be positive, got {}",
1781                pred.sigma
1782            );
1783            assert!(pred.sigma.is_finite(), "sigma must be finite");
1784        }
1785    }
1786
1787    #[test]
1788    fn constant_target_has_small_sigma() {
1789        let mut model = DistributionalSGBT::new(test_config());
1790
1791        // Train on constant target = 5.0 with varying features
1792        for i in 0..200 {
1793            let x = i as f64 * 0.1;
1794            model.train_one(&(vec![x, x * 2.0], 5.0));
1795        }
1796
1797        let pred = model.predict(&[1.0, 2.0]);
1798        assert!(pred.mu.is_finite());
1799        assert!(pred.sigma.is_finite());
1800        assert!(pred.sigma > 0.0);
1801        // Sigma should be relatively small for constant target
1802        // (no noise to explain)
1803    }
1804
1805    #[test]
1806    fn noisy_target_has_finite_predictions() {
1807        let mut model = DistributionalSGBT::new(test_config());
1808
1809        // Simple xorshift for deterministic "noise"
1810        let mut rng: u64 = 42;
1811        for i in 0..200 {
1812            rng ^= rng << 13;
1813            rng ^= rng >> 7;
1814            rng ^= rng << 17;
1815            let noise = (rng % 1000) as f64 / 500.0 - 1.0; // [-1, 1]
1816            let x = i as f64 * 0.1;
1817            model.train_one(&(vec![x], x * 2.0 + noise));
1818        }
1819
1820        let pred = model.predict(&[5.0]);
1821        assert!(pred.mu.is_finite());
1822        assert!(pred.sigma.is_finite());
1823        assert!(pred.sigma > 0.0);
1824    }
1825
1826    #[test]
1827    fn predict_interval_bounds_correct() {
1828        let mut model = DistributionalSGBT::new(test_config());
1829
1830        for i in 0..200 {
1831            let x = i as f64 * 0.1;
1832            model.train_one(&(vec![x], x * 2.0));
1833        }
1834
1835        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
1836        let pred = model.predict(&[5.0]);
1837
1838        assert!(lo < pred.mu, "lower bound should be < mu");
1839        assert!(hi > pred.mu, "upper bound should be > mu");
1840        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
1841    }
1842
1843    #[test]
1844    fn batch_prediction_matches_individual() {
1845        let mut model = DistributionalSGBT::new(test_config());
1846
1847        for i in 0..100 {
1848            let x = i as f64 * 0.1;
1849            model.train_one(&(vec![x, x * 2.0], x));
1850        }
1851
1852        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
1853        let batch = model.predict_batch(&features);
1854
1855        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
1856            let individual = model.predict(feat);
1857            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
1858            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
1859        }
1860    }
1861
1862    #[test]
1863    fn reset_clears_state() {
1864        let mut model = DistributionalSGBT::new(test_config());
1865
1866        for i in 0..200 {
1867            let x = i as f64 * 0.1;
1868            model.train_one(&(vec![x], x * 2.0));
1869        }
1870
1871        assert!(model.n_samples_seen() > 0);
1872        model.reset();
1873
1874        assert_eq!(model.n_samples_seen(), 0);
1875        assert!(!model.is_initialized());
1876    }
1877
1878    #[test]
1879    fn gaussian_prediction_lower_upper() {
1880        let pred = GaussianPrediction {
1881            mu: 10.0,
1882            sigma: 2.0,
1883            log_sigma: 2.0_f64.ln(),
1884            honest_sigma: 0.0,
1885        };
1886
1887        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
1888        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
1889    }
1890
1891    #[test]
1892    fn train_batch_works() {
1893        let mut model = DistributionalSGBT::new(test_config());
1894        let samples: Vec<(Vec<f64>, f64)> = (0..100)
1895            .map(|i| {
1896                let x = i as f64 * 0.1;
1897                (vec![x], x * 2.0)
1898            })
1899            .collect();
1900
1901        model.train_batch(&samples);
1902        assert_eq!(model.n_samples_seen(), 100);
1903    }
1904
1905    #[test]
1906    fn debug_format_works() {
1907        let model = DistributionalSGBT::new(test_config());
1908        let debug = format!("{:?}", model);
1909        assert!(debug.contains("DistributionalSGBT"));
1910    }
1911
1912    #[test]
1913    fn n_trees_counts_both_ensembles() {
1914        let model = DistributionalSGBT::new(test_config());
1915        // 10 location + 10 scale = 20 trees minimum
1916        assert!(model.n_trees() >= 20);
1917    }
1918
1919    // -- σ-modulated learning rate tests --
1920
1921    fn modulated_config() -> SGBTConfig {
1922        SGBTConfig::builder()
1923            .n_steps(10)
1924            .learning_rate(0.1)
1925            .grace_period(20)
1926            .max_depth(4)
1927            .n_bins(16)
1928            .initial_target_count(10)
1929            .uncertainty_modulated_lr(true)
1930            .build()
1931            .unwrap()
1932    }
1933
1934    #[test]
1935    fn sigma_modulated_initializes_rolling_mean() {
1936        let mut model = DistributionalSGBT::new(modulated_config());
1937        assert!(model.is_uncertainty_modulated());
1938
1939        // Before initialization, rolling_sigma_mean is 1.0 (placeholder)
1940        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
1941
1942        // Train past initialization
1943        for i in 0..200 {
1944            let x = i as f64 * 0.1;
1945            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
1946        }
1947
1948        // After training, rolling_sigma_mean should have adapted from initial std
1949        assert!(model.rolling_sigma_mean() > 0.0);
1950        assert!(model.rolling_sigma_mean().is_finite());
1951    }
1952
1953    #[test]
1954    fn predict_distributional_returns_sigma_ratio() {
1955        let mut model = DistributionalSGBT::new(modulated_config());
1956
1957        for i in 0..200 {
1958            let x = i as f64 * 0.1;
1959            model.train_one(&(vec![x], x * 2.0 + 1.0));
1960        }
1961
1962        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
1963        assert!(mu.is_finite());
1964        assert!(sigma > 0.0);
1965        assert!(
1966            (0.1..=10.0).contains(&sigma_ratio),
1967            "sigma_ratio={}",
1968            sigma_ratio
1969        );
1970    }
1971
1972    #[test]
1973    fn predict_distributional_without_modulation_returns_one() {
1974        let mut model = DistributionalSGBT::new(test_config());
1975        assert!(!model.is_uncertainty_modulated());
1976
1977        for i in 0..200 {
1978            let x = i as f64 * 0.1;
1979            model.train_one(&(vec![x], x * 2.0));
1980        }
1981
1982        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
1983        assert!(
1984            (sigma_ratio - 1.0).abs() < 1e-12,
1985            "should be 1.0 when disabled"
1986        );
1987    }
1988
1989    #[test]
1990    fn modulated_model_sigma_finite_under_varying_noise() {
1991        let mut model = DistributionalSGBT::new(modulated_config());
1992
1993        let mut rng: u64 = 123;
1994        for i in 0..500 {
1995            rng ^= rng << 13;
1996            rng ^= rng >> 7;
1997            rng ^= rng << 17;
1998            let noise = (rng % 1000) as f64 / 100.0 - 5.0; // [-5, 5]
1999            let x = i as f64 * 0.1;
2000            // Regime shift at i=250: noise amplitude increases
2001            let scale = if i < 250 { 1.0 } else { 5.0 };
2002            model.train_one(&(vec![x], x * 2.0 + noise * scale));
2003        }
2004
2005        let pred = model.predict(&[10.0]);
2006        assert!(pred.mu.is_finite());
2007        assert!(pred.sigma.is_finite());
2008        assert!(pred.sigma > 0.0);
2009        assert!(model.rolling_sigma_mean().is_finite());
2010    }
2011
2012    #[test]
2013    fn reset_clears_rolling_sigma_mean() {
2014        let mut model = DistributionalSGBT::new(modulated_config());
2015
2016        for i in 0..200 {
2017            let x = i as f64 * 0.1;
2018            model.train_one(&(vec![x], x * 2.0));
2019        }
2020
2021        let sigma_before = model.rolling_sigma_mean();
2022        assert!(sigma_before > 0.0);
2023
2024        model.reset();
2025        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
2026    }
2027
2028    #[test]
2029    fn streaming_learner_returns_mu() {
2030        let mut model = DistributionalSGBT::new(test_config());
2031        for i in 0..200 {
2032            let x = i as f64 * 0.1;
2033            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
2034        }
2035        let pred = StreamingLearner::predict(&model, &[5.0]);
2036        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
2037        assert!(
2038            (pred - gaussian.mu).abs() < 1e-12,
2039            "StreamingLearner::predict should return mu"
2040        );
2041    }
2042
2043    // -- Diagnostics tests --
2044
2045    fn trained_model() -> DistributionalSGBT {
2046        let config = SGBTConfig::builder()
2047            .n_steps(10)
2048            .learning_rate(0.1)
2049            .grace_period(10) // low grace period to ensure splits
2050            .max_depth(4)
2051            .n_bins(16)
2052            .initial_target_count(10)
2053            .build()
2054            .unwrap();
2055        let mut model = DistributionalSGBT::new(config);
2056        for i in 0..500 {
2057            let x = i as f64 * 0.1;
2058            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
2059        }
2060        model
2061    }
2062
2063    #[test]
2064    fn diagnostics_returns_correct_tree_count() {
2065        let model = trained_model();
2066        let diag = model.diagnostics();
2067        // 10 location + 10 scale = 20 trees
2068        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
2069    }
2070
2071    #[test]
2072    fn diagnostics_trees_have_leaves() {
2073        let model = trained_model();
2074        let diag = model.diagnostics();
2075        for (i, tree) in diag.trees.iter().enumerate() {
2076            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
2077        }
2078        // At least some trees should have seen samples.
2079        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
2080        assert!(
2081            total_samples > 0,
2082            "at least some trees should have seen samples"
2083        );
2084    }
2085
2086    #[test]
2087    fn diagnostics_leaf_weight_stats_finite() {
2088        let model = trained_model();
2089        let diag = model.diagnostics();
2090        for (i, tree) in diag.trees.iter().enumerate() {
2091            let (min, max, mean, std) = tree.leaf_weight_stats;
2092            assert!(min.is_finite(), "tree {i} min not finite");
2093            assert!(max.is_finite(), "tree {i} max not finite");
2094            assert!(mean.is_finite(), "tree {i} mean not finite");
2095            assert!(std.is_finite(), "tree {i} std not finite");
2096            assert!(min <= max, "tree {i} min > max");
2097        }
2098    }
2099
2100    #[test]
2101    fn diagnostics_base_predictions_match() {
2102        let model = trained_model();
2103        let diag = model.diagnostics();
2104        assert!(
2105            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
2106            "location_base should be plausible"
2107        );
2108    }
2109
2110    #[test]
2111    fn predict_decomposed_reconstructs_prediction() {
2112        let model = trained_model();
2113        let features = [5.0, 2.5, 1.0];
2114        let pred = model.predict(&features);
2115        let decomp = model.predict_decomposed(&features);
2116
2117        assert!(
2118            (decomp.mu() - pred.mu).abs() < 1e-10,
2119            "decomposed mu ({}) != predict mu ({})",
2120            decomp.mu(),
2121            pred.mu
2122        );
2123        assert!(
2124            (decomp.sigma() - pred.sigma).abs() < 1e-10,
2125            "decomposed sigma ({}) != predict sigma ({})",
2126            decomp.sigma(),
2127            pred.sigma
2128        );
2129    }
2130
2131    #[test]
2132    fn predict_decomposed_correct_lengths() {
2133        let model = trained_model();
2134        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
2135        assert_eq!(
2136            decomp.location_contributions.len(),
2137            model.n_steps(),
2138            "location contributions should have n_steps entries"
2139        );
2140        assert_eq!(
2141            decomp.scale_contributions.len(),
2142            model.n_steps(),
2143            "scale contributions should have n_steps entries"
2144        );
2145    }
2146
2147    #[test]
2148    fn feature_importances_work() {
2149        let model = trained_model();
2150        let imp = model.feature_importances();
2151        // After enough training, importances should be non-empty and non-negative.
2152        // They may or may not sum to 1.0 if no splits have occurred yet.
2153        for (i, &v) in imp.iter().enumerate() {
2154            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
2155            assert!(v.is_finite(), "importance {i} should be finite");
2156        }
2157        let sum: f64 = imp.iter().sum();
2158        if sum > 0.0 {
2159            assert!(
2160                (sum - 1.0).abs() < 1e-10,
2161                "non-zero importances should sum to 1.0, got {sum}"
2162            );
2163        }
2164    }
2165
2166    #[test]
2167    fn feature_importances_split_works() {
2168        let model = trained_model();
2169        let (loc_imp, scale_imp) = model.feature_importances_split();
2170        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
2171            let sum: f64 = imp.iter().sum();
2172            if sum > 0.0 {
2173                assert!(
2174                    (sum - 1.0).abs() < 1e-10,
2175                    "{name} importances should sum to 1.0, got {sum}"
2176                );
2177            }
2178            for &v in imp.iter() {
2179                assert!(v >= 0.0 && v.is_finite());
2180            }
2181        }
2182    }
2183
2184    // -- Empirical σ tests --
2185
2186    #[test]
2187    fn empirical_sigma_default_mode() {
2188        use crate::ensemble::config::ScaleMode;
2189        let config = test_config();
2190        let model = DistributionalSGBT::new(config);
2191        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
2192    }
2193
2194    #[test]
2195    fn empirical_sigma_tracks_errors() {
2196        let mut model = DistributionalSGBT::new(test_config());
2197
2198        // Train on clean linear data
2199        for i in 0..200 {
2200            let x = i as f64 * 0.1;
2201            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
2202        }
2203
2204        let sigma_clean = model.empirical_sigma();
2205        assert!(sigma_clean > 0.0, "sigma should be positive");
2206        assert!(sigma_clean.is_finite(), "sigma should be finite");
2207
2208        // Now train on noisy data — sigma should increase
2209        let mut rng: u64 = 42;
2210        for i in 200..400 {
2211            rng ^= rng << 13;
2212            rng ^= rng >> 7;
2213            rng ^= rng << 17;
2214            let noise = (rng % 10000) as f64 / 100.0 - 50.0; // big noise
2215            let x = i as f64 * 0.1;
2216            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
2217        }
2218
2219        let sigma_noisy = model.empirical_sigma();
2220        assert!(
2221            sigma_noisy > sigma_clean,
2222            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
2223        );
2224    }
2225
2226    #[test]
2227    fn empirical_sigma_modulated_lr_adapts() {
2228        let config = SGBTConfig::builder()
2229            .n_steps(10)
2230            .learning_rate(0.1)
2231            .grace_period(20)
2232            .max_depth(4)
2233            .n_bins(16)
2234            .initial_target_count(10)
2235            .uncertainty_modulated_lr(true)
2236            .build()
2237            .unwrap();
2238        let mut model = DistributionalSGBT::new(config);
2239
2240        // Train and verify sigma_ratio changes
2241        for i in 0..300 {
2242            let x = i as f64 * 0.1;
2243            model.train_one(&(vec![x], x * 2.0 + 1.0));
2244        }
2245
2246        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
2247        assert!(sigma_ratio.is_finite());
2248        assert!(
2249            (0.1..=10.0).contains(&sigma_ratio),
2250            "sigma_ratio={sigma_ratio}"
2251        );
2252    }
2253
2254    #[test]
2255    fn tree_chain_mode_trains_scale_trees() {
2256        use crate::ensemble::config::ScaleMode;
2257        let config = SGBTConfig::builder()
2258            .n_steps(10)
2259            .learning_rate(0.1)
2260            .grace_period(10)
2261            .max_depth(4)
2262            .n_bins(16)
2263            .initial_target_count(10)
2264            .scale_mode(ScaleMode::TreeChain)
2265            .build()
2266            .unwrap();
2267        let mut model = DistributionalSGBT::new(config);
2268        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);
2269
2270        for i in 0..500 {
2271            let x = i as f64 * 0.1;
2272            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
2273        }
2274
2275        let pred = model.predict(&[5.0, 2.5, 1.0]);
2276        assert!(pred.mu.is_finite());
2277        assert!(pred.sigma > 0.0);
2278        assert!(pred.sigma.is_finite());
2279    }
2280
2281    #[test]
2282    fn diagnostics_shows_empirical_sigma() {
2283        let model = trained_model();
2284        let diag = model.diagnostics();
2285        assert!(
2286            diag.empirical_sigma > 0.0,
2287            "empirical_sigma should be positive"
2288        );
2289        assert!(
2290            diag.empirical_sigma.is_finite(),
2291            "empirical_sigma should be finite"
2292        );
2293    }
2294
2295    #[test]
2296    fn diagnostics_scale_trees_split_fields() {
2297        let model = trained_model();
2298        let diag = model.diagnostics();
2299        assert_eq!(diag.location_trees.len(), model.n_steps());
2300        assert_eq!(diag.scale_trees.len(), model.n_steps());
2301        // In empirical mode, scale_trees_active might be 0 (trees not trained)
2302        // This is expected and actually the point.
2303    }
2304
2305    #[test]
2306    fn reset_clears_empirical_sigma() {
2307        let mut model = DistributionalSGBT::new(test_config());
2308        for i in 0..200 {
2309            let x = i as f64 * 0.1;
2310            model.train_one(&(vec![x], x * 2.0));
2311        }
2312        model.reset();
2313        // After reset, ewma_sq_err resets to 1.0
2314        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
2315    }
2316
2317    #[test]
2318    fn predict_smooth_returns_finite() {
2319        let config = SGBTConfig::builder()
2320            .n_steps(3)
2321            .learning_rate(0.1)
2322            .grace_period(20)
2323            .max_depth(4)
2324            .n_bins(16)
2325            .initial_target_count(10)
2326            .build()
2327            .unwrap();
2328        let mut model = DistributionalSGBT::new(config);
2329
2330        for i in 0..200 {
2331            let x = (i as f64) * 0.1;
2332            let features = vec![x, x.sin()];
2333            let target = 2.0 * x + 1.0;
2334            model.train_one(&(features, target));
2335        }
2336
2337        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
2338        assert!(pred.mu.is_finite(), "smooth mu should be finite");
2339        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
2340        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
2341    }
2342
2343    // -- PD sigma modulation tests --
2344
2345    #[test]
2346    fn sigma_velocity_responds_to_error_spike() {
2347        let config = SGBTConfig::builder()
2348            .n_steps(3)
2349            .learning_rate(0.1)
2350            .grace_period(20)
2351            .max_depth(4)
2352            .n_bins(16)
2353            .initial_target_count(10)
2354            .uncertainty_modulated_lr(true)
2355            .build()
2356            .unwrap();
2357        let mut model = DistributionalSGBT::new(config);
2358
2359        // Stable phase
2360        for i in 0..200 {
2361            let x = (i as f64) * 0.1;
2362            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
2363        }
2364
2365        let velocity_before = model.sigma_velocity();
2366
2367        // Error spike -- sudden regime change
2368        for i in 0..50 {
2369            let x = (i as f64) * 0.1;
2370            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
2371        }
2372
2373        let velocity_after = model.sigma_velocity();
2374
2375        // Velocity should be positive (sigma increasing due to errors)
2376        assert!(
2377            velocity_after > velocity_before,
2378            "sigma velocity should increase after error spike: before={}, after={}",
2379            velocity_before,
2380            velocity_after,
2381        );
2382    }
2383
2384    #[test]
2385    fn sigma_velocity_getter_works() {
2386        let config = SGBTConfig::builder()
2387            .n_steps(2)
2388            .learning_rate(0.1)
2389            .grace_period(20)
2390            .max_depth(4)
2391            .n_bins(16)
2392            .initial_target_count(10)
2393            .build()
2394            .unwrap();
2395        let model = DistributionalSGBT::new(config);
2396        // Fresh model should have zero velocity
2397        assert_eq!(model.sigma_velocity(), 0.0);
2398    }
2399
2400    #[test]
2401    fn diagnostics_leaf_sample_counts_populated() {
2402        let config = SGBTConfig::builder()
2403            .n_steps(3)
2404            .learning_rate(0.1)
2405            .grace_period(10)
2406            .max_depth(4)
2407            .n_bins(16)
2408            .initial_target_count(10)
2409            .build()
2410            .unwrap();
2411        let mut model = DistributionalSGBT::new(config);
2412
2413        for i in 0..200 {
2414            let x = (i as f64) * 0.1;
2415            let features = vec![x, x.sin()];
2416            let target = 2.0 * x + 1.0;
2417            model.train_one(&(features, target));
2418        }
2419
2420        let diags = model.diagnostics();
2421        for (ti, tree) in diags.trees.iter().enumerate() {
2422            assert_eq!(
2423                tree.leaf_sample_counts.len(),
2424                tree.n_leaves,
2425                "tree {} should have sample count per leaf",
2426                ti,
2427            );
2428            // Total samples across leaves should equal samples_seen for trees with data
2429            if tree.samples_seen > 0 {
2430                let total: u64 = tree.leaf_sample_counts.iter().sum();
2431                assert!(
2432                    total > 0,
2433                    "tree {} has {} samples_seen but leaf counts sum to 0",
2434                    ti,
2435                    tree.samples_seen,
2436                );
2437            }
2438        }
2439    }
2440
2441    // -------------------------------------------------------------------
2442    // Packed cache tests (dual-path inference)
2443    // -------------------------------------------------------------------
2444
2445    #[test]
2446    fn packed_cache_disabled_by_default() {
2447        let model = DistributionalSGBT::new(test_config());
2448        assert!(!model.has_packed_cache());
2449        assert_eq!(model.config().packed_refresh_interval, 0);
2450    }
2451
    #[test]
    // NOTE(review): gated on "_packed_cache_tests_disabled", a feature name
    // that reads as intentionally never enabled -- this test is currently
    // compiled out. Confirm whether packed-cache coverage should be restored.
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_refreshes_after_interval() {
        // Refresh interval of 20 means the cache is (re)built every 20
        // training samples once the model is initialized.
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .packed_refresh_interval(20)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // Train past initialization (10 samples) + refresh interval (20 samples)
        for i in 0..40 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Cache should exist after enough training
        assert!(
            model.has_packed_cache(),
            "packed cache should exist after training past refresh interval"
        );

        // Predictions should be finite
        let pred = model.predict(&[2.0, 4.0, 1.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
    }
2484
    #[test]
    // NOTE(review): same never-enabled feature gate as above -- this
    // cache-consistency test does not currently run.
    #[cfg(feature = "_packed_cache_tests_disabled")]
    fn packed_cache_matches_full_tree() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();

        let mut model = DistributionalSGBT::new(config);

        // Train
        for i in 0..80 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }

        // Get full-tree prediction (no cache)
        assert!(!model.has_packed_cache());
        let full_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Enable cache and get cached prediction
        model.enable_packed_cache(10);
        assert!(model.has_packed_cache());
        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);

        // Should match within f32 precision (f64->f32 conversion loses some precision)
        // Tolerance of 0.1 is loose because every leaf value is rounded to f32
        // in the packed cache and the per-tree errors can accumulate.
        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
        assert!(
            mu_diff < 0.1,
            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
            cached_pred.mu,
            full_pred.mu,
            mu_diff
        );

        // Sigma should be identical (not affected by packed cache)
        assert!(
            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
            "sigma should be identical: full={}, cached={}",
            full_pred.sigma,
            cached_pred.sigma
        );
    }
2533
2534    #[test]
2535    fn honest_sigma_in_gaussian_prediction() {
2536        let config = SGBTConfig::builder()
2537            .n_steps(5)
2538            .learning_rate(0.1)
2539            .max_depth(3)
2540            .grace_period(2)
2541            .initial_target_count(10)
2542            .build()
2543            .unwrap();
2544        let mut model = DistributionalSGBT::new(config);
2545        for i in 0..100 {
2546            let x = i as f64 * 0.1;
2547            model.train_one(&(vec![x], x * 2.0));
2548        }
2549
2550        let pred = model.predict(&[5.0]);
2551        // honest_sigma should be finite and non-negative
2552        assert!(
2553            pred.honest_sigma.is_finite(),
2554            "honest_sigma should be finite, got {}",
2555            pred.honest_sigma
2556        );
2557        assert!(
2558            pred.honest_sigma >= 0.0,
2559            "honest_sigma should be >= 0, got {}",
2560            pred.honest_sigma
2561        );
2562    }
2563
    #[test]
    // NOTE(review): despite the name, this test only asserts that
    // honest_sigma is finite and non-negative both in- and out-of-
    // distribution; a strict increase cannot be guaranteed (all trees may
    // extrapolate identically). Consider renaming to match the assertions.
    fn honest_sigma_increases_with_divergence() {
        let config = SGBTConfig::builder()
            .n_steps(8)
            .learning_rate(0.1)
            .max_depth(4)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train on linear data
        for i in 0..200 {
            let x = i as f64 * 0.05;
            model.train_one(&(vec![x], x * 2.0));
        }

        // Predict in-distribution (training inputs covered roughly [0, 10))
        let in_dist = model.predict(&[5.0]);
        // Predict very far out of distribution
        let out_dist = model.predict(&[500.0]);

        // Both should be finite and non-negative
        assert!(in_dist.honest_sigma.is_finite());
        assert!(out_dist.honest_sigma.is_finite());
        assert!(in_dist.honest_sigma >= 0.0);
        assert!(out_dist.honest_sigma >= 0.0);

        // Out-of-distribution should generally produce different honest_sigma
        // (trees will respond differently to unseen regions)
        // We can't guarantee strictly greater since all trees might
        // extrapolate identically, but at minimum both should be valid.
    }
2598
2599    #[test]
2600    fn honest_sigma_zero_for_single_step() {
2601        // With only 1 step, variance of 1 contribution is undefined -> returns 0
2602        let config = SGBTConfig::builder()
2603            .n_steps(1)
2604            .learning_rate(0.1)
2605            .max_depth(3)
2606            .grace_period(2)
2607            .initial_target_count(10)
2608            .build()
2609            .unwrap();
2610        let mut model = DistributionalSGBT::new(config);
2611        for i in 0..100 {
2612            let x = i as f64 * 0.1;
2613            model.train_one(&(vec![x], x * 2.0));
2614        }
2615        let pred = model.predict(&[5.0]);
2616        assert!(
2617            pred.honest_sigma.abs() < 1e-15,
2618            "honest_sigma should be 0 with 1 step, got {}",
2619            pred.honest_sigma
2620        );
2621    }
2622
2623    #[test]
2624    fn distributional_soft_routed_prediction() {
2625        let config = SGBTConfig::builder()
2626            .n_steps(10)
2627            .learning_rate(0.1)
2628            .grace_period(20)
2629            .max_depth(4)
2630            .n_bins(16)
2631            .initial_target_count(10)
2632            .build()
2633            .unwrap();
2634        let mut model = DistributionalSGBT::new(config);
2635
2636        // Train 200 samples
2637        for i in 0..200 {
2638            let x = i as f64 * 0.05;
2639            let y = x * 2.0 + 1.0;
2640            model.train_one(&(vec![x, x * 0.3], y));
2641        }
2642
2643        // Verify soft-routed prediction returns finite values
2644        let pred = model.predict_soft_routed(&[1.0, 0.3]);
2645        assert!(
2646            pred.mu.is_finite(),
2647            "soft-routed mu should be finite, got {}",
2648            pred.mu
2649        );
2650        assert!(
2651            pred.sigma.is_finite() && pred.sigma > 0.0,
2652            "soft-routed sigma should be finite and positive, got {}",
2653            pred.sigma
2654        );
2655        assert!(
2656            pred.log_sigma.is_finite(),
2657            "soft-routed log_sigma should be finite, got {}",
2658            pred.log_sigma
2659        );
2660        assert!(
2661            pred.honest_sigma.is_finite(),
2662            "soft-routed honest_sigma should be finite, got {}",
2663            pred.honest_sigma
2664        );
2665    }
2666}
2667
/// Round-trip tests for the optional serde-backed persistence layer.
#[cfg(test)]
#[cfg(feature = "_serde_support")]
mod serde_tests {
    use super::*;
    use crate::SGBTConfig;

    /// Trains a small model on a sine target so serialization tests have
    /// non-trivial state to round-trip.
    fn make_trained_distributional() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        model
    }

    #[test]
    fn json_round_trip_preserves_predictions() {
        let model = make_trained_distributional();
        let state = model.to_distributional_state();
        let json = crate::serde_support::save_distributional_model(&state).unwrap();
        let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
        let restored = DistributionalSGBT::from_distributional_state(loaded_state);

        // mu and sigma must survive the save/load cycle at every probe point.
        let test_points = [0.5, 1.0, 2.0, 3.0];
        for &x in &test_points {
            let orig = model.predict(&[x]);
            let rest = restored.predict(&[x]);
            assert!(
                (orig.mu - rest.mu).abs() < 1e-10,
                "JSON round-trip mu mismatch at x={}: {} vs {}",
                x,
                orig.mu,
                rest.mu
            );
            assert!(
                (orig.sigma - rest.sigma).abs() < 1e-10,
                "JSON round-trip sigma mismatch at x={}: {} vs {}",
                x,
                orig.sigma,
                rest.sigma
            );
        }
    }

    #[test]
    fn state_preserves_rolling_sigma_mean() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        let state = model.to_distributional_state();
        assert!(state.uncertainty_modulated_lr);
        assert!(state.rolling_sigma_mean >= 0.0);

        let restored = DistributionalSGBT::from_distributional_state(state);
        assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
    }
}

/// Core behavior tests that do NOT depend on serde.
///
/// These previously lived inside `serde_tests` and were therefore compiled
/// out -- and silently skipped by a default `cargo test` -- unless the
/// `_serde_support` feature was enabled, even though none of them touch
/// serialization. They are now gated only on `cfg(test)`.
#[cfg(test)]
mod core_behavior_tests {
    use super::*;

    /// Auto-bandwidths should be computed per feature and surfaced through
    /// diagnostics, without breaking prediction.
    #[test]
    fn auto_bandwidth_computed_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        // auto_bandwidths should be populated
        let bws = model.auto_bandwidths();
        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");

        // Diagnostics should include auto_bandwidths
        let diag = model.diagnostics();
        assert_eq!(diag.auto_bandwidths.len(), 2);

        // TreeDiagnostic should have prediction stats
        assert!(diag.location_trees[0].prediction_mean.is_finite());
        assert!(diag.location_trees[0].prediction_std.is_finite());

        let pred = model.predict(&[1.0, 1.0_f64.sin()]);
        assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
        assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
    }

    /// Extreme leaf weights should be clamped by `max_leaf_output` so the
    /// ensemble prediction stays finite.
    #[test]
    fn max_leaf_output_clamps_predictions() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(1.0) // Intentionally large to force extreme leaf weights
            .max_leaf_output(0.5)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with extreme targets to force large leaf weights
        for i in 0..200 {
            let target = if i % 2 == 0 { 100.0 } else { -100.0 };
            let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
            model.train_one(&sample);
        }

        // Each tree's prediction should be clamped to [-0.5, 0.5]
        let pred = model.predict(&[2.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with clamping"
        );
    }

    /// A high `min_hessian_sum` should suppress output from under-populated
    /// leaves without destabilizing prediction.
    #[test]
    fn min_hessian_sum_suppresses_fresh_leaves() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.01)
            .min_hessian_sum(50.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train minimal samples
        for i in 0..60 {
            let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
            model.train_one(&sample);
        }

        let pred = model.predict(&[30.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with min_hessian_sum"
        );
    }

    /// Interpolated prediction must produce a valid Gaussian.
    #[test]
    fn predict_interpolated_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let pred = model.predict_interpolated(&[1.0, 0.5]);
        assert!(pred.mu.is_finite(), "interpolated mu should be finite");
        assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
    }

    /// Huber loss should keep mu finite even with periodic extreme outliers.
    #[test]
    fn huber_k_bounds_gradients() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(1.345)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with occasional extreme outliers
        for i in 0..300 {
            let target = if i % 50 == 0 {
                1000.0
            } else {
                (i as f64 * 0.1).sin()
            };
            let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[5.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "Huber-loss mu should be finite despite outliers"
        );
        assert!(pred.sigma > 0.0, "sigma should be positive");
    }

    /// Ensemble-level gradient statistics should be populated after training.
    #[test]
    fn ensemble_gradient_stats_populated() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let diag = model.diagnostics();
        assert!(
            diag.ensemble_grad_mean.is_finite(),
            "ensemble grad mean should be finite"
        );
        assert!(
            diag.ensemble_grad_std >= 0.0,
            "ensemble grad std should be non-negative"
        );
        assert!(
            diag.ensemble_grad_std.is_finite(),
            "ensemble grad std should be finite"
        );
    }

    /// Negative `huber_k` is invalid and must be rejected by the builder.
    #[test]
    fn huber_k_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(-1.0)
            .build();
        assert!(result.is_err(), "negative huber_k should fail validation");
    }

    /// Negative `max_leaf_output` is invalid and must be rejected.
    #[test]
    fn max_leaf_output_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .max_leaf_output(-1.0)
            .build();
        assert!(
            result.is_err(),
            "negative max_leaf_output should fail validation"
        );
    }

    /// Sibling interpolation should vary at least as smoothly as hard
    /// routing across a feature sweep.
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            let sample = crate::Sample::new(vec![x, x * 0.3], y);
            model.train_one(&sample);
        }

        // Verify the method runs correctly and produces finite predictions
        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(
            pred.mu.is_finite(),
            "sibling interpolated mu should be finite"
        );
        assert!(
            pred.sigma > 0.0,
            "sibling interpolated sigma should be positive"
        );

        // Sweep — at minimum, sibling should produce same or more variation than hard
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            let hard_preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict(&[x, x * 0.3]).mu
                })
                .collect();
            let hard_changes = hard_preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();

            let preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict_sibling_interpolated(&[x, x * 0.3]).mu
                })
                .collect();

            let sibling_changes = preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sibling_changes >= hard_changes,
                "sibling should produce >= hard changes: sibling={}, hard={}",
                sibling_changes,
                hard_changes
            );
        }
    }
}