use super::path::{Explainable, LeafInfo, TreePath, TreeSplit};
use crate::primitives::Matrix;
use crate::tree::DecisionTreeRegressor;
/// A [`DecisionTreeRegressor`] paired with the input width it expects, exposing
/// predictions together with per-sample explanation paths via [`Explainable`].
#[derive(Debug, Clone)]
pub struct TreeExplainable {
    /// The wrapped regression tree used for every prediction.
    model: DecisionTreeRegressor,
    /// Number of feature values in one sample; used as the column count of every
    /// matrix handed to `model`.
    n_features: usize,
}
impl TreeExplainable {
    /// Wraps `model` so its predictions can be returned together with explanation paths.
    ///
    /// `n_features` is the number of values per sample. It is used as the column count of
    /// every matrix passed to the model.
    ///
    /// Before returning, the constructor runs one throw-away prediction on an all-zero row
    /// of width `n_features` and ignores the result. NOTE(review): this looks like a smoke
    /// check that the model accepts inputs of this width. Confirm the intent against
    /// `DecisionTreeRegressor::predict`.
    ///
    /// # Panics
    ///
    /// Panics with "Test matrix creation should succeed" if the 1×`n_features` probe matrix
    /// cannot be built. Any panic raised by the probe `predict` call also propagates.
    pub fn new(model: DecisionTreeRegressor, n_features: usize) -> Self {
        let zeros = vec![0.0; n_features];
        let probe = Matrix::from_vec(1, n_features, zeros)
            .expect("Test matrix creation should succeed");
        // The result is intentionally discarded; only the call itself matters here.
        let _ = model.predict(&probe);
        Self { model, n_features }
    }

    /// Borrows the wrapped regression tree.
    pub fn model(&self) -> &DecisionTreeRegressor {
        &self.model
    }

    /// Number of feature values each sample must contain.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Runs the model on one sample and packages the result as an explanation.
    ///
    /// Returns `(prediction, splits, leaf)`. NOTE(review): this does not walk the tree yet.
    /// `splits` is always empty, and the leaf's `n_samples` is hard-coded to `1` rather than
    /// read from the fitted tree. The path therefore carries only the prediction value.
    ///
    /// # Panics
    ///
    /// Panics with "Matrix creation" if a 1×`n_features` matrix cannot be built from
    /// `sample` (presumably when `sample.len() != self.n_features`; this is unverified).
    fn trace_path(&self, sample: &[f32]) -> (f32, Vec<TreeSplit>, LeafInfo) {
        let row = Matrix::from_vec(1, self.n_features, sample.to_vec()).expect("Matrix creation");
        let output = self.model.predict(&row);
        // A single input row yields a single prediction; take it.
        let value = output.as_slice()[0];
        let no_splits: Vec<TreeSplit> = Vec::new();
        (
            value,
            no_splits,
            LeafInfo {
                prediction: value,
                n_samples: 1,
                class_distribution: None,
            },
        )
    }
}
impl Explainable for TreeExplainable {
    type Path = TreePath;

    /// Predicts every sample in a row-major batch and pairs each prediction with its
    /// explanation path.
    ///
    /// `x` holds `n_samples` consecutive rows of `self.n_features()` values each. Returns
    /// `(predictions, paths)`, both of length `n_samples` and in input order.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != n_features * n_samples`, or if tracing any row panics.
    fn predict_explained(&self, x: &[f32], n_samples: usize) -> (Vec<f32>, Vec<Self::Path>) {
        let width = self.n_features();
        assert_eq!(
            x.len(),
            width * n_samples,
            "Input length {} must equal n_features ({}) * n_samples ({})",
            x.len(),
            width,
            n_samples
        );

        // Index-based slicing (rather than `chunks_exact`) keeps `width == 0` working.
        (0..n_samples)
            .map(|row| {
                let offset = row * width;
                let (prediction, splits, leaf) = self.trace_path(&x[offset..offset + width]);
                (prediction, TreePath::new(splits, leaf))
            })
            .unzip()
    }

    /// Explains a single sample of `self.n_features()` values.
    ///
    /// # Panics
    ///
    /// Panics if `sample.len() != self.n_features()` (via the batch length assertion).
    fn explain_one(&self, sample: &[f32]) -> Self::Path {
        let (_predictions, mut paths) = self.predict_explained(sample, 1);
        // A batch of one always yields exactly one path.
        paths.pop().expect("Should have one path")
    }
}
/// Conversion of a model into a [`TreeExplainable`] wrapper.
pub trait IntoTreeExplainable {
    /// Consumes `self` and wraps it, recording `n_features` as the per-sample input width.
    fn into_explainable(self, n_features: usize) -> TreeExplainable;
}
/// Lets a fitted regression tree be turned into an explainable wrapper with
/// `tree.into_explainable(n_features)`.
impl IntoTreeExplainable for DecisionTreeRegressor {
    /// Delegates to [`TreeExplainable::new`], so the same probe prediction and panics apply.
    fn into_explainable(self, n_features: usize) -> TreeExplainable {
        TreeExplainable::new(self, n_features)
    }
}