use crate::Progress;
use super::{
    shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
    train_early_stopping_split, EarlyStoppingMonitor, TrainOptions, TrainProgressEvent,
};
use ndarray::{self, prelude::*};
use num::{clamp, ToPrimitive};
use rayon::{self, prelude::*};
use std::num::NonZeroUsize;
use tangram_metrics::{CrossEntropy, CrossEntropyInput};
use tangram_progress_counter::ProgressCounter;
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};
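/// A `MulticlassClassifier` is a multinomial logistic regression model: one
/// weight vector and one bias per class, mapped to class probabilities with a
/// softmax. `means` holds the per-feature means of the training data, which
/// serve as the baseline when computing feature contributions.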
#[derive(Clone, Debug)]
pub struct MulticlassClassifier {
    pub biases: Array1<f32>,
    pub weights: Array2<f32>,
    pub means: Vec<f32>,
}
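/// The output of training: the trained model, the per-epoch training losses
/// if `TrainOptions::compute_losses` was set, and the normalized feature
/// importances.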
pub struct MulticlassClassifierTrainOutput {
    pub model: MulticlassClassifier,
    pub losses: Option<Vec<f32>>,
    pub feature_importances: Option<Vec<f32>>,
}
impl MulticlassClassifier {
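    /// Train a `MulticlassClassifier` with mini-batch gradient descent,
    /// optionally reserving a fraction of the data to monitor early stopping.
    ///
    /// A minimal usage sketch, assuming the caller has already prepared
    /// `features` (an `ArrayView2<f32>`), `labels` (an `EnumTableColumnView`),
    /// `train_options`, and `progress`:
    ///
    /// ```ignore
    /// let output = MulticlassClassifier::train(features, labels, &train_options, progress);
    /// let model = output.model;
    /// let bytes = model.to_bytes();
    /// ```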
    pub fn train(
        features: ArrayView2<f32>,
        labels: EnumTableColumnView,
        train_options: &TrainOptions,
        progress: Progress,
    ) -> MulticlassClassifierTrainOutput {
        let n_classes = labels.variants().len();
        let n_features = features.ncols();
        let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
            train_early_stopping_split(
                features,
                labels.as_slice().into(),
                train_options
                    .early_stopping_options
                    .as_ref()
                    .map(|o| o.early_stopping_fraction)
                    .unwrap_or(0.0),
            );
        // Record the per-feature means of the training data so that feature
        // contributions can later be computed against this baseline.
        let means = features_train
            .axis_iter(Axis(1))
            .map(|column| column.mean().unwrap())
            .collect();
        let mut model = MulticlassClassifier {
            biases: Array1::zeros(n_classes),
            weights: Array2::zeros((n_features, n_classes)),
            means,
        };
        let mut early_stopping_monitor =
            train_options
                .early_stopping_options
                .as_ref()
                .map(|early_stopping_options| {
                    EarlyStoppingMonitor::new(
                        early_stopping_options.min_decrease_in_loss_for_significant_change,
                        early_stopping_options.n_rounds_without_improvement_to_stop,
                    )
                });
        let progress_counter = ProgressCounter::new(train_options.max_epochs.to_u64().unwrap());
        (progress.handle_progress_event)(TrainProgressEvent::Train(progress_counter.clone()));
        // This buffer holds the predicted probabilities for the training
        // split. It needs one row per training example because it is chunked
        // in lockstep with `features_train` and `labels_train` below.
        let mut probabilities_buffer: Array2<f32> = Array2::zeros((labels_train.len(), n_classes));
        let mut losses = if train_options.compute_losses {
            Some(Vec::new())
        } else {
            None
        };
        let kill_chip = progress.kill_chip;
        for _ in 0..train_options.max_epochs {
            progress_counter.inc(1);
            let n_examples_per_batch = train_options.n_examples_per_batch;
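            // Process the batches in parallel while sharing one mutable model,
            // "Hogwild!"-style: the wrapper below claims Send and Sync for a
            // raw pointer so that rayon will accept it, and concurrent batch
            // updates are deliberately allowed to race on the weights.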
            struct MulticlassClassifierPtr(*mut MulticlassClassifier);
            unsafe impl Send for MulticlassClassifierPtr {}
            unsafe impl Sync for MulticlassClassifierPtr {}
            let model_ptr = MulticlassClassifierPtr(&mut model);
            pzip!(
                features_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
                labels_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
                probabilities_buffer.axis_chunks_iter_mut(Axis(0), n_examples_per_batch),
            )
            .for_each(|(features, labels, probabilities)| {
                let model = unsafe { &mut *model_ptr.0 };
                MulticlassClassifier::train_batch(
                    model,
                    features,
                    labels,
                    probabilities,
                    train_options,
                    kill_chip,
                );
            });
            if let Some(losses) = &mut losses {
                let loss =
                    MulticlassClassifier::compute_loss(probabilities_buffer.view(), labels_train);
                losses.push(loss);
            }
            if let Some(early_stopping_monitor) = early_stopping_monitor.as_mut() {
                let early_stopping_metric_value =
                    MulticlassClassifier::compute_early_stopping_metric_value(
                        &model,
                        features_early_stopping,
                        labels_early_stopping,
                        train_options,
                    );
                let should_stop = early_stopping_monitor.update(early_stopping_metric_value);
                if should_stop {
                    break;
                }
            }
            if progress.kill_chip.is_activated() {
                break;
            }
        }
        (progress.handle_progress_event)(TrainProgressEvent::TrainDone);
        let feature_importances = MulticlassClassifier::compute_feature_importances(&model);
        MulticlassClassifierTrainOutput {
            model,
            losses,
            feature_importances: Some(feature_importances),
        }
    }
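    /// Compute one importance score per feature: the mean absolute weight for
    /// the feature across all classes, normalized so the scores sum to one.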
    fn compute_feature_importances(model: &MulticlassClassifier) -> Vec<f32> {
        let mut feature_importances = model
            .weights
            .axis_iter(Axis(0))
            .map(|weights_each_class| {
                weights_each_class
                    .iter()
                    .map(|weight| weight.abs())
                    .sum::<f32>()
                    / model.weights.ncols().to_f32().unwrap()
            })
            .collect::<Vec<_>>();
        let feature_importances_sum: f32 = feature_importances.iter().sum();
        feature_importances
            .iter_mut()
            .for_each(|feature_importance| *feature_importance /= feature_importances_sum);
        feature_importances
    }
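    /// Perform one gradient descent step on a single mini-batch: run the
    /// forward pass, apply the softmax, record the predicted probabilities,
    /// then update the weights and biases for each class using the gradient
    /// of the cross entropy loss.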
    fn train_batch(
        &mut self,
        features: ArrayView2<f32>,
        labels: ArrayView1<Option<NonZeroUsize>>,
        mut probabilities: ArrayViewMut2<f32>,
        train_options: &TrainOptions,
        kill_chip: &tangram_kill_chip::KillChip,
    ) {
        if kill_chip.is_activated() {
            return;
        }
        let learning_rate = train_options.learning_rate;
        let n_classes = self.weights.ncols();
        // Forward pass: compute the logits and turn them into probabilities.
        let mut logits = features.dot(&self.weights) + &self.biases;
        softmax(logits.view_mut());
        // After the softmax, `logits` holds probabilities. Record them in the
        // caller's buffer so the training loss can be computed later.
        probabilities.assign(&logits);
        // Turn the probabilities into the gradient of the loss with respect
        // to the logits, which for a softmax with cross entropy loss is
        // `p - y`: subtract one from the probability of the true class.
        // Labels are one-indexed, so subtract one to get the class index.
        let mut predictions = logits;
        for (mut predictions, label) in zip!(predictions.axis_iter_mut(Axis(0)), labels) {
            for (class_index, prediction) in predictions.iter_mut().enumerate() {
                *prediction -= if class_index == label.unwrap().get() - 1 {
                    1.0
                } else {
                    0.0
                };
            }
        }
        let py = predictions;
        for class_index in 0..n_classes {
            // The weight gradient for each feature is the mean over the batch
            // of the feature value times the logit gradient for this class.
            let weight_gradients = (&features * &py.column(class_index).insert_axis(Axis(1)))
                .mean_axis(Axis(0))
                .unwrap();
            for (weight, weight_gradient) in zip!(
                self.weights.column_mut(class_index),
                weight_gradients.iter()
            ) {
                *weight += -learning_rate * weight_gradient;
            }
            // The bias gradient is the mean of the logit gradients for this class.
            let bias_gradient = py.column(class_index).mean().unwrap();
            self.biases[class_index] += -learning_rate * bias_gradient;
        }
    }
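    /// Compute the cross entropy loss: the mean over all examples of the
    /// negative log probability assigned to the true class. Probabilities are
    /// clamped away from 0 and 1 so the logarithm stays finite.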
    pub fn compute_loss(
        probabilities: ArrayView2<f32>,
        labels: ArrayView1<Option<NonZeroUsize>>,
    ) -> f32 {
        let mut loss = 0.0;
        for (label, probabilities) in zip!(labels.into_iter(), probabilities.axis_iter(Axis(0))) {
            for (index, &probability) in probabilities.indexed_iter() {
                let probability = clamp(probability, f32::EPSILON, 1.0 - f32::EPSILON);
                if index == label.unwrap().get() - 1 {
                    loss += -probability.ln();
                }
            }
        }
        loss / labels.len().to_f32().unwrap()
    }
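    /// Compute the cross entropy of the model on the held-out early stopping
    /// split, processing the examples in parallel batches and merging the
    /// partial metrics at the end.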
    fn compute_early_stopping_metric_value(
        &self,
        features: ArrayView2<f32>,
        labels: ArrayView1<Option<NonZeroUsize>>,
        train_options: &TrainOptions,
    ) -> f32 {
        let n_classes = self.biases.len();
        pzip!(
            features.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
            labels.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
        )
        .fold(
            || {
                // Allocate an uninitialized scratch buffer for the predictions.
                // Every row that is read below is written by `predict` first.
                let predictions = unsafe {
                    <Array2<f32>>::uninit((train_options.n_examples_per_batch, n_classes))
                        .assume_init()
                };
                let metric = CrossEntropy::default();
                (predictions, metric)
            },
            |(mut predictions, mut metric), (features, labels)| {
                // The last batch can be smaller than `n_examples_per_batch`,
                // so only the first `features.nrows()` rows are used.
                let slice = s![0..features.nrows(), ..];
                let mut predictions_slice = predictions.slice_mut(slice);
                self.predict(features, predictions_slice.view_mut());
                for (prediction, label) in
                    zip!(predictions_slice.axis_iter(Axis(0)), labels.iter())
                {
                    metric.update(CrossEntropyInput {
                        probabilities: prediction,
                        label: *label,
                    });
                }
                (predictions, metric)
            },
        )
        .map(|(_, metric)| metric)
        .reduce(CrossEntropy::new, |mut a, b| {
            a.merge(b);
            a
        })
        .finalize()
        .0
        .unwrap()
    }
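    /// Write the predicted class probabilities for `features` into
    /// `probabilities`: initialize each row with the biases, accumulate the
    /// logits with a matrix multiply, then apply the softmax in place.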
    pub fn predict(&self, features: ArrayView2<f32>, mut probabilities: ArrayViewMut2<f32>) {
        for mut row in probabilities.axis_iter_mut(Axis(0)) {
            row.assign(&self.biases.view());
        }
        ndarray::linalg::general_mat_mul(1.0, &features, &self.weights, 1.0, &mut probabilities);
        softmax(probabilities);
    }
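    /// Compute the feature contributions (SHAP values) for each example and
    /// each class. For a linear model these have a closed form: a feature's
    /// contribution is its weight times the difference between its value and
    /// the training mean.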
    pub fn compute_feature_contributions(
        &self,
        features: ArrayView2<f32>,
    ) -> Vec<Vec<ComputeShapValuesForExampleOutput>> {
        features
            .axis_iter(Axis(0))
            .map(|features| {
                zip!(self.weights.axis_iter(Axis(1)), self.biases.view())
                    .map(|(weights, bias)| {
                        compute_shap_values_for_example(
                            features.as_slice().unwrap(),
                            *bias,
                            weights.view(),
                            &self.means,
                        )
                    })
                    .collect()
            })
            .collect()
    }
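    // Serialization round-trips through the `buffalo` format defined in
    // `crate::serialize`.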
    pub fn from_reader(
        multiclass_classifier: crate::serialize::MulticlassClassifierReader,
    ) -> MulticlassClassifier {
        crate::serialize::deserialize_multiclass_classifier(multiclass_classifier)
    }
    pub fn to_writer(
        &self,
        writer: &mut buffalo::Writer,
    ) -> buffalo::Position<crate::serialize::MulticlassClassifierWriter> {
        crate::serialize::serialize_multiclass_classifier(self, writer)
    }
    pub fn from_bytes(bytes: &[u8]) -> MulticlassClassifier {
        let reader = buffalo::read::<crate::serialize::MulticlassClassifierReader>(bytes);
        Self::from_reader(reader)
    }
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut writer = buffalo::Writer::new();
        self.to_writer(&mut writer);
        writer.into_bytes()
    }
}
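/// Apply a numerically stable softmax to each row of `logits` in place:
/// subtract the row maximum before exponentiating so the exponentials cannot
/// overflow, then normalize so each row sums to one.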
fn softmax(mut logits: ArrayViewMut2<f32>) {
    for mut logits in logits.axis_iter_mut(Axis(0)) {
        let max = logits.iter().fold(f32::MIN, |a, &b| f32::max(a, b));
        for logit in logits.iter_mut() {
            *logit = (*logit - max).exp();
        }
        let sum = logits.iter().sum::<f32>();
        for logit in logits.iter_mut() {
            *logit /= sum;
        }
    }
}