use crate::{
shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
train::{train, Task, TrainOutput},
train_tree::TrainTree,
Progress, TrainOptions, Tree,
};
use ndarray::prelude::*;
use num::{clamp, ToPrimitive};
use rayon::prelude::*;
use std::{num::NonZeroUsize, ops::Neg};
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};
/// A `BinaryClassifier` predicts which of two classes an example belongs to
/// by summing a bias and the output of each tree to produce a logit, then
/// applying the sigmoid function.
#[derive(Clone, Debug)]
pub struct BinaryClassifier {
	/// The bias is the initial value of each example's logit.
	pub bias: f32,
	/// The trees in the model.
	pub trees: Vec<Tree>,
}
/// The output of [`BinaryClassifier::train`].
#[derive(Debug)]
pub struct BinaryClassifierTrainOutput {
	/// The trained model.
	pub model: BinaryClassifier,
	/// The loss after each round of training, if loss computation was enabled.
	pub losses: Option<Vec<f32>>,
	/// The importance of each feature, if feature importance computation was enabled.
	pub feature_importances: Option<Vec<f32>>,
}
impl BinaryClassifier {
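	/// Train a binary classifier with the given `features`, `labels`, and
	/// `train_options`, reporting progress through `progress`.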
pub fn train(
features: TableView,
labels: EnumTableColumnView,
train_options: &TrainOptions,
progress: Progress,
) -> BinaryClassifierTrainOutput {
let task = Task::BinaryClassification;
let train_output = train(
task,
features,
TableColumnView::Enum(labels),
train_options,
progress,
);
		// `train` is shared across tasks and returns the output variant
		// matching the task, so any other variant is unreachable here.
		match train_output {
			TrainOutput::BinaryClassifier(train_output) => train_output,
			_ => unreachable!(),
		}
}
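	/// Make predictions for each row of `features`, writing the predicted
	/// probability of the positive class into `probabilities`, which must
	/// have one element per example.
	///
	/// A rough usage sketch, assuming a trained `model` and an already
	/// prepared `features` array (the setup is illustrative, not part of
	/// this function):
	///
	/// ```ignore
	/// let mut probabilities = Array1::<f32>::zeros(features.nrows());
	/// model.predict(features.view(), probabilities.view_mut());
	/// ```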
pub fn predict(&self, features: ArrayView2<TableValue>, mut probabilities: ArrayViewMut1<f32>) {
		// Start each example's logit at the bias, then accumulate the output
		// of every tree.
		probabilities.fill(self.bias);
		let probabilities = probabilities.as_slice_mut().unwrap();
for tree in self.trees.iter() {
zip!(features.axis_iter(Axis(0)), probabilities.iter_mut()).for_each(
|(example, logit)| {
*logit += tree.predict(example.as_slice().unwrap());
},
);
}
		// Apply the sigmoid to turn each accumulated logit into a
		// probability.
		probabilities.iter_mut().for_each(|probability| {
			*probability = 1.0 / (probability.neg().exp() + 1.0);
		});
}
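	/// Compute SHAP values for each example in `features`. The SHAP values
	/// attribute the difference between an example's logit and the bias to
	/// the individual features.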
pub fn compute_feature_contributions(
&self,
features: ArrayView2<TableValue>,
) -> Vec<ComputeShapValuesForExampleOutput> {
let trees = ArrayView1::from_shape(self.trees.len(), &self.trees).unwrap();
features
.axis_iter(Axis(0))
.map(|features| {
compute_shap_values_for_example(features.as_slice().unwrap(), trees, self.bias)
})
.collect()
}
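	/// Deserialize a `BinaryClassifier` from a `buffalo` reader.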
pub fn from_reader(
binary_classifier: crate::serialize::BinaryClassifierReader,
) -> BinaryClassifier {
crate::serialize::deserialize_binary_classifier(binary_classifier)
}
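	/// Serialize this model into `writer`, returning the position at which
	/// it was written.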
pub fn to_writer(
&self,
writer: &mut buffalo::Writer,
) -> buffalo::Position<crate::serialize::BinaryClassifierWriter> {
crate::serialize::serialize_binary_classifier(self, writer)
}
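	/// Deserialize a `BinaryClassifier` from a byte buffer produced by
	/// [`BinaryClassifier::to_bytes`].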
	pub fn from_bytes(bytes: &[u8]) -> BinaryClassifier {
		let reader = buffalo::read::<crate::serialize::BinaryClassifierReader>(bytes);
		Self::from_reader(reader)
	}
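	/// Serialize this model to a byte buffer.
	///
	/// A hypothetical round trip, where `model` is an existing
	/// `BinaryClassifier`:
	///
	/// ```ignore
	/// let bytes = model.to_bytes();
	/// let model = BinaryClassifier::from_bytes(&bytes);
	/// ```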
pub fn to_bytes(&self) -> Vec<u8> {
let mut writer = buffalo::Writer::new();
self.to_writer(&mut writer);
writer.into_bytes()
}
}
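/// Add the predictions of the trees trained in the current round of boosting
/// to the running logit for each example.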
pub fn update_logits(
trees_for_round: &[TrainTree],
binned_features: ArrayView2<TableValue>,
mut predictions: ArrayViewMut2<f32>,
) {
for tree in trees_for_round {
for (prediction, features) in
zip!(predictions.iter_mut(), binned_features.axis_iter(Axis(0)))
{
*prediction += tree.predict(features.as_slice().unwrap());
}
}
}
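/// Compute the mean binary cross entropy loss,
/// `-y * ln(p) - (1 - y) * ln(1 - p)`, averaged over all examples, where `y`
/// is the label mapped to `{0, 1}` and `p` is the predicted probability,
/// clamped away from 0 and 1 so the logarithms stay finite.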
pub fn compute_loss(logits: ArrayView2<f32>, labels: ArrayView1<Option<NonZeroUsize>>) -> f32 {
let mut total = 0.0;
for (label, logit) in zip!(labels.iter(), logits) {
let label = (label.unwrap().get() - 1).to_f32().unwrap();
let probability = 1.0 / (logit.neg().exp() + 1.0);
		let probability_clamped = clamp(probability, f32::EPSILON, 1.0 - f32::EPSILON);
		total += -1.0 * label * probability_clamped.ln()
			+ -1.0 * (1.0 - label) * (1.0 - probability_clamped).ln();
}
total / labels.len().to_f32().unwrap()
}
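/// Compute the bias, which is the model's prediction before any trees are
/// added. For binary classification it is the log odds of the positive
/// class, `ln(pos_count / neg_count)`, so the sigmoid of the bias recovers
/// the fraction of positive examples in the training set.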
pub fn compute_biases(labels: ArrayView1<Option<NonZeroUsize>>) -> Array1<f32> {
	// Count the positive examples. Labels are one-indexed enum values, so 1
	// is the negative class and 2 is the positive class.
	let pos_count = labels
		.iter()
		.filter(|label| label.unwrap().get() == 2)
		.count();
	let neg_count = labels.len() - pos_count;
	// The bias is the log odds of the positive class.
	let log_odds = (pos_count.to_f32().unwrap() / neg_count.to_f32().unwrap()).ln();
}
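/// Compute the first and second derivatives of the binary cross entropy
/// loss with respect to each example's logit: the gradient is `p - y` and
/// the hessian is `p * (1 - p)`, where `p` is the predicted probability and
/// `y` is the label mapped to `{0, 1}`.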
pub fn compute_gradients_and_hessians(
gradients: &mut [f32],
hessians: &mut [f32],
labels: &[Option<NonZeroUsize>],
predictions: &[f32],
) {
pzip!(gradients, hessians, labels, predictions).for_each(
|(gradient, hessian, label, prediction)| {
			// Clamp the predicted probability away from 0 and 1 so the
			// gradient and hessian stay finite.
			let probability = clamp(
				sigmoid(*prediction),
				f32::EPSILON,
				1.0 - f32::EPSILON,
			);
			// The gradient of the log loss with respect to the logit is
			// p - y and the hessian is p * (1 - p).
			*gradient = probability - (label.unwrap().get() - 1).to_f32().unwrap();
			*hessian = probability * (1.0 - probability);
},
);
}
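/// The logistic function, which maps a logit to a probability in `(0, 1)`.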
fn sigmoid(value: f32) -> f32 {
1.0 / (value.neg().exp() + 1.0)
}