1use alloc::vec;
8use alloc::vec::Vec;
9
10use core::cmp::Ordering;
11use core::fmt;
12
/// Structural and usage statistics for a single tree in a boosting ensemble.
///
/// Populated by `build_ensemble_diagnostics` from the tree's node arena and
/// its boosting-step counters.
#[derive(Debug, Clone)]
pub struct TreeDiagnostics {
    /// Total number of nodes in the tree's arena.
    pub n_nodes: usize,
    /// Number of leaf nodes in the tree's arena.
    pub n_leaves: usize,
    /// Depth of the deepest leaf (0 for an empty tree).
    pub max_depth: usize,
    /// Number of samples this tree's step has seen so far.
    pub n_samples: u64,
    /// Number of times this tree's slot has been replaced.
    pub n_replacements: u64,
    /// Learning-rate-scaled prediction of this tree at the feature vector
    /// supplied to the diagnostics builder; 0.0 when no features were given.
    pub contribution: f64,
}
29
/// Aggregate statistics for an entire boosting ensemble.
#[derive(Debug, Clone)]
pub struct EnsembleDiagnostics {
    /// Per-tree diagnostics, in step order.
    pub trees: Vec<TreeDiagnostics>,
    /// Split-count-based importance per feature index: each entry is the
    /// fraction of all splits in the ensemble that used that feature
    /// (all zeros when the ensemble has no splits).
    pub feature_importance: Vec<f64>,
    /// Sum of `n_replacements` over all trees.
    pub total_replacements: u64,
    /// Number of trees in the ensemble (`trees.len()`).
    pub n_trees: usize,
    /// Constant base prediction added before tree contributions.
    pub base_prediction: f64,
    /// Shrinkage applied to each tree's raw prediction.
    pub learning_rate: f64,
    /// Total number of samples the ensemble has been trained on.
    pub n_samples: u64,
}
49
/// Diagnostics for a distributional (location/scale) model.
#[derive(Debug, Clone)]
pub struct DistributionalDiagnostics {
    /// Diagnostics for the location (mean) ensemble.
    pub location: EnsembleDiagnostics,
    /// Diagnostics for the scale ensemble, if the model fits one.
    pub scale: Option<EnsembleDiagnostics>,
    /// Current honest-sigma estimate.
    // NOTE(review): exact estimator not visible here — presumably an
    // out-of-sample ("honest") residual std-dev; confirm at the producer.
    pub honest_sigma: f64,
    /// Rolling mean of recent honest-sigma values.
    pub rolling_honest_sigma_mean: f64,
    /// Effective max tree size, when tracked by the model.
    // NOTE(review): meaning of "mts" inferred from the abbreviation — verify.
    pub effective_mts: Option<u64>,
}
67
68impl fmt::Display for TreeDiagnostics {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(
75 f,
76 "nodes={}, leaves={}, depth={}, samples={}, replacements={}, contribution={:.6}",
77 self.n_nodes,
78 self.n_leaves,
79 self.max_depth,
80 self.n_samples,
81 self.n_replacements,
82 self.contribution,
83 )
84 }
85}
86
87impl fmt::Display for EnsembleDiagnostics {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 let _ = writeln!(f, "=== Ensemble Diagnostics ===");
90 let _ = writeln!(
91 f,
92 "Trees: {}, Base: {:.4}, LR: {:.4}, Samples: {}",
93 self.n_trees, self.base_prediction, self.learning_rate, self.n_samples,
94 );
95 let _ = writeln!(f, "Total replacements: {}", self.total_replacements);
96
97 let mut importance: Vec<(usize, f64)> = self
99 .feature_importance
100 .iter()
101 .enumerate()
102 .map(|(i, &v)| (i, v))
103 .collect();
104 importance.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
105 let top_n = importance.len().min(10);
106 let _ = writeln!(f, "Feature importance (top {top_n}):");
107 for &(feat, imp) in importance.iter().take(top_n) {
108 let _ = writeln!(f, " feature[{feat}]: {imp:.4}");
109 }
110
111 if !self.trees.is_empty() {
113 let avg_depth = self.trees.iter().map(|t| t.max_depth).sum::<usize>() as f64
114 / self.trees.len() as f64;
115 let avg_nodes = self.trees.iter().map(|t| t.n_nodes).sum::<usize>() as f64
116 / self.trees.len() as f64;
117 let _ = writeln!(f, "Avg depth: {avg_depth:.1}, Avg nodes: {avg_nodes:.1}");
118 }
119
120 Ok(())
121 }
122}
123
124impl fmt::Display for DistributionalDiagnostics {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 let _ = writeln!(f, "=== Distributional Diagnostics ===");
127 let _ = writeln!(f, "--- Location ---");
128 let _ = write!(f, "{}", self.location);
129 if let Some(ref scale) = self.scale {
130 let _ = writeln!(f, "--- Scale ---");
131 let _ = write!(f, "{}", scale);
132 }
133 let _ = writeln!(f, "honest_sigma: {:.6}", self.honest_sigma);
134 let _ = writeln!(
135 f,
136 "rolling_honest_sigma_mean: {:.6}",
137 self.rolling_honest_sigma_mean,
138 );
139 if let Some(mts) = self.effective_mts {
140 let _ = writeln!(f, "effective_mts: {mts}");
141 }
142 Ok(())
143 }
144}
145
146use crate::ensemble::step::BoostingStep;
151
152#[allow(dead_code)]
158pub(crate) fn build_ensemble_diagnostics(
159 steps: &[BoostingStep],
160 base_prediction: f64,
161 learning_rate: f64,
162 n_samples: u64,
163 features: Option<&[f64]>,
164) -> EnsembleDiagnostics {
165 let mut trees = Vec::with_capacity(steps.len());
166 let mut split_counts: Vec<u64> = Vec::new();
167
168 for step in steps {
169 let slot = step.slot();
170 let tree = slot.active_tree();
171 let arena = tree.arena();
172
173 let n_nodes = arena.n_nodes();
174 let n_leaves = arena.n_leaves();
175 let max_depth = (0..arena.is_leaf.len())
176 .filter(|&i| arena.is_leaf[i])
177 .map(|i| arena.depth[i] as usize)
178 .max()
179 .unwrap_or(0);
180 let n_tree_samples = step.n_samples_seen();
181 let n_replacements = slot.replacements();
182
183 let contribution = match features {
184 Some(f) => learning_rate * step.predict(f),
185 None => 0.0,
186 };
187
188 let thresholds = tree.collect_split_thresholds_per_feature();
190 if !thresholds.is_empty() {
191 if split_counts.len() < thresholds.len() {
192 split_counts.resize(thresholds.len(), 0);
193 }
194 for (feat_idx, splits) in thresholds.iter().enumerate() {
195 if feat_idx < split_counts.len() {
196 split_counts[feat_idx] += splits.len() as u64;
197 }
198 }
199 }
200
201 trees.push(TreeDiagnostics {
202 n_nodes,
203 n_leaves,
204 max_depth,
205 n_samples: n_tree_samples,
206 n_replacements,
207 contribution,
208 });
209 }
210
211 let total: u64 = split_counts.iter().sum();
213 let feature_importance = if total > 0 {
214 split_counts
215 .iter()
216 .map(|&c| c as f64 / total as f64)
217 .collect()
218 } else {
219 vec![0.0; split_counts.len()]
220 };
221
222 let total_replacements: u64 = trees.iter().map(|t| t.n_replacements).sum();
223
224 EnsembleDiagnostics {
225 n_trees: trees.len(),
226 trees,
227 feature_importance,
228 total_replacements,
229 base_prediction,
230 learning_rate,
231 n_samples,
232 }
233}