use super::{
shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
train_early_stopping_split, EarlyStoppingMonitor, Progress, TrainOptions, TrainProgressEvent,
};
use ndarray::{self, prelude::*};
use num::{clamp, ToPrimitive};
use rayon::{self, prelude::*};
use std::{num::NonZeroUsize, ops::Neg};
use tangram_metrics::{BinaryCrossEntropy, BinaryCrossEntropyInput};
use tangram_progress_counter::ProgressCounter;
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};
#[derive(Clone, Debug)]
pub struct BinaryClassifier {
pub bias: f32,
pub weights: Array1<f32>,
pub means: Vec<f32>,
}
pub struct BinaryClassifierTrainOutput {
pub model: BinaryClassifier,
pub losses: Option<Vec<f32>>,
pub feature_importances: Option<Vec<f32>>,
}
impl BinaryClassifier {
pub fn train(
features: ArrayView2<f32>,
labels: EnumTableColumnView,
train_options: &TrainOptions,
progress: Progress,
) -> BinaryClassifierTrainOutput {
let n_features = features.ncols();
let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
train_early_stopping_split(
features,
labels.as_slice().into(),
train_options
.early_stopping_options
.as_ref()
.map(|o| o.early_stopping_fraction)
.unwrap_or(0.0),
);
let means = features_train
.axis_iter(Axis(1))
.map(|column| column.mean().unwrap())
.collect();
let mut model = BinaryClassifier {
bias: 0.0,
weights: <Array1<f32>>::zeros(n_features),
means,
};
let mut early_stopping_monitor =
train_options
.early_stopping_options
.as_ref()
.map(|early_stopping_options| {
EarlyStoppingMonitor::new(
early_stopping_options.min_decrease_in_loss_for_significant_change,
early_stopping_options.n_rounds_without_improvement_to_stop,
)
});
let progress_counter = ProgressCounter::new(train_options.max_epochs.to_u64().unwrap());
(progress.handle_progress_event)(TrainProgressEvent::Train(progress_counter.clone()));
let mut probabilities_buffer: Array1<f32> = Array1::zeros(labels.len());
let mut losses = if train_options.compute_losses {
Some(Vec::new())
} else {
None
};
let kill_chip = progress.kill_chip;
for _ in 0..train_options.max_epochs {
progress_counter.inc(1);
let n_examples_per_batch = train_options.n_examples_per_batch;
struct BinaryClassifierPtr(*mut BinaryClassifier);
unsafe impl Send for BinaryClassifierPtr {}
unsafe impl Sync for BinaryClassifierPtr {}
let model_ptr = BinaryClassifierPtr(&mut model);
pzip!(
features_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
labels_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
probabilities_buffer.axis_chunks_iter_mut(Axis(0), n_examples_per_batch),
)
.for_each(|(features, labels, probabilities)| {
let model = unsafe { &mut *model_ptr.0 };
BinaryClassifier::train_batch(
model,
features,
labels,
probabilities,
train_options,
kill_chip,
);
});
if let Some(losses) = &mut losses {
let loss =
BinaryClassifier::compute_loss(probabilities_buffer.view(), labels_train);
losses.push(loss);
}
if let Some(early_stopping_monitor) = early_stopping_monitor.as_mut() {
let early_stopping_metric_value =
BinaryClassifier::compute_early_stopping_metric_value(
&model,
features_early_stopping,
labels_early_stopping,
train_options,
);
let should_stop = early_stopping_monitor.update(early_stopping_metric_value);
if should_stop {
break;
}
}
if progress.kill_chip.is_activated() {
break;
}
}
(progress.handle_progress_event)(TrainProgressEvent::TrainDone);
let feature_importances = BinaryClassifier::compute_feature_importances(&model);
BinaryClassifierTrainOutput {
model,
losses,
feature_importances: Some(feature_importances),
}
}
fn compute_feature_importances(model: &BinaryClassifier) -> Vec<f32> {
let mut feature_importances = model
.weights
.iter()
.map(|weight| weight.abs())
.collect::<Vec<_>>();
let feature_importances_sum = feature_importances.iter().sum::<f32>();
feature_importances
.iter_mut()
.for_each(|feature_importance| *feature_importance /= feature_importances_sum);
feature_importances
}
fn train_batch(
&mut self,
features: ArrayView2<f32>,
labels: ArrayView1<Option<NonZeroUsize>>,
mut probabilities: ArrayViewMut1<f32>,
train_options: &TrainOptions,
kill_chip: &tangram_kill_chip::KillChip,
) {
if kill_chip.is_activated() {
return;
}
let learning_rate = train_options.learning_rate;
let mut py = features.dot(&self.weights) + self.bias;
for (probability, py) in zip!(probabilities.iter_mut(), py.iter_mut()) {
*probability = 1.0 / (py.neg().exp() + 1.0);
*py = *probability;
}
for (py, label) in zip!(py.view_mut(), labels) {
let label = match label.map(|l| l.get()) {
Some(1) => 0.0,
Some(2) => 1.0,
_ => unreachable!(),
};
*py -= label
}
let py = py.insert_axis(Axis(1));
let weight_gradients = (&features * &py).mean_axis(Axis(0)).unwrap();
let bias_gradient = py.mean_axis(Axis(0)).unwrap()[0];
for (weight, weight_gradient) in zip!(self.weights.view_mut(), weight_gradients.view()) {
*weight += -learning_rate * weight_gradient;
}
self.bias += -learning_rate * bias_gradient;
}
pub fn compute_loss(
probabilities: ArrayView1<f32>,
labels: ArrayView1<Option<NonZeroUsize>>,
) -> f32 {
let mut total = 0.0;
for (label, probability) in zip!(labels.iter(), probabilities) {
let label = (label.unwrap().get() - 1).to_f32().unwrap();
let probability_clamped =
clamp(*probability, std::f32::EPSILON, 1.0 - std::f32::EPSILON);
total += -1.0 * label * probability_clamped.ln()
+ -1.0 * (1.0 - label) * (1.0 - probability_clamped).ln()
}
total / labels.len().to_f32().unwrap()
}
fn compute_early_stopping_metric_value(
&self,
features: ArrayView2<f32>,
labels: ArrayView1<Option<NonZeroUsize>>,
train_options: &TrainOptions,
) -> f32 {
pzip!(
features.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
labels.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
)
.fold(
|| {
let predictions = unsafe {
<Array1<f32>>::uninit(train_options.n_examples_per_batch).assume_init()
};
let metric = BinaryCrossEntropy::new();
(predictions, metric)
},
|(mut predictions, mut metric), (features, labels)| {
let slice = s![0..features.nrows()];
let mut predictions_slice = predictions.slice_mut(slice);
self.predict(features, predictions_slice.view_mut());
for (prediction, label) in zip!(predictions_slice.iter(), labels.iter()) {
metric.update(BinaryCrossEntropyInput {
probability: *prediction,
label: *label,
});
}
(predictions, metric)
},
)
.map(|(_, metric)| metric)
.reduce(BinaryCrossEntropy::new, |mut a, b| {
a.merge(b);
a
})
.finalize()
.unwrap()
}
pub fn predict(&self, features: ArrayView2<f32>, mut probabilities: ArrayViewMut1<f32>) {
probabilities.fill(self.bias);
ndarray::linalg::general_mat_vec_mul(
1.0,
&features,
&self.weights,
1.0,
&mut probabilities,
);
for probability in probabilities.iter_mut() {
*probability = 1.0 / (probability.neg().exp() + 1.0);
}
}
pub fn compute_feature_contributions(
&self,
features: ArrayView2<f32>,
) -> Vec<ComputeShapValuesForExampleOutput> {
features
.axis_iter(Axis(0))
.map(|features| {
compute_shap_values_for_example(
features.as_slice().unwrap(),
self.bias,
self.weights.view(),
&self.means,
)
})
.collect()
}
pub fn from_reader(
binary_classifier: crate::serialize::BinaryClassifierReader,
) -> BinaryClassifier {
crate::serialize::deserialize_binary_classifier(binary_classifier)
}
pub fn to_writer(
&self,
writer: &mut buffalo::Writer,
) -> buffalo::Position<crate::serialize::BinaryClassifierWriter> {
crate::serialize::serialize_binary_classifier(self, writer)
}
pub fn from_bytes(&self, bytes: &[u8]) -> BinaryClassifier {
let reader = buffalo::read::<crate::serialize::BinaryClassifierReader>(bytes);
Self::from_reader(reader)
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut writer = buffalo::Writer::new();
self.to_writer(&mut writer);
writer.into_bytes()
}
}