irithyll-core 1.0.0

Core types, training engine, and inference for irithyll streaming ML — no_std + alloc, histogram binning, Hoeffding trees, SGBT ensembles, drift detection, f32 + int16 packed formats
Documentation
//! Distributional SGBT diagnostic structures and methods.

use alloc::vec;
use alloc::vec::Vec;

use crate::ensemble::config::ScaleMode;
use crate::ensemble::step::BoostingStep;

use super::DistributionalSGBT;

/// Per-tree diagnostic summary.
#[derive(Debug, Clone)]
pub struct DistributionalTreeDiagnostic {
    /// Number of leaf nodes in this tree.
    pub n_leaves: usize,
    /// Maximum depth reached by any leaf (0 when the tree has no leaves).
    pub max_depth_reached: usize,
    /// Total samples this tree has seen.
    pub samples_seen: u64,
    /// Leaf weight statistics: `(min, max, mean, std)`; all zeros when the
    /// tree has no leaves.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices this tree has split on (accumulated split gain > 0).
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts showing data distribution across leaves
    /// (same order as the leaf scan used for `leaf_weight_stats`).
    pub leaf_sample_counts: Vec<u64>,
    /// Running mean of predictions from this tree (Welford online).
    pub prediction_mean: f64,
    /// Running standard deviation of predictions from this tree.
    pub prediction_std: f64,
}

/// Full model diagnostics for [`DistributionalSGBT`].
///
/// Contains per-tree summaries, feature usage, base predictions, and
/// empirical σ state.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// Per-tree diagnostic summaries (location trees first, then scale trees).
    pub trees: Vec<DistributionalTreeDiagnostic>,
    /// Location trees only (owned copies of the leading entries of `trees`,
    /// not borrowed views).
    pub location_trees: Vec<DistributionalTreeDiagnostic>,
    /// Scale trees only (owned copies of the trailing entries of `trees`,
    /// not borrowed views).
    pub scale_trees: Vec<DistributionalTreeDiagnostic>,
    /// How many trees each feature is used in (split count per feature).
    pub feature_split_counts: Vec<usize>,
    /// Base prediction for location (mean).
    pub location_base: f64,
    /// Base prediction for scale (log-sigma).
    pub scale_base: f64,
    /// Current empirical σ (`sqrt(ewma_sq_err)`), always available
    /// regardless of [`ScaleMode`].
    pub empirical_sigma: f64,
    /// Scale mode in use.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that actually split (>1 leaf). 0 = frozen chain.
    pub scale_trees_active: usize,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// `f64::INFINITY` means that feature uses hard routing.
    pub auto_bandwidths: Vec<f64>,
    /// Ensemble-level gradient running mean.
    pub ensemble_grad_mean: f64,
    /// Ensemble-level gradient standard deviation.
    pub ensemble_grad_std: f64,
}

/// Decomposed prediction showing each tree's contribution.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Base location prediction (mean of initial targets).
    pub location_base: f64,
    /// Base scale prediction (log-sigma of initial targets).
    pub scale_base: f64,
    /// Per-step location contributions: `learning_rate * tree_prediction`.
    /// `location_base + sum(location_contributions)` = μ (see `mu()`).
    pub location_contributions: Vec<f64>,
    /// Per-step scale contributions: `learning_rate * tree_prediction`.
    /// `scale_base + sum(scale_contributions)` = log(σ) (see `log_sigma()`).
    /// All zeros in empirical-σ mode, where no scale trees contribute.
    pub scale_contributions: Vec<f64>,
}

impl DecomposedPrediction {
    /// Final μ: the base location plus every per-step location contribution.
    pub fn mu(&self) -> f64 {
        let contribution_total: f64 = self.location_contributions.iter().sum();
        self.location_base + contribution_total
    }

    /// Final log(σ): the base scale plus every per-step scale contribution.
    pub fn log_sigma(&self) -> f64 {
        let contribution_total: f64 = self.scale_contributions.iter().sum();
        self.scale_base + contribution_total
    }

    /// Final σ: `exp(log_sigma())`, floored at `1e-8` to keep it positive.
    pub fn sigma(&self) -> f64 {
        crate::math::exp(self.log_sigma()).max(1e-8)
    }
}

/// Build a full [`ModelDiagnostics`] snapshot for `model`.
///
/// Walks every boosting step (location chain first, then scale chain),
/// summarizing each tree's leaves, depth, split features, and prediction
/// moments, and accumulating per-feature split counts across all trees.
pub(crate) fn compute_diagnostics(model: &DistributionalSGBT) -> ModelDiagnostics {
    let n = model.location_steps.len();
    // Location trees occupy `trees[..n]`; scale trees follow.
    let mut trees = Vec::with_capacity(2 * n);
    let mut feature_split_counts: Vec<usize> = Vec::new();

    // Summarize each step in `steps`, pushing one diagnostic per tree and
    // folding its split features into the shared per-feature counters.
    fn collect_tree_diags(
        steps: &[BoostingStep],
        trees: &mut Vec<DistributionalTreeDiagnostic>,
        feature_split_counts: &mut Vec<usize>,
    ) {
        for step in steps {
            let slot = step.slot();
            let tree = slot.active_tree();
            let arena = tree.arena();

            // Scan the node arena once per field, keeping leaf nodes only.
            let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.leaf_value[i])
                .collect();

            let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.sample_count[i])
                .collect();

            let max_depth_reached = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.depth[i] as usize)
                .max()
                .unwrap_or(0);

            // (min, max, mean, std) of leaf weights; all zeros for an
            // empty tree so downstream code never sees NaN.
            let leaf_weight_stats = if leaf_values.is_empty() {
                (0.0, 0.0, 0.0, 0.0)
            } else {
                let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let max = leaf_values
                    .iter()
                    .cloned()
                    .fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = leaf_values.iter().sum();
                let mean = sum / leaf_values.len() as f64;
                let var: f64 = leaf_values
                    .iter()
                    .map(|v| {
                        let d = v - mean;
                        d * d
                    })
                    .sum::<f64>()
                    / leaf_values.len() as f64;
                (min, max, mean, crate::math::sqrt(var))
            };

            // A feature counts as "split on" iff its accumulated gain > 0.
            let gains = slot.split_gains();
            let split_features: Vec<usize> = gains
                .iter()
                .enumerate()
                .filter(|(_, &g)| g > 0.0)
                .map(|(i, _)| i)
                .collect();

            // Grow the counter whenever a step reports a wider gain vector.
            // (Previously it was sized only from the first non-empty step,
            // silently dropping counts for higher feature indices later.)
            if feature_split_counts.len() < gains.len() {
                feature_split_counts.resize(gains.len(), 0);
            }
            for &fi in &split_features {
                if fi < feature_split_counts.len() {
                    feature_split_counts[fi] += 1;
                }
            }

            trees.push(DistributionalTreeDiagnostic {
                n_leaves: leaf_values.len(),
                max_depth_reached,
                samples_seen: step.n_samples_seen(),
                leaf_weight_stats,
                split_features,
                leaf_sample_counts,
                prediction_mean: slot.prediction_mean(),
                prediction_std: slot.prediction_std(),
            });
        }
    }

    collect_tree_diags(&model.location_steps, &mut trees, &mut feature_split_counts);
    collect_tree_diags(&model.scale_steps, &mut trees, &mut feature_split_counts);

    // Split the combined list back into owned per-chain copies.
    let location_trees = trees[..n].to_vec();
    let scale_trees = trees[n..].to_vec();
    // A scale tree is "active" once it has actually split (more than 1 leaf).
    let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

    ModelDiagnostics {
        trees,
        location_trees,
        scale_trees,
        feature_split_counts,
        location_base: model.location_base,
        scale_base: model.scale_base,
        // Always reported, even when the scale chain is in use.
        empirical_sigma: crate::math::sqrt(model.ewma_sq_err),
        scale_mode: model.scale_mode,
        scale_trees_active,
        auto_bandwidths: model.auto_bandwidths.clone(),
        ensemble_grad_mean: model.ensemble_grad_mean,
        // Population std from the running M2; `.max(1)` guards the
        // division before any sample has been seen.
        ensemble_grad_std: crate::math::sqrt(
            model.ensemble_grad_m2 / model.ensemble_grad_count.max(1) as f64,
        ),
    }
}

pub(crate) fn decompose_prediction(
    model: &DistributionalSGBT,
    features: &[f64],
) -> DecomposedPrediction {
    let lr = model.config.learning_rate;
    let location: Vec<f64> = model
        .location_steps
        .iter()
        .map(|s| lr * s.predict(features))
        .collect();

    let (sb, scale) = match model.scale_mode {
        ScaleMode::Empirical => {
            let empirical_sigma = crate::math::sqrt(model.ewma_sq_err).max(1e-8);
            (
                crate::math::ln(empirical_sigma),
                vec![0.0; model.location_steps.len()],
            )
        }
        ScaleMode::TreeChain => {
            let s: Vec<f64> = model
                .scale_steps
                .iter()
                .map(|s| lr * s.predict(features))
                .collect();
            (model.scale_base, s)
        }
    };

    DecomposedPrediction {
        location_base: model.location_base,
        scale_base: sb,
        location_contributions: location,
        scale_contributions: scale,
    }
}

pub(crate) fn compute_feature_importances(
    model: &DistributionalSGBT,
    location_only: bool,
) -> Vec<f64> {
    let mut totals: Vec<f64> = Vec::new();
    let steps = if location_only {
        vec![&model.location_steps]
    } else {
        vec![&model.location_steps, &model.scale_steps]
    };

    for st in steps {
        for step in st {
            let gains = step.slot().split_gains();
            if totals.is_empty() && !gains.is_empty() {
                totals.resize(gains.len(), 0.0);
            }
            for (i, &g) in gains.iter().enumerate() {
                if i < totals.len() {
                    totals[i] += g;
                }
            }
        }
    }
    let sum: f64 = totals.iter().sum();
    if sum > 0.0 {
        totals.iter_mut().for_each(|v| *v /= sum);
    }
    totals
}

/// Gain-based feature importances for the scale chain only, normalized to
/// sum to 1 when any gain has been recorded.
pub(crate) fn compute_feature_importances_scale(model: &DistributionalSGBT) -> Vec<f64> {
    let mut totals: Vec<f64> = Vec::new();
    for step in &model.scale_steps {
        let gains = step.slot().split_gains();
        // Grow to the widest gain vector seen so far. (Previously `totals`
        // was sized only once, from the first non-empty step, silently
        // dropping gains for higher feature indices reported later.)
        if totals.len() < gains.len() {
            totals.resize(gains.len(), 0.0);
        }
        for (i, &g) in gains.iter().enumerate() {
            totals[i] += g;
        }
    }
    // Normalize to a probability-like distribution when possible.
    let sum: f64 = totals.iter().sum();
    if sum > 0.0 {
        totals.iter_mut().for_each(|v| *v /= sum);
    }
    totals
}