use crate::dataset::BinnedDataset;
use crate::learner::{LinearBooster, LinearConfig, TreeBooster, TreeConfig, WeakLearner};
use crate::loss::{LossFunction, MseLoss};
use crate::Result;
use super::stats::{compute_mse, compute_r2, compute_residuals};
/// Result of fitting a regularized linear model to a (sub)sample of the dataset.
#[derive(Debug, Clone)]
pub struct LinearProbeResult {
    /// Coefficient of determination of the final predictions vs. the sampled targets.
    pub r2: f32,
    /// Mean squared error of the final predictions; `f32::MAX` when the probe
    /// was skipped (too few samples or no features).
    pub mse: f32,
    /// Final per-sample predictions (base prediction + linear term).
    pub predictions: Vec<f32>,
    /// Per-sample residuals of the fit (as produced by `compute_residuals`);
    /// consumed by the tree probe as its targets.
    pub residuals: Vec<f32>,
    /// Learned linear weights, one entry per feature.
    pub weights: Vec<f32>,
    /// Number of boosting iterations actually executed before convergence.
    pub iterations: usize,
}
/// Result of fitting a single depth-limited tree to the linear probe's residuals.
#[derive(Debug, Clone)]
pub struct TreeProbeResult {
    /// R² of the tree's predictions against the residuals.
    pub r2_on_residuals: f32,
    /// Absolute MSE reduction relative to the zero-prediction baseline,
    /// clamped to be non-negative.
    pub mse_reduction: f32,
    /// `mse_reduction` divided by the baseline residual MSE
    /// (0 when the baseline MSE is effectively zero).
    pub relative_improvement: f32,
    /// Total number of internal (split) nodes in the fitted tree.
    pub num_splits: usize,
    /// Per-feature count of internal nodes that split on that feature.
    pub feature_usage: Vec<usize>,
}
/// Fits a ridge-style linear model to (a sample of) the dataset via iterative
/// gradient boosting, stopping early once the MSE stops improving.
///
/// Returns a null result (r2 = 0, mse = MAX) when there are fewer than 10
/// samples or no features, rather than erroring.
pub fn run_linear_probe(
    dataset: &BinnedDataset,
    sample_indices: Option<&[usize]>,
    max_iter: usize,
) -> Result<LinearProbeResult> {
    let num_features = dataset.num_features();
    let (raw_features, sample_targets) = extract_features_for_probe(dataset, sample_indices);
    let num_samples = sample_targets.len();

    // Degenerate input: report a null fit instead of failing.
    if num_samples < 10 || num_features == 0 {
        return Ok(LinearProbeResult {
            r2: 0.0,
            mse: f32::MAX,
            predictions: vec![0.0; num_samples],
            residuals: sample_targets.clone(),
            weights: vec![0.0; num_features],
            iterations: 0,
        });
    }

    // L2-only regularization (l1_ratio = 0) with a fixed lambda of 1.0.
    let linear_config = LinearConfig::default()
        .with_lambda(1.0)
        .with_l1_ratio(0.0)
        .with_max_iter(max_iter)
        .with_tol(1e-4);
    let mut linear = LinearBooster::new(num_features, linear_config);
    let loss = MseLoss;

    let base_pred = loss.initial_prediction(&sample_targets);
    let mut predictions = vec![base_pred; num_samples];
    let mut gradients = vec![0.0f32; num_samples];
    let mut hessians = vec![1.0f32; num_samples];

    let mut prev_mse = f32::MAX;
    let mut iterations = 0;
    for iter in 0..max_iter {
        // Refresh first/second-order loss statistics at the current predictions.
        for ((g, h), (&t, &p)) in gradients
            .iter_mut()
            .zip(hessians.iter_mut())
            .zip(sample_targets.iter().zip(predictions.iter()))
        {
            let (grad, hess) = loss.gradient_hessian(t, p);
            *g = grad;
            *h = hess;
        }

        linear.fit_on_gradients(&raw_features, num_features, &gradients, &hessians)?;
        let linear_preds = linear.predict_batch(&raw_features, num_features);
        for (pred, lp) in predictions.iter_mut().zip(linear_preds) {
            *pred = base_pred + lp;
        }

        let mse = compute_mse(&sample_targets, &predictions);
        iterations = iter + 1;
        // Stop once the change in MSE between rounds is negligible.
        if (prev_mse - mse).abs() < 1e-6 {
            break;
        }
        prev_mse = mse;
    }

    let r2 = compute_r2(&sample_targets, &predictions);
    let mse = compute_mse(&sample_targets, &predictions);
    let residuals = compute_residuals(&sample_targets, &predictions);

    Ok(LinearProbeResult {
        r2,
        mse,
        predictions,
        residuals,
        weights: linear.weights().to_vec(),
        iterations,
    })
}
/// Fits one depth-limited tree on the linear probe's residuals and measures
/// how much additional variance a tree can explain beyond the linear model.
///
/// Returns a null result when there are fewer than 20 residuals to split on.
pub fn run_tree_probe(
    dataset: &BinnedDataset,
    linear_result: &LinearProbeResult,
    sample_indices: Option<&[usize]>,
    max_depth: usize,
) -> Result<TreeProbeResult> {
    let num_features = dataset.num_features();
    let residuals = &linear_result.residuals;
    let n = residuals.len();

    // Too few rows for a meaningful split search.
    if n < 20 {
        return Ok(TreeProbeResult {
            r2_on_residuals: 0.0,
            mse_reduction: 0.0,
            relative_improvement: 0.0,
            num_splits: 0,
            feature_usage: vec![0; num_features],
        });
    }

    // Restrict to the same rows the linear probe was fit on.
    let probe_dataset = match sample_indices {
        Some(indices) => dataset.subset_by_indices(indices),
        None => dataset.clone(),
    };

    // Full-strength single tree: learning rate 1.0, leaves capped at 2^depth.
    let tree_config = TreeConfig::default()
        .with_max_depth(max_depth)
        .with_max_leaves(2_usize.pow(max_depth as u32))
        .with_learning_rate(1.0)
        .with_lambda(1.0);
    let mut tree_booster = TreeBooster::new(tree_config);

    // MSE gradients/hessians evaluated at prediction 0 against the residuals.
    let loss = MseLoss;
    let mut gradients = Vec::with_capacity(n);
    let mut hessians = Vec::with_capacity(n);
    for &r in residuals {
        let (g, h) = loss.gradient_hessian(r, 0.0);
        gradients.push(g);
        hessians.push(h);
    }

    tree_booster.fit_on_gradients(&probe_dataset, &gradients, &hessians, None)?;

    let tree_preds = tree_booster
        .tree()
        .map(|tree| tree.predict_all(&probe_dataset))
        .unwrap_or_else(|| vec![0.0; n]);

    // Baseline is the MSE of the residuals against an all-zero prediction.
    let zeros = vec![0.0; n];
    let residual_mse = compute_mse(residuals, &zeros);
    let after_tree_mse = compute_mse(residuals, &tree_preds);
    let r2_on_residuals = compute_r2(residuals, &tree_preds);
    let mse_reduction = (residual_mse - after_tree_mse).max(0.0);
    let relative_improvement = if residual_mse > 1e-10 {
        mse_reduction / residual_mse
    } else {
        0.0
    };

    // Tally how many internal nodes split on each feature.
    let mut feature_usage = vec![0usize; num_features];
    if let Some(tree) = tree_booster.tree() {
        for node in tree.nodes() {
            if let crate::tree::NodeType::Internal { feature_idx, .. } = node.node_type {
                if feature_idx < num_features {
                    feature_usage[feature_idx] += 1;
                }
            }
        }
    }
    let num_splits: usize = feature_usage.iter().sum();

    Ok(TreeProbeResult {
        r2_on_residuals,
        mse_reduction,
        relative_improvement,
        num_splits,
        feature_usage,
    })
}
/// Reconstructs approximate raw feature values from binned data for the
/// selected rows (or all rows), returning a row-major feature matrix and the
/// matching targets.
///
/// Each bin is mapped back to a representative value: the midpoint of its two
/// surrounding boundaries for interior bins, the first/last boundary for edge
/// bins, and the bin index itself when no boundaries were recorded.
fn extract_features_for_probe(
    dataset: &BinnedDataset,
    sample_indices: Option<&[usize]>,
) -> (Vec<f32>, Vec<f32>) {
    let num_features = dataset.num_features();
    let feature_info = dataset.all_feature_info();
    let all_targets = dataset.targets();

    let indices: Vec<usize> = match sample_indices {
        Some(idx) => idx.to_vec(),
        None => (0..dataset.num_rows()).collect(),
    };
    let num_samples = indices.len();

    let mut features = vec![0.0f32; num_samples * num_features];
    let mut targets = Vec::with_capacity(num_samples);

    for (out_idx, &row_idx) in indices.iter().enumerate() {
        targets.push(all_targets[row_idx]);
        // Fill this sample's row of the output matrix in place.
        let row = &mut features[out_idx * num_features..(out_idx + 1) * num_features];
        for (f, slot) in row.iter_mut().enumerate() {
            let bin = dataset.get_bin(row_idx, f) as usize;
            let boundaries = &feature_info[f].bin_boundaries;
            *slot = if boundaries.is_empty() {
                // No boundary information: fall back to the raw bin index.
                bin as f32
            } else if bin == 0 {
                boundaries.first().copied().unwrap_or(0.0) as f32
            } else if bin >= boundaries.len() {
                boundaries.last().copied().unwrap_or(0.0) as f32
            } else {
                // Interior bin (1 <= bin < boundaries.len()): midpoint of the
                // two surrounding boundaries.
                ((boundaries[bin - 1] + boundaries[bin]) / 2.0) as f32
            };
        }
    }

    (features, targets)
}
/// Aggregated outcome of running the linear probe followed by the tree probe.
#[derive(Debug, Clone)]
pub struct CombinedProbeResult {
    /// Full result of the linear probe.
    pub linear: LinearProbeResult,
    /// Full result of the tree probe (fit on the linear residuals).
    pub tree: TreeProbeResult,
    /// Combined explained variance: linear R² plus the tree's R² applied to
    /// the variance the linear model left unexplained, clamped to [0, 1].
    pub combined_r2: f32,
    /// Share of explained variance attributed to the tree; forced to 0 when
    /// the linear fit is already near-perfect (R² >= 0.99).
    pub tree_contribution: f32,
}
/// Runs the linear probe, then the tree probe on its residuals, and composes
/// their explained variance into a single score.
pub fn run_combined_probe(
    dataset: &BinnedDataset,
    sample_indices: Option<&[usize]>,
    linear_max_iter: usize,
    tree_max_depth: usize,
) -> Result<CombinedProbeResult> {
    let linear = run_linear_probe(dataset, sample_indices, linear_max_iter)?;
    let tree = run_tree_probe(dataset, &linear, sample_indices, tree_max_depth)?;

    // The tree can only explain variance the linear model left behind.
    let combined_r2 = linear.r2 + (1.0 - linear.r2) * tree.r2_on_residuals;
    let tree_contribution = if linear.r2 < 0.99 {
        tree.r2_on_residuals * (1.0 - linear.r2)
    } else {
        // Near-perfect linear fit: attribute nothing to the tree.
        0.0
    };

    Ok(CombinedProbeResult {
        linear,
        tree,
        combined_r2: combined_r2.clamp(0.0, 1.0),
        tree_contribution,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::{FeatureInfo, FeatureType};

    /// Builds an `n`-row, 2-feature dataset whose target is dominated by a
    /// linear combination of the two features plus a tiny deterministic ripple.
    fn create_linear_dataset(n: usize) -> BinnedDataset {
        let num_features = 2;
        let mut features = Vec::with_capacity(n * num_features);
        // Pseudo-random bin assignments via modular multiplication.
        let x0_bins: Vec<u8> = (0..n).map(|r| ((r * 17) % 256) as u8).collect();
        let x1_bins: Vec<u8> = (0..n).map(|r| ((r * 23) % 256) as u8).collect();
        // Feature-major layout: all of feature 0's bins, then all of feature 1's.
        features.extend(x0_bins.iter().cloned());
        features.extend(x1_bins.iter().cloned());
        let targets: Vec<f32> = (0..n)
            .map(|i| {
                let x0 = x0_bins[i] as f32 / 255.0;
                let x1 = x1_bins[i] as f32 / 255.0;
                // Linear signal (2*x0 + 3*x1) plus a small periodic perturbation.
                2.0 * x0 + 3.0 * x1 + (i % 10) as f32 * 0.001
            })
            .collect();
        let feature_info = (0..num_features)
            .map(|i| FeatureInfo {
                name: format!("f{}", i),
                feature_type: FeatureType::Numeric,
                // NOTE(review): bins above take values 0..=255 (256 distinct)
                // while num_bins is 255 — confirm BinnedDataset tolerates
                // bin values equal to num_bins.
                num_bins: 255,
                bin_boundaries: (0..255).map(|b| b as f64 / 255.0).collect(),
            })
            .collect();
        BinnedDataset::new(n, features, targets, feature_info)
    }

    /// The linear probe on linearly-generated data should produce a valid R²
    /// and residuals aligned with predictions.
    #[test]
    fn test_linear_probe_captures_linear_signal() {
        let dataset = create_linear_dataset(1000);
        let result = run_linear_probe(&dataset, None, 100).unwrap();
        assert!(
            result.r2 >= 0.0 && result.r2 <= 1.0,
            "R² should be in valid range: {}",
            result.r2
        );
        assert!(!result.predictions.is_empty(), "Should produce predictions");
        assert_eq!(result.residuals.len(), result.predictions.len());
    }

    /// The tree probe run after the linear probe should still find splits
    /// (the small ripple leaves structure in the residuals).
    #[test]
    fn test_tree_probe_on_linear_data() {
        let dataset = create_linear_dataset(1000);
        let linear_result = run_linear_probe(&dataset, None, 100).unwrap();
        let tree_result = run_tree_probe(&dataset, &linear_result, None, 4).unwrap();
        assert!(
            tree_result.r2_on_residuals >= 0.0,
            "R² on residuals should be valid"
        );
        assert!(tree_result.num_splits > 0, "Tree should have splits");
    }

    /// The combined score should never fall meaningfully below the linear R²
    /// on its own (the tree term is non-negative up to clamping).
    #[test]
    fn test_combined_probe() {
        let dataset = create_linear_dataset(500);
        let result = run_combined_probe(&dataset, None, 50, 3).unwrap();
        assert!(result.linear.r2 >= 0.0);
        assert!(result.combined_r2 >= result.linear.r2 - 0.01);
    }
}