irithyll_core/ensemble/distributional/mod.rs

//! Distributional SGBT -- outputs Gaussian N(μ, σ²) instead of a point estimate.
//!
//! [`DistributionalSGBT`] supports two scale estimation modes via
//! [`ScaleMode`]:
//!
//! ## Empirical σ (default)
//!
//! Tracks an EWMA of squared prediction errors:
//!
//! ```text
//! err = target - mu
//! ewma_sq_err = alpha * err² + (1 - alpha) * ewma_sq_err
//! sigma = sqrt(ewma_sq_err)
//! ```
//!
//! Always calibrated (σ literally *is* recent error magnitude), zero tuning,
//! O(1) memory and compute.  When `uncertainty_modulated_lr` is enabled,
//! high recent errors → σ large → location LR scales up → faster correction.
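//!
//! A worked update, assuming the default `alpha = 0.05`:
//!
//! ```text
//! // ewma_sq_err = 1.0, err = 2.0
//! // ewma_sq_err' = 0.05 * 4.0 + 0.95 * 1.0 = 1.15
//! // sigma = sqrt(1.15) ≈ 1.07
//! ```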
//!
//! ## Tree chain (NGBoost-style)
//!
//! Maintains two independent tree ensembles: one for location (μ), one for
//! scale (log σ).  Gives feature-conditional uncertainty but requires strong
//! scale-gradient signal for the trees to split.
//!
//! # References
//!
//! Duan et al. (2020). "NGBoost: Natural Gradient Boosting for Probabilistic Prediction."

mod diagnostics;
mod inference;
mod training;

#[cfg(test)]
mod tests;

pub use diagnostics::{DecomposedPrediction, DistributionalTreeDiagnostic, ModelDiagnostics};

use alloc::vec::Vec;

use crate::ensemble::config::{SGBTConfig, ScaleMode};
use crate::ensemble::step::BoostingStep;
use crate::sample::{Observation, SampleRef};

/// Cached packed f32 binary for fast location-only inference.
///
/// Rebuilt periodically from the location ensemble. Predictions use
/// contiguous BFS-packed memory for cache-optimal tree traversal.
#[derive(Clone)]
struct PackedInferenceCache {
    bytes: Vec<u8>,
    base: f64,
    n_features: usize,
}

/// Prediction from a distributional model: full Gaussian N(μ, σ²).
#[derive(Debug, Clone, Copy)]
pub struct GaussianPrediction {
    /// Location parameter (mean).
    pub mu: f64,
    /// Scale parameter (standard deviation, always > 0).
    pub sigma: f64,
    /// Log of scale parameter (raw model output for scale ensemble).
    pub log_sigma: f64,
    /// Tree contribution variance (epistemic uncertainty).
    ///
    /// Standard deviation of individual location-tree contributions,
    /// computed via one-pass Welford variance with Bessel's correction.
    /// Reacts instantly when trees disagree (no EWMA lag), making it
    /// superior to empirical sigma for regime-change detection.
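    ///
    /// For reference, the standard Welford update per tree contribution `x`
    /// (a sketch of the textbook algorithm, not this crate's exact code):
    ///
    /// ```text
    /// n += 1
    /// delta = x - mean
    /// mean += delta / n
    /// m2 += delta * (x - mean)    // uses the updated mean
    /// // variance = m2 / (n - 1)
    /// ```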
    ///
    /// Zero when the model has 0 or 1 active location trees.
    pub honest_sigma: f64,
}

impl GaussianPrediction {
    /// Lower bound of a symmetric confidence interval.
    ///
    /// For 95% CI, use `z = 1.96`.
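    ///
    /// Illustrative: for N(1.0, 0.5²), `lower(1.96)` is
    /// `1.0 - 1.96 * 0.5 = 0.02` and `upper(1.96)` is `1.98`.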
    #[inline]
    pub fn lower(&self, z: f64) -> f64 {
        self.mu - z * self.sigma
    }

    /// Upper bound of a symmetric confidence interval.
    #[inline]
    pub fn upper(&self, z: f64) -> f64 {
        self.mu + z * self.sigma
    }
}

/// NGBoost-style distributional streaming gradient boosted trees.
///
/// Outputs a full Gaussian predictive distribution N(μ, σ²) by maintaining two
/// independent ensembles -- one for location (mean) and one for scale (log-sigma).
///
/// # Example
///
/// ```text
/// use irithyll_core::SGBTConfig;
/// use irithyll_core::ensemble::distributional::DistributionalSGBT;
///
/// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
/// let mut model = DistributionalSGBT::new(config);
///
/// // Train on streaming data
/// model.train_one(&(vec![1.0, 2.0], 3.5));
///
/// // Get full distributional prediction
/// let pred = model.predict(&[1.0, 2.0]);
/// println!("mean={}, sigma={}", pred.mu, pred.sigma);
/// ```
#[derive(Clone)]
pub struct DistributionalSGBT {
    config: SGBTConfig,
    location_steps: Vec<BoostingStep>,
    scale_steps: Vec<BoostingStep>,
    location_base: f64,
    scale_base: f64,
    base_initialized: bool,
    initial_targets: Vec<f64>,
    initial_target_count: usize,
    samples_seen: u64,
    rng_state: u64,
    uncertainty_modulated_lr: bool,
    rolling_sigma_mean: f64,
    scale_mode: ScaleMode,
    ewma_sq_err: f64,
    empirical_sigma_alpha: f64,
    prev_sigma: f64,
    sigma_velocity: f64,
    auto_bandwidths: Vec<f64>,
    last_replacement_sum: u64,
    ensemble_grad_mean: f64,
    ensemble_grad_m2: f64,
    ensemble_grad_count: u64,
    rolling_honest_sigma_mean: f64,
    packed_cache: Option<PackedInferenceCache>,
    samples_since_refresh: u64,
    packed_refresh_interval: u64,
}

impl core::fmt::Debug for DistributionalSGBT {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut s = f.debug_struct("DistributionalSGBT");
        s.field("n_steps", &self.location_steps.len())
            .field("samples_seen", &self.samples_seen)
            .field("location_base", &self.location_base)
            .field("scale_mode", &self.scale_mode)
            .field("base_initialized", &self.base_initialized);
        match self.scale_mode {
            ScaleMode::Empirical => {
                s.field("empirical_sigma", &crate::math::sqrt(self.ewma_sq_err));
            }
            ScaleMode::TreeChain => {
                s.field("scale_base", &self.scale_base);
            }
        }
        if self.uncertainty_modulated_lr {
            s.field("rolling_sigma_mean", &self.rolling_sigma_mean);
        }
        s.finish()
    }
}

impl DistributionalSGBT {
    /// Create a new distributional SGBT model.
    pub fn new(config: SGBTConfig) -> Self {
        let n_steps = config.n_steps;
        let initial_target_count = config.initial_target_count;
        let seed = config.seed;
        let uncertainty_modulated_lr = config.uncertainty_modulated_lr;
        let scale_mode = config.scale_mode;

        // Convert a leaf half-life into a per-sample decay factor alpha,
        // chosen so that alpha^half_life = 0.5.
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
        let tree_config = crate::ensemble::config::build_tree_config(&config)
            .leaf_decay_alpha_opt(leaf_decay_alpha);
        let max_tree_samples = config.max_tree_samples;
        let shadow_warmup = config.shadow_warmup.unwrap_or(0);

        // Build one step per boosting round, each with a distinct
        // deterministic seed; the salt keeps the scale ensemble's RNG
        // stream disjoint from the location ensemble's.
        let build_steps = |salt: u64| -> Vec<BoostingStep> {
            (0..n_steps)
                .map(|i| {
                    let mut tc = tree_config.clone();
                    tc.seed = seed ^ salt ^ (i as u64);
                    let detector = config.drift_detector.create();
                    if shadow_warmup > 0 {
                        BoostingStep::new_with_graduated(
                            tc,
                            detector,
                            max_tree_samples,
                            shadow_warmup,
                        )
                    } else {
                        BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
                    }
                })
                .collect()
        };

        let location_steps = build_steps(0);
        let scale_steps = build_steps(0xD15C_A1E5_5CA1_E000);

        Self {
            config,
            location_steps,
            scale_steps,
            location_base: 0.0,
            scale_base: 0.0,
            base_initialized: false,
            initial_targets: Vec::with_capacity(initial_target_count),
            initial_target_count,
            samples_seen: 0,
            rng_state: 1u64.wrapping_add(seed),
            uncertainty_modulated_lr,
            rolling_sigma_mean: 1.0,
            scale_mode,
            ewma_sq_err: 0.0,
            empirical_sigma_alpha: 0.05,
            prev_sigma: 0.0,
            sigma_velocity: 0.0,
            auto_bandwidths: Vec::new(),
            last_replacement_sum: 0,
            ensemble_grad_mean: 0.0,
            ensemble_grad_m2: 0.0,
            ensemble_grad_count: 0,
            rolling_honest_sigma_mean: 1.0,
            packed_cache: None,
            samples_since_refresh: 0,
            packed_refresh_interval: 1000,
        }
    }

    /// Access the configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Train on a single observation.
    pub fn train_one(&mut self, obs: &impl Observation) {
        training::train_distributional_one(self, obs);
    }

    /// Train on a batch of observations.
    pub fn train_batch(&mut self, samples: &[(Vec<f64>, f64)]) {
        for (features, target) in samples {
            self.train_one(&(features.clone(), *target));
        }
    }

    /// Predict the full distributional output for a single sample.
    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
        inference::predict_distributional(self, features)
    }

    /// Predict distributional output for multiple samples (batch).
    pub fn predict_batch(&self, batch: &[Vec<f64>]) -> Vec<GaussianPrediction> {
        batch.iter().map(|f| self.predict(f)).collect()
    }

    /// Predict a confidence interval with the given z-score.
    ///
    /// Returns `(lower, upper)` for a symmetric interval around μ.
    pub fn predict_interval(&self, features: &[f64], z: f64) -> (f64, f64) {
        let pred = self.predict(features);
        (pred.lower(z), pred.upper(z))
    }

    /// Tuple form: `(μ, σ, σ_ratio)`. σ_ratio is 1.0 when uncertainty
    /// modulation is disabled; otherwise it is `honest_sigma` divided by its
    /// rolling mean, clamped to `[0.1, 10.0]`, for use as an LR multiplier.
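    ///
    /// Illustrative: with `honest_sigma = 2.0` against a rolling mean of 1.0,
    /// σ_ratio is 2.0, so a caller would scale the location learning rate up
    /// by 2x; the clamp bounds the modulation.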
    pub fn predict_distributional(&self, features: &[f64]) -> (f64, f64, f64) {
        let pred = self.predict(features);
        let ratio = if self.uncertainty_modulated_lr {
            (pred.honest_sigma / self.rolling_honest_sigma_mean).clamp(0.1, 10.0)
        } else {
            1.0
        };
        (pred.mu, pred.sigma, ratio)
    }

    /// Predict with sigmoid-blended soft routing for smooth interpolation.
    ///
    /// Instead of hard left/right routing at tree split nodes, each split
    /// uses sigmoid blending: `alpha = sigmoid((threshold - feature) / bandwidth)`.
    /// The result is a continuous function that varies smoothly with every
    /// feature change.
    ///
    /// `bandwidth` controls transition sharpness: smaller = sharper (closer
    /// to hard splits), larger = smoother.
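    ///
    /// E.g., at `feature == threshold` the blend weight is `sigmoid(0) = 0.5`
    /// (an even mix of both subtrees); one bandwidth away it is
    /// `sigmoid(1) ≈ 0.73`, saturating toward a hard split as the distance
    /// grows.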
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> GaussianPrediction {
        inference::predict_smooth(self, features, bandwidth)
    }

    /// Predict with parent-leaf linear interpolation.
    ///
    /// Blends each leaf prediction with its parent's preserved prediction
    /// based on sample count, so fresh, under-sampled leaves lean on their
    /// parent rather than emitting unreliable values.
    pub fn predict_interpolated(&self, features: &[f64]) -> GaussianPrediction {
        inference::predict_interpolated(self, features)
    }

    /// Predict with sibling-based interpolation for feature-continuous predictions.
    ///
    /// At each split node near the threshold boundary, blends left and right
    /// subtree predictions linearly. Uses auto-calibrated bandwidths as the
    /// interpolation margin. Predictions vary continuously as features change.
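    ///
    /// A sketch of the idea (illustrative weighting, not necessarily this
    /// crate's exact formula): within a margin `bw` of the threshold,
    ///
    /// ```text
    /// w = ((feature - threshold) / bw).clamp(-1.0, 1.0) * 0.5 + 0.5
    /// pred = (1.0 - w) * left + w * right
    /// ```
    ///
    /// moves smoothly from pure-left (`w = 0`) to pure-right (`w = 1`).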
    pub fn predict_sibling_interpolated(&self, features: &[f64]) -> GaussianPrediction {
        inference::predict_sibling_interpolated(self, features)
    }

    /// Is the model initialized with base predictions?
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// Number of location ensemble trees.
    pub fn n_location_trees(&self) -> usize {
        self.location_steps.len()
    }

    /// Number of scale ensemble trees.
    pub fn n_scale_trees(&self) -> usize {
        self.scale_steps.len()
    }

    /// Total number of trees in both ensembles.
    pub fn n_trees(&self) -> usize {
        self.location_steps.len() + self.scale_steps.len()
    }

    /// Total samples seen.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Is uncertainty-modulated learning rate enabled?
    pub fn is_uncertainty_modulated(&self) -> bool {
        self.uncertainty_modulated_lr
    }

    /// Current rolling mean of predicted σ.
    pub fn rolling_sigma_mean(&self) -> f64 {
        self.rolling_sigma_mean
    }

    /// Reset all state.
    pub fn reset(&mut self) {
        self.location_steps.clear();
        self.scale_steps.clear();
        self.location_base = 0.0;
        self.scale_base = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        self.rng_state = 1u64.wrapping_add(self.config.seed);
        self.rolling_sigma_mean = 1.0;
        self.ewma_sq_err = 0.0;
        self.prev_sigma = 0.0;
        self.sigma_velocity = 0.0;
        self.auto_bandwidths.clear();
        self.ensemble_grad_mean = 0.0;
        self.ensemble_grad_m2 = 0.0;
        self.ensemble_grad_count = 0;
        self.rolling_honest_sigma_mean = 1.0;
        self.packed_cache = None;
    }

    /// Full model diagnostics.
    pub fn diagnostics(&self) -> ModelDiagnostics {
        diagnostics::compute_diagnostics(self)
    }

    /// Decomposed prediction (per-tree contributions).
    pub fn predict_decomposed(&self, features: &[f64]) -> DecomposedPrediction {
        diagnostics::decompose_prediction(self, features)
    }

    /// Feature importances (location + scale combined).
    pub fn feature_importances(&self) -> Vec<f64> {
        diagnostics::compute_feature_importances(self, false)
    }

    /// Feature importances split by ensemble.
    pub fn feature_importances_split(&self) -> (Vec<f64>, Vec<f64>) {
        let location = diagnostics::compute_feature_importances(self, true);
        let scale = diagnostics::compute_feature_importances_scale(self);
        (location, scale)
    }

    /// Compute honest_sigma from current location tree predictions.
    #[allow(dead_code)]
    fn compute_honest_sigma(&self, features: &[f64]) -> f64 {
        // Fewer than two trees means there is no disagreement to measure.
        if self.location_steps.len() < 2 {
            return 0.0;
        }

        let preds: Vec<f64> = self
            .location_steps
            .iter()
            .map(|s| s.predict(features))
            .collect();

        // Two-pass sample variance with Bessel's correction (n - 1); the
        // max(1.0) guard keeps the divisor positive, though n >= 2 holds here.
        let n = preds.len() as f64;
        let mean = preds.iter().sum::<f64>() / n;
        let var = preds
            .iter()
            .map(|p| {
                let d = p - mean;
                d * d
            })
            .sum::<f64>()
            / (n - 1.0).max(1.0);
        crate::math::sqrt(var)
    }
}

impl crate::learner::StreamingLearner for DistributionalSGBT {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        DistributionalSGBT::train_one(self, &sample);
    }

    // The trait's scalar prediction is the Gaussian mean.
    fn predict(&self, features: &[f64]) -> f64 {
        DistributionalSGBT::predict(self, features).mu
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        DistributionalSGBT::reset(self);
    }
}