// irithyll_core/ensemble/distributional/diagnostics.rs

use alloc::vec;
4use alloc::vec::Vec;
5
6use crate::ensemble::config::ScaleMode;
7use crate::ensemble::step::BoostingStep;
8
9use super::DistributionalSGBT;
10
/// Diagnostic snapshot of a single boosting tree (from either the location
/// or the scale chain), as collected by `compute_diagnostics`.
#[derive(Debug, Clone)]
pub struct DistributionalTreeDiagnostic {
    /// Number of leaf nodes in the tree's active arena.
    pub n_leaves: usize,
    /// Depth of the deepest leaf; 0 when the tree has no leaves.
    pub max_depth_reached: usize,
    /// Number of samples this boosting step has observed so far.
    pub samples_seen: u64,
    /// (min, max, mean, std) of the leaf values; all zeros when no leaves.
    pub leaf_weight_stats: (f64, f64, f64, f64),
    /// Indices of features with a strictly positive split gain in this tree.
    pub split_features: Vec<usize>,
    /// Per-leaf sample counts, in arena order.
    pub leaf_sample_counts: Vec<u64>,
    /// Prediction mean as reported by the step's slot.
    pub prediction_mean: f64,
    /// Prediction standard deviation as reported by the step's slot.
    pub prediction_std: f64,
}
31
/// Ensemble-wide diagnostics for a [`DistributionalSGBT`] model.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
    /// All tree diagnostics: location trees first, then scale trees.
    pub trees: Vec<DistributionalTreeDiagnostic>,
    /// Diagnostics for the location (mu) chain only.
    pub location_trees: Vec<DistributionalTreeDiagnostic>,
    /// Diagnostics for the scale (log-sigma) chain only.
    pub scale_trees: Vec<DistributionalTreeDiagnostic>,
    /// For each feature index, how many trees split on it (positive gain).
    pub feature_split_counts: Vec<usize>,
    /// Base (bias) value of the location chain.
    pub location_base: f64,
    /// Base (bias) value of the scale chain.
    pub scale_base: f64,
    /// sqrt of the model's EWMA of squared errors.
    pub empirical_sigma: f64,
    /// How the model produces its scale estimate.
    pub scale_mode: ScaleMode,
    /// Number of scale trees that have split at least once (more than one leaf).
    pub scale_trees_active: usize,
    /// Copy of the model's auto-tuned bandwidths (semantics defined by the trainer).
    pub auto_bandwidths: Vec<f64>,
    /// Running mean of the ensemble gradient.
    pub ensemble_grad_mean: f64,
    /// Running std of the ensemble gradient: sqrt(M2 / count), count floored at 1.
    pub ensemble_grad_std: f64,
}
64
/// A single prediction broken into base values plus per-tree contributions,
/// for both the location (mu) and scale (log-sigma) heads.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
    /// Base (bias) term of the location head.
    pub location_base: f64,
    /// Base term of the scale head; in `ScaleMode::Empirical` this is
    /// ln(empirical sigma) rather than the model's trained scale base.
    pub scale_base: f64,
    /// Learning-rate-scaled contribution of each location tree.
    pub location_contributions: Vec<f64>,
    /// Learning-rate-scaled contribution of each scale tree
    /// (all zeros in `ScaleMode::Empirical`).
    pub scale_contributions: Vec<f64>,
}
79
80impl DecomposedPrediction {
81 pub fn mu(&self) -> f64 {
83 self.location_base + self.location_contributions.iter().sum::<f64>()
84 }
85
86 pub fn log_sigma(&self) -> f64 {
88 self.scale_base + self.scale_contributions.iter().sum::<f64>()
89 }
90
91 pub fn sigma(&self) -> f64 {
93 crate::math::exp(self.log_sigma()).max(1e-8)
94 }
95}
96
/// Builds a full diagnostic snapshot of the model: per-tree statistics for
/// both the location and scale chains, plus ensemble-level state.
pub(crate) fn compute_diagnostics(model: &DistributionalSGBT) -> ModelDiagnostics {
    let n = model.location_steps.len();
    let mut trees = Vec::with_capacity(2 * n);
    let mut feature_split_counts: Vec<usize> = Vec::new();

    // Appends one diagnostic entry per step and accumulates, per feature,
    // how many trees split on it (i.e. recorded a positive split gain).
    fn collect_tree_diags(
        steps: &[BoostingStep],
        trees: &mut Vec<DistributionalTreeDiagnostic>,
        feature_split_counts: &mut Vec<usize>,
    ) {
        for step in steps {
            let slot = step.slot();
            let tree = slot.active_tree();
            let arena = tree.arena();

            // Values of every leaf node, in arena order.
            let leaf_values: Vec<f64> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.leaf_value[i])
                .collect();

            // Sample counts of the same leaves, same order.
            let leaf_sample_counts: Vec<u64> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.sample_count[i])
                .collect();

            // Depth of the deepest leaf; 0 for a tree with no leaves.
            let max_depth_reached = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .map(|i| arena.depth[i] as usize)
                .max()
                .unwrap_or(0);

            // (min, max, mean, std) over the leaf values.
            let leaf_weight_stats = if leaf_values.is_empty() {
                (0.0, 0.0, 0.0, 0.0)
            } else {
                let min = leaf_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let max = leaf_values
                    .iter()
                    .cloned()
                    .fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = leaf_values.iter().sum();
                let mean = sum / leaf_values.len() as f64;
                // Population variance (divide by N, not N-1).
                let var: f64 = leaf_values
                    .iter()
                    .map(|v| {
                        let d = v - mean;
                        d * d
                    })
                    .sum::<f64>()
                    / leaf_values.len() as f64;
                (min, max, mean, crate::math::sqrt(var))
            };

            // Features this tree actually used: strictly positive gain.
            let gains = slot.split_gains();
            let split_features: Vec<usize> = gains
                .iter()
                .enumerate()
                .filter(|(_, &g)| g > 0.0)
                .map(|(i, _)| i)
                .collect();

            if !gains.is_empty() {
                // Lazily size the counter from the first non-empty gains
                // slice; indices beyond that length are dropped by the
                // bounds guard below.
                if feature_split_counts.is_empty() {
                    feature_split_counts.resize(gains.len(), 0);
                }
                for &fi in &split_features {
                    if fi < feature_split_counts.len() {
                        feature_split_counts[fi] += 1;
                    }
                }
            }

            trees.push(DistributionalTreeDiagnostic {
                n_leaves: leaf_values.len(),
                max_depth_reached,
                samples_seen: step.n_samples_seen(),
                leaf_weight_stats,
                split_features,
                leaf_sample_counts,
                prediction_mean: slot.prediction_mean(),
                prediction_std: slot.prediction_std(),
            });
        }
    }

    // Location trees are collected first, so `trees[..n]` is the location
    // chain and `trees[n..]` is the scale chain.
    collect_tree_diags(&model.location_steps, &mut trees, &mut feature_split_counts);
    collect_tree_diags(&model.scale_steps, &mut trees, &mut feature_split_counts);

    let location_trees = trees[..n].to_vec();
    let scale_trees = trees[n..].to_vec();
    // A scale tree with a single leaf has never split: considered inactive.
    let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();

    ModelDiagnostics {
        trees,
        location_trees,
        scale_trees,
        feature_split_counts,
        location_base: model.location_base,
        scale_base: model.scale_base,
        empirical_sigma: crate::math::sqrt(model.ewma_sq_err),
        scale_mode: model.scale_mode,
        scale_trees_active,
        auto_bandwidths: model.auto_bandwidths.clone(),
        ensemble_grad_mean: model.ensemble_grad_mean,
        // Welford-style std: sqrt(M2 / count), with count floored at 1 to
        // avoid dividing by zero before any gradients were recorded.
        ensemble_grad_std: crate::math::sqrt(
            model.ensemble_grad_m2 / model.ensemble_grad_count.max(1) as f64,
        ),
    }
}
205
206pub(crate) fn decompose_prediction(
207 model: &DistributionalSGBT,
208 features: &[f64],
209) -> DecomposedPrediction {
210 let lr = model.config.learning_rate;
211 let location: Vec<f64> = model
212 .location_steps
213 .iter()
214 .map(|s| lr * s.predict(features))
215 .collect();
216
217 let (sb, scale) = match model.scale_mode {
218 ScaleMode::Empirical => {
219 let empirical_sigma = crate::math::sqrt(model.ewma_sq_err).max(1e-8);
220 (
221 crate::math::ln(empirical_sigma),
222 vec![0.0; model.location_steps.len()],
223 )
224 }
225 ScaleMode::TreeChain => {
226 let s: Vec<f64> = model
227 .scale_steps
228 .iter()
229 .map(|s| lr * s.predict(features))
230 .collect();
231 (model.scale_base, s)
232 }
233 };
234
235 DecomposedPrediction {
236 location_base: model.location_base,
237 scale_base: sb,
238 location_contributions: location,
239 scale_contributions: scale,
240 }
241}
242
243pub(crate) fn compute_feature_importances(
244 model: &DistributionalSGBT,
245 location_only: bool,
246) -> Vec<f64> {
247 let mut totals: Vec<f64> = Vec::new();
248 let steps = if location_only {
249 vec![&model.location_steps]
250 } else {
251 vec![&model.location_steps, &model.scale_steps]
252 };
253
254 for st in steps {
255 for step in st {
256 let gains = step.slot().split_gains();
257 if totals.is_empty() && !gains.is_empty() {
258 totals.resize(gains.len(), 0.0);
259 }
260 for (i, &g) in gains.iter().enumerate() {
261 if i < totals.len() {
262 totals[i] += g;
263 }
264 }
265 }
266 }
267 let sum: f64 = totals.iter().sum();
268 if sum > 0.0 {
269 totals.iter_mut().for_each(|v| *v /= sum);
270 }
271 totals
272}
273
274pub(crate) fn compute_feature_importances_scale(model: &DistributionalSGBT) -> Vec<f64> {
275 let mut totals: Vec<f64> = Vec::new();
276 for step in &model.scale_steps {
277 let gains = step.slot().split_gains();
278 if totals.is_empty() && !gains.is_empty() {
279 totals.resize(gains.len(), 0.0);
280 }
281 for (i, &g) in gains.iter().enumerate() {
282 if i < totals.len() {
283 totals[i] += g;
284 }
285 }
286 }
287 let sum: f64 = totals.iter().sum();
288 if sum > 0.0 {
289 totals.iter_mut().for_each(|v| *v /= sum);
290 }
291 totals
292}