use super::path::{Explainable, ForestPath, LeafInfo, TreePath};
use crate::primitives::Matrix;
use crate::tree::RandomForestRegressor;
/// Adapter that exposes a [`RandomForestRegressor`] through the [`Explainable`] trait.
///
/// Stores the expected feature count next to the model so that flat input
/// buffers can be split into per-sample rows before prediction.
#[derive(Debug, Clone)]
pub struct EnsembleExplainable {
    /// The wrapped forest; every prediction is delegated to its `predict`.
    model: RandomForestRegressor,
    /// Number of features per sample; used to shape the `1 x n_features`
    /// prediction matrices and to slice flat inputs into rows.
    n_features: usize,
}
impl EnsembleExplainable {
    /// Wraps `model`, which is expected to take `n_features` features per sample.
    ///
    /// Construction runs one prediction on an all-zero `1 x n_features` row and
    /// discards the output. NOTE(review): presumably this is a smoke test that the
    /// model accepts this feature count (e.g. panics if unfitted) — confirm against
    /// `RandomForestRegressor::predict`.
    ///
    /// # Panics
    ///
    /// Panics if the probe matrix cannot be built. It may also panic if the probe
    /// prediction itself panics.
    pub fn new(model: RandomForestRegressor, n_features: usize) -> Self {
        let zeros = vec![0.0; n_features];
        let probe = Matrix::from_vec(1, n_features, zeros)
            .expect("Test matrix creation should succeed");
        // The output is intentionally discarded; only the call itself matters here.
        let _ = model.predict(&probe);
        Self { model, n_features }
    }

    /// Borrows the wrapped forest.
    pub fn model(&self) -> &RandomForestRegressor {
        &self.model
    }

    /// Number of features each sample is expected to have.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Feature importances as reported by the wrapped forest (may be `None`).
    pub fn feature_importances(&self) -> Option<Vec<f32>> {
        self.model.feature_importances()
    }

    /// Predicts one sample.
    ///
    /// Returns `(ensemble_prediction, per_tree_predictions, tree_paths)`.
    ///
    /// Limitation: individual trees are not inspected. Both returned vectors
    /// contain exactly one element derived from the ensemble prediction. The
    /// single `TreePath` has no split steps and a leaf reporting that prediction
    /// with `n_samples: 1` and no class distribution.
    ///
    /// # Panics
    ///
    /// Panics if `sample` cannot form a `1 x n_features` matrix.
    fn predict_with_trees(&self, sample: &[f32]) -> (f32, Vec<f32>, Vec<TreePath>) {
        let row = Matrix::from_vec(1, self.n_features, sample.to_vec()).expect("Matrix creation");
        // A one-row input yields one prediction; take it by value.
        let value = self.model.predict(&row).as_slice()[0];
        let leaf = LeafInfo {
            prediction: value,
            n_samples: 1,
            class_distribution: None,
        };
        let placeholder_path = TreePath::new(Vec::new(), leaf);
        (value, vec![value], vec![placeholder_path])
    }
}
impl Explainable for EnsembleExplainable {
    type Path = ForestPath;

    /// Predicts `n_samples` rows stored contiguously in `x` (row-major, each row
    /// `n_features` long) and returns the predictions with one `ForestPath` per row.
    ///
    /// Each path carries a copy of the model's feature importances, or an empty
    /// vector if the model reports none.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != n_features * n_samples`.
    fn predict_explained(&self, x: &[f32], n_samples: usize) -> (Vec<f32>, Vec<Self::Path>) {
        let width = self.n_features();
        assert_eq!(
            x.len(),
            width * n_samples,
            "Input length {} must equal n_features ({}) * n_samples ({})",
            x.len(),
            width,
            n_samples
        );

        let mut predictions: Vec<f32> = Vec::with_capacity(n_samples);
        let mut explanations: Vec<ForestPath> = Vec::with_capacity(n_samples);

        // NOTE(review): presumably `unwrap_or_default` is on this project's clippy
        // `disallowed-methods` list. A missing importance vector is deliberately
        // tolerated here and becomes an empty `Vec`.
        #[allow(clippy::disallowed_methods)]
        let importances = self.feature_importances().unwrap_or_default();

        // Index-based slicing (not `chunks_exact`) keeps `width == 0` from panicking.
        let rows = (0..n_samples).map(|idx| &x[idx * width..idx * width + width]);
        for row in rows {
            let (value, per_tree, tree_paths) = self.predict_with_trees(row);
            let explanation =
                ForestPath::new(tree_paths, per_tree).with_feature_importance(importances.clone());
            predictions.push(value);
            explanations.push(explanation);
        }

        (predictions, explanations)
    }

    /// Explains a single sample of exactly `n_features` values.
    ///
    /// # Panics
    ///
    /// Panics if `sample.len() != n_features`.
    fn explain_one(&self, sample: &[f32]) -> Self::Path {
        // Requesting one sample yields exactly one path.
        let (_, mut paths) = self.predict_explained(sample, 1);
        paths.pop().expect("Should have one path")
    }
}
/// Conversion of a model into an [`EnsembleExplainable`] wrapper.
pub trait IntoEnsembleExplainable {
    /// Consumes `self` and wraps it, expecting `n_features` features per sample.
    fn into_explainable(self, n_features: usize) -> EnsembleExplainable;
}
impl IntoEnsembleExplainable for RandomForestRegressor {
    /// Moves this forest into an [`EnsembleExplainable`] via [`EnsembleExplainable::new`].
    ///
    /// The constructor runs its probe prediction, so it inherits that
    /// constructor's panics.
    fn into_explainable(self, n_features: usize) -> EnsembleExplainable {
        let forest = self;
        EnsembleExplainable::new(forest, n_features)
    }
}