// irithyll/ensemble/distributional.rs

1//! Distributional SGBT -- outputs Gaussian N(μ, σ²) instead of a point estimate.
2//!
3//! [`DistributionalSGBT`] supports two scale estimation modes via
4//! [`ScaleMode`]:
5//!
6//! ## Empirical σ (default)
7//!
8//! Tracks an EWMA of squared prediction errors:
9//!
10//! ```text
11//! err = target - mu
12//! ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err
13//! sigma = sqrt(ewma_sq_err)
14//! ```
15//!
16//! Always calibrated (σ literally *is* recent error magnitude), zero tuning,
17//! O(1) memory and compute.  When `uncertainty_modulated_lr` is enabled,
18//! high recent errors → σ large → location LR scales up → faster correction.
19//!
20//! ## Tree chain (NGBoost-style)
21//!
22//! Maintains two independent tree ensembles: one for location (μ), one for
23//! scale (log σ).  Gives feature-conditional uncertainty but requires strong
24//! scale-gradient signal for the trees to split.
25//!
26//! # References
27//!
28//! Duan et al. (2020). "NGBoost: Natural Gradient Boosting for Probabilistic Prediction."
29
30use crate::ensemble::config::{SGBTConfig, ScaleMode};
31use crate::ensemble::step::BoostingStep;
32use crate::export_embedded::export_distributional_packed;
33use crate::sample::{Observation, SampleRef};
34use crate::tree::builder::TreeConfig;
35
/// Cached packed f32 binary for fast location-only inference.
///
/// Re-exported periodically from the location ensemble. Predictions use
/// contiguous BFS-packed memory for cache-optimal tree traversal.
//
// Clone is derived: the previous hand-written impl cloned each field
// verbatim, which is exactly what the derive generates.
#[derive(Clone)]
struct PackedInferenceCache {
    /// Packed f32 model bytes produced by the exporter.
    bytes: Vec<u8>,
    /// Base (bias) prediction added on top of the packed trees' output.
    base: f64,
    /// Number of input features the packed model expects.
    n_features: usize,
}
55
/// Prediction from a distributional model: full Gaussian N(μ, σ²).
#[derive(Clone, Copy, Debug)]
pub struct GaussianPrediction {
    /// Location parameter (mean of the predictive Gaussian).
    pub mu: f64,
    /// Scale parameter (standard deviation, always > 0).
    pub sigma: f64,
    /// Raw log-scale value (what the scale ensemble actually outputs).
    pub log_sigma: f64,
    /// Tree contribution variance (epistemic uncertainty).
    ///
    /// Standard deviation of the individual location-tree contributions,
    /// computed via one-pass Welford variance with Bessel's correction.
    /// Because it reflects instantaneous tree disagreement rather than a
    /// lagging EWMA, it reacts immediately and is the better signal for
    /// regime-change detection than empirical sigma.
    ///
    /// Zero when the model has 0 or 1 active location trees.
    pub honest_sigma: f64,
}

impl GaussianPrediction {
    /// Lower bound of a symmetric confidence interval (`μ - z·σ`).
    ///
    /// For a 95% CI, use `z = 1.96`.
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        // Mirror of `upper`: μ + (-z)·σ is bit-for-bit equal to μ - z·σ.
        self.upper(-z)
    }

    /// Upper bound of a symmetric confidence interval (`μ + z·σ`).
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        self.mu + z * self.sigma
    }
}
91
92// ---------------------------------------------------------------------------
93// Diagnostic structs
94// ---------------------------------------------------------------------------
95
/// Per-tree diagnostic summary.
#[derive(Clone, Debug)]
pub struct TreeDiagnostic {
    /// Leaf count of this tree.
    pub n_leaves: usize,
    /// Deepest level any leaf of this tree sits at.
    pub max_depth_reached: usize,
    /// Total number of samples this tree has been trained on.
    pub samples_seen: u64,
    /// Leaf weight statistics, packed as `(min, max, mean, std)`.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Indices of the features this tree has actually split on (non-zero gain).
    pub split_features: Vec<usize>,
    /// Sample count per leaf, showing how data is distributed across leaves.
    pub leaf_sample_counts: Vec<u64>,
    /// Online (Welford) running mean of this tree's predictions.
    pub prediction_mean: f64,
    /// Online running standard deviation of this tree's predictions.
    pub prediction_std: f64,
}
116
/// Full model diagnostics for [`DistributionalSGBT`].
///
/// Contains per-tree summaries, feature usage, base predictions, and
/// empirical σ state. Every field is an owned snapshot; nothing here
/// borrows from the live model.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// Per-tree diagnostic summaries (location trees first, then scale trees).
    pub trees: Vec<TreeDiagnostic>,
    /// Location trees only — an owned subset matching `trees`, not a borrowed view.
    pub location_trees: Vec<TreeDiagnostic>,
    /// Scale trees only — an owned subset matching `trees`, not a borrowed view.
    pub scale_trees: Vec<TreeDiagnostic>,
    /// How many trees each feature is used in (split count per feature index).
    pub feature_split_counts: Vec<usize>,
    /// Base prediction for location (mean of the initial targets).
    pub location_base: f64,
    /// Base prediction for scale (log-sigma).
    pub scale_base: f64,
    /// Current empirical σ (`sqrt(ewma_sq_err)`), available in every scale mode.
    pub empirical_sigma: f64,
    /// Scale mode in use.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that actually split (>1 leaf). 0 = frozen chain.
    pub scale_trees_active: usize,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// `f64::INFINITY` means that feature uses hard routing.
    pub auto_bandwidths: Vec<f64>,
    /// Ensemble-level gradient running mean (Welford).
    pub ensemble_grad_mean: f64,
    /// Ensemble-level gradient standard deviation (Welford).
    pub ensemble_grad_std: f64,
}
149
/// Decomposed prediction showing each tree's contribution.
#[derive(Clone, Debug)]
pub struct DecomposedPrediction {
    /// Base location prediction (mean of the initial targets).
    pub location_base: f64,
    /// Base scale prediction (log-sigma of the initial targets).
    pub scale_base: f64,
    /// Per-step location contributions (`learning_rate * tree_prediction`);
    /// `location_base + sum(location_contributions)` reconstructs μ.
    pub location_contributions: Vec<f64>,
    /// Per-step scale contributions (`learning_rate * tree_prediction`);
    /// `scale_base + sum(scale_contributions)` reconstructs log(σ).
    pub scale_contributions: Vec<f64>,
}

impl DecomposedPrediction {
    /// Final μ: base plus the summed per-tree location contributions.
    pub fn mu(&self) -> f64 {
        let total: f64 = self.location_contributions.iter().sum();
        self.location_base + total
    }

    /// Final log(σ): base plus the summed per-tree scale contributions.
    pub fn log_sigma(&self) -> f64 {
        let total: f64 = self.scale_contributions.iter().sum();
        self.scale_base + total
    }

    /// Final σ: exponentiated log(σ), floored at 1e-8 for numerical safety.
    pub fn sigma(&self) -> f64 {
        self.log_sigma().exp().max(1e-8)
    }
}
181
182// ---------------------------------------------------------------------------
183// DistributionalSGBT
184// ---------------------------------------------------------------------------
185
186/// NGBoost-style distributional streaming gradient boosted trees.
187///
188/// Outputs a full Gaussian predictive distribution N(μ, σ²) by maintaining two
189/// independent ensembles -- one for location (mean) and one for scale (log-sigma).
190///
191/// # Example
192///
193/// ```
194/// use irithyll::SGBTConfig;
195/// use irithyll::ensemble::distributional::DistributionalSGBT;
196///
197/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
198/// let mut model = DistributionalSGBT::new(config);
199///
200/// // Train on streaming data
201/// model.train_one(&(vec![1.0, 2.0], 3.5));
202///
203/// // Get full distributional prediction
204/// let pred = model.predict(&[1.0, 2.0]);
205/// println!("mean={}, sigma={}", pred.mu, pred.sigma);
206/// ```
207pub struct DistributionalSGBT {
208    /// Configuration (shared between location and scale ensembles).
209    config: SGBTConfig,
210    /// Location (mean) boosting steps.
211    location_steps: Vec<BoostingStep>,
212    /// Scale (log-sigma) boosting steps (only used in `TreeChain` mode).
213    scale_steps: Vec<BoostingStep>,
214    /// Base prediction for location (mean of initial targets).
215    location_base: f64,
216    /// Base prediction for scale (log of std of initial targets).
217    scale_base: f64,
218    /// Whether base predictions have been initialized.
219    base_initialized: bool,
220    /// Running collection of initial targets for computing base predictions.
221    initial_targets: Vec<f64>,
222    /// Number of initial targets to collect before setting base predictions.
223    initial_target_count: usize,
224    /// Total samples trained.
225    samples_seen: u64,
226    /// RNG state for variant logic.
227    rng_state: u64,
228    /// Whether σ-modulated learning rate is enabled.
229    uncertainty_modulated_lr: bool,
230    /// EWMA of the model's predicted σ -- used as the denominator in σ-ratio.
231    ///
232    /// Updated with alpha = 0.001 (slow adaptation) after each training step.
233    /// Initialized from the standard deviation of the initial target collection.
234    rolling_sigma_mean: f64,
235    /// Scale estimation mode: Empirical (default) or TreeChain.
236    scale_mode: ScaleMode,
237    /// EWMA of squared prediction errors (empirical σ mode).
238    ///
239    /// `sigma = sqrt(ewma_sq_err)`. Updated every training step with
240    /// `ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err`.
241    ewma_sq_err: f64,
242    /// EWMA alpha for empirical σ.
243    empirical_sigma_alpha: f64,
244    /// Previous empirical σ value for computing σ velocity.
245    prev_sigma: f64,
246    /// EWMA-smoothed derivative of empirical σ (σ velocity).
247    /// Positive = σ increasing (errors growing), negative = σ decreasing.
248    sigma_velocity: f64,
249    /// Per-feature auto-calibrated bandwidths for smooth prediction.
250    auto_bandwidths: Vec<f64>,
251    /// Sum of replacement counts at last bandwidth computation.
252    last_replacement_sum: u64,
253    /// Ensemble-level gradient running mean (Welford).
254    ensemble_grad_mean: f64,
255    /// Ensemble-level gradient M2 (Welford sum of squared deviations).
256    ensemble_grad_m2: f64,
257    /// Ensemble-level gradient sample count.
258    ensemble_grad_count: u64,
259    /// EWMA of honest_sigma (tree contribution std dev) for σ-modulated LR.
260    ///
261    /// When `uncertainty_modulated_lr` is enabled, the honest_sigma ratio
262    /// (`honest_sigma / rolling_honest_sigma_mean`) modulates the effective
263    /// learning rate, reacting instantly to tree disagreement.
264    rolling_honest_sigma_mean: f64,
265    /// Packed f32 cache for fast location-only inference (dual-path).
266    packed_cache: Option<PackedInferenceCache>,
267    /// Samples trained since last packed cache refresh.
268    samples_since_refresh: u64,
269    /// How often to refresh the packed cache (0 = disabled).
270    packed_refresh_interval: u64,
271    // -----------------------------------------------------------------------
272    // Diagnostic cache (updated during train_one)
273    // -----------------------------------------------------------------------
274    /// Previous per-tree contributions for residual alignment (cosine similarity).
275    prev_contributions: Vec<f64>,
276    /// Contributions from two calls ago, for delta-based alignment.
277    prev_prev_contributions: Vec<f64>,
278    /// Cached cosine similarity of consecutive tree contribution vectors.
279    cached_residual_alignment: f64,
280    /// Cached mean |G|/(H+λ)² across all leaves.
281    cached_reg_sensitivity: f64,
282    /// Cached F-statistic (between-leaf / within-leaf variance).
283    cached_depth_sufficiency: f64,
284    /// Cached trace(H/(H+λ)) across all leaves.
285    cached_effective_dof: f64,
286    /// Per-tree EWMA of signed contribution accuracy. Positive = helps, negative = hurts.
287    contribution_accuracy: Vec<f64>,
288    /// EWMA alpha for contribution accuracy tracking.
289    prune_alpha: f64,
290}
291
292impl Clone for DistributionalSGBT {
293    fn clone(&self) -> Self {
294        Self {
295            config: self.config.clone(),
296            location_steps: self.location_steps.clone(),
297            scale_steps: self.scale_steps.clone(),
298            location_base: self.location_base,
299            scale_base: self.scale_base,
300            base_initialized: self.base_initialized,
301            initial_targets: self.initial_targets.clone(),
302            initial_target_count: self.initial_target_count,
303            samples_seen: self.samples_seen,
304            rng_state: self.rng_state,
305            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
306            rolling_sigma_mean: self.rolling_sigma_mean,
307            scale_mode: self.scale_mode,
308            ewma_sq_err: self.ewma_sq_err,
309            empirical_sigma_alpha: self.empirical_sigma_alpha,
310            prev_sigma: self.prev_sigma,
311            sigma_velocity: self.sigma_velocity,
312            auto_bandwidths: self.auto_bandwidths.clone(),
313            last_replacement_sum: self.last_replacement_sum,
314            ensemble_grad_mean: self.ensemble_grad_mean,
315            ensemble_grad_m2: self.ensemble_grad_m2,
316            ensemble_grad_count: self.ensemble_grad_count,
317            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
318            packed_cache: self.packed_cache.clone(),
319            samples_since_refresh: self.samples_since_refresh,
320            packed_refresh_interval: self.packed_refresh_interval,
321            prev_contributions: self.prev_contributions.clone(),
322            prev_prev_contributions: self.prev_prev_contributions.clone(),
323            cached_residual_alignment: self.cached_residual_alignment,
324            cached_reg_sensitivity: self.cached_reg_sensitivity,
325            cached_depth_sufficiency: self.cached_depth_sufficiency,
326            cached_effective_dof: self.cached_effective_dof,
327            contribution_accuracy: self.contribution_accuracy.clone(),
328            prune_alpha: self.prune_alpha,
329        }
330    }
331}
332
333impl std::fmt::Debug for DistributionalSGBT {
334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        let mut s = f.debug_struct("DistributionalSGBT");
336        s.field("n_steps", &self.location_steps.len())
337            .field("samples_seen", &self.samples_seen)
338            .field("location_base", &self.location_base)
339            .field("scale_mode", &self.scale_mode)
340            .field("base_initialized", &self.base_initialized);
341        match self.scale_mode {
342            ScaleMode::Empirical => {
343                s.field("empirical_sigma", &self.ewma_sq_err.sqrt());
344            }
345            ScaleMode::TreeChain => {
346                s.field("scale_base", &self.scale_base);
347            }
348        }
349        if self.uncertainty_modulated_lr {
350            s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
351        }
352        s.finish()
353    }
354}
355
356impl DistributionalSGBT {
357    /// Create a new distributional SGBT with the given configuration.
358    ///
359    /// When `scale_mode` is `Empirical` (default), scale trees are still allocated
360    /// but never trained — only the EWMA error tracker produces σ.  When
361    /// `scale_mode` is `TreeChain`, both location and scale ensembles are active.
362    pub fn new(config: SGBTConfig) -> Self {
363        let leaf_decay_alpha = config
364            .leaf_half_life
365            .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());
366
367        let tree_config = TreeConfig::new()
368            .max_depth(config.max_depth)
369            .n_bins(config.n_bins)
370            .lambda(config.lambda)
371            .gamma(config.gamma)
372            .grace_period(config.grace_period)
373            .delta(config.delta)
374            .feature_subsample_rate(config.feature_subsample_rate)
375            .leaf_decay_alpha_opt(leaf_decay_alpha)
376            .split_reeval_interval_opt(config.split_reeval_interval)
377            .feature_types_opt(config.feature_types.clone())
378            .gradient_clip_sigma_opt(config.gradient_clip_sigma)
379            .monotone_constraints_opt(config.monotone_constraints.clone())
380            .max_leaf_output_opt(config.max_leaf_output)
381            .adaptive_depth_opt(config.adaptive_depth)
382            .min_hessian_sum_opt(config.min_hessian_sum)
383            .leaf_model_type(config.leaf_model_type.clone());
384
385        let max_tree_samples = config.max_tree_samples;
386
387        // Location ensemble (seed offset: 0)
388        let shadow_warmup = config.shadow_warmup.unwrap_or(0);
389        let location_steps: Vec<BoostingStep> = (0..config.n_steps)
390            .map(|i| {
391                let mut tc = tree_config.clone();
392                tc.seed = config.seed ^ (i as u64);
393                let detector = config.drift_detector.create();
394                if shadow_warmup > 0 {
395                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
396                } else {
397                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
398                }
399            })
400            .collect();
401
402        // Scale ensemble (seed offset: 0xSCALE) -- only trained in TreeChain mode
403        let scale_steps: Vec<BoostingStep> = (0..config.n_steps)
404            .map(|i| {
405                let mut tc = tree_config.clone();
406                tc.seed = config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
407                let detector = config.drift_detector.create();
408                if shadow_warmup > 0 {
409                    BoostingStep::new_with_graduated(tc, detector, max_tree_samples, shadow_warmup)
410                } else {
411                    BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
412                }
413            })
414            .collect();
415
416        let seed = config.seed;
417        let initial_target_count = config.initial_target_count;
418        let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
419        let scale_mode = config.scale_mode;
420        let empirical_sigma_alpha = config.empirical_sigma_alpha;
421        let packed_refresh_interval = config.packed_refresh_interval;
422        let n_steps = config.n_steps;
423        let prune_alpha = if config.proactive_prune_interval.is_some() {
424            let gp = config.grace_period.max(1) as f64;
425            1.0 - (-2.0 / gp).exp()
426        } else {
427            0.01
428        };
429        Self {
430            config,
431            location_steps,
432            scale_steps,
433            location_base: 0.0,
434            scale_base: 0.0,
435            base_initialized: false,
436            initial_targets: Vec::new(),
437            initial_target_count,
438            samples_seen: 0,
439            rng_state: seed,
440            uncertainty_modulated_lr,
441            rolling_sigma_mean: 1.0, // overwritten during base initialization
442            scale_mode,
443            ewma_sq_err: 1.0, // overwritten during base initialization
444            empirical_sigma_alpha,
445            prev_sigma: 0.0,
446            sigma_velocity: 0.0,
447            auto_bandwidths: Vec::new(),
448            last_replacement_sum: 0,
449            ensemble_grad_mean: 0.0,
450            ensemble_grad_m2: 0.0,
451            ensemble_grad_count: 0,
452            rolling_honest_sigma_mean: 0.0,
453            packed_cache: None,
454            samples_since_refresh: 0,
455            packed_refresh_interval,
456            prev_contributions: Vec::new(),
457            prev_prev_contributions: Vec::new(),
458            cached_residual_alignment: 0.0,
459            cached_reg_sensitivity: 0.0,
460            cached_depth_sufficiency: 0.0,
461            cached_effective_dof: 0.0,
462            contribution_accuracy: vec![0.0; n_steps],
463            prune_alpha,
464        }
465    }
466
467    /// Compute honest_sigma from location tree contributions for a feature vector.
468    ///
469    /// Returns the Bessel-corrected standard deviation of individual tree
470    /// contributions (`learning_rate * tree.predict(features)`). This is
471    /// O(n_steps) — the same cost as a single prediction pass.
472    fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
473        let n = self.location_steps.len();
474        if n <= 1 {
475            return 0.0;
476        }
477        let lr = self.config.learning_rate;
478        let mut sum = 0.0_f64;
479        let mut sq_sum = 0.0_f64;
480        for step in &self.location_steps {
481            let c = lr * step.predict(features);
482            sum += c;
483            sq_sum += c * c;
484        }
485        let nf = n as f64;
486        let mean_c = sum / nf;
487        let var = (sq_sum / nf) - (mean_c * mean_c);
488        let var_corrected = var * nf / (nf - 1.0);
489        var_corrected.max(0.0).sqrt()
490    }
491
    /// Train on a single observation.
    ///
    /// Phases, in order:
    /// 1. Warmup: collect the first `initial_target_count` targets and derive
    ///    the base μ / log σ predictions (no tree training until then).
    /// 2. `adaptive_mts`: optionally shrink per-step `max_tree_samples` when
    ///    honest_sigma runs above its rolling mean.
    /// 3. Dispatch to the empirical-σ or tree-chain training path.
    /// 4. Proactive pruning bookkeeping (contribution-accuracy EWMAs and the
    ///    interval-based prune check).
    /// 5. Refresh diagnostic caches and auto-bandwidths.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Initialize base predictions from first few targets
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                // Location base = mean
                let sum: f64 = self.initial_targets.iter().sum();
                let mean = sum / self.initial_targets.len() as f64;
                self.location_base = mean;

                // Scale base = log(std) -- clamped for stability
                let var: f64 = self
                    .initial_targets
                    .iter()
                    .map(|&y| (y - mean) * (y - mean))
                    .sum::<f64>()
                    / self.initial_targets.len() as f64;
                let initial_std = var.sqrt().max(1e-6);
                self.scale_base = initial_std.ln();

                // Initialize rolling sigma mean and ewma from initial targets std
                self.rolling_sigma_mean = initial_std;
                self.ewma_sq_err = var.max(1e-12);

                // Initialize PD sigma state
                self.prev_sigma = initial_std;
                self.sigma_velocity = 0.0;

                self.base_initialized = true;
                self.initial_targets.clear();
                // The warmup buffer is never needed again; release its capacity.
                self.initial_targets.shrink_to_fit();
            }
            // No tree training during warmup.
            return;
        }

        // Apply adaptive_mts: override per-step max_tree_samples based on honest_sigma.
        // High tree disagreement (honest_sigma above its rolling mean) shortens
        // tree lifetimes so the ensemble adapts faster; floored at 100 samples.
        if let Some((base_mts, k)) = self.config.adaptive_mts {
            let sigma_ratio = if self.rolling_honest_sigma_mean > 1e-12 {
                let honest_sigma = self.compute_honest_sigma(features);
                honest_sigma / self.rolling_honest_sigma_mean
            } else {
                // Rolling mean not warmed up yet -- use the neutral ratio.
                1.0
            };
            let effective_mts = (base_mts as f64 / (1.0 + k * sigma_ratio)).max(100.0) as u64;
            for step in &mut self.location_steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
            for step in &mut self.scale_steps {
                step.slot_mut().set_max_tree_samples(Some(effective_mts));
            }
        }

        match self.scale_mode {
            ScaleMode::Empirical => self.train_one_empirical(target, features),
            ScaleMode::TreeChain => self.train_one_tree_chain(target, features),
        }

        // Proactive pruning: replace worst-contributing location tree every N samples.
        if let Some(interval) = self.config.proactive_prune_interval {
            // Always update contribution accuracy EWMAs (used by accuracy-based pruning).
            if self.config.accuracy_based_pruning {
                let mut location_pred = self.location_base;
                for step in self.location_steps.iter() {
                    location_pred += self.config.learning_rate * step.predict(features);
                }
                let residual = target - location_pred;
                let sign = residual.signum();
                // Signed alignment: positive when a tree's contribution points
                // in the residual's direction (helps), negative when it hurts.
                for (i, step) in self.location_steps.iter().enumerate() {
                    let contribution = self.config.learning_rate * step.predict(features);
                    let alignment = contribution * sign;
                    self.contribution_accuracy[i] = self.prune_alpha * alignment
                        + (1.0 - self.prune_alpha) * self.contribution_accuracy[i];
                }
            }

            // Interval-based fallback: fire prune check every N samples.
            if interval > 0 && self.samples_seen % interval == 0 {
                self.check_proactive_prune();
            }
        }

        // Update diagnostic cache for config_diagnostics() signals
        self.update_diagnostic_cache(features);

        // Refresh auto-bandwidths when trees have been replaced
        self.refresh_bandwidths();
    }
584
    /// Update cached diagnostic signals from tree internals.
    ///
    /// Computes four signals used by the auto-builder:
    /// - **residual_alignment**: cosine similarity of consecutive tree contributions
    /// - **regularization_sensitivity**: mean |G|/(H+λ)² across leaves
    /// - **depth_sufficiency**: F-statistic (between-leaf / within-leaf variance)
    /// - **effective_dof**: trace(H/(H+λ)) across all leaves
    ///
    /// Called once per training sample; cost is O(n_steps · n_nodes).
    fn update_diagnostic_cache(&mut self, features: &[f64]) {
        use crate::tree::node::NodeId;

        let lambda = self.config.lambda;
        let lr = self.config.learning_rate;
        let n_steps = self.location_steps.len();

        // 1. Residual alignment: cosine similarity of consecutive contribution vectors
        let mut contributions = Vec::with_capacity(n_steps);
        for step in &self.location_steps {
            contributions.push(lr * step.predict(features));
        }

        // Length guards: the vectors can change size if trees are added or
        // replaced between calls, and both histories must be populated.
        if !self.prev_contributions.is_empty()
            && self.prev_contributions.len() == contributions.len()
            && !self.prev_prev_contributions.is_empty()
            && self.prev_prev_contributions.len() == contributions.len()
        {
            // Delta-based alignment: cosine similarity of consecutive *changes*
            // in the contribution vector, not the raw vectors themselves.
            // This prevents saturation when contributions change slowly.
            let delta_curr: Vec<f64> = contributions
                .iter()
                .zip(&self.prev_contributions)
                .map(|(a, b)| a - b)
                .collect();
            let delta_prev: Vec<f64> = self
                .prev_contributions
                .iter()
                .zip(&self.prev_prev_contributions)
                .map(|(a, b)| a - b)
                .collect();

            let dot: f64 = delta_curr.iter().zip(&delta_prev).map(|(a, b)| a * b).sum();
            let norm_curr: f64 = delta_curr.iter().map(|x| x * x).sum::<f64>().sqrt();
            let norm_prev: f64 = delta_prev.iter().map(|x| x * x).sum::<f64>().sqrt();
            // Degenerate (near-zero) deltas give no directional information;
            // report 0.0 rather than dividing by ~0.
            self.cached_residual_alignment = if norm_curr > 1e-15 && norm_prev > 1e-15 {
                dot / (norm_curr * norm_prev)
            } else {
                0.0
            };
        }
        // Shift history: prev -> prev_prev, current -> prev (no reallocation).
        self.prev_prev_contributions =
            core::mem::replace(&mut self.prev_contributions, contributions);

        // 2-4. Leaf traversal for reg_sensitivity, depth_sufficiency, effective_dof
        let mut total_sensitivity = 0.0;
        let mut total_dof = 0.0;
        let mut leaf_weights: Vec<f64> = Vec::new();
        let mut leaf_within_vars: Vec<f64> = Vec::new();
        let mut n_leaves_total: u64 = 0;

        for step in &self.location_steps {
            let tree = step.slot().active_tree();
            let arena = tree.arena();

            for node_idx in 0..arena.n_nodes() {
                let nid = NodeId(node_idx as u32);
                if arena.is_leaf(nid) {
                    if let Some((g, h)) = tree.leaf_grad_hess(nid) {
                        let denom = h + lambda;
                        // Skip leaves with a degenerate H+λ to avoid blow-ups.
                        if denom.abs() > 1e-15 {
                            // Reg sensitivity: |G| / (H+λ)²
                            total_sensitivity += g.abs() / (denom * denom);
                            // Effective DOF: H / (H+λ)
                            total_dof += h / denom;
                            // Leaf weight: w* = -G/(H+λ)
                            leaf_weights.push(-g / denom);
                            // Within-leaf variance: 1/(H+λ)
                            leaf_within_vars.push(1.0 / denom);
                            n_leaves_total += 1;
                        }
                    }
                }
            }
        }

        // With zero usable leaves the previous cached values are kept.
        if n_leaves_total > 0 {
            let n = n_leaves_total as f64;
            self.cached_reg_sensitivity = total_sensitivity / n;
            self.cached_effective_dof = total_dof;

            // Depth sufficiency: F = between_var / within_var
            let mean_weight = leaf_weights.iter().sum::<f64>() / n;
            // (n-1).max(1) guards the single-leaf case against dividing by 0.
            let between_var = leaf_weights
                .iter()
                .map(|w| (w - mean_weight).powi(2))
                .sum::<f64>()
                / (n - 1.0).max(1.0);
            let within_var = leaf_within_vars.iter().sum::<f64>() / n;
            self.cached_depth_sufficiency = between_var / within_var.max(1e-15);
        }
    }
685
    /// Empirical-σ training: location trees only, σ from EWMA of squared errors.
    ///
    /// Order matters here: σ state is updated from the *pre-update* prediction
    /// error, and the resulting σ-ratio then scales the learning rate used for
    /// this sample's stage-wise tree updates.
    fn train_one_empirical(&mut self, target: f64, features: &[f64]) {
        // Current location prediction (before this step's update)
        let mut mu = self.location_base;
        for s in 0..self.location_steps.len() {
            mu += self.config.learning_rate * self.location_steps[s].predict(features);
        }

        // Update rolling honest_sigma mean (tree contribution variance EWMA)
        let honest_sigma = self.compute_honest_sigma(features);
        const HONEST_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
            * self.rolling_honest_sigma_mean
            + HONEST_SIGMA_ALPHA * honest_sigma;

        // Empirical sigma from EWMA of squared prediction errors
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;
        // Floor keeps σ strictly positive for downstream ratios and logs.
        let empirical_sigma = self.ewma_sq_err.sqrt().max(1e-8);

        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller)
        let sigma_ratio = if self.uncertainty_modulated_lr {
            // Compute sigma velocity (derivative of sigma over time)
            let d_sigma = empirical_sigma - self.prev_sigma;
            self.prev_sigma = empirical_sigma;

            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain: self-calibrating, no config needed
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                self.sigma_velocity.abs() / self.rolling_sigma_mean
            } else {
                0.0
            };

            // PD ratio: proportional + derivative
            let pd_sigma = empirical_sigma + k_d * self.sigma_velocity;
            // Clamp keeps the LR multiplier within [0.1, 10] so a σ spike
            // cannot destabilize training.
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            // Update rolling sigma mean with slow EWMA
            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * empirical_sigma;

            ratio
        } else {
            // Modulation disabled: neutral multiplier.
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Train location steps only -- no scale trees needed
        // Stage-wise boosting: each step sees the accumulated prediction of
        // the steps before it and fits the remaining gradient.
        let mut mu_accum = self.location_base;
        for s in 0..self.location_steps.len() {
            let (g_mu, h_mu) = self.location_gradient(mu_accum, target);
            // Welford update for ensemble gradient stats
            self.update_ensemble_grad_stats(g_mu);
            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu_accum += (base_lr * sigma_ratio) * loc_pred;
        }

        // Refresh packed cache if interval reached
        self.maybe_refresh_packed_cache();
    }
754
    /// Tree-chain training: full NGBoost dual-chain with location + scale trees.
    ///
    /// One online boosting pass: both chains are walked together, step by
    /// step. At step `s` the location tree is trained on the residual
    /// gradient at the current partial μ, and the scale tree on the Gaussian
    /// NLL gradient w.r.t. log σ at the current partial (μ, log σ).
    ///
    /// When `uncertainty_modulated_lr` is set, the location (and only the
    /// location) learning rate is scaled by the PD σ-ratio; scale trees
    /// always train at the unmodulated base rate.
    fn train_one_tree_chain(&mut self, target: f64, features: &[f64]) {
        let mut mu = self.location_base;
        let mut log_sigma = self.scale_base;

        // Update rolling honest_sigma mean (tree contribution variance EWMA)
        let honest_sigma = self.compute_honest_sigma(features);
        const HONEST_SIGMA_ALPHA: f64 = 0.001;
        self.rolling_honest_sigma_mean = (1.0 - HONEST_SIGMA_ALPHA)
            * self.rolling_honest_sigma_mean
            + HONEST_SIGMA_ALPHA * honest_sigma;

        // Compute σ-ratio for uncertainty-modulated learning rate (PD controller).
        // NOTE(review): at this point `log_sigma` is still `scale_base` -- no
        // scale-tree contributions have been added yet -- so the controller
        // sees exp(scale_base) rather than the full chain's σ. Confirm this
        // is intentional (the empirical-mode path feeds the controller the
        // post-update empirical σ instead).
        let sigma_ratio = if self.uncertainty_modulated_lr {
            let current_sigma = log_sigma.exp().max(1e-8);

            // Compute sigma velocity (derivative of sigma over time)
            let d_sigma = current_sigma - self.prev_sigma;
            self.prev_sigma = current_sigma;

            // EWMA-smooth the velocity (same alpha as empirical sigma for synchronization)
            let alpha = self.empirical_sigma_alpha;
            self.sigma_velocity = (1.0 - alpha) * self.sigma_velocity + alpha * d_sigma;

            // Adaptive derivative gain: self-calibrating, no config needed
            let k_d = if self.rolling_sigma_mean > 1e-12 {
                self.sigma_velocity.abs() / self.rolling_sigma_mean
            } else {
                0.0
            };

            // PD ratio: proportional + derivative, clamped to [0.1, 10.0]
            let pd_sigma = current_sigma + k_d * self.sigma_velocity;
            let ratio = (pd_sigma / self.rolling_sigma_mean).clamp(0.1, 10.0);

            const SIGMA_EWMA_ALPHA: f64 = 0.001;
            self.rolling_sigma_mean = (1.0 - SIGMA_EWMA_ALPHA) * self.rolling_sigma_mean
                + SIGMA_EWMA_ALPHA * current_sigma;

            ratio
        } else {
            1.0
        };

        let base_lr = self.config.learning_rate;

        // Sequential boosting: both ensembles target their respective residuals
        for s in 0..self.location_steps.len() {
            let sigma = log_sigma.exp().max(1e-8);
            // Standardized residual under the current partial distribution.
            let z = (target - mu) / sigma;

            // Location gradients (squared or Huber loss w.r.t. mu)
            let (g_mu, h_mu) = self.location_gradient(mu, target);
            // Welford update for ensemble gradient stats
            self.update_ensemble_grad_stats(g_mu);

            // Scale gradients (NLL w.r.t. log_sigma):
            // d/dlogσ [logσ + (t-μ)²/(2σ²)] = 1 - z²; second derivative 2z²,
            // clamped so a near-zero hessian cannot destabilize leaf updates.
            let g_sigma = 1.0 - z * z;
            let h_sigma = (2.0 * z * z).clamp(0.01, 100.0);

            let train_count = self.config.variant.train_count(h_mu, &mut self.rng_state);

            // Train location step -- σ-modulated LR when enabled
            let loc_pred =
                self.location_steps[s].train_and_predict(features, g_mu, h_mu, train_count);
            mu += (base_lr * sigma_ratio) * loc_pred;

            // Train scale step -- ALWAYS at unmodulated base rate.
            let scale_pred =
                self.scale_steps[s].train_and_predict(features, g_sigma, h_sigma, train_count);
            log_sigma += base_lr * scale_pred;
        }

        // Also update empirical sigma tracker for diagnostics
        let err = target - mu;
        let alpha = self.empirical_sigma_alpha;
        self.ewma_sq_err = (1.0 - alpha) * self.ewma_sq_err + alpha * err * err;

        // Refresh packed cache if interval reached
        self.maybe_refresh_packed_cache();
    }
836
837    /// Predict the full Gaussian distribution for a feature vector.
838    ///
839    /// When a packed cache is available, uses it for the location (μ) prediction
840    /// via contiguous BFS-packed memory traversal. Falls back to full tree
841    /// traversal if the cache is absent or produces non-finite results.
842    ///
843    /// Sigma computation always uses the primary path (EWMA or scale chain)
844    /// and is unaffected by the packed cache.
845    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
846        // Try packed cache for mu if available
847        let mu = if let Some(ref cache) = self.packed_cache {
848            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
849            match irithyll_core::EnsembleView::from_bytes(&cache.bytes) {
850                Ok(view) => {
851                    let packed_mu = cache.base + view.predict(&features_f32) as f64;
852                    if packed_mu.is_finite() {
853                        packed_mu
854                    } else {
855                        self.predict_full_trees(features)
856                    }
857                }
858                Err(_) => self.predict_full_trees(features),
859            }
860        } else {
861            self.predict_full_trees(features)
862        };
863
864        let (sigma, log_sigma) = match self.scale_mode {
865            ScaleMode::Empirical => {
866                let s = self.ewma_sq_err.sqrt().max(1e-8);
867                (s, s.ln())
868            }
869            ScaleMode::TreeChain => {
870                let mut ls = self.scale_base;
871                if self.auto_bandwidths.is_empty() {
872                    for s in 0..self.scale_steps.len() {
873                        ls += self.config.learning_rate * self.scale_steps[s].predict(features);
874                    }
875                } else {
876                    for s in 0..self.scale_steps.len() {
877                        ls += self.config.learning_rate
878                            * self.scale_steps[s]
879                                .predict_smooth_auto(features, &self.auto_bandwidths);
880                    }
881                }
882                (ls.exp().max(1e-8), ls)
883            }
884        };
885
886        let honest_sigma = self.compute_honest_sigma(features);
887
888        GaussianPrediction {
889            mu,
890            sigma,
891            log_sigma,
892            honest_sigma,
893        }
894    }
895
896    /// Full-tree location prediction (fallback when packed cache is unavailable).
897    fn predict_full_trees(&self, features: &[f64]) -> f64 {
898        let mut mu = self.location_base;
899        if self.auto_bandwidths.is_empty() {
900            for s in 0..self.location_steps.len() {
901                mu += self.config.learning_rate * self.location_steps[s].predict(features);
902            }
903        } else {
904            for s in 0..self.location_steps.len() {
905                mu += self.config.learning_rate
906                    * self.location_steps[s].predict_smooth_auto(features, &self.auto_bandwidths);
907            }
908        }
909        mu
910    }
911
912    /// Predict using sigmoid-blended soft routing for smooth interpolation.
913    ///
914    /// Instead of hard left/right routing at tree split nodes, each split
915    /// uses sigmoid blending: `alpha = sigmoid((threshold - feature) / bandwidth)`.
916    /// The result is a continuous function that varies smoothly with every
917    /// feature change.
918    ///
919    /// `bandwidth` controls transition sharpness: smaller = sharper (closer
920    /// to hard splits), larger = smoother.
921    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
922        let mut mu = self.location_base;
923        for s in 0..self.location_steps.len() {
924            mu += self.config.learning_rate
925                * self.location_steps[s].predict_smooth(features, bandwidth);
926        }
927
928        let (sigma, log_sigma) = match self.scale_mode {
929            ScaleMode::Empirical => {
930                let s = self.ewma_sq_err.sqrt().max(1e-8);
931                (s, s.ln())
932            }
933            ScaleMode::TreeChain => {
934                let mut ls = self.scale_base;
935                for s in 0..self.scale_steps.len() {
936                    ls += self.config.learning_rate
937                        * self.scale_steps[s].predict_smooth(features, bandwidth);
938                }
939                (ls.exp().max(1e-8), ls)
940            }
941        };
942
943        let honest_sigma = self.compute_honest_sigma(features);
944
945        GaussianPrediction {
946            mu,
947            sigma,
948            log_sigma,
949            honest_sigma,
950        }
951    }
952
953    /// Predict with parent-leaf linear interpolation.
954    ///
955    /// Blends each leaf prediction with its parent's preserved prediction
956    /// based on sample count, preventing stale predictions from fresh leaves.
957    pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
958        let mut mu = self.location_base;
959        for s in 0..self.location_steps.len() {
960            mu += self.config.learning_rate * self.location_steps[s].predict_interpolated(features);
961        }
962
963        let (sigma, log_sigma) = match self.scale_mode {
964            ScaleMode::Empirical => {
965                let s = self.ewma_sq_err.sqrt().max(1e-8);
966                (s, s.ln())
967            }
968            ScaleMode::TreeChain => {
969                let mut ls = self.scale_base;
970                for s in 0..self.scale_steps.len() {
971                    ls += self.config.learning_rate
972                        * self.scale_steps[s].predict_interpolated(features);
973                }
974                (ls.exp().max(1e-8), ls)
975            }
976        };
977
978        let honest_sigma = self.compute_honest_sigma(features);
979
980        GaussianPrediction {
981            mu,
982            sigma,
983            log_sigma,
984            honest_sigma,
985        }
986    }
987
988    /// Predict with sibling-based interpolation for feature-continuous predictions.
989    ///
990    /// At each split node near the threshold boundary, blends left and right
991    /// subtree predictions linearly. Uses auto-calibrated bandwidths as the
992    /// interpolation margin. Predictions vary continuously as features change.
993    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
994        let mut mu = self.location_base;
995        for s in 0..self.location_steps.len() {
996            mu += self.config.learning_rate
997                * self.location_steps[s]
998                    .predict_sibling_interpolated(features, &self.auto_bandwidths);
999        }
1000
1001        let (sigma, log_sigma) = match self.scale_mode {
1002            ScaleMode::Empirical => {
1003                let s = self.ewma_sq_err.sqrt().max(1e-8);
1004                (s, s.ln())
1005            }
1006            ScaleMode::TreeChain => {
1007                let mut ls = self.scale_base;
1008                for s in 0..self.scale_steps.len() {
1009                    ls += self.config.learning_rate
1010                        * self.scale_steps[s]
1011                            .predict_sibling_interpolated(features, &self.auto_bandwidths);
1012                }
1013                (ls.exp().max(1e-8), ls)
1014            }
1015        };
1016
1017        let honest_sigma = self.compute_honest_sigma(features);
1018
1019        GaussianPrediction {
1020            mu,
1021            sigma,
1022            log_sigma,
1023            honest_sigma,
1024        }
1025    }
1026
1027    /// Predict with graduated active-shadow blending.
1028    ///
1029    /// Smoothly transitions between active and shadow trees during replacement.
1030    /// Requires `shadow_warmup` to be configured.
1031    pub fn predict_graduated(&self, features: &[f64]) -> GaussianPrediction {
1032        let mut mu = self.location_base;
1033        for s in 0..self.location_steps.len() {
1034            mu += self.config.learning_rate * self.location_steps[s].predict_graduated(features);
1035        }
1036
1037        let (sigma, log_sigma) = match self.scale_mode {
1038            ScaleMode::Empirical => {
1039                let s = self.ewma_sq_err.sqrt().max(1e-8);
1040                (s, s.ln())
1041            }
1042            ScaleMode::TreeChain => {
1043                let mut ls = self.scale_base;
1044                for s in 0..self.scale_steps.len() {
1045                    ls +=
1046                        self.config.learning_rate * self.scale_steps[s].predict_graduated(features);
1047                }
1048                (ls.exp().max(1e-8), ls)
1049            }
1050        };
1051
1052        let honest_sigma = self.compute_honest_sigma(features);
1053
1054        GaussianPrediction {
1055            mu,
1056            sigma,
1057            log_sigma,
1058            honest_sigma,
1059        }
1060    }
1061
1062    /// Predict with graduated blending + sibling interpolation (premium path).
1063    pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
1064        let mut mu = self.location_base;
1065        for s in 0..self.location_steps.len() {
1066            mu += self.config.learning_rate
1067                * self.location_steps[s]
1068                    .predict_graduated_sibling_interpolated(features, &self.auto_bandwidths);
1069        }
1070
1071        let (sigma, log_sigma) = match self.scale_mode {
1072            ScaleMode::Empirical => {
1073                let s = self.ewma_sq_err.sqrt().max(1e-8);
1074                (s, s.ln())
1075            }
1076            ScaleMode::TreeChain => {
1077                let mut ls = self.scale_base;
1078                for s in 0..self.scale_steps.len() {
1079                    ls += self.config.learning_rate
1080                        * self.scale_steps[s].predict_graduated_sibling_interpolated(
1081                            features,
1082                            &self.auto_bandwidths,
1083                        );
1084                }
1085                (ls.exp().max(1e-8), ls)
1086            }
1087        };
1088
1089        let honest_sigma = self.compute_honest_sigma(features);
1090
1091        GaussianPrediction {
1092            mu,
1093            sigma,
1094            log_sigma,
1095            honest_sigma,
1096        }
1097    }
1098
1099    /// Predict using per-node auto-bandwidth soft routing.
1100    ///
1101    /// Every prediction is a continuous weighted blend instead of a
1102    /// piecewise-constant step function. No training changes.
1103    pub fn predict_soft_routed(&self, features: &[f64]) -> GaussianPrediction {
1104        // Location chain with soft routing
1105        let mut mu = self.location_base;
1106        for step in &self.location_steps {
1107            mu += self.config.learning_rate * step.predict_soft_routed(features);
1108        }
1109
1110        // Scale chain
1111        let (sigma, log_sigma) = match self.scale_mode {
1112            ScaleMode::Empirical => {
1113                let s = self.ewma_sq_err.sqrt().max(1e-8);
1114                (s, s.ln())
1115            }
1116            ScaleMode::TreeChain => {
1117                let mut ls = self.scale_base;
1118                for step in &self.scale_steps {
1119                    ls += self.config.learning_rate * step.predict_soft_routed(features);
1120                }
1121                (ls.exp().max(1e-8), ls)
1122            }
1123        };
1124
1125        let honest_sigma = self.compute_honest_sigma(features);
1126
1127        GaussianPrediction {
1128            mu,
1129            sigma,
1130            log_sigma,
1131            honest_sigma,
1132        }
1133    }
1134
1135    /// Predict with σ-ratio diagnostic exposed.
1136    ///
1137    /// Returns `(mu, sigma, sigma_ratio)` where `sigma_ratio` is
1138    /// `current_sigma / rolling_sigma_mean` -- the multiplier applied to the
1139    /// location learning rate when [`uncertainty_modulated_lr`](SGBTConfig::uncertainty_modulated_lr)
1140    /// is enabled.
1141    ///
1142    /// When σ-modulation is disabled, `sigma_ratio` is always `1.0`.
1143    pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
1144        let pred = self.predict(features);
1145        let sigma_ratio = if self.uncertainty_modulated_lr {
1146            (pred.sigma / self.rolling_sigma_mean).clamp(0.1, 10.0)
1147        } else {
1148            1.0
1149        };
1150        (pred.mu, pred.sigma, sigma_ratio)
1151    }
1152
1153    /// Current empirical sigma (`sqrt(ewma_sq_err)`).
1154    ///
1155    /// Returns the model's recent error magnitude. Available in both scale modes.
1156    #[inline]
1157    pub fn empirical_sigma(&self) -> f64 {
1158        self.ewma_sq_err.sqrt()
1159    }
1160
1161    /// Current scale mode.
1162    #[inline]
1163    pub fn scale_mode(&self) -> ScaleMode {
1164        self.scale_mode
1165    }
1166
1167    /// Current σ velocity -- the EWMA-smoothed derivative of empirical σ.
1168    ///
1169    /// Positive values indicate growing prediction errors (model deteriorating
1170    /// or regime change). Negative values indicate improving predictions.
1171    /// Only meaningful when `ScaleMode::Empirical` is active.
1172    #[inline]
1173    pub fn sigma_velocity(&self) -> f64 {
1174        self.sigma_velocity
1175    }
1176
1177    /// Predict the mean (location parameter) only.
1178    #[inline]
1179    pub fn predict_mu(&self, features: &[f64]) -> f64 {
1180        self.predict(features).mu
1181    }
1182
1183    /// Predict the standard deviation (scale parameter) only.
1184    #[inline]
1185    pub fn predict_sigma(&self, features: &[f64]) -> f64 {
1186        self.predict(features).sigma
1187    }
1188
1189    /// Predict a symmetric confidence interval.
1190    ///
1191    /// `confidence` is the Z-score multiplier:
1192    /// - 1.0 → 68% CI
1193    /// - 1.96 → 95% CI
1194    /// - 2.576 → 99% CI
1195    pub fn predict_interval(&self, features: &[f64], confidence: f64) -> (f64, f64) {
1196        let pred = self.predict(features);
1197        (pred.lower(confidence), pred.upper(confidence))
1198    }
1199
1200    /// Batch prediction.
1201    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<GaussianPrediction> {
1202        feature_matrix.iter().map(|f| self.predict(f)).collect()
1203    }
1204
1205    /// Train on a batch of observations.
1206    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
1207        for sample in samples {
1208            self.train_one(sample);
1209        }
1210    }
1211
1212    /// Train on a batch with periodic callback.
1213    pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
1214        &mut self,
1215        samples: &[O],
1216        interval: usize,
1217        mut callback: F,
1218    ) {
1219        let interval = interval.max(1);
1220        for (i, sample) in samples.iter().enumerate() {
1221            self.train_one(sample);
1222            if (i + 1) % interval == 0 {
1223                callback(i + 1);
1224            }
1225        }
1226        let total = samples.len();
1227        if total % interval != 0 {
1228            callback(total);
1229        }
1230    }
1231
1232    /// Compute location gradient (squared or adaptive Huber).
1233    ///
1234    /// When `huber_k` is configured, uses Huber loss with adaptive
1235    /// `delta = k * empirical_sigma` for bounded gradients.
1236    #[inline]
1237    fn location_gradient(&self, mu: f64, target: f64) -> (f64, f64) {
1238        if let Some(k) = self.config.huber_k {
1239            let delta = k * self.ewma_sq_err.sqrt().max(1e-8);
1240            let residual = mu - target;
1241            if residual.abs() <= delta {
1242                (residual, 1.0)
1243            } else {
1244                (delta * residual.signum(), 1e-6)
1245            }
1246        } else {
1247            (mu - target, 1.0)
1248        }
1249    }
1250
1251    /// Welford update for ensemble-level gradient statistics.
1252    #[inline]
1253    fn update_ensemble_grad_stats(&mut self, gradient: f64) {
1254        self.ensemble_grad_count += 1;
1255        let delta = gradient - self.ensemble_grad_mean;
1256        self.ensemble_grad_mean += delta / self.ensemble_grad_count as f64;
1257        let delta2 = gradient - self.ensemble_grad_mean;
1258        self.ensemble_grad_m2 += delta * delta2;
1259    }
1260
1261    /// Ensemble-level gradient standard deviation.
1262    pub fn ensemble_grad_std(&self) -> f64 {
1263        if self.ensemble_grad_count < 2 {
1264            return 0.0;
1265        }
1266        (self.ensemble_grad_m2 / (self.ensemble_grad_count - 1) as f64)
1267            .sqrt()
1268            .max(0.0)
1269    }
1270
1271    /// Ensemble-level gradient mean.
1272    pub fn ensemble_grad_mean(&self) -> f64 {
1273        self.ensemble_grad_mean
1274    }
1275
1276    /// Check if packed cache should be refreshed and do so if interval reached.
1277    fn maybe_refresh_packed_cache(&mut self) {
1278        if self.packed_refresh_interval > 0 {
1279            self.samples_since_refresh += 1;
1280            if self.samples_since_refresh >= self.packed_refresh_interval {
1281                self.refresh_packed_cache();
1282                self.samples_since_refresh = 0;
1283            }
1284        }
1285    }
1286
1287    /// Re-export the location ensemble into the packed cache.
1288    fn refresh_packed_cache(&mut self) {
1289        // Determine n_features from the first tree that has been initialized
1290        let n_features = self
1291            .location_steps
1292            .iter()
1293            .filter_map(|s| s.slot().active_tree().n_features())
1294            .max()
1295            .unwrap_or(0);
1296
1297        if n_features == 0 {
1298            return;
1299        }
1300
1301        let (bytes, base) = export_distributional_packed(self, n_features);
1302        self.packed_cache = Some(PackedInferenceCache {
1303            bytes,
1304            base,
1305            n_features,
1306        });
1307    }
1308
1309    /// Enable or reconfigure the packed inference cache at runtime.
1310    ///
1311    /// Sets the refresh interval and immediately builds the initial cache
1312    /// if the model has been initialized. Pass `0` to disable.
1313    pub fn enable_packed_cache(&mut self, interval: u64) {
1314        self.packed_refresh_interval = interval;
1315        self.samples_since_refresh = 0;
1316        if interval > 0 && self.base_initialized {
1317            self.refresh_packed_cache();
1318        } else if interval == 0 {
1319            self.packed_cache = None;
1320        }
1321    }
1322
1323    /// Whether the packed inference cache is currently populated.
1324    #[inline]
1325    pub fn has_packed_cache(&self) -> bool {
1326        self.packed_cache.is_some()
1327    }
1328
1329    /// Refresh auto-bandwidths if any tree has been replaced since last computation.
1330    fn refresh_bandwidths(&mut self) {
1331        let current_sum: u64 = self
1332            .location_steps
1333            .iter()
1334            .chain(self.scale_steps.iter())
1335            .map(|s| s.slot().replacements())
1336            .sum();
1337        if current_sum != self.last_replacement_sum || self.auto_bandwidths.is_empty() {
1338            self.auto_bandwidths = self.compute_auto_bandwidths();
1339            self.last_replacement_sum = current_sum;
1340        }
1341    }
1342
1343    /// Compute per-feature auto-calibrated bandwidths from all trees.
1344    fn compute_auto_bandwidths(&self) -> Vec<f64> {
1345        const K: f64 = 2.0;
1346
1347        let n_features = self
1348            .location_steps
1349            .iter()
1350            .chain(self.scale_steps.iter())
1351            .filter_map(|s| s.slot().active_tree().n_features())
1352            .max()
1353            .unwrap_or(0);
1354
1355        if n_features == 0 {
1356            return Vec::new();
1357        }
1358
1359        let mut all_thresholds: Vec<Vec<f64>> = vec![Vec::new(); n_features];
1360
1361        for step in self.location_steps.iter().chain(self.scale_steps.iter()) {
1362            let tree_thresholds = step
1363                .slot()
1364                .active_tree()
1365                .collect_split_thresholds_per_feature();
1366            for (i, ts) in tree_thresholds.into_iter().enumerate() {
1367                if i < n_features {
1368                    all_thresholds[i].extend(ts);
1369                }
1370            }
1371        }
1372
1373        let n_bins = self.config.n_bins as f64;
1374
1375        all_thresholds
1376            .iter()
1377            .map(|ts| {
1378                if ts.is_empty() {
1379                    return f64::INFINITY;
1380                }
1381
1382                let mut sorted = ts.clone();
1383                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1384                sorted.dedup_by(|a, b| (*a - *b).abs() < 1e-15);
1385
1386                if sorted.len() < 2 {
1387                    return f64::INFINITY;
1388                }
1389
1390                let mut gaps: Vec<f64> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
1391
1392                if sorted.len() < 3 {
1393                    let range = sorted.last().unwrap() - sorted.first().unwrap();
1394                    if range < 1e-15 {
1395                        return f64::INFINITY;
1396                    }
1397                    return (range / n_bins) * K;
1398                }
1399
1400                gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1401                let median_gap = if gaps.len() % 2 == 0 {
1402                    (gaps[gaps.len() / 2 - 1] + gaps[gaps.len() / 2]) / 2.0
1403                } else {
1404                    gaps[gaps.len() / 2]
1405                };
1406
1407                if median_gap < 1e-15 {
1408                    f64::INFINITY
1409                } else {
1410                    median_gap * K
1411                }
1412            })
1413            .collect()
1414    }
1415
1416    /// Per-feature auto-calibrated bandwidths used by `predict()`.
1417    pub fn auto_bandwidths(&self) -> &[f64] {
1418        &self.auto_bandwidths
1419    }
1420
1421    /// Reset to initial untrained state.
1422    pub fn reset(&mut self) {
1423        for step in &mut self.location_steps {
1424            step.reset();
1425        }
1426        for step in &mut self.scale_steps {
1427            step.reset();
1428        }
1429        self.location_base = 0.0;
1430        self.scale_base = 0.0;
1431        self.base_initialized = false;
1432        self.initial_targets.clear();
1433        self.samples_seen = 0;
1434        self.rng_state = self.config.seed;
1435        self.rolling_sigma_mean = 1.0;
1436        self.ewma_sq_err = 1.0;
1437        self.prev_sigma = 0.0;
1438        self.sigma_velocity = 0.0;
1439        self.auto_bandwidths.clear();
1440        self.last_replacement_sum = 0;
1441        self.ensemble_grad_mean = 0.0;
1442        self.ensemble_grad_m2 = 0.0;
1443        self.ensemble_grad_count = 0;
1444        self.rolling_honest_sigma_mean = 0.0;
1445        self.packed_cache = None;
1446        self.samples_since_refresh = 0;
1447        self.prev_contributions.clear();
1448        self.prev_prev_contributions.clear();
1449        self.cached_residual_alignment = 0.0;
1450        self.cached_reg_sensitivity = 0.0;
1451        self.cached_depth_sufficiency = 0.0;
1452        self.cached_effective_dof = 0.0;
1453        self.contribution_accuracy = vec![0.0; self.location_steps.len()];
1454    }
1455
1456    /// Total samples trained.
1457    #[inline]
1458    pub fn n_samples_seen(&self) -> u64 {
1459        self.samples_seen
1460    }
1461
1462    /// Number of boosting steps (same for location and scale).
1463    #[inline]
1464    pub fn n_steps(&self) -> usize {
1465        self.location_steps.len()
1466    }
1467
1468    /// Total trees (location + scale, active + alternates).
1469    pub fn n_trees(&self) -> usize {
1470        let loc = self.location_steps.len()
1471            + self
1472                .location_steps
1473                .iter()
1474                .filter(|s| s.has_alternate())
1475                .count();
1476        let scale = self.scale_steps.len()
1477            + self
1478                .scale_steps
1479                .iter()
1480                .filter(|s| s.has_alternate())
1481                .count();
1482        loc + scale
1483    }
1484
1485    /// Total leaves across all active trees (location + scale).
1486    pub fn total_leaves(&self) -> usize {
1487        let loc: usize = self.location_steps.iter().map(|s| s.n_leaves()).sum();
1488        let scale: usize = self.scale_steps.iter().map(|s| s.n_leaves()).sum();
1489        loc + scale
1490    }
1491
1492    /// Whether base predictions have been initialized.
1493    #[inline]
1494    pub fn is_initialized(&self) -> bool {
1495        self.base_initialized
1496    }
1497
1498    /// Access the configuration.
1499    #[inline]
1500    pub fn config(&self) -> &SGBTConfig {
1501        &self.config
1502    }
1503
1504    /// Set the learning rate for future boosting rounds.
1505    #[inline]
1506    pub fn set_learning_rate(&mut self, lr: f64) {
1507        self.config.learning_rate = lr;
1508    }
1509
1510    /// Set the L2 regularization parameter (lambda) for future boosting rounds.
1511    #[inline]
1512    pub fn set_lambda(&mut self, lambda: f64) {
1513        self.config.lambda = lambda.max(0.0);
1514    }
1515
1516    /// Set the maximum tree depth for future replacement trees.
1517    #[inline]
1518    pub fn set_max_depth(&mut self, depth: usize) {
1519        self.config.max_depth = depth.clamp(1, 20);
1520    }
1521
1522    /// Adjust the number of boosting steps (trees per chain).
1523    ///
1524    /// Both location and scale ensembles are resized together.
1525    /// Clamped to `3..=1000`.
1526    pub fn set_n_steps(&mut self, n: usize) {
1527        let n = n.clamp(3, 1000);
1528        let current = self.location_steps.len();
1529        if n > current {
1530            let leaf_decay_alpha = self
1531                .config
1532                .leaf_half_life
1533                .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());
1534            let tree_config = TreeConfig::new()
1535                .max_depth(self.config.max_depth)
1536                .n_bins(self.config.n_bins)
1537                .lambda(self.config.lambda)
1538                .gamma(self.config.gamma)
1539                .grace_period(self.config.grace_period)
1540                .delta(self.config.delta)
1541                .feature_subsample_rate(self.config.feature_subsample_rate)
1542                .leaf_decay_alpha_opt(leaf_decay_alpha)
1543                .split_reeval_interval_opt(self.config.split_reeval_interval)
1544                .feature_types_opt(self.config.feature_types.clone())
1545                .gradient_clip_sigma_opt(self.config.gradient_clip_sigma)
1546                .monotone_constraints_opt(self.config.monotone_constraints.clone())
1547                .max_leaf_output_opt(self.config.max_leaf_output)
1548                .adaptive_depth_opt(self.config.adaptive_depth)
1549                .min_hessian_sum_opt(self.config.min_hessian_sum)
1550                .leaf_model_type(self.config.leaf_model_type.clone());
1551            let mts = self.config.max_tree_samples;
1552            let shadow_warmup = self.config.shadow_warmup.unwrap_or(0);
1553            for i in current..n {
1554                // Location step
1555                let mut tc = tree_config.clone();
1556                tc.seed = self.config.seed ^ (i as u64);
1557                let detector = self.config.drift_detector.create();
1558                let step = if shadow_warmup > 0 {
1559                    BoostingStep::new_with_graduated(tc, detector, mts, shadow_warmup)
1560                } else {
1561                    BoostingStep::new_with_max_samples(tc, detector, mts)
1562                };
1563                self.location_steps.push(step);
1564
1565                // Scale step
1566                let mut tc = tree_config.clone();
1567                tc.seed = self.config.seed ^ (i as u64) ^ 0x0005_CA1E_0000_0000;
1568                let detector = self.config.drift_detector.create();
1569                let step = if shadow_warmup > 0 {
1570                    BoostingStep::new_with_graduated(tc, detector, mts, shadow_warmup)
1571                } else {
1572                    BoostingStep::new_with_max_samples(tc, detector, mts)
1573                };
1574                self.scale_steps.push(step);
1575            }
1576        } else if n < current {
1577            self.location_steps.truncate(n);
1578            self.scale_steps.truncate(n);
1579        }
1580        self.contribution_accuracy.resize(n, 0.0);
1581        self.config.n_steps = n;
1582    }
1583
1584    /// Total tree replacements across all boosting steps (location + scale).
1585    pub fn total_replacements(&self) -> u64 {
1586        self.location_steps
1587            .iter()
1588            .chain(self.scale_steps.iter())
1589            .map(|s| s.slot().replacements())
1590            .sum()
1591    }
1592
    /// Manually trigger a proactive prune check on the location chain.
    ///
    /// Same semantics as [`crate::ensemble::SGBT::check_proactive_prune`]: finds the worst
    /// mature tree and replaces it if harmful (accuracy-based) or stale
    /// (variance-based). Returns `true` if a tree was replaced.
    pub fn check_proactive_prune(&mut self) -> bool {
        // Never prune a single-tree chain — there is nothing to fall back on.
        if self.location_steps.len() <= 1 {
            return false;
        }
        if self.config.accuracy_based_pruning {
            // Accuracy mode: replace only a mature tree (>= grace_period
            // samples) whose contribution-accuracy score is negative, i.e.
            // it appears to be actively hurting predictions.
            let grace_period = self.config.grace_period as u64;
            let worst = self
                .location_steps
                .iter()
                .enumerate()
                .zip(self.contribution_accuracy.iter())
                .filter(|((_, step), _)| step.slot().n_samples_seen() >= grace_period)
                .min_by(|((_, _), a), ((_, _), b)| {
                    // NaN-tolerant: incomparable scores are treated as equal.
                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                });
            if let Some(((worst_idx, _), &worst_acc)) = worst {
                if worst_acc < 0.0 {
                    self.location_steps[worst_idx].slot_mut().replace_active();
                    // Fresh tree starts with a clean accuracy slate.
                    self.contribution_accuracy[worst_idx] = 0.0;
                    return true;
                }
            }
            false
        } else {
            // Variance mode: the tree with the smallest prediction spread is
            // considered the most stale and is unconditionally replaced.
            let worst_idx = self
                .location_steps
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    let a_std = a.slot().prediction_std();
                    let b_std = b.slot().prediction_std();
                    a_std
                        .partial_cmp(&b_std)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(i, _)| i)
                .unwrap_or(0);
            self.location_steps[worst_idx].slot_mut().replace_active();
            true
        }
    }
1639
    /// Access the location boosting steps (for export/inspection).
    ///
    /// Slice is in chain order; the scale chain is not exposed here.
    pub fn location_steps(&self) -> &[BoostingStep] {
        &self.location_steps
    }
1644
    /// Base prediction for the location (mean) ensemble.
    ///
    /// Added to the sum of per-tree contributions to recover μ
    /// (see [`Self::predict_decomposed`]).
    #[inline]
    pub fn location_base(&self) -> f64 {
        self.location_base
    }
1650
    /// Learning rate from the model configuration.
    ///
    /// Shorthand for `config().learning_rate`.
    #[inline]
    pub fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }
1656
    /// Current rolling σ mean (EWMA of predicted σ).
    ///
    /// Returns `1.0` if the model hasn't been initialized yet.
    ///
    /// NOTE(review): presumably used as the baseline for σ-modulated
    /// learning-rate scaling when [`Self::is_uncertainty_modulated`] is
    /// `true` — confirm against the training path (not visible here).
    #[inline]
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }
1664
    /// Whether σ-modulated learning rate is active.
    ///
    /// Mirrors the `uncertainty_modulated_lr` flag, which is also
    /// round-tripped through serialization.
    #[inline]
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }
1670
    /// Current rolling honest_sigma mean (EWMA of tree contribution std dev).
    ///
    /// Returns `0.0` before any training samples have been processed.
    ///
    /// Also surfaced as the fifth entry of the `StreamingLearner`
    /// diagnostics array and as the `uncertainty` field in
    /// `ConfigDiagnostics`.
    #[inline]
    pub fn rolling_honest_sigma_mean(&self) -> f64 {
        self.rolling_honest_sigma_mean
    }
1678
1679    // -------------------------------------------------------------------
1680    // Diagnostics
1681    // -------------------------------------------------------------------
1682
    /// Full model diagnostics: per-tree structure, feature usage, base predictions.
    ///
    /// The `trees` vector contains location trees first (indices `0..n_steps`),
    /// then scale trees (`n_steps..2*n_steps`).
    ///
    /// `scale_trees_active` counts how many scale trees have actually split
    /// (more than 1 leaf). If this is 0, the scale chain is effectively frozen.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        let n = self.location_steps.len();
        let mut trees = Vec::with_capacity(2 * n);
        let mut feature_split_counts: Vec<usize> = Vec::new();

        // Shared collector for both chains: walks each step's active tree
        // arena (parallel arrays indexed by node id) and appends one
        // TreeDiagnostic per step.
        fn collect_tree_diags(
            steps: &[BoostingStep],
            trees: &mut Vec<TreeDiagnostic>,
            feature_split_counts: &mut Vec<usize>,
        ) {
            for step in steps {
                let slot = step.slot();
                let tree = slot.active_tree();
                let arena = tree.arena();

                // Leaf values and per-leaf sample counts, in node-id order.
                let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.leaf_value[i])
                    .collect();

                let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.sample_count[i])
                    .collect();

                // Deepest leaf determines the depth actually reached.
                let max_depth_reached = (0..arena.is_leaf.len())
                    .filter(|&i| arena.is_leaf[i])
                    .map(|i| arena.depth[i] as usize)
                    .max()
                    .unwrap_or(0);

                // (min, max, mean, std) of leaf values; all zeros for a
                // tree with no leaves.
                let leaf_weight_stats = if leaf_values.is_empty() {
                    (0.0, 0.0, 0.0, 0.0)
                } else {
                    let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                    let max = leaf_values
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);
                    let sum: f64 = leaf_values.iter().sum();
                    let mean = sum / leaf_values.len() as f64;
                    // Population variance (divides by N, not N-1).
                    let var: f64 = leaf_values.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
                        / leaf_values.len() as f64;
                    (min, max, mean, var.sqrt())
                };

                // A feature counts as "split on" once it has accumulated
                // positive gain in this tree.
                let gains = slot.split_gains();
                let split_features: Vec<usize> = gains
                    .iter()
                    .enumerate()
                    .filter(|(_, &g)| g > 0.0)
                    .map(|(i, _)| i)
                    .collect();

                if !gains.is_empty() {
                    // Size the global counter from the first non-empty gains
                    // vector; indices beyond that length are ignored.
                    if feature_split_counts.is_empty() {
                        feature_split_counts.resize(gains.len(), 0);
                    }
                    for &fi in &split_features {
                        if fi < feature_split_counts.len() {
                            feature_split_counts[fi] += 1;
                        }
                    }
                }

                trees.push(TreeDiagnostic {
                    n_leaves: leaf_values.len(),
                    max_depth_reached,
                    samples_seen: step.n_samples_seen(),
                    leaf_weight_stats,
                    split_features,
                    leaf_sample_counts,
                    prediction_mean: slot.prediction_mean(),
                    prediction_std: slot.prediction_std(),
                });
            }
        }

        // Location trees first, then scale trees (documented ordering).
        collect_tree_diags(&self.location_steps, &mut trees, &mut feature_split_counts);
        collect_tree_diags(&self.scale_steps, &mut trees, &mut feature_split_counts);

        let location_trees = trees[..n].to_vec();
        let scale_trees = trees[n..].to_vec();
        // A scale tree with a single leaf has never split — it contributes
        // only a constant and is effectively inactive.
        let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

        ModelDiagnostics {
            trees,
            location_trees,
            scale_trees,
            feature_split_counts,
            location_base: self.location_base,
            scale_base: self.scale_base,
            empirical_sigma: self.ewma_sq_err.sqrt(),
            scale_mode: self.scale_mode,
            scale_trees_active,
            auto_bandwidths: self.auto_bandwidths.clone(),
            ensemble_grad_mean: self.ensemble_grad_mean,
            ensemble_grad_std: self.ensemble_grad_std(),
        }
    }
1790
1791    /// Ensemble-level diagnostics (location + optional scale) with per-tree
1792    /// contributions for a given input.
1793    ///
1794    /// Returns [`DistributionalDiagnostics`](crate::ensemble::diagnostics::DistributionalDiagnostics)
1795    /// containing location ensemble diagnostics, optional scale diagnostics
1796    /// (when in `TreeChain` mode), current `honest_sigma`, and the rolling
1797    /// `honest_sigma` baseline.
1798    pub fn ensemble_diagnostics(
1799        &self,
1800        features: &[f64],
1801    ) -> crate::ensemble::diagnostics::DistributionalDiagnostics {
1802        use crate::ensemble::diagnostics::build_ensemble_diagnostics;
1803
1804        let location = build_ensemble_diagnostics(
1805            &self.location_steps,
1806            self.location_base,
1807            self.config.learning_rate,
1808            self.samples_seen,
1809            Some(features),
1810        );
1811
1812        let scale = match self.scale_mode {
1813            ScaleMode::TreeChain => Some(build_ensemble_diagnostics(
1814                &self.scale_steps,
1815                self.scale_base,
1816                self.config.learning_rate,
1817                self.samples_seen,
1818                Some(features),
1819            )),
1820            ScaleMode::Empirical => None,
1821        };
1822
1823        let honest_sigma = self.compute_honest_sigma(features);
1824
1825        // Compute effective_mts if adaptive_mts is configured
1826        let effective_mts = self.config.adaptive_mts.map(|(base_mts, k)| {
1827            let sigma_ratio = if self.rolling_honest_sigma_mean > 1e-12 {
1828                honest_sigma / self.rolling_honest_sigma_mean
1829            } else {
1830                1.0
1831            };
1832            (base_mts as f64 / (1.0 + k * sigma_ratio)).max(100.0) as u64
1833        });
1834
1835        crate::ensemble::diagnostics::DistributionalDiagnostics {
1836            location,
1837            scale,
1838            honest_sigma,
1839            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
1840            effective_mts,
1841        }
1842    }
1843
1844    /// Per-tree contribution to the final prediction.
1845    ///
1846    /// Returns two vectors: location contributions and scale contributions.
1847    /// Each entry is `learning_rate * tree_prediction` -- the additive
1848    /// contribution of that boosting step to the final μ or log(σ).
1849    ///
1850    /// Summing `location_base + sum(location_contributions)` recovers μ.
1851    /// Summing `scale_base + sum(scale_contributions)` recovers log(σ).
1852    ///
1853    /// In `Empirical` scale mode, `scale_base` is `ln(empirical_sigma)` and
1854    /// `scale_contributions` are all zero (σ is not tree-derived).
1855    pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
1856        let lr = self.config.learning_rate;
1857        let location: Vec<f64> = self
1858            .location_steps
1859            .iter()
1860            .map(|s| lr * s.predict(features))
1861            .collect();
1862
1863        let (sb, scale) = match self.scale_mode {
1864            ScaleMode::Empirical => {
1865                let empirical_sigma = self.ewma_sq_err.sqrt().max(1e-8);
1866                (empirical_sigma.ln(), vec![0.0; self.location_steps.len()])
1867            }
1868            ScaleMode::TreeChain => {
1869                let s: Vec<f64> = self
1870                    .scale_steps
1871                    .iter()
1872                    .map(|s| lr * s.predict(features))
1873                    .collect();
1874                (self.scale_base, s)
1875            }
1876        };
1877
1878        DecomposedPrediction {
1879            location_base: self.location_base,
1880            scale_base: sb,
1881            location_contributions: location,
1882            scale_contributions: scale,
1883        }
1884    }
1885
1886    /// Feature importances based on accumulated split gains across all trees.
1887    ///
1888    /// Aggregates gains from both location and scale ensembles, then
1889    /// normalizes to sum to 1.0. Indexed by feature.
1890    /// Returns an empty Vec if no splits have occurred yet.
1891    pub fn feature_importances(&self) -> Vec<f64> {
1892        let mut totals: Vec<f64> = Vec::new();
1893        for steps in [&self.location_steps, &self.scale_steps] {
1894            for step in steps {
1895                let gains = step.slot().split_gains();
1896                if totals.is_empty() && !gains.is_empty() {
1897                    totals.resize(gains.len(), 0.0);
1898                }
1899                for (i, &g) in gains.iter().enumerate() {
1900                    if i < totals.len() {
1901                        totals[i] += g;
1902                    }
1903                }
1904            }
1905        }
1906        let sum: f64 = totals.iter().sum();
1907        if sum > 0.0 {
1908            totals.iter_mut().for_each(|v| *v /= sum);
1909        }
1910        totals
1911    }
1912
1913    /// Feature importances split by ensemble: `(location_importances, scale_importances)`.
1914    ///
1915    /// Each vector is independently normalized to sum to 1.0.
1916    /// Useful for understanding which features drive the mean vs. the uncertainty.
1917    pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
1918        fn aggregate(steps: &[BoostingStep]) -> Vec<f64> {
1919            let mut totals: Vec<f64> = Vec::new();
1920            for step in steps {
1921                let gains = step.slot().split_gains();
1922                if totals.is_empty() && !gains.is_empty() {
1923                    totals.resize(gains.len(), 0.0);
1924                }
1925                for (i, &g) in gains.iter().enumerate() {
1926                    if i < totals.len() {
1927                        totals[i] += g;
1928                    }
1929                }
1930            }
1931            let sum: f64 = totals.iter().sum();
1932            if sum > 0.0 {
1933                totals.iter_mut().for_each(|v| *v /= sum);
1934            }
1935            totals
1936        }
1937        (
1938            aggregate(&self.location_steps),
1939            aggregate(&self.scale_steps),
1940        )
1941    }
1942
1943    /// Convert this model into a serializable [`crate::serde_support::DistributionalModelState`].
1944    ///
1945    /// Captures the full ensemble state (both location and scale trees) for
1946    /// persistence. Histogram accumulators are NOT serialized -- they rebuild
1947    /// naturally from continued training.
1948    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
1949    pub fn to_distributional_state(&self) -> crate::serde_support::DistributionalModelState {
1950        use super::snapshot_tree;
1951        use crate::serde_support::{DistributionalModelState, StepSnapshot};
1952
1953        fn snapshot_step(step: &BoostingStep) -> StepSnapshot {
1954            let slot = step.slot();
1955            let tree_snap = snapshot_tree(slot.active_tree());
1956            let alt_snap = slot.alternate_tree().map(snapshot_tree);
1957            let drift_state = slot.detector().serialize_state();
1958            let alt_drift_state = slot.alt_detector().and_then(|d| d.serialize_state());
1959            StepSnapshot {
1960                tree: tree_snap,
1961                alternate_tree: alt_snap,
1962                drift_state,
1963                alt_drift_state,
1964            }
1965        }
1966
1967        DistributionalModelState {
1968            config: self.config.clone(),
1969            location_steps: self.location_steps.iter().map(snapshot_step).collect(),
1970            scale_steps: self.scale_steps.iter().map(snapshot_step).collect(),
1971            location_base: self.location_base,
1972            scale_base: self.scale_base,
1973            base_initialized: self.base_initialized,
1974            initial_targets: self.initial_targets.clone(),
1975            initial_target_count: self.initial_target_count,
1976            samples_seen: self.samples_seen,
1977            rng_state: self.rng_state,
1978            uncertainty_modulated_lr: self.uncertainty_modulated_lr,
1979            rolling_sigma_mean: self.rolling_sigma_mean,
1980            ewma_sq_err: self.ewma_sq_err,
1981            rolling_honest_sigma_mean: self.rolling_honest_sigma_mean,
1982        }
1983    }
1984
1985    /// Reconstruct a [`DistributionalSGBT`] from a serialized [`crate::serde_support::DistributionalModelState`].
1986    ///
1987    /// Rebuilds both location and scale ensembles including tree topology
1988    /// and leaf values. Histogram accumulators are left empty and will
1989    /// rebuild from continued training.
1990    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
1991    pub fn from_distributional_state(
1992        state: crate::serde_support::DistributionalModelState,
1993    ) -> Self {
1994        use super::rebuild_tree;
1995        use crate::ensemble::replacement::TreeSlot;
1996        use crate::serde_support::StepSnapshot;
1997
1998        let leaf_decay_alpha = state
1999            .config
2000            .leaf_half_life
2001            .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());
2002        let max_tree_samples = state.config.max_tree_samples;
2003
2004        let base_tree_config = TreeConfig::new()
2005            .max_depth(state.config.max_depth)
2006            .n_bins(state.config.n_bins)
2007            .lambda(state.config.lambda)
2008            .gamma(state.config.gamma)
2009            .grace_period(state.config.grace_period)
2010            .delta(state.config.delta)
2011            .feature_subsample_rate(state.config.feature_subsample_rate)
2012            .leaf_decay_alpha_opt(leaf_decay_alpha)
2013            .split_reeval_interval_opt(state.config.split_reeval_interval)
2014            .feature_types_opt(state.config.feature_types.clone())
2015            .gradient_clip_sigma_opt(state.config.gradient_clip_sigma)
2016            .monotone_constraints_opt(state.config.monotone_constraints.clone())
2017            .adaptive_depth_opt(state.config.adaptive_depth)
2018            .leaf_model_type(state.config.leaf_model_type.clone());
2019
2020        // Rebuild a Vec<BoostingStep> from step snapshots with a given seed transform.
2021        let rebuild_steps = |snaps: &[StepSnapshot], seed_xor: u64| -> Vec<BoostingStep> {
2022            snaps
2023                .iter()
2024                .enumerate()
2025                .map(|(i, snap)| {
2026                    let tc = base_tree_config
2027                        .clone()
2028                        .seed(state.config.seed ^ (i as u64) ^ seed_xor);
2029
2030                    let active = rebuild_tree(&snap.tree, tc.clone());
2031                    let alternate = snap
2032                        .alternate_tree
2033                        .as_ref()
2034                        .map(|s| rebuild_tree(s, tc.clone()));
2035
2036                    let mut detector = state.config.drift_detector.create();
2037                    if let Some(ref ds) = snap.drift_state {
2038                        detector.restore_state(ds);
2039                    }
2040                    let mut slot =
2041                        TreeSlot::from_trees(active, alternate, tc, detector, max_tree_samples);
2042                    if let Some(ref ads) = snap.alt_drift_state {
2043                        if let Some(alt_det) = slot.alt_detector_mut() {
2044                            alt_det.restore_state(ads);
2045                        }
2046                    }
2047                    BoostingStep::from_slot(slot)
2048                })
2049                .collect()
2050        };
2051
2052        // Location: seed offset 0, Scale: seed offset 0x0005_CA1E_0000_0000
2053        let location_steps = rebuild_steps(&state.location_steps, 0);
2054        let scale_steps = rebuild_steps(&state.scale_steps, 0x0005_CA1E_0000_0000);
2055
2056        let scale_mode = state.config.scale_mode;
2057        let empirical_sigma_alpha = state.config.empirical_sigma_alpha;
2058        let packed_refresh_interval = state.config.packed_refresh_interval;
2059        let n_location_steps = location_steps.len();
2060        let prune_alpha = if state.config.proactive_prune_interval.is_some() {
2061            let gp = state.config.grace_period.max(1) as f64;
2062            1.0 - (-2.0 / gp).exp()
2063        } else {
2064            0.01
2065        };
2066        Self {
2067            config: state.config,
2068            location_steps,
2069            scale_steps,
2070            location_base: state.location_base,
2071            scale_base: state.scale_base,
2072            base_initialized: state.base_initialized,
2073            initial_targets: state.initial_targets,
2074            initial_target_count: state.initial_target_count,
2075            samples_seen: state.samples_seen,
2076            rng_state: state.rng_state,
2077            uncertainty_modulated_lr: state.uncertainty_modulated_lr,
2078            rolling_sigma_mean: state.rolling_sigma_mean,
2079            scale_mode,
2080            ewma_sq_err: state.ewma_sq_err,
2081            empirical_sigma_alpha,
2082            prev_sigma: 0.0,
2083            sigma_velocity: 0.0,
2084            auto_bandwidths: Vec::new(),
2085            last_replacement_sum: 0,
2086            ensemble_grad_mean: 0.0,
2087            ensemble_grad_m2: 0.0,
2088            ensemble_grad_count: 0,
2089            rolling_honest_sigma_mean: state.rolling_honest_sigma_mean,
2090            packed_cache: None,
2091            samples_since_refresh: 0,
2092            packed_refresh_interval,
2093            prev_contributions: Vec::new(),
2094            prev_prev_contributions: Vec::new(),
2095            cached_residual_alignment: 0.0,
2096            cached_reg_sensitivity: 0.0,
2097            cached_depth_sufficiency: 0.0,
2098            cached_effective_dof: 0.0,
2099            contribution_accuracy: vec![0.0; n_location_steps],
2100            prune_alpha,
2101        }
2102    }
2103}
2104
2105// ---------------------------------------------------------------------------
2106// StreamingLearner impl
2107// ---------------------------------------------------------------------------
2108
2109use crate::learner::StreamingLearner;
2110
impl StreamingLearner for DistributionalSGBT {
    /// Trait adapter: wraps the raw triple in a [`SampleRef`] and forwards
    /// to the inherent training entry point.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
        DistributionalSGBT::train_one(self, &sample);
    }

    /// Returns the mean (μ) of the predicted Gaussian distribution.
    // UFCS again: the inherent predict returns a GaussianPrediction; only μ
    // fits the scalar trait signature.
    fn predict(&self, features: &[f64]) -> f64 {
        DistributionalSGBT::predict(self, features).mu
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        DistributionalSGBT::reset(self);
    }

    /// Fixed-order diagnostics: [residual alignment, regularization
    /// sensitivity, depth sufficiency, effective DoF, rolling honest σ].
    fn diagnostics_array(&self) -> [f64; 5] {
        [
            self.cached_residual_alignment,
            self.cached_reg_sensitivity,
            self.cached_depth_sufficiency,
            self.cached_effective_dof,
            self.rolling_honest_sigma_mean(),
        ]
    }

    /// Multiplicative learning-rate tweak plus additive lambda shift.
    /// `set_lambda` clamps the result to be non-negative; the learning rate
    /// is stored unclamped.
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.set_learning_rate(self.config.learning_rate * lr_multiplier);
        self.set_lambda(self.config.lambda + lambda_delta);
    }

    /// Applies signed deltas to depth and step count. Local floors (depth
    /// >= 1, steps >= 3) guard the i32 → usize casts; the setters clamp
    /// further (depth 1..=20, steps 3..=1000).
    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        if depth_delta != 0 {
            let current = self.config.max_depth as i32;
            self.set_max_depth((current + depth_delta).max(1) as usize);
        }
        if steps_delta != 0 {
            let current = self.config.n_steps as i32;
            self.set_n_steps((current + steps_delta).max(3) as usize);
        }
    }

    fn replacement_count(&self) -> u64 {
        self.total_replacements()
    }
}
2161
2162// ---------------------------------------------------------------------------
2163// DiagnosticSource impl
2164// ---------------------------------------------------------------------------
2165
2166impl crate::automl::DiagnosticSource for DistributionalSGBT {
2167    fn config_diagnostics(&self) -> Option<crate::automl::auto_builder::ConfigDiagnostics> {
2168        Some(crate::automl::auto_builder::ConfigDiagnostics {
2169            residual_alignment: self.cached_residual_alignment,
2170            regularization_sensitivity: self.cached_reg_sensitivity,
2171            depth_sufficiency: self.cached_depth_sufficiency,
2172            effective_dof: self.cached_effective_dof,
2173            uncertainty: self.rolling_honest_sigma_mean(),
2174        })
2175    }
2176}
2177
2178// ---------------------------------------------------------------------------
2179// Tests
2180// ---------------------------------------------------------------------------
2181
2182#[cfg(test)]
2183mod tests {
2184    use super::*;
2185
    // Compact configuration shared by the tests below: few steps, shallow
    // trees, small bin count, and a short warm-up so training converges fast.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap()
    }
2197
    #[test]
    fn fresh_model_predicts_zero() {
        // Before any training: μ should be at the default base (≈ 0), and σ
        // must still be strictly positive so intervals are well-defined.
        let model = DistributionalSGBT::new(test_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.mu.abs() < 1e-12);
        assert!(pred.sigma > 0.0);
    }
2205
2206    #[test]
2207    fn sigma_always_positive() {
2208        let mut model = DistributionalSGBT::new(test_config());
2209
2210        // Train on various data
2211        for i in 0..200 {
2212            let x = i as f64 * 0.1;
2213            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
2214        }
2215
2216        // Check multiple predictions
2217        for i in 0..20 {
2218            let x = i as f64 * 0.5;
2219            let pred = model.predict(&[x, x * 0.5]);
2220            assert!(
2221                pred.sigma > 0.0,
2222                "sigma must be positive, got {}",
2223                pred.sigma
2224            );
2225            assert!(pred.sigma.is_finite(), "sigma must be finite");
2226        }
2227    }
2228
    #[test]
    fn constant_target_has_small_sigma() {
        let mut model = DistributionalSGBT::new(test_config());

        // Train on constant target = 5.0 with varying features
        for i in 0..200 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x, x * 2.0], 5.0));
        }

        // Only finiteness/positivity are asserted — the "small" expectation
        // below is documented but deliberately not pinned to a threshold.
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
        // Sigma should be relatively small for constant target
        // (no noise to explain)
    }
2246
    #[test]
    fn noisy_target_has_finite_predictions() {
        let mut model = DistributionalSGBT::new(test_config());

        // Simple xorshift for deterministic "noise"
        let mut rng: u64 = 42;
        for i in 0..200 {
            // xorshift64 step: three shift-xor rounds.
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let noise = (rng % 1000) as f64 / 500.0 - 1.0; // [-1, 1]
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x * 2.0 + noise));
        }

        // Noisy targets must not blow up μ or σ.
        let pred = model.predict(&[5.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma.is_finite());
        assert!(pred.sigma > 0.0);
    }
2267
2268    #[test]
2269    fn predict_interval_bounds_correct() {
2270        let mut model = DistributionalSGBT::new(test_config());
2271
2272        for i in 0..200 {
2273            let x = i as f64 * 0.1;
2274            model.train_one(&(vec![x], x * 2.0));
2275        }
2276
2277        let (lo, hi) = model.predict_interval(&[5.0], 1.96);
2278        let pred = model.predict(&[5.0]);
2279
2280        assert!(lo < pred.mu, "lower bound should be < mu");
2281        assert!(hi > pred.mu, "upper bound should be > mu");
2282        assert!((hi - lo - 2.0 * 1.96 * pred.sigma).abs() < 1e-10);
2283    }
2284
2285    #[test]
2286    fn batch_prediction_matches_individual() {
2287        let mut model = DistributionalSGBT::new(test_config());
2288
2289        for i in 0..100 {
2290            let x = i as f64 * 0.1;
2291            model.train_one(&(vec![x, x * 2.0], x));
2292        }
2293
2294        let features = vec![vec![1.0, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
2295        let batch = model.predict_batch(&features);
2296
2297        for (feat, batch_pred) in features.iter().zip(batch.iter()) {
2298            let individual = model.predict(feat);
2299            assert!((batch_pred.mu - individual.mu).abs() < 1e-12);
2300            assert!((batch_pred.sigma - individual.sigma).abs() < 1e-12);
2301        }
2302    }
2303
2304    #[test]
2305    fn reset_clears_state() {
2306        let mut model = DistributionalSGBT::new(test_config());
2307
2308        for i in 0..200 {
2309            let x = i as f64 * 0.1;
2310            model.train_one(&(vec![x], x * 2.0));
2311        }
2312
2313        assert!(model.n_samples_seen() > 0);
2314        model.reset();
2315
2316        assert_eq!(model.n_samples_seen(), 0);
2317        assert!(!model.is_initialized());
2318    }
2319
2320    #[test]
2321    fn gaussian_prediction_lower_upper() {
2322        let pred = GaussianPrediction {
2323            mu: 10.0,
2324            sigma: 2.0,
2325            log_sigma: 2.0_f64.ln(),
2326            honest_sigma: 0.0,
2327        };
2328
2329        assert!((pred.lower(1.96) - (10.0 - 1.96 * 2.0)).abs() < 1e-10);
2330        assert!((pred.upper(1.96) - (10.0 + 1.96 * 2.0)).abs() < 1e-10);
2331    }
2332
2333    #[test]
2334    fn train_batch_works() {
2335        let mut model = DistributionalSGBT::new(test_config());
2336        let samples: Vec<(Vec<f64>, f64)> = (0..100)
2337            .map(|i| {
2338                let x = i as f64 * 0.1;
2339                (vec![x], x * 2.0)
2340            })
2341            .collect();
2342
2343        model.train_batch(&samples);
2344        assert_eq!(model.n_samples_seen(), 100);
2345    }
2346
2347    #[test]
2348    fn debug_format_works() {
2349        let model = DistributionalSGBT::new(test_config());
2350        let debug = format!("{:?}", model);
2351        assert!(debug.contains("DistributionalSGBT"));
2352    }
2353
2354    #[test]
2355    fn n_trees_counts_both_ensembles() {
2356        let model = DistributionalSGBT::new(test_config());
2357        // 10 location + 10 scale = 20 trees minimum
2358        assert!(model.n_trees() >= 20);
2359    }
2360
2361    // -- σ-modulated learning rate tests --
2362
2363    fn modulated_config() -> SGBTConfig {
2364        SGBTConfig::builder()
2365            .n_steps(10)
2366            .learning_rate(0.1)
2367            .grace_period(20)
2368            .max_depth(4)
2369            .n_bins(16)
2370            .initial_target_count(10)
2371            .uncertainty_modulated_lr(true)
2372            .build()
2373            .unwrap()
2374    }
2375
2376    #[test]
2377    fn sigma_modulated_initializes_rolling_mean() {
2378        let mut model = DistributionalSGBT::new(modulated_config());
2379        assert!(model.is_uncertainty_modulated());
2380
2381        // Before initialization, rolling_sigma_mean is 1.0 (placeholder)
2382        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
2383
2384        // Train past initialization
2385        for i in 0..200 {
2386            let x = i as f64 * 0.1;
2387            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
2388        }
2389
2390        // After training, rolling_sigma_mean should have adapted from initial std
2391        assert!(model.rolling_sigma_mean() > 0.0);
2392        assert!(model.rolling_sigma_mean().is_finite());
2393    }
2394
2395    #[test]
2396    fn predict_distributional_returns_sigma_ratio() {
2397        let mut model = DistributionalSGBT::new(modulated_config());
2398
2399        for i in 0..200 {
2400            let x = i as f64 * 0.1;
2401            model.train_one(&(vec![x], x * 2.0 + 1.0));
2402        }
2403
2404        let (mu, sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
2405        assert!(mu.is_finite());
2406        assert!(sigma > 0.0);
2407        assert!(
2408            (0.1..=10.0).contains(&sigma_ratio),
2409            "sigma_ratio={}",
2410            sigma_ratio
2411        );
2412    }
2413
2414    #[test]
2415    fn predict_distributional_without_modulation_returns_one() {
2416        let mut model = DistributionalSGBT::new(test_config());
2417        assert!(!model.is_uncertainty_modulated());
2418
2419        for i in 0..200 {
2420            let x = i as f64 * 0.1;
2421            model.train_one(&(vec![x], x * 2.0));
2422        }
2423
2424        let (_mu, _sigma, sigma_ratio) = model.predict_distributional(&[5.0]);
2425        assert!(
2426            (sigma_ratio - 1.0).abs() < 1e-12,
2427            "should be 1.0 when disabled"
2428        );
2429    }
2430
2431    #[test]
2432    fn modulated_model_sigma_finite_under_varying_noise() {
2433        let mut model = DistributionalSGBT::new(modulated_config());
2434
2435        let mut rng: u64 = 123;
2436        for i in 0..500 {
2437            rng ^= rng << 13;
2438            rng ^= rng >> 7;
2439            rng ^= rng << 17;
2440            let noise = (rng % 1000) as f64 / 100.0 - 5.0; // [-5, 5]
2441            let x = i as f64 * 0.1;
2442            // Regime shift at i=250: noise amplitude increases
2443            let scale = if i < 250 { 1.0 } else { 5.0 };
2444            model.train_one(&(vec![x], x * 2.0 + noise * scale));
2445        }
2446
2447        let pred = model.predict(&[10.0]);
2448        assert!(pred.mu.is_finite());
2449        assert!(pred.sigma.is_finite());
2450        assert!(pred.sigma > 0.0);
2451        assert!(model.rolling_sigma_mean().is_finite());
2452    }
2453
2454    #[test]
2455    fn reset_clears_rolling_sigma_mean() {
2456        let mut model = DistributionalSGBT::new(modulated_config());
2457
2458        for i in 0..200 {
2459            let x = i as f64 * 0.1;
2460            model.train_one(&(vec![x], x * 2.0));
2461        }
2462
2463        let sigma_before = model.rolling_sigma_mean();
2464        assert!(sigma_before > 0.0);
2465
2466        model.reset();
2467        assert!((model.rolling_sigma_mean() - 1.0).abs() < 1e-12);
2468    }
2469
2470    #[test]
2471    fn streaming_learner_returns_mu() {
2472        let mut model = DistributionalSGBT::new(test_config());
2473        for i in 0..200 {
2474            let x = i as f64 * 0.1;
2475            StreamingLearner::train(&mut model, &[x], x * 2.0 + 1.0);
2476        }
2477        let pred = StreamingLearner::predict(&model, &[5.0]);
2478        let gaussian = DistributionalSGBT::predict(&model, &[5.0]);
2479        assert!(
2480            (pred - gaussian.mu).abs() < 1e-12,
2481            "StreamingLearner::predict should return mu"
2482        );
2483    }
2484
2485    // -- Diagnostics tests --
2486
2487    fn trained_model() -> DistributionalSGBT {
2488        let config = SGBTConfig::builder()
2489            .n_steps(10)
2490            .learning_rate(0.1)
2491            .grace_period(10) // low grace period to ensure splits
2492            .max_depth(4)
2493            .n_bins(16)
2494            .initial_target_count(10)
2495            .build()
2496            .unwrap();
2497        let mut model = DistributionalSGBT::new(config);
2498        for i in 0..500 {
2499            let x = i as f64 * 0.1;
2500            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
2501        }
2502        model
2503    }
2504
2505    #[test]
2506    fn diagnostics_returns_correct_tree_count() {
2507        let model = trained_model();
2508        let diag = model.diagnostics();
2509        // 10 location + 10 scale = 20 trees
2510        assert_eq!(diag.trees.len(), 20, "should have 2*n_steps trees");
2511    }
2512
2513    #[test]
2514    fn diagnostics_trees_have_leaves() {
2515        let model = trained_model();
2516        let diag = model.diagnostics();
2517        for (i, tree) in diag.trees.iter().enumerate() {
2518            assert!(tree.n_leaves >= 1, "tree {i} should have at least 1 leaf");
2519        }
2520        // At least some trees should have seen samples.
2521        let total_samples: u64 = diag.trees.iter().map(|t| t.samples_seen).sum();
2522        assert!(
2523            total_samples > 0,
2524            "at least some trees should have seen samples"
2525        );
2526    }
2527
2528    #[test]
2529    fn diagnostics_leaf_weight_stats_finite() {
2530        let model = trained_model();
2531        let diag = model.diagnostics();
2532        for (i, tree) in diag.trees.iter().enumerate() {
2533            let (min, max, mean, std) = tree.leaf_weight_stats;
2534            assert!(min.is_finite(), "tree {i} min not finite");
2535            assert!(max.is_finite(), "tree {i} max not finite");
2536            assert!(mean.is_finite(), "tree {i} mean not finite");
2537            assert!(std.is_finite(), "tree {i} std not finite");
2538            assert!(min <= max, "tree {i} min > max");
2539        }
2540    }
2541
2542    #[test]
2543    fn diagnostics_base_predictions_match() {
2544        let model = trained_model();
2545        let diag = model.diagnostics();
2546        assert!(
2547            (diag.location_base - model.predict(&[0.0, 0.0, 0.0]).mu).abs() < 100.0,
2548            "location_base should be plausible"
2549        );
2550    }
2551
2552    #[test]
2553    fn predict_decomposed_reconstructs_prediction() {
2554        let model = trained_model();
2555        let features = [5.0, 2.5, 1.0];
2556        let pred = model.predict(&features);
2557        let decomp = model.predict_decomposed(&features);
2558
2559        assert!(
2560            (decomp.mu() - pred.mu).abs() < 1e-10,
2561            "decomposed mu ({}) != predict mu ({})",
2562            decomp.mu(),
2563            pred.mu
2564        );
2565        assert!(
2566            (decomp.sigma() - pred.sigma).abs() < 1e-10,
2567            "decomposed sigma ({}) != predict sigma ({})",
2568            decomp.sigma(),
2569            pred.sigma
2570        );
2571    }
2572
2573    #[test]
2574    fn predict_decomposed_correct_lengths() {
2575        let model = trained_model();
2576        let decomp = model.predict_decomposed(&[1.0, 0.5, 0.0]);
2577        assert_eq!(
2578            decomp.location_contributions.len(),
2579            model.n_steps(),
2580            "location contributions should have n_steps entries"
2581        );
2582        assert_eq!(
2583            decomp.scale_contributions.len(),
2584            model.n_steps(),
2585            "scale contributions should have n_steps entries"
2586        );
2587    }
2588
2589    #[test]
2590    fn feature_importances_work() {
2591        let model = trained_model();
2592        let imp = model.feature_importances();
2593        // After enough training, importances should be non-empty and non-negative.
2594        // They may or may not sum to 1.0 if no splits have occurred yet.
2595        for (i, &v) in imp.iter().enumerate() {
2596            assert!(v >= 0.0, "importance {i} should be non-negative, got {v}");
2597            assert!(v.is_finite(), "importance {i} should be finite");
2598        }
2599        let sum: f64 = imp.iter().sum();
2600        if sum > 0.0 {
2601            assert!(
2602                (sum - 1.0).abs() < 1e-10,
2603                "non-zero importances should sum to 1.0, got {sum}"
2604            );
2605        }
2606    }
2607
2608    #[test]
2609    fn feature_importances_split_works() {
2610        let model = trained_model();
2611        let (loc_imp, scale_imp) = model.feature_importances_split();
2612        for (name, imp) in [("location", &loc_imp), ("scale", &scale_imp)] {
2613            let sum: f64 = imp.iter().sum();
2614            if sum > 0.0 {
2615                assert!(
2616                    (sum - 1.0).abs() < 1e-10,
2617                    "{name} importances should sum to 1.0, got {sum}"
2618                );
2619            }
2620            for &v in imp.iter() {
2621                assert!(v >= 0.0 && v.is_finite());
2622            }
2623        }
2624    }
2625
2626    // -- Empirical σ tests --
2627
2628    #[test]
2629    fn empirical_sigma_default_mode() {
2630        use crate::ensemble::config::ScaleMode;
2631        let config = test_config();
2632        let model = DistributionalSGBT::new(config);
2633        assert_eq!(model.scale_mode(), ScaleMode::Empirical);
2634    }
2635
2636    #[test]
2637    fn empirical_sigma_tracks_errors() {
2638        let mut model = DistributionalSGBT::new(test_config());
2639
2640        // Train on clean linear data
2641        for i in 0..200 {
2642            let x = i as f64 * 0.1;
2643            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
2644        }
2645
2646        let sigma_clean = model.empirical_sigma();
2647        assert!(sigma_clean > 0.0, "sigma should be positive");
2648        assert!(sigma_clean.is_finite(), "sigma should be finite");
2649
2650        // Now train on noisy data — sigma should increase
2651        let mut rng: u64 = 42;
2652        for i in 200..400 {
2653            rng ^= rng << 13;
2654            rng ^= rng >> 7;
2655            rng ^= rng << 17;
2656            let noise = (rng % 10000) as f64 / 100.0 - 50.0; // big noise
2657            let x = i as f64 * 0.1;
2658            model.train_one(&(vec![x, x * 0.5], x * 2.0 + noise));
2659        }
2660
2661        let sigma_noisy = model.empirical_sigma();
2662        assert!(
2663            sigma_noisy > sigma_clean,
2664            "noisy regime should increase sigma: clean={sigma_clean}, noisy={sigma_noisy}"
2665        );
2666    }
2667
2668    #[test]
2669    fn empirical_sigma_modulated_lr_adapts() {
2670        let config = SGBTConfig::builder()
2671            .n_steps(10)
2672            .learning_rate(0.1)
2673            .grace_period(20)
2674            .max_depth(4)
2675            .n_bins(16)
2676            .initial_target_count(10)
2677            .uncertainty_modulated_lr(true)
2678            .build()
2679            .unwrap();
2680        let mut model = DistributionalSGBT::new(config);
2681
2682        // Train and verify sigma_ratio changes
2683        for i in 0..300 {
2684            let x = i as f64 * 0.1;
2685            model.train_one(&(vec![x], x * 2.0 + 1.0));
2686        }
2687
2688        let (_, _, sigma_ratio) = model.predict_distributional(&[5.0]);
2689        assert!(sigma_ratio.is_finite());
2690        assert!(
2691            (0.1..=10.0).contains(&sigma_ratio),
2692            "sigma_ratio={sigma_ratio}"
2693        );
2694    }
2695
2696    #[test]
2697    fn tree_chain_mode_trains_scale_trees() {
2698        use crate::ensemble::config::ScaleMode;
2699        let config = SGBTConfig::builder()
2700            .n_steps(10)
2701            .learning_rate(0.1)
2702            .grace_period(10)
2703            .max_depth(4)
2704            .n_bins(16)
2705            .initial_target_count(10)
2706            .scale_mode(ScaleMode::TreeChain)
2707            .build()
2708            .unwrap();
2709        let mut model = DistributionalSGBT::new(config);
2710        assert_eq!(model.scale_mode(), ScaleMode::TreeChain);
2711
2712        for i in 0..500 {
2713            let x = i as f64 * 0.1;
2714            model.train_one(&(vec![x, x * 0.5, (i % 3) as f64], x * 2.0 + 1.0));
2715        }
2716
2717        let pred = model.predict(&[5.0, 2.5, 1.0]);
2718        assert!(pred.mu.is_finite());
2719        assert!(pred.sigma > 0.0);
2720        assert!(pred.sigma.is_finite());
2721    }
2722
2723    #[test]
2724    fn diagnostics_shows_empirical_sigma() {
2725        let model = trained_model();
2726        let diag = model.diagnostics();
2727        assert!(
2728            diag.empirical_sigma > 0.0,
2729            "empirical_sigma should be positive"
2730        );
2731        assert!(
2732            diag.empirical_sigma.is_finite(),
2733            "empirical_sigma should be finite"
2734        );
2735    }
2736
2737    #[test]
2738    fn diagnostics_scale_trees_split_fields() {
2739        let model = trained_model();
2740        let diag = model.diagnostics();
2741        assert_eq!(diag.location_trees.len(), model.n_steps());
2742        assert_eq!(diag.scale_trees.len(), model.n_steps());
2743        // In empirical mode, scale_trees_active might be 0 (trees not trained)
2744        // This is expected and actually the point.
2745    }
2746
2747    #[test]
2748    fn reset_clears_empirical_sigma() {
2749        let mut model = DistributionalSGBT::new(test_config());
2750        for i in 0..200 {
2751            let x = i as f64 * 0.1;
2752            model.train_one(&(vec![x], x * 2.0));
2753        }
2754        model.reset();
2755        // After reset, ewma_sq_err resets to 1.0
2756        assert!((model.empirical_sigma() - 1.0).abs() < 1e-12);
2757    }
2758
2759    #[test]
2760    fn predict_smooth_returns_finite() {
2761        let config = SGBTConfig::builder()
2762            .n_steps(3)
2763            .learning_rate(0.1)
2764            .grace_period(20)
2765            .max_depth(4)
2766            .n_bins(16)
2767            .initial_target_count(10)
2768            .build()
2769            .unwrap();
2770        let mut model = DistributionalSGBT::new(config);
2771
2772        for i in 0..200 {
2773            let x = (i as f64) * 0.1;
2774            let features = vec![x, x.sin()];
2775            let target = 2.0 * x + 1.0;
2776            model.train_one(&(features, target));
2777        }
2778
2779        let pred = model.predict_smooth(&[1.0, 1.0_f64.sin()], 0.5);
2780        assert!(pred.mu.is_finite(), "smooth mu should be finite");
2781        assert!(pred.sigma.is_finite(), "smooth sigma should be finite");
2782        assert!(pred.sigma > 0.0, "smooth sigma should be positive");
2783    }
2784
2785    // -- PD sigma modulation tests --
2786
2787    #[test]
2788    fn sigma_velocity_responds_to_error_spike() {
2789        let config = SGBTConfig::builder()
2790            .n_steps(3)
2791            .learning_rate(0.1)
2792            .grace_period(20)
2793            .max_depth(4)
2794            .n_bins(16)
2795            .initial_target_count(10)
2796            .uncertainty_modulated_lr(true)
2797            .build()
2798            .unwrap();
2799        let mut model = DistributionalSGBT::new(config);
2800
2801        // Stable phase
2802        for i in 0..200 {
2803            let x = (i as f64) * 0.1;
2804            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
2805        }
2806
2807        let velocity_before = model.sigma_velocity();
2808
2809        // Error spike -- sudden regime change
2810        for i in 0..50 {
2811            let x = (i as f64) * 0.1;
2812            model.train_one(&(vec![x, x.sin()], 100.0 * x + 50.0));
2813        }
2814
2815        let velocity_after = model.sigma_velocity();
2816
2817        // Velocity should be positive (sigma increasing due to errors)
2818        assert!(
2819            velocity_after > velocity_before,
2820            "sigma velocity should increase after error spike: before={}, after={}",
2821            velocity_before,
2822            velocity_after,
2823        );
2824    }
2825
2826    #[test]
2827    fn sigma_velocity_getter_works() {
2828        let config = SGBTConfig::builder()
2829            .n_steps(2)
2830            .learning_rate(0.1)
2831            .grace_period(20)
2832            .max_depth(4)
2833            .n_bins(16)
2834            .initial_target_count(10)
2835            .build()
2836            .unwrap();
2837        let model = DistributionalSGBT::new(config);
2838        // Fresh model should have zero velocity
2839        assert_eq!(model.sigma_velocity(), 0.0);
2840    }
2841
2842    #[test]
2843    fn diagnostics_leaf_sample_counts_populated() {
2844        let config = SGBTConfig::builder()
2845            .n_steps(3)
2846            .learning_rate(0.1)
2847            .grace_period(10)
2848            .max_depth(4)
2849            .n_bins(16)
2850            .initial_target_count(10)
2851            .build()
2852            .unwrap();
2853        let mut model = DistributionalSGBT::new(config);
2854
2855        for i in 0..200 {
2856            let x = (i as f64) * 0.1;
2857            let features = vec![x, x.sin()];
2858            let target = 2.0 * x + 1.0;
2859            model.train_one(&(features, target));
2860        }
2861
2862        let diags = model.diagnostics();
2863        for (ti, tree) in diags.trees.iter().enumerate() {
2864            assert_eq!(
2865                tree.leaf_sample_counts.len(),
2866                tree.n_leaves,
2867                "tree {} should have sample count per leaf",
2868                ti,
2869            );
2870            // Total samples across leaves should equal samples_seen for trees with data
2871            if tree.samples_seen > 0 {
2872                let total: u64 = tree.leaf_sample_counts.iter().sum();
2873                assert!(
2874                    total > 0,
2875                    "tree {} has {} samples_seen but leaf counts sum to 0",
2876                    ti,
2877                    tree.samples_seen,
2878                );
2879            }
2880        }
2881    }
2882
2883    // -------------------------------------------------------------------
2884    // Packed cache tests (dual-path inference)
2885    // -------------------------------------------------------------------
2886
2887    #[test]
2888    fn packed_cache_disabled_by_default() {
2889        let model = DistributionalSGBT::new(test_config());
2890        assert!(!model.has_packed_cache());
2891        assert_eq!(model.config().packed_refresh_interval, 0);
2892    }
2893
2894    #[test]
2895    fn packed_cache_refreshes_after_interval() {
2896        let config = SGBTConfig::builder()
2897            .n_steps(5)
2898            .learning_rate(0.1)
2899            .grace_period(5)
2900            .max_depth(3)
2901            .n_bins(8)
2902            .initial_target_count(10)
2903            .packed_refresh_interval(20)
2904            .build()
2905            .unwrap();
2906
2907        let mut model = DistributionalSGBT::new(config);
2908
2909        // Train past initialization (10 samples) + refresh interval (20 samples)
2910        for i in 0..40 {
2911            let x = i as f64 * 0.1;
2912            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
2913        }
2914
2915        // Cache should exist after enough training
2916        assert!(
2917            model.has_packed_cache(),
2918            "packed cache should exist after training past refresh interval"
2919        );
2920
2921        // Predictions should be finite
2922        let pred = model.predict(&[2.0, 4.0, 1.0]);
2923        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
2924    }
2925
2926    #[test]
2927    fn packed_cache_matches_full_tree() {
2928        let config = SGBTConfig::builder()
2929            .n_steps(5)
2930            .learning_rate(0.1)
2931            .grace_period(5)
2932            .max_depth(3)
2933            .n_bins(8)
2934            .initial_target_count(10)
2935            .build()
2936            .unwrap();
2937
2938        let mut model = DistributionalSGBT::new(config);
2939
2940        // Train
2941        for i in 0..80 {
2942            let x = i as f64 * 0.1;
2943            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
2944        }
2945
2946        // Get full-tree prediction (no cache)
2947        assert!(!model.has_packed_cache());
2948        let full_pred = model.predict(&[2.0, 4.0, 1.0]);
2949
2950        // Enable cache and get cached prediction
2951        model.enable_packed_cache(10);
2952        assert!(model.has_packed_cache());
2953        let cached_pred = model.predict(&[2.0, 4.0, 1.0]);
2954
2955        // Should match within f32 precision (f64->f32 conversion loses some precision)
2956        let mu_diff = (full_pred.mu - cached_pred.mu).abs();
2957        assert!(
2958            mu_diff < 0.1,
2959            "packed cache mu ({}) should match full tree mu ({}) within f32 tolerance, diff={}",
2960            cached_pred.mu,
2961            full_pred.mu,
2962            mu_diff
2963        );
2964
2965        // Sigma should be identical (not affected by packed cache)
2966        assert!(
2967            (full_pred.sigma - cached_pred.sigma).abs() < 1e-12,
2968            "sigma should be identical: full={}, cached={}",
2969            full_pred.sigma,
2970            cached_pred.sigma
2971        );
2972    }
2973
2974    #[test]
2975    fn honest_sigma_in_gaussian_prediction() {
2976        let config = SGBTConfig::builder()
2977            .n_steps(5)
2978            .learning_rate(0.1)
2979            .max_depth(3)
2980            .grace_period(2)
2981            .initial_target_count(10)
2982            .build()
2983            .unwrap();
2984        let mut model = DistributionalSGBT::new(config);
2985        for i in 0..100 {
2986            let x = i as f64 * 0.1;
2987            model.train_one(&(vec![x], x * 2.0));
2988        }
2989
2990        let pred = model.predict(&[5.0]);
2991        // honest_sigma should be finite and non-negative
2992        assert!(
2993            pred.honest_sigma.is_finite(),
2994            "honest_sigma should be finite, got {}",
2995            pred.honest_sigma
2996        );
2997        assert!(
2998            pred.honest_sigma >= 0.0,
2999            "honest_sigma should be >= 0, got {}",
3000            pred.honest_sigma
3001        );
3002    }
3003
3004    #[test]
3005    fn honest_sigma_increases_with_divergence() {
3006        let config = SGBTConfig::builder()
3007            .n_steps(8)
3008            .learning_rate(0.1)
3009            .max_depth(4)
3010            .grace_period(2)
3011            .initial_target_count(10)
3012            .build()
3013            .unwrap();
3014        let mut model = DistributionalSGBT::new(config);
3015
3016        // Train on linear data
3017        for i in 0..200 {
3018            let x = i as f64 * 0.05;
3019            model.train_one(&(vec![x], x * 2.0));
3020        }
3021
3022        // Predict in-distribution
3023        let in_dist = model.predict(&[5.0]);
3024        // Predict very far out of distribution
3025        let out_dist = model.predict(&[500.0]);
3026
3027        // Both should be finite and non-negative
3028        assert!(in_dist.honest_sigma.is_finite());
3029        assert!(out_dist.honest_sigma.is_finite());
3030        assert!(in_dist.honest_sigma >= 0.0);
3031        assert!(out_dist.honest_sigma >= 0.0);
3032
3033        // Out-of-distribution should generally produce different honest_sigma
3034        // (trees will respond differently to unseen regions)
3035        // We can't guarantee strictly greater since all trees might
3036        // extrapolate identically, but at minimum both should be valid.
3037    }
3038
3039    #[test]
3040    fn honest_sigma_zero_for_single_step() {
3041        // With only 1 step, variance of 1 contribution is undefined -> returns 0
3042        let config = SGBTConfig::builder()
3043            .n_steps(1)
3044            .learning_rate(0.1)
3045            .max_depth(3)
3046            .grace_period(2)
3047            .initial_target_count(10)
3048            .build()
3049            .unwrap();
3050        let mut model = DistributionalSGBT::new(config);
3051        for i in 0..100 {
3052            let x = i as f64 * 0.1;
3053            model.train_one(&(vec![x], x * 2.0));
3054        }
3055        let pred = model.predict(&[5.0]);
3056        assert!(
3057            pred.honest_sigma.abs() < 1e-15,
3058            "honest_sigma should be 0 with 1 step, got {}",
3059            pred.honest_sigma
3060        );
3061    }
3062
3063    #[test]
3064    fn adaptive_mts_with_distributional() {
3065        let config = SGBTConfig::builder()
3066            .n_steps(10)
3067            .learning_rate(0.1)
3068            .max_depth(3)
3069            .grace_period(5)
3070            .initial_target_count(10)
3071            .adaptive_mts(500, 1.0)
3072            .build()
3073            .unwrap();
3074        let mut model = DistributionalSGBT::new(config);
3075
3076        // Train 500 samples -- should not crash and produce finite predictions.
3077        for i in 0..500 {
3078            let x = (i as f64) * 0.02;
3079            model.train_one(&(vec![x, x * 0.5], x.sin()));
3080        }
3081        let pred = model.predict(&[1.0, 0.5]);
3082        assert!(
3083            pred.mu.is_finite(),
3084            "adaptive_mts distributional: mu should be finite, got {}",
3085            pred.mu
3086        );
3087        assert!(
3088            pred.sigma.is_finite() && pred.sigma > 0.0,
3089            "adaptive_mts distributional: sigma should be finite and positive, got {}",
3090            pred.sigma
3091        );
3092    }
3093
3094    #[test]
3095    fn proactive_prune_with_distributional() {
3096        let config = SGBTConfig::builder()
3097            .n_steps(10)
3098            .learning_rate(0.1)
3099            .max_depth(3)
3100            .grace_period(5)
3101            .initial_target_count(10)
3102            .proactive_prune_interval(100)
3103            .build()
3104            .unwrap();
3105        let mut model = DistributionalSGBT::new(config);
3106
3107        // Train 300 samples -- at sample 100, 200, 300 a prune fires.
3108        for i in 0..300 {
3109            let x = (i as f64) * 0.03;
3110            model.train_one(&(vec![x, x * 0.7], x.cos()));
3111        }
3112        let pred = model.predict(&[2.0, 1.4]);
3113        assert!(
3114            pred.mu.is_finite(),
3115            "proactive_prune distributional: mu should be finite, got {}",
3116            pred.mu
3117        );
3118        assert!(
3119            pred.sigma.is_finite() && pred.sigma > 0.0,
3120            "proactive_prune distributional: sigma should be finite and positive, got {}",
3121            pred.sigma
3122        );
3123    }
3124
3125    /// Verify DistributionalSGBT implements DiagnosticSource and returns Some.
3126    #[test]
3127    fn diagnostic_source_impl() {
3128        use crate::automl::DiagnosticSource;
3129
3130        let mut model = DistributionalSGBT::new(test_config());
3131
3132        for i in 0..200 {
3133            let x = (i as f64) * 0.1;
3134            model.train_one(&(vec![x, x * 0.3], x.sin()));
3135        }
3136
3137        let diag = model.config_diagnostics();
3138        assert!(
3139            diag.is_some(),
3140            "config_diagnostics should return Some after training"
3141        );
3142        let diag = diag.unwrap();
3143
3144        assert!(
3145            diag.effective_dof > 0.0,
3146            "effective_dof should be > 0, got {}",
3147            diag.effective_dof
3148        );
3149        assert!(
3150            diag.uncertainty.is_finite(),
3151            "uncertainty should be finite, got {}",
3152            diag.uncertainty
3153        );
3154    }
3155
3156    // -------------------------------------------------------------------
3157    // Diagnostic signal tests
3158    // -------------------------------------------------------------------
3159
3160    #[test]
3161    fn diagnostic_signals_populated() {
3162        use crate::automl::DiagnosticSource;
3163        let mut model = DistributionalSGBT::new(test_config());
3164
3165        // Train enough samples for meaningful diagnostics
3166        for i in 0..500 {
3167            let x = i as f64 * 0.1;
3168            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
3169        }
3170
3171        let diag = model
3172            .config_diagnostics()
3173            .expect("should return Some diagnostics");
3174
3175        // After 500 samples all signals should be populated (non-zero)
3176        assert!(
3177            diag.residual_alignment != 0.0,
3178            "residual_alignment should be non-zero, got {}",
3179            diag.residual_alignment
3180        );
3181        assert!(
3182            diag.regularization_sensitivity != 0.0,
3183            "regularization_sensitivity should be non-zero, got {}",
3184            diag.regularization_sensitivity
3185        );
3186        assert!(
3187            diag.depth_sufficiency != 0.0,
3188            "depth_sufficiency should be non-zero, got {}",
3189            diag.depth_sufficiency
3190        );
3191        assert!(
3192            diag.effective_dof != 0.0,
3193            "effective_dof should be non-zero, got {}",
3194            diag.effective_dof
3195        );
3196        assert!(
3197            diag.uncertainty.is_finite(),
3198            "uncertainty should be finite, got {}",
3199            diag.uncertainty
3200        );
3201    }
3202
3203    #[test]
3204    fn residual_alignment_range() {
3205        use crate::automl::DiagnosticSource;
3206        let mut model = DistributionalSGBT::new(test_config());
3207
3208        for i in 0..500 {
3209            let x = i as f64 * 0.1;
3210            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
3211        }
3212
3213        let diag = model
3214            .config_diagnostics()
3215            .expect("should return Some diagnostics");
3216
3217        assert!(
3218            diag.residual_alignment >= -1.0 && diag.residual_alignment <= 1.0,
3219            "residual_alignment should be in [-1, 1], got {}",
3220            diag.residual_alignment
3221        );
3222    }
3223
3224    #[test]
3225    fn effective_dof_positive() {
3226        use crate::automl::DiagnosticSource;
3227        let mut model = DistributionalSGBT::new(test_config());
3228
3229        for i in 0..500 {
3230            let x = i as f64 * 0.1;
3231            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
3232        }
3233
3234        let diag = model
3235            .config_diagnostics()
3236            .expect("should return Some diagnostics");
3237
3238        assert!(
3239            diag.effective_dof > 0.0,
3240            "effective_dof should be positive after training, got {}",
3241            diag.effective_dof
3242        );
3243    }
3244
3245    #[test]
3246    fn depth_sufficiency_positive() {
3247        use crate::automl::DiagnosticSource;
3248        let mut model = DistributionalSGBT::new(test_config());
3249
3250        for i in 0..500 {
3251            let x = i as f64 * 0.1;
3252            model.train_one(&(vec![x, x * 0.5], x * 2.0 + 1.0));
3253        }
3254
3255        let diag = model
3256            .config_diagnostics()
3257            .expect("should return Some diagnostics");
3258
3259        assert!(
3260            diag.depth_sufficiency > 0.0,
3261            "depth_sufficiency should be positive after training, got {}",
3262            diag.depth_sufficiency
3263        );
3264    }
3265}
3266
#[cfg(test)]
mod distributional_param_tests {
    //! Parameter and robustness tests for `DistributionalSGBT`.
    //!
    //! NOTE(review): these tests previously lived inside the
    //! `#[cfg(feature = "serde-json")] mod serde_tests` block even though none
    //! of them exercise serialization, so a default `cargo test` (without the
    //! feature) silently skipped them. They are now unconditional.
    use super::*;
    use crate::SGBTConfig;

    /// Auto-bandwidths should be computed per feature and surfaced through
    /// diagnostics; predictions stay finite.
    #[test]
    fn auto_bandwidth_computed_distributional() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..200 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x.sin()], 2.0 * x + 1.0));
        }

        // auto_bandwidths should be populated
        let bws = model.auto_bandwidths();
        assert_eq!(bws.len(), 2, "should have 2 feature bandwidths");

        // Diagnostics should include auto_bandwidths
        let diag = model.diagnostics();
        assert_eq!(diag.auto_bandwidths.len(), 2);

        // TreeDiagnostic should have prediction stats
        assert!(diag.location_trees[0].prediction_mean.is_finite());
        assert!(diag.location_trees[0].prediction_std.is_finite());

        let pred = model.predict(&[1.0, 1.0_f64.sin()]);
        assert!(pred.mu.is_finite(), "auto-bandwidth mu should be finite");
        assert!(pred.sigma > 0.0, "auto-bandwidth sigma should be positive");
    }

    /// `max_leaf_output` should keep predictions finite even when a large
    /// learning rate and extreme targets would otherwise blow up leaf weights.
    #[test]
    fn max_leaf_output_clamps_predictions() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(1.0) // Intentionally large to force extreme leaf weights
            .max_leaf_output(0.5)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with extreme targets to force large leaf weights
        for i in 0..200 {
            let target = if i % 2 == 0 { 100.0 } else { -100.0 };
            let sample = crate::Sample::new(vec![i as f64 % 5.0, (i as f64).sin()], target);
            model.train_one(&sample);
        }

        // Each tree's prediction should be clamped to [-0.5, 0.5]
        let pred = model.predict(&[2.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with clamping"
        );
    }

    /// A high `min_hessian_sum` should suppress fresh leaves without breaking
    /// prediction finiteness on a small training stream.
    #[test]
    fn min_hessian_sum_suppresses_fresh_leaves() {
        let config = SGBTConfig::builder()
            .n_steps(3)
            .learning_rate(0.01)
            .min_hessian_sum(50.0)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train minimal samples
        for i in 0..60 {
            let sample = crate::Sample::new(vec![i as f64, (i as f64).sin()], i as f64 * 0.1);
            model.train_one(&sample);
        }

        let pred = model.predict(&[30.0, 0.5]);
        assert!(
            pred.mu.is_finite(),
            "prediction should be finite with min_hessian_sum"
        );
    }

    /// Interpolated prediction should return a finite mu and positive sigma.
    #[test]
    fn predict_interpolated_returns_finite() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let pred = model.predict_interpolated(&[1.0, 0.5]);
        assert!(pred.mu.is_finite(), "interpolated mu should be finite");
        assert!(pred.sigma > 0.0, "interpolated sigma should be positive");
    }

    /// Huber gradient clipping should keep predictions finite despite
    /// occasional extreme outlier targets.
    #[test]
    fn huber_k_bounds_gradients() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(1.345)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        // Train with occasional extreme outliers
        for i in 0..300 {
            let target = if i % 50 == 0 {
                1000.0
            } else {
                (i as f64 * 0.1).sin()
            };
            let sample = crate::Sample::new(vec![i as f64 % 10.0, (i as f64).cos()], target);
            model.train_one(&sample);
        }

        let pred = model.predict(&[5.0, 0.3]);
        assert!(
            pred.mu.is_finite(),
            "Huber-loss mu should be finite despite outliers"
        );
        assert!(pred.sigma > 0.0, "sigma should be positive");
    }

    /// Ensemble-level gradient statistics should be finite and non-negative.
    #[test]
    fn ensemble_gradient_stats_populated() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..200 {
            let x = i as f64 * 0.1;
            let sample = crate::Sample::new(vec![x, x.sin()], x.cos());
            model.train_one(&sample);
        }

        let diag = model.diagnostics();
        assert!(
            diag.ensemble_grad_mean.is_finite(),
            "ensemble grad mean should be finite"
        );
        assert!(
            diag.ensemble_grad_std >= 0.0,
            "ensemble grad std should be non-negative"
        );
        assert!(
            diag.ensemble_grad_std.is_finite(),
            "ensemble grad std should be finite"
        );
    }

    /// Negative `huber_k` must be rejected by config validation.
    #[test]
    fn huber_k_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .huber_k(-1.0)
            .build();
        assert!(result.is_err(), "negative huber_k should fail validation");
    }

    /// Negative `max_leaf_output` must be rejected by config validation.
    #[test]
    fn max_leaf_output_validation() {
        let result = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.01)
            .max_leaf_output(-1.0)
            .build();
        assert!(
            result.is_err(),
            "negative max_leaf_output should fail validation"
        );
    }

    /// Sibling interpolation should yield finite predictions and vary with
    /// features at least as much as hard (non-interpolated) prediction.
    #[test]
    fn predict_sibling_interpolated_varies_with_features() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(6)
            .delta(0.1)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);

        for i in 0..2000 {
            let x = (i as f64) * 0.01;
            let y = x.sin() * x + 0.5 * (x * 2.0).cos();
            let sample = crate::Sample::new(vec![x, x * 0.3], y);
            model.train_one(&sample);
        }

        // Verify the method runs correctly and produces finite predictions
        let pred = model.predict_sibling_interpolated(&[5.0, 1.5]);
        assert!(
            pred.mu.is_finite(),
            "sibling interpolated mu should be finite"
        );
        assert!(
            pred.sigma > 0.0,
            "sibling interpolated sigma should be positive"
        );

        // Sweep — at minimum, sibling should produce same or more variation than hard
        let bws = model.auto_bandwidths();
        if bws.iter().any(|&b| b.is_finite()) {
            let hard_preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict(&[x, x * 0.3]).mu
                })
                .collect();
            let hard_changes = hard_preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();

            let preds: Vec<f64> = (0..200)
                .map(|i| {
                    let x = i as f64 * 0.1;
                    model.predict_sibling_interpolated(&[x, x * 0.3]).mu
                })
                .collect();

            let sibling_changes = preds
                .windows(2)
                .filter(|w| (w[0] - w[1]).abs() > f64::EPSILON)
                .count();
            assert!(
                sibling_changes >= hard_changes,
                "sibling should produce >= hard changes: sibling={}, hard={}",
                sibling_changes,
                hard_changes
            );
        }
    }
}

#[cfg(test)]
#[cfg(feature = "serde-json")]
mod serde_tests {
    //! Serialization round-trip and state-snapshot tests. These genuinely
    //! depend on the `serde-json` feature (via `crate::serde_support` and the
    //! distributional state conversions) and stay feature-gated.
    use super::*;
    use crate::SGBTConfig;

    /// Build a small trained model whose predictions can be compared before
    /// and after a serialization round-trip.
    fn make_trained_distributional() -> DistributionalSGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        model
    }

    /// JSON save → load must preserve both mu and sigma to tight tolerance.
    #[test]
    fn json_round_trip_preserves_predictions() {
        let model = make_trained_distributional();
        let state = model.to_distributional_state();
        let json = crate::serde_support::save_distributional_model(&state).unwrap();
        let loaded_state = crate::serde_support::load_distributional_model(&json).unwrap();
        let restored = DistributionalSGBT::from_distributional_state(loaded_state);

        let test_points = [0.5, 1.0, 2.0, 3.0];
        for &x in &test_points {
            let orig = model.predict(&[x]);
            let rest = restored.predict(&[x]);
            assert!(
                (orig.mu - rest.mu).abs() < 1e-10,
                "JSON round-trip mu mismatch at x={}: {} vs {}",
                x,
                orig.mu,
                rest.mu
            );
            assert!(
                (orig.sigma - rest.sigma).abs() < 1e-10,
                "JSON round-trip sigma mismatch at x={}: {} vs {}",
                x,
                orig.sigma,
                rest.sigma
            );
        }
    }

    /// State snapshot must carry the uncertainty-modulated-LR flag and the
    /// rolling sigma mean, and restoring must preserve the sample count.
    #[test]
    fn state_preserves_rolling_sigma_mean() {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .max_depth(3)
            .grace_period(2)
            .initial_target_count(10)
            .uncertainty_modulated_lr(true)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&(vec![x], x.sin()));
        }
        let state = model.to_distributional_state();
        assert!(state.uncertainty_modulated_lr);
        assert!(state.rolling_sigma_mean >= 0.0);

        let restored = DistributionalSGBT::from_distributional_state(state);
        assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
    }
}