use crate::{
shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
train::TrainOutput,
train_tree::TrainTree,
Progress, TrainOptions, Tree,
};
use ndarray::prelude::*;
use num::{clamp, ToPrimitive};
use rayon::{self, prelude::*};
use std::num::NonZeroUsize;
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};
/// A multiclass classifier made up of per-class trees trained over multiple rounds,
/// whose summed outputs (plus per-class biases) are turned into probabilities via softmax.
#[derive(Clone, Debug)]
pub struct MulticlassClassifier {
    /// The initial logit for each class; `compute_biases` fills these with the log of the
    /// (clamped) class proportion in the training labels. Length equals the number of classes.
    pub biases: Array1<f32>,
    /// The trained trees: axis 0 indexes training rounds and axis 1 indexes classes
    /// (`predict` walks rows over `Axis(0)`, pairing each row's trees with the per-class
    /// logits; `compute_feature_contributions` walks columns over `Axis(1)` per class).
    pub trees: Array2<Tree>,
}
/// The result of [`MulticlassClassifier::train`].
#[derive(Debug)]
pub struct MulticlassClassifierTrainOutput {
    /// The trained model.
    pub model: MulticlassClassifier,
    /// Training losses, if they were computed (presumably one per round — TODO confirm
    /// against `crate::train::train`).
    pub losses: Option<Vec<f32>>,
    /// Feature importances, if they were computed.
    pub feature_importances: Option<Vec<f32>>,
}
impl MulticlassClassifier {
    /// Train a multiclass classifier on `features` against the enum column `labels`.
    ///
    /// This delegates to the shared `crate::train::train` entry point with a
    /// multiclass-classification task sized by the number of label variants.
    pub fn train(
        features: TableView,
        labels: EnumTableColumnView,
        train_options: &TrainOptions,
        progress: Progress,
    ) -> MulticlassClassifierTrainOutput {
        let n_classes = labels.variants().len();
        let task = crate::train::Task::MulticlassClassification { n_classes };
        let output = crate::train::train(
            task,
            features,
            TableColumnView::Enum(labels),
            train_options,
            progress,
        );
        if let TrainOutput::MulticlassClassifier(output) = output {
            output
        } else {
            // The multiclass task above can only produce a multiclass output.
            unreachable!()
        }
    }

    /// Predict class probabilities for each row of `features`, writing them into
    /// the corresponding row of `probabilities`.
    pub fn predict(&self, features: ArrayView2<TableValue>, mut probabilities: ArrayViewMut2<f32>) {
        let rows = zip!(
            probabilities.axis_iter_mut(Axis(0)),
            features.axis_iter(Axis(0))
        );
        rows.for_each(|(mut logits, example)| {
            // Start each example from the per-class biases.
            logits.assign(&self.biases);
            let example = example.as_slice().unwrap();
            // Add every round's tree output to the matching class logit.
            for trees_for_round in self.trees.axis_iter(Axis(0)) {
                for (logit, tree) in zip!(logits.iter_mut(), trees_for_round.iter()) {
                    *logit += tree.predict(example);
                }
            }
            // Normalize the accumulated logits into probabilities, in place.
            softmax(logits.as_slice_mut().unwrap());
        });
    }

    /// Compute SHAP feature contributions for each example, returning one
    /// per-class vector of outputs per example (classes follow `Axis(1)` of
    /// `self.trees`, paired with the matching bias).
    pub fn compute_feature_contributions(
        &self,
        features: ArrayView2<TableValue>,
    ) -> Vec<Vec<ComputeShapValuesForExampleOutput>> {
        features
            .axis_iter(Axis(0))
            .map(|features| {
                let example = features.as_slice().unwrap();
                zip!(self.trees.axis_iter(Axis(1)), self.biases.iter())
                    .map(|(trees_for_class, bias)| {
                        compute_shap_values_for_example(example, trees_for_class, *bias)
                    })
                    .collect()
            })
            .collect()
    }

    /// Build a `MulticlassClassifier` from its serialized reader form.
    pub fn from_reader(
        multiclass_classifier: crate::serialize::MulticlassClassifierReader,
    ) -> MulticlassClassifier {
        crate::serialize::deserialize_multiclass_classifier(multiclass_classifier)
    }

    /// Serialize this model into `writer`, returning the position it was written at.
    pub fn to_writer(
        &self,
        writer: &mut buffalo::Writer,
    ) -> buffalo::Position<crate::serialize::MulticlassClassifierWriter> {
        crate::serialize::serialize_multiclass_classifier(self, writer)
    }

    /// Deserialize a `MulticlassClassifier` from `bytes`.
    ///
    /// NOTE(review): `self` is never read here; the receiver is kept only so the
    /// existing call sites remain valid.
    pub fn from_bytes(&self, bytes: &[u8]) -> MulticlassClassifier {
        let reader = buffalo::read::<crate::serialize::MulticlassClassifierReader>(bytes);
        Self::from_reader(reader)
    }

    /// Serialize this model to an owned byte buffer.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut writer = buffalo::Writer::new();
        self.to_writer(&mut writer);
        writer.into_bytes()
    }
}
/// Add the predictions of this round's trees to the running logits.
///
/// `binned_features` has one row per example; `predictions` has one row per
/// example and one column per class (one entry per tree in `trees_for_round`).
pub fn update_logits(
    trees_for_round: &[TrainTree],
    binned_features: ArrayView2<TableValue>,
    mut predictions: ArrayViewMut2<f32>,
) {
    let features_rows = binned_features.axis_iter(Axis(0));
    // Fix: iterate the ROWS of `predictions` (`Axis(0)`), not its columns
    // (`Axis(1)`), so each example's logits are paired with that example's
    // features. Axis 0 of the logits matrix is the example axis everywhere
    // else in this module (`predict`, `compute_loss`,
    // `compute_gradients_and_hessians`), and with `Axis(1)` every example's
    // features were zipped against a per-class column instead.
    let logits_rows = predictions.axis_iter_mut(Axis(0));
    for (features, mut logits) in zip!(features_rows, logits_rows) {
        // One tree per class: add each tree's output to the matching class logit.
        for (logit, tree) in zip!(logits.iter_mut(), trees_for_round.iter()) {
            *logit += tree.predict(features.as_slice().unwrap());
        }
    }
}
/// Compute the mean cross-entropy loss of `logits` against `labels`.
///
/// Each row of `logits` is one example; labels are 1-based (`NonZeroUsize`), so
/// class `i` corresponds to label value `i + 1`. The labeled class's probability
/// is clamped away from 0 and 1 before taking the log. Panics if a label is `None`.
pub fn compute_loss(logits: ArrayView2<f32>, labels: ArrayView1<Option<NonZeroUsize>>) -> f32 {
    let mut total = 0.0;
    for (label, example_logits) in zip!(labels.into_iter(), logits.axis_iter(Axis(0))) {
        // Turn this example's logits into probabilities.
        let mut probabilities = example_logits.to_owned();
        softmax(probabilities.as_slice_mut().unwrap());
        let true_class = label.unwrap().get() - 1;
        // An out-of-range label contributes nothing, matching the original scan.
        if let Some(&probability) = probabilities.get(true_class) {
            let clamped = clamp(probability, std::f32::EPSILON, 1.0 - std::f32::EPSILON);
            total -= clamped.ln();
        }
    }
    total / labels.len().to_f32().unwrap()
}
/// Compute the initial per-class biases as the log of each class's (clamped)
/// proportion in `labels`.
///
/// Labels are 1-based; `n_trees_per_round` is the number of classes. Panics if
/// a label is `None`.
pub fn compute_biases(
    labels: ArrayView1<Option<NonZeroUsize>>,
    n_trees_per_round: usize,
) -> Array1<f32> {
    // Count how many examples belong to each class.
    let mut class_counts: Array1<f32> = Array::zeros(n_trees_per_round);
    for label in labels {
        class_counts[label.unwrap().get() - 1] += 1.0;
    }
    let n_examples = labels.len().to_f32().unwrap();
    // Convert counts to clamped proportions, then to log priors.
    class_counts.mapv_into(|count| {
        let proportion = clamp(count / n_examples, std::f32::EPSILON, 1.0 - std::f32::EPSILON);
        proportion.ln()
    })
}
/// Compute, in parallel over examples, the cross-entropy gradient and hessian
/// for class `class_index`, writing them into `gradients` and `hessians`.
///
/// Each row of `logits` is one example; labels are 1-based. Panics if a label
/// is `None`.
pub fn compute_gradients_and_hessians(
    class_index: usize,
    gradients: &mut [f32],
    hessians: &mut [f32],
    labels: &[Option<NonZeroUsize>],
    logits: ArrayView2<f32>,
) {
    pzip!(gradients, hessians, logits.axis_iter(Axis(0)), labels).for_each(
        |(gradient, hessian, logits, label)| {
            // Numerically stable softmax probability for `class_index`:
            // subtract the max logit before exponentiating.
            let max = logits.iter().copied().fold(std::f32::MIN, f32::max);
            let sum: f32 = logits.iter().map(|logit| (logit - max).exp()).sum();
            let probability = (logits[class_index] - max).exp() / sum;
            // 1.0 when this example's label is `class_index`, else 0.0.
            let indicator = if label.unwrap().get() - 1 == class_index {
                1.0
            } else {
                0.0
            };
            *gradient = probability - indicator;
            *hessian = probability * (1.0 - probability);
        },
    );
}
/// Convert raw logits into probabilities, in place.
///
/// Subtracts the maximum logit before exponentiating for numerical stability,
/// then normalizes so the entries sum to 1.
fn softmax(logits: &mut [f32]) {
    let max = logits.iter().copied().fold(std::f32::MIN, f32::max);
    logits.iter_mut().for_each(|logit| *logit = (*logit - max).exp());
    let sum: f32 = logits.iter().sum();
    logits.iter_mut().for_each(|logit| *logit /= sum);
}