Skip to main content

irithyll_core/ensemble/
diagnostics.rs

1//! Diagnostics for SGBT ensembles.
2//!
3//! Provides [`TreeDiagnostics`], [`EnsembleDiagnostics`], and
4//! [`DistributionalDiagnostics`] for inspecting tree structure, feature
5//! importance, per-tree contributions, and ensemble health.
6
7use alloc::vec;
8use alloc::vec::Vec;
9
10use core::cmp::Ordering;
11use core::fmt;
12
/// Structural and training statistics for one tree of the ensemble.
#[derive(Debug, Clone)]
pub struct TreeDiagnostics {
    /// Node count of the tree, counting internal nodes and leaves together.
    pub n_nodes: usize,
    /// How many of the nodes are leaves.
    pub n_leaves: usize,
    /// Greatest depth reached by the tree.
    pub max_depth: usize,
    /// Number of samples this tree has been trained on.
    pub n_samples: u64,
    /// How often the tree occupying this slot has been replaced.
    pub n_replacements: u64,
    /// Amount this tree adds to the current prediction
    /// (`lr * tree.predict(x)`).
    pub contribution: f64,
}
29
30/// Diagnostics for an SGBT ensemble.
31#[derive(Debug, Clone)]
32pub struct EnsembleDiagnostics {
33    /// Per-tree diagnostics.
34    pub trees: Vec<TreeDiagnostics>,
35    /// Feature importance: fraction of splits per feature across all trees.
36    /// Sums to 1.0. Indexed by feature index.
37    pub feature_importance: Vec<f64>,
38    /// Total number of tree replacements across all slots.
39    pub total_replacements: u64,
40    /// Number of active trees (n_steps).
41    pub n_trees: usize,
42    /// Base prediction (intercept).
43    pub base_prediction: f64,
44    /// Learning rate.
45    pub learning_rate: f64,
46    /// Total samples the ensemble has seen.
47    pub n_samples: u64,
48}
49
50/// Diagnostics for a [`DistributionalSGBT`](super::distributional::DistributionalSGBT).
51#[derive(Debug, Clone)]
52pub struct DistributionalDiagnostics {
53    /// Location (mu) ensemble diagnostics.
54    pub location: EnsembleDiagnostics,
55    /// Scale (sigma) ensemble diagnostics (if tree-chain mode).
56    pub scale: Option<EnsembleDiagnostics>,
57    /// Standard deviation of per-tree contributions to the current prediction,
58    /// used as a model-derived uncertainty estimate.
59    pub honest_sigma: f64,
60    /// Exponential moving average of `honest_sigma`, providing a stable
61    /// baseline for detecting sudden changes in prediction variance.
62    pub rolling_honest_sigma_mean: f64,
63    /// Effective minimum-training-samples threshold when `adaptive_mts` is
64    /// enabled; `None` if adaptive MTS is not active.
65    pub effective_mts: Option<u64>,
66}
67
68// ---------------------------------------------------------------------------
69// Display impls
70// ---------------------------------------------------------------------------
71
72impl fmt::Display for TreeDiagnostics {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        write!(
75            f,
76            "nodes={}, leaves={}, depth={}, samples={}, replacements={}, contribution={:.6}",
77            self.n_nodes,
78            self.n_leaves,
79            self.max_depth,
80            self.n_samples,
81            self.n_replacements,
82            self.contribution,
83        )
84    }
85}
86
87impl fmt::Display for EnsembleDiagnostics {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        let _ = writeln!(f, "=== Ensemble Diagnostics ===");
90        let _ = writeln!(
91            f,
92            "Trees: {}, Base: {:.4}, LR: {:.4}, Samples: {}",
93            self.n_trees, self.base_prediction, self.learning_rate, self.n_samples,
94        );
95        let _ = writeln!(f, "Total replacements: {}", self.total_replacements);
96
97        // Feature importance (top 10)
98        let mut importance: Vec<(usize, f64)> = self
99            .feature_importance
100            .iter()
101            .enumerate()
102            .map(|(i, &v)| (i, v))
103            .collect();
104        importance.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
105        let top_n = importance.len().min(10);
106        let _ = writeln!(f, "Feature importance (top {top_n}):");
107        for &(feat, imp) in importance.iter().take(top_n) {
108            let _ = writeln!(f, "  feature[{feat}]: {imp:.4}");
109        }
110
111        // Tree summary
112        if !self.trees.is_empty() {
113            let avg_depth = self.trees.iter().map(|t| t.max_depth).sum::<usize>() as f64
114                / self.trees.len() as f64;
115            let avg_nodes = self.trees.iter().map(|t| t.n_nodes).sum::<usize>() as f64
116                / self.trees.len() as f64;
117            let _ = writeln!(f, "Avg depth: {avg_depth:.1}, Avg nodes: {avg_nodes:.1}");
118        }
119
120        Ok(())
121    }
122}
123
124impl fmt::Display for DistributionalDiagnostics {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        let _ = writeln!(f, "=== Distributional Diagnostics ===");
127        let _ = writeln!(f, "--- Location ---");
128        let _ = write!(f, "{}", self.location);
129        if let Some(ref scale) = self.scale {
130            let _ = writeln!(f, "--- Scale ---");
131            let _ = write!(f, "{}", scale);
132        }
133        let _ = writeln!(f, "honest_sigma: {:.6}", self.honest_sigma);
134        let _ = writeln!(
135            f,
136            "rolling_honest_sigma_mean: {:.6}",
137            self.rolling_honest_sigma_mean,
138        );
139        if let Some(mts) = self.effective_mts {
140            let _ = writeln!(f, "effective_mts: {mts}");
141        }
142        Ok(())
143    }
144}
145
146// ---------------------------------------------------------------------------
147// Builder helpers (used by SGBT and DistributionalSGBT)
148// ---------------------------------------------------------------------------
149
150use crate::ensemble::step::BoostingStep;
151
152/// Build per-tree diagnostics and aggregate feature importance from a slice of
153/// boosting steps.
154///
155/// `features` is the feature vector used to compute per-tree contributions.
156/// If `None`, contributions are set to 0.0.
157#[allow(dead_code)]
158pub(crate) fn build_ensemble_diagnostics(
159    steps: &[BoostingStep],
160    base_prediction: f64,
161    learning_rate: f64,
162    n_samples: u64,
163    features: Option<&[f64]>,
164) -> EnsembleDiagnostics {
165    let mut trees = Vec::with_capacity(steps.len());
166    let mut split_counts: Vec<u64> = Vec::new();
167
168    for step in steps {
169        let slot = step.slot();
170        let tree = slot.active_tree();
171        let arena = tree.arena();
172
173        let n_nodes = arena.n_nodes();
174        let n_leaves = arena.n_leaves();
175        let max_depth = (0..arena.is_leaf.len())
176            .filter(|&i| arena.is_leaf[i])
177            .map(|i| arena.depth[i] as usize)
178            .max()
179            .unwrap_or(0);
180        let n_tree_samples = step.n_samples_seen();
181        let n_replacements = slot.replacements();
182
183        let contribution = match features {
184            Some(f) => learning_rate * step.predict(f),
185            None => 0.0,
186        };
187
188        // Accumulate split counts per feature using thresholds.
189        let thresholds = tree.collect_split_thresholds_per_feature();
190        if !thresholds.is_empty() {
191            if split_counts.len() < thresholds.len() {
192                split_counts.resize(thresholds.len(), 0);
193            }
194            for (feat_idx, splits) in thresholds.iter().enumerate() {
195                if feat_idx < split_counts.len() {
196                    split_counts[feat_idx] += splits.len() as u64;
197                }
198            }
199        }
200
201        trees.push(TreeDiagnostics {
202            n_nodes,
203            n_leaves,
204            max_depth,
205            n_samples: n_tree_samples,
206            n_replacements,
207            contribution,
208        });
209    }
210
211    // Normalize split counts to feature importance (sums to 1.0).
212    let total: u64 = split_counts.iter().sum();
213    let feature_importance = if total > 0 {
214        split_counts
215            .iter()
216            .map(|&c| c as f64 / total as f64)
217            .collect()
218    } else {
219        vec![0.0; split_counts.len()]
220    };
221
222    let total_replacements: u64 = trees.iter().map(|t| t.n_replacements).sum();
223
224    EnsembleDiagnostics {
225        n_trees: trees.len(),
226        trees,
227        feature_importance,
228        total_replacements,
229        base_prediction,
230        learning_rate,
231        n_samples,
232    }
233}