use alloc::vec;
use alloc::vec::Vec;
use crate::ensemble::config::ScaleMode;
use crate::ensemble::step::BoostingStep;
use super::DistributionalSGBT;
/// Structural and statistical summary of a single boosted tree, collected
/// by `compute_diagnostics` for model introspection.
#[derive(Debug, Clone)]
pub struct DistributionalTreeDiagnostic {
/// Number of leaf nodes in the tree's arena.
pub n_leaves: usize,
/// Deepest leaf depth observed (0 when the tree has no leaves).
pub max_depth_reached: usize,
/// Samples the boosting step has observed (from `step.n_samples_seen()`).
pub samples_seen: u64,
/// `(min, max, mean, std)` of the leaf values; all zeros when leafless.
pub leaf_weight_stats: (f64, f64, f64, f64),
/// Feature indices whose recorded split gain is strictly positive.
pub split_features: Vec<usize>,
/// Per-leaf sample counts, in arena index order.
pub leaf_sample_counts: Vec<u64>,
/// Prediction mean reported by the step's slot.
pub prediction_mean: f64,
/// Prediction standard deviation reported by the step's slot.
pub prediction_std: f64,
}
/// Aggregate diagnostics for an entire `DistributionalSGBT` model,
/// produced by `compute_diagnostics`.
#[derive(Debug, Clone)]
pub struct ModelDiagnostics {
/// All per-tree diagnostics: location trees first, then scale trees.
pub trees: Vec<DistributionalTreeDiagnostic>,
/// Diagnostics for the location (mu) tree chain only.
pub location_trees: Vec<DistributionalTreeDiagnostic>,
/// Diagnostics for the scale (log-sigma) tree chain only.
pub scale_trees: Vec<DistributionalTreeDiagnostic>,
/// Per feature: number of trees that recorded a positive split gain on it.
pub feature_split_counts: Vec<usize>,
/// Base (prior) location prediction of the model.
pub location_base: f64,
/// Base (prior) log-scale prediction of the model.
pub scale_base: f64,
/// sqrt of the model's EWMA of squared errors.
pub empirical_sigma: f64,
/// Which scale-estimation mode the model is configured with.
pub scale_mode: ScaleMode,
/// Count of scale trees that actually split (more than one leaf).
pub scale_trees_active: usize,
/// Copy of the model's automatically selected bandwidths.
pub auto_bandwidths: Vec<f64>,
/// Running mean of the ensemble gradient.
pub ensemble_grad_mean: f64,
/// Running std of the ensemble gradient: sqrt(grad_m2 / grad_count).
pub ensemble_grad_std: f64,
}
/// A prediction broken down into base values plus per-tree contributions,
/// so callers can attribute the output to individual boosting steps.
#[derive(Debug, Clone)]
pub struct DecomposedPrediction {
/// Base location; `mu()` adds the summed location contributions to this.
pub location_base: f64,
/// Base log-scale; `log_sigma()` adds the summed scale contributions to this.
pub scale_base: f64,
/// Learning-rate-scaled contribution of each location tree.
pub location_contributions: Vec<f64>,
/// Learning-rate-scaled contribution of each scale tree
/// (all zeros in empirical scale mode).
pub scale_contributions: Vec<f64>,
}
impl DecomposedPrediction {
    /// Predicted mean: base location plus the sum of all tree contributions.
    pub fn mu(&self) -> f64 {
        let delta = self.location_contributions.iter().fold(0.0, |acc, &c| acc + c);
        self.location_base + delta
    }
    /// Predicted log standard deviation: base log-scale plus the sum of all
    /// scale-tree contributions.
    pub fn log_sigma(&self) -> f64 {
        let delta = self.scale_contributions.iter().fold(0.0, |acc, &c| acc + c);
        self.scale_base + delta
    }
    /// Predicted standard deviation, clamped below at 1e-8 to stay strictly
    /// positive for downstream likelihood computations.
    pub fn sigma(&self) -> f64 {
        let s = crate::math::exp(self.log_sigma());
        s.max(1e-8)
    }
}
/// Build a full `ModelDiagnostics` snapshot of the given model, walking the
/// location chain and then the scale chain of boosting steps.
pub(crate) fn compute_diagnostics(model: &DistributionalSGBT) -> ModelDiagnostics {
    // Walk one chain of boosting steps, appending a per-tree diagnostic for
    // each step and tallying how many trees split on each feature.
    fn summarize_steps(
        steps: &[BoostingStep],
        out: &mut Vec<DistributionalTreeDiagnostic>,
        split_counts: &mut Vec<usize>,
    ) {
        for step in steps {
            let slot = step.slot();
            let tree = slot.active_tree();
            let arena = tree.arena();
            // Collect the arena indices of all leaves once, then derive the
            // per-leaf views from that single index list.
            let leaf_ids: Vec<usize> = (0..arena.is_leaf.len())
                .filter(|&i| arena.is_leaf[i])
                .collect();
            let leaf_values: Vec<f64> =
                leaf_ids.iter().map(|&i| arena.leaf_value[i]).collect();
            let leaf_sample_counts: Vec<u64> =
                leaf_ids.iter().map(|&i| arena.sample_count[i]).collect();
            let max_depth_reached = leaf_ids
                .iter()
                .map(|&i| arena.depth[i] as usize)
                .max()
                .unwrap_or(0);
            // (min, max, mean, std) over leaf values; all zeros when leafless.
            let leaf_weight_stats = if leaf_values.is_empty() {
                (0.0, 0.0, 0.0, 0.0)
            } else {
                let count = leaf_values.len() as f64;
                let mut lo = f64::INFINITY;
                let mut hi = f64::NEG_INFINITY;
                let mut total = 0.0;
                for &v in &leaf_values {
                    lo = lo.min(v);
                    hi = hi.max(v);
                    total += v;
                }
                let mean = total / count;
                // Population variance (divide by n, not n - 1).
                let var = leaf_values
                    .iter()
                    .map(|&v| (v - mean) * (v - mean))
                    .sum::<f64>()
                    / count;
                (lo, hi, mean, crate::math::sqrt(var))
            };
            let gains = slot.split_gains();
            // Features that contributed at least one positive-gain split.
            let split_features: Vec<usize> = gains
                .iter()
                .enumerate()
                .filter_map(|(i, &g)| if g > 0.0 { Some(i) } else { None })
                .collect();
            if !gains.is_empty() {
                // Lazily size the global counter from the first tree that
                // reports any gains.
                if split_counts.is_empty() {
                    split_counts.resize(gains.len(), 0);
                }
                for &fi in &split_features {
                    if fi < split_counts.len() {
                        split_counts[fi] += 1;
                    }
                }
            }
            out.push(DistributionalTreeDiagnostic {
                n_leaves: leaf_values.len(),
                max_depth_reached,
                samples_seen: step.n_samples_seen(),
                leaf_weight_stats,
                split_features,
                leaf_sample_counts,
                prediction_mean: slot.prediction_mean(),
                prediction_std: slot.prediction_std(),
            });
        }
    }
    let n = model.location_steps.len();
    let mut trees = Vec::with_capacity(2 * n);
    let mut feature_split_counts: Vec<usize> = Vec::new();
    summarize_steps(&model.location_steps, &mut trees, &mut feature_split_counts);
    summarize_steps(&model.scale_steps, &mut trees, &mut feature_split_counts);
    // The first n entries came from the location chain, the rest from scale.
    let location_trees = trees[..n].to_vec();
    let scale_trees = trees[n..].to_vec();
    // A scale tree with a single leaf never split and counts as inactive.
    let scale_trees_active = scale_trees.iter().filter(|t| t.n_leaves > 1).count();
    ModelDiagnostics {
        trees,
        location_trees,
        scale_trees,
        feature_split_counts,
        location_base: model.location_base,
        scale_base: model.scale_base,
        empirical_sigma: crate::math::sqrt(model.ewma_sq_err),
        scale_mode: model.scale_mode,
        scale_trees_active,
        auto_bandwidths: model.auto_bandwidths.clone(),
        ensemble_grad_mean: model.ensemble_grad_mean,
        // Running std from Welford-style accumulators; guard against a zero
        // count to avoid dividing by zero.
        ensemble_grad_std: crate::math::sqrt(
            model.ensemble_grad_m2 / model.ensemble_grad_count.max(1) as f64,
        ),
    }
}
/// Decompose a single prediction into base values plus per-tree
/// contributions for the given feature vector.
pub(crate) fn decompose_prediction(
    model: &DistributionalSGBT,
    features: &[f64],
) -> DecomposedPrediction {
    let lr = model.config.learning_rate;
    // Per-tree location contributions, already scaled by the learning rate.
    let mut location_contributions = Vec::with_capacity(model.location_steps.len());
    for step in &model.location_steps {
        location_contributions.push(lr * step.predict(features));
    }
    let (scale_base, scale_contributions) = if let ScaleMode::TreeChain = model.scale_mode {
        // Tree-chain mode: each scale tree contributes to log-sigma.
        let contribs: Vec<f64> = model
            .scale_steps
            .iter()
            .map(|s| lr * s.predict(features))
            .collect();
        (model.scale_base, contribs)
    } else {
        // Empirical mode: sigma comes from the EWMA of squared errors and the
        // per-tree scale contributions are all zero (sized to match the
        // location chain).
        let empirical_sigma = crate::math::sqrt(model.ewma_sq_err).max(1e-8);
        (
            crate::math::ln(empirical_sigma),
            vec![0.0; model.location_steps.len()],
        )
    };
    DecomposedPrediction {
        location_base: model.location_base,
        scale_base,
        location_contributions,
        scale_contributions,
    }
}
pub(crate) fn compute_feature_importances(
model: &DistributionalSGBT,
location_only: bool,
) -> Vec<f64> {
let mut totals: Vec<f64> = Vec::new();
let steps = if location_only {
vec![&model.location_steps]
} else {
vec![&model.location_steps, &model.scale_steps]
};
for st in steps {
for step in st {
let gains = step.slot().split_gains();
if totals.is_empty() && !gains.is_empty() {
totals.resize(gains.len(), 0.0);
}
for (i, &g) in gains.iter().enumerate() {
if i < totals.len() {
totals[i] += g;
}
}
}
}
let sum: f64 = totals.iter().sum();
if sum > 0.0 {
totals.iter_mut().for_each(|v| *v /= sum);
}
totals
}
/// Normalized total split gain per feature over the scale tree chain only.
/// Returns an empty vector when no scale tree has recorded any gains.
pub(crate) fn compute_feature_importances_scale(model: &DistributionalSGBT) -> Vec<f64> {
    let mut gain_totals: Vec<f64> = Vec::new();
    for step in model.scale_steps.iter() {
        let gains = step.slot().split_gains();
        // Lazily size the accumulator from the first tree reporting gains.
        if gain_totals.is_empty() && !gains.is_empty() {
            gain_totals.resize(gains.len(), 0.0);
        }
        for (feature, &gain) in gains.iter().enumerate() {
            // Checked indexing: ignore features beyond the accumulator width.
            if let Some(total) = gain_totals.get_mut(feature) {
                *total += gain;
            }
        }
    }
    // Normalize so the importances sum to 1.0 when any gain was recorded.
    let grand_total: f64 = gain_totals.iter().sum();
    if grand_total > 0.0 {
        for v in gain_totals.iter_mut() {
            *v /= grand_total;
        }
    }
    gain_totals
}