use crate::{
shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
train::{train, Task, TrainOutput},
train_tree::TrainTree,
Progress, TrainOptions, Tree,
};
use ndarray::prelude::*;
use num::ToPrimitive;
use rayon::prelude::*;
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};
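/// A gradient boosted decision tree regressor. Train one with `Regressor::train`.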
#[derive(Clone, Debug)]
pub struct Regressor {
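/// The initial prediction before any trees are applied, computed as the mean of the training labels.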
pub bias: f32,
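/// The trained trees. Each tree's output is added to the bias to produce a prediction.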
pub trees: Vec<Tree>,
}
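/// The output of `Regressor::train`.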
#[derive(Debug)]
pub struct RegressorTrainOutput {
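/// The trained model.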
pub model: Regressor,
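/// The loss at each round of training, if losses were computed.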
pub losses: Option<Vec<f32>>,
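/// The importance of each feature, if feature importances were computed.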
pub feature_importances: Option<Vec<f32>>,
}
impl Regressor {
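/// Train a regressor on `features` and `labels` using `train_options`, reporting progress through `progress`.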
pub fn train(
features: TableView,
labels: NumberTableColumnView,
train_options: &TrainOptions,
progress: Progress,
) -> RegressorTrainOutput {
let task = Task::Regression;
let train_output = train(
task,
features,
TableColumnView::Number(labels),
train_options,
progress,
);
match train_output {
TrainOutput::Regressor(train_output) => train_output,
_ => unreachable!(),
}
}
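/// Write a prediction for each row of `features` into `predictions`, which must have one element per row. Each prediction is the bias plus the sum of the outputs of all trees.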
pub fn predict(&self, features: ArrayView2<TableValue>, mut predictions: ArrayViewMut1<f32>) {
predictions.fill(self.bias);
let predictions = predictions.as_slice_mut().unwrap();
for tree in self.trees.iter() {
zip!(features.axis_iter(Axis(0)), predictions.iter_mut()).for_each(
|(example, prediction)| {
*prediction += tree.predict(example.as_slice().unwrap());
},
)
}
}
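/// Compute SHAP values for each row of `features`, giving the contribution of each feature to the prediction.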
pub fn compute_feature_contributions(
&self,
features: ArrayView2<TableValue>,
) -> Vec<ComputeShapValuesForExampleOutput> {
let trees = ArrayView1::from_shape(self.trees.len(), &self.trees).unwrap();
features
.axis_iter(Axis(0))
.map(|features| {
compute_shap_values_for_example(features.as_slice().unwrap(), trees, self.bias)
})
.collect()
}
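/// Deserialize a `Regressor` from a `buffalo` reader.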
pub fn from_reader(regressor: crate::serialize::RegressorReader) -> Regressor {
crate::serialize::deserialize_regressor(regressor)
}
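/// Serialize this `Regressor` using the given `buffalo` writer, returning the position of the written value.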
pub fn to_writer(
&self,
writer: &mut buffalo::Writer,
) -> buffalo::Position<crate::serialize::RegressorWriter> {
crate::serialize::serialize_regressor(self, writer)
}
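/// Deserialize a `Regressor` from bytes produced by `to_bytes`.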
pub fn from_bytes(bytes: &[u8]) -> Regressor {
let reader = buffalo::read::<crate::serialize::RegressorReader>(bytes);
Self::from_reader(reader)
}
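/// Serialize this `Regressor` to a byte buffer.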
pub fn to_bytes(&self) -> Vec<u8> {
let mut writer = buffalo::Writer::new();
self.to_writer(&mut writer);
writer.into_bytes()
}
}
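/// Add the output of each tree trained in this round to the running predictions (the first row of `predictions`).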
pub fn update_logits(
trees_for_round: &[TrainTree],
features: ArrayView2<TableValue>,
mut predictions: ArrayViewMut2<f32>,
) {
for (prediction, features) in zip!(predictions.row_mut(0), features.axis_iter(Axis(0))) {
for tree in trees_for_round {
*prediction += tree.predict(features.as_slice().unwrap());
}
}
}
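/// Compute the mean squared error loss (with a factor of 1/2) between `predictions` and `labels`.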
pub fn compute_loss(predictions: ArrayView2<f32>, labels: ArrayView1<f32>) -> f32 {
let mut loss = 0.0;
for (label, prediction) in zip!(labels, predictions) {
loss += 0.5 * (label - prediction) * (label - prediction);
}
loss / labels.len().to_f32().unwrap()
}
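/// Compute the bias for a regressor: a single value equal to the mean of the labels.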
pub fn compute_biases(labels: ArrayView1<f32>) -> Array1<f32> {
arr1(&[labels.mean().unwrap()])
}
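/// Compute the gradients of the squared error loss, `prediction - label`, for each example. The hessians are constant (1.0) for squared error, so they are left untouched.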
pub fn compute_gradients_and_hessians(
gradients: &mut [f32],
_hessians: &mut [f32],
labels: &[f32],
predictions: &[f32],
) {
pzip!(gradients, labels, predictions).for_each(|(gradient, label, prediction)| {
*gradient = prediction - label;
});
}