// irithyll_core/ensemble/distributional/diagnostics.rs

//! Distributional SGBT diagnostic structures and methods.

use alloc::vec;
use alloc::vec::Vec;

use crate::ensemble::config::ScaleMode;
use crate::ensemble::step::BoostingStep;

use super::DistributionalSGBT;
/// Per-tree diagnostic summary.
#[derive(Debug, Clone)]
pub struct DistributionalTreeDiagnostic {
    /// Number of leaf nodes in this tree.
    pub n_leaves: usize,
    /// Maximum depth reached by any leaf (0 when the tree has no leaves).
    pub max_depth_reached: usize,
    /// Total samples this tree has seen.
    pub samples_seen: u64,
    /// Leaf weight statistics: `(min, max, mean, std)`.
    /// All four are `0.0` when the tree has no leaves.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Feature indices this tree has split on (strictly positive gain).
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts showing data distribution across leaves.
    pub leaf_sample_counts: Vec<u64>,
    /// Running mean of predictions from this tree (Welford online).
    pub prediction_mean: f64,
    /// Running standard deviation of predictions from this tree.
    pub prediction_std: f64,
}
31
/// Full model diagnostics for [`DistributionalSGBT`].
///
/// Contains per-tree summaries, feature usage, base predictions, and
/// empirical σ state.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// Per-tree diagnostic summaries (location trees first, then scale trees).
    pub trees: Vec<DistributionalTreeDiagnostic>,
    /// Location trees only (owned copy of the leading slice of `trees`).
    pub location_trees: Vec<DistributionalTreeDiagnostic>,
    /// Scale trees only (owned copy of the trailing slice of `trees`).
    pub scale_trees: Vec<DistributionalTreeDiagnostic>,
    /// How many trees each feature is used in (split count per feature).
    pub feature_split_counts: Vec<usize>,
    /// Base prediction for location (mean).
    pub location_base: f64,
    /// Base prediction for scale (log-sigma).
    pub scale_base: f64,
    /// Current empirical σ (`sqrt(ewma_sq_err)`), always available.
    pub empirical_sigma: f64,
    /// Scale mode in use.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that actually split (>1 leaf). 0 = frozen chain.
    pub scale_trees_active: usize,
    /// Per-feature auto-calibrated bandwidths for smooth prediction.
    /// `f64::INFINITY` means that feature uses hard routing.
    pub auto_bandwidths: Vec<f64>,
    /// Ensemble-level gradient running mean.
    pub ensemble_grad_mean: f64,
    /// Ensemble-level gradient standard deviation.
    pub ensemble_grad_std: f64,
}
64
/// Decomposed prediction showing each tree's contribution.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Base location prediction (mean of initial targets).
    pub location_base: f64,
    /// Base scale prediction (log-sigma of initial targets).
    pub scale_base: f64,
    /// Per-step location contributions: `learning_rate * tree_prediction`.
    /// `location_base + sum(location_contributions)` = μ.
    pub location_contributions: Vec<f64>,
    /// Per-step scale contributions: `learning_rate * tree_prediction`.
    /// `scale_base + sum(scale_contributions)` = log(σ).
    /// All zeros when the model runs in empirical-sigma mode.
    pub scale_contributions: Vec<f64>,
}
79
80impl DecomposedPrediction {
81    /// Reconstruct the final μ from base + contributions.
82    pub fn mu(&self) -> f64 {
83        self.location_base + self.location_contributions.iter().sum::<f64>()
84    }
85
86    /// Reconstruct the final log(σ) from base + contributions.
87    pub fn log_sigma(&self) -> f64 {
88        self.scale_base + self.scale_contributions.iter().sum::<f64>()
89    }
90
91    /// Reconstruct the final σ (exponentiated).
92    pub fn sigma(&self) -> f64 {
93        crate::math::exp(self.log_sigma()).max(1e-8)
94    }
95}
96
/// Build a full [`ModelDiagnostics`] snapshot from the current model state.
///
/// Walks every boosting step — location chain first, then scale chain —
/// summarising each step's active tree (leaves, depth, split features,
/// running prediction statistics) and accumulating per-feature split counts
/// across the whole ensemble.
pub(crate) fn compute_diagnostics(model: &DistributionalSGBT) -> ModelDiagnostics {
    let n = model.location_steps.len();
    let mut trees = Vec::with_capacity(2 * n);
    let mut feature_split_counts: Vec<usize> = Vec::new();

    // Pushes one `DistributionalTreeDiagnostic` per step onto `trees`,
    // bumping `feature_split_counts` for every feature a tree split on.
    fn collect_tree_diags(
        steps: &[BoostingStep],
        trees: &mut Vec<DistributionalTreeDiagnostic>,
        feature_split_counts: &mut Vec<usize>,
    ) {
        for step in steps {
            let slot = step.slot();
            let tree = slot.active_tree();
            let arena = tree.arena();

            // Values of leaf nodes only (internal nodes are skipped).
            let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.leaf_value[i])
                .collect();

            // Sample counts for the same leaves, in the same arena order.
            let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.sample_count[i])
                .collect();

            // Deepest leaf; 0 when the tree has no leaves at all.
            let max_depth_reached = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.depth[i] as usize)
                .max()
                .unwrap_or(0);

            // (min, max, mean, population std) over leaf weights;
            // all zeros for a leafless tree.
            let leaf_weight_stats = if leaf_values.is_empty() {
                (0.0, 0.0, 0.0, 0.0)
            } else {
                let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let max = leaf_values
                    .iter()
                    .cloned()
                    .fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = leaf_values.iter().sum();
                let mean = sum / leaf_values.len() as f64;
                let var: f64 = leaf_values
                    .iter()
                    .map(|v| {
                        let d = v - mean;
                        d * d
                    })
                    .sum::<f64>()
                    / leaf_values.len() as f64;
                (min, max, mean, crate::math::sqrt(var))
            };

            // Features with strictly positive accumulated split gain.
            let gains = slot.split_gains();
            let split_features: Vec<usize> = gains
                .iter()
                .enumerate()
                .filter(|(_, &g)| g > 0.0)
                .map(|(i, _)| i)
                .collect();

            if !gains.is_empty() {
                // Size the global counter lazily from the first non-empty
                // gain vector; features beyond that length are ignored
                // (guarded below) should a later tree report more gains.
                if feature_split_counts.is_empty() {
                    feature_split_counts.resize(gains.len(), 0);
                }
                for &fi in &split_features {
                    if fi < feature_split_counts.len() {
                        feature_split_counts[fi] += 1;
                    }
                }
            }

            trees.push(DistributionalTreeDiagnostic {
                n_leaves: leaf_values.len(),
                max_depth_reached,
                samples_seen: step.n_samples_seen(),
                leaf_weight_stats,
                split_features,
                leaf_sample_counts,
                prediction_mean: slot.prediction_mean(),
                prediction_std: slot.prediction_std(),
            });
        }
    }

    collect_tree_diags(&model.location_steps, &mut trees, &mut feature_split_counts);
    collect_tree_diags(&model.scale_steps, &mut trees, &mut feature_split_counts);

    // `trees` now holds the n location steps followed by the scale steps.
    let location_trees = trees[..n].to_vec();
    let scale_trees = trees[n..].to_vec();
    // A scale tree with a single leaf never split — an all-single-leaf
    // chain is effectively frozen.
    let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

    ModelDiagnostics {
        trees,
        location_trees,
        scale_trees,
        feature_split_counts,
        location_base: model.location_base,
        scale_base: model.scale_base,
        empirical_sigma: crate::math::sqrt(model.ewma_sq_err),
        scale_mode: model.scale_mode,
        scale_trees_active,
        auto_bandwidths: model.auto_bandwidths.clone(),
        ensemble_grad_mean: model.ensemble_grad_mean,
        // Std from accumulated m2 / count; `max(1)` avoids dividing by
        // zero before any samples have been observed.
        ensemble_grad_std: crate::math::sqrt(
            model.ensemble_grad_m2 / model.ensemble_grad_count.max(1) as f64,
        ),
    }
}
205
206pub(crate) fn decompose_prediction(
207    model: &DistributionalSGBT,
208    features: &[f64],
209) -> DecomposedPrediction {
210    let lr = model.config.learning_rate;
211    let location: Vec<f64> = model
212        .location_steps
213        .iter()
214        .map(|s| lr * s.predict(features))
215        .collect();
216
217    let (sb, scale) = match model.scale_mode {
218        ScaleMode::Empirical => {
219            let empirical_sigma = crate::math::sqrt(model.ewma_sq_err).max(1e-8);
220            (
221                crate::math::ln(empirical_sigma),
222                vec![0.0; model.location_steps.len()],
223            )
224        }
225        ScaleMode::TreeChain => {
226            let s: Vec<f64> = model
227                .scale_steps
228                .iter()
229                .map(|s| lr * s.predict(features))
230                .collect();
231            (model.scale_base, s)
232        }
233    };
234
235    DecomposedPrediction {
236        location_base: model.location_base,
237        scale_base: sb,
238        location_contributions: location,
239        scale_contributions: scale,
240    }
241}
242
243pub(crate) fn compute_feature_importances(
244    model: &DistributionalSGBT,
245    location_only: bool,
246) -> Vec<f64> {
247    let mut totals: Vec<f64> = Vec::new();
248    let steps = if location_only {
249        vec![&model.location_steps]
250    } else {
251        vec![&model.location_steps, &model.scale_steps]
252    };
253
254    for st in steps {
255        for step in st {
256            let gains = step.slot().split_gains();
257            if totals.is_empty() && !gains.is_empty() {
258                totals.resize(gains.len(), 0.0);
259            }
260            for (i, &g) in gains.iter().enumerate() {
261                if i < totals.len() {
262                    totals[i] += g;
263                }
264            }
265        }
266    }
267    let sum: f64 = totals.iter().sum();
268    if sum > 0.0 {
269        totals.iter_mut().for_each(|v| *v /= sum);
270    }
271    totals
272}
273
274pub(crate) fn compute_feature_importances_scale(model: &DistributionalSGBT) -> Vec<f64> {
275    let mut totals: Vec<f64> = Vec::new();
276    for step in &model.scale_steps {
277        let gains = step.slot().split_gains();
278        if totals.is_empty() && !gains.is_empty() {
279            totals.resize(gains.len(), 0.0);
280        }
281        for (i, &g) in gains.iter().enumerate() {
282            if i < totals.len() {
283                totals[i] += g;
284            }
285        }
286    }
287    let sum: f64 = totals.iter().sum();
288    if sum > 0.0 {
289        totals.iter_mut().for_each(|v| *v /= sum);
290    }
291    totals
292}