#[cfg(feature = "timing")]
use crate::timing::Timing;
use crate::{
binary_classifier::{BinaryClassifier, BinaryClassifierTrainOutput},
compute_bin_stats::{BinStats, BinStatsEntry},
compute_binned_features::{
compute_binned_features_column_major, compute_binned_features_row_major,
},
compute_binning_instructions::compute_binning_instructions,
compute_feature_importances::compute_feature_importances,
multiclass_classifier::{MulticlassClassifier, MulticlassClassifierTrainOutput},
pool::Pool,
regressor::{Regressor, RegressorTrainOutput},
train_tree::{
train_tree, TrainBranchNode, TrainBranchSplit, TrainBranchSplitContinuous,
TrainBranchSplitDiscrete, TrainLeafNode, TrainNode, TrainTree, TrainTreeOptions,
},
BinnedFeaturesLayout, BranchNode, BranchSplit, BranchSplitContinuous, BranchSplitDiscrete,
LeafNode, Node, Progress, TrainOptions, TrainProgressEvent, Tree,
};
use ndarray::prelude::*;
use num::ToPrimitive;
use rayon::prelude::*;
use tangram_progress_counter::ProgressCounter;
use tangram_table::prelude::*;
/// The kind of model `train` should produce.
#[derive(Clone, Copy, Debug)]
pub enum Task {
/// Train a `Regressor`. `train` unwraps the labels as a number column and trains one tree per round.
Regression,
/// Train a `BinaryClassifier`. `train` unwraps the labels as an enum column and trains one tree per round.
BinaryClassification,
/// Train a `MulticlassClassifier`. `train` unwraps the labels as an enum column and trains `n_classes` trees per round.
MulticlassClassification { n_classes: usize },
}
/// The result of `train`. The variant matches the `Task` that was passed in.
#[derive(Debug)]
pub enum TrainOutput {
/// Returned for `Task::Regression`.
Regressor(RegressorTrainOutput),
/// Returned for `Task::BinaryClassification`.
BinaryClassifier(BinaryClassifierTrainOutput),
/// Returned for `Task::MulticlassClassification`.
MulticlassClassifier(MulticlassClassifierTrainOutput),
}
/// Trains an ensemble of trees for the given `task` and returns the trained model.
///
/// The function proceeds in these steps:
///
/// 1. If `train_options.early_stopping_options` is set, split off an early stopping set.
/// 2. Compute binning instructions and binned features for the training set.
/// 3. Compute the initial biases and set every prediction to them.
/// 4. For up to `train_options.max_rounds` rounds, train `n_trees_per_round` trees on the
///    gradients and hessians of the current predictions and add each tree's leaf values
///    to the predictions.
/// 5. Compute feature importances, convert the `TrainTree`s to `Tree`s, and assemble the output.
///
/// Progress is reported through `progress.handle_progress_event`. Training ends after the
/// current round when the early stopping monitor says so or when `progress.kill_chip` is
/// activated.
///
/// # Panics
///
/// Panics if `labels` is not a number column for `Task::Regression`, or not an enum column
/// for the classification tasks, because the column is unwrapped with `as_number()` /
/// `as_enum()`.
pub fn train(
task: Task,
features: TableView,
labels: TableColumnView,
train_options: &TrainOptions,
progress: Progress,
) -> TrainOutput {
#[cfg(feature = "timing")]
let timing = Timing::new();
// When early stopping is enabled, hold out part of the data and create the monitor that
// decides when to stop. Otherwise train on everything and leave the early stopping values `None`.
let early_stopping_enabled = train_options.early_stopping_options.is_some();
let (
features_train,
labels_train,
features_early_stopping,
labels_early_stopping,
mut early_stopping_monitor,
) = if let Some(early_stopping_options) = &train_options.early_stopping_options {
let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
train_early_stopping_split(
features,
labels,
early_stopping_options.early_stopping_fraction,
);
let early_stopping_monitor = EarlyStoppingMonitor::new(
early_stopping_options.min_decrease_in_loss_for_significant_change,
early_stopping_options.n_rounds_without_improvement_to_stop,
);
(
features_train,
labels_train,
Some(features_early_stopping),
Some(labels_early_stopping),
Some(early_stopping_monitor),
)
} else {
(features, labels, None, None, None)
};
// `n_features` is the column count before any feature is filtered out below.
let n_features = features_train.ncols();
let n_examples_train = features_train.nrows();
// Compute the binning instructions for the training features.
#[cfg(feature = "timing")]
let start = std::time::Instant::now();
let binning_instructions = compute_binning_instructions(&features_train, train_options);
#[cfg(feature = "timing")]
timing.compute_binning_instructions.inc(start.elapsed());
let binned_features_layout = train_options.binned_features_layout;
// Report initialization progress with a counter whose total is the number of training rows.
let progress_counter = ProgressCounter::new(features_train.nrows().to_u64().unwrap());
(progress.handle_progress_event)(TrainProgressEvent::Initialize(progress_counter.clone()));
// Compute the column major binned features. The output also carries `used_feature_indexes`,
// which is used from here on to restrict the features to the used columns.
#[cfg(feature = "timing")]
let start = std::time::Instant::now();
let compute_binned_features_column_major_output = compute_binned_features_column_major(
&features_train,
&binning_instructions,
train_options,
&|| progress_counter.inc(1),
);
// Keep only the used columns of the training features.
let features_train = features_train
.view_columns(&compute_binned_features_column_major_output.used_feature_indexes);
// Keep only the used columns of the early stopping features too, converted with `to_rows()`.
let features_early_stopping = features_early_stopping
.as_ref()
.map(|features_early_stopping| {
features_early_stopping
.view_columns(&compute_binned_features_column_major_output.used_feature_indexes)
.to_rows()
});
// Binning instructions for the used features only, in the order of `used_feature_indexes`.
let used_features_binning_instructions = compute_binned_features_column_major_output
.used_feature_indexes
.iter()
.map(|original_feature_index| binning_instructions[*original_feature_index].clone())
.collect::<Vec<_>>();
// The row major binned features are only computed when the row major layout is requested.
let binned_features_row_major =
if let BinnedFeaturesLayout::RowMajor = train_options.binned_features_layout {
Some(compute_binned_features_row_major(
&features_train,
&used_features_binning_instructions,
&|| progress_counter.inc(1),
))
} else {
None
};
#[cfg(feature = "timing")]
timing.compute_binned_features.inc(start.elapsed());
// Regression and binary classification train one tree per round; multiclass trains one per class.
let n_trees_per_round = match task {
Task::Regression => 1,
Task::BinaryClassification => 1,
Task::MulticlassClassification { n_classes } => n_classes,
};
// This flag is forwarded to `train_tree`; it is only set for regression.
let hessians_are_constant = match task {
Task::Regression => true,
Task::BinaryClassification => false,
Task::MulticlassClassification { .. } => false,
};
// Compute the initial biases from the training labels using the task specific function.
let biases = match task {
Task::Regression => {
let labels_train = labels_train.as_number().unwrap();
let labels_train = labels_train.as_slice().into();
crate::regressor::compute_biases(labels_train)
}
Task::BinaryClassification => {
let labels_train = labels_train.as_enum().unwrap();
let labels_train = labels_train.as_slice().into();
crate::binary_classifier::compute_biases(labels_train)
}
Task::MulticlassClassification { .. } => {
let labels_train = labels_train.as_enum().unwrap();
let labels_train = labels_train.as_slice().into();
crate::multiclass_classifier::compute_biases(labels_train, n_trees_per_round)
}
};
// Allocate the working buffers without initializing them.
// `predictions` has shape (n_examples_train, n_trees_per_round) and is created with `.f()`;
// the `predictions.column(..).as_slice().unwrap()` calls below depend on each column being
// contiguous in memory.
// NOTE(review): `assume_init()` is called on uninitialized memory. `predictions` is filled
// with the biases and `examples_index` is filled at the start of every tree, but whether
// every other buffer is fully written before it is read depends on `train_tree` and the
// `compute_gradients_and_hessians` functions, which are not visible here. TODO confirm.
let mut predictions =
unsafe { Array::uninit((n_examples_train, n_trees_per_round).f()).assume_init() };
let mut gradients = unsafe { Array::uninit(n_examples_train).assume_init() };
let mut hessians = unsafe { Array::uninit(n_examples_train).assume_init() };
let mut gradients_ordered_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
let mut hessians_ordered_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
let mut examples_index = unsafe { Array::uninit(n_examples_train).assume_init() };
let mut examples_index_left_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
let mut examples_index_right_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
// The early stopping predictions have shape (n_trees_per_round, n_early_stopping_examples).
// Each column, one per example, starts at the biases.
let mut predictions_early_stopping = if early_stopping_enabled {
let mut predictions_early_stopping = unsafe {
Array::uninit((
n_trees_per_round,
labels_early_stopping.as_ref().unwrap().len(),
))
.assume_init()
};
for mut predictions in predictions_early_stopping.axis_iter_mut(Axis(1)) {
predictions.assign(&biases);
}
Some(predictions_early_stopping)
} else {
None
};
// Create the pool of bin stats, sized by `train_options.max_leaf_nodes`. The closure builds
// one item: for the column major layout a vector of bin stats entries per feature, for the
// row major layout a single flat vector covering all features. The binning instructions are
// cloned because the closure is `move` and they are still needed below.
let binning_instructions_for_pool = used_features_binning_instructions.clone();
let bin_stats_pool = match binned_features_layout {
BinnedFeaturesLayout::ColumnMajor => Pool::new(
train_options.max_leaf_nodes,
Box::new(move || {
BinStats::ColumnMajor(
binning_instructions_for_pool
.iter()
.map(|binning_instructions| {
vec![BinStatsEntry::default(); binning_instructions.n_bins()]
})
.collect(),
)
}),
),
BinnedFeaturesLayout::RowMajor => Pool::new(
train_options.max_leaf_nodes,
Box::new(move || {
BinStats::RowMajor(
binning_instructions_for_pool
.iter()
.flat_map(|binning_instructions| {
vec![BinStatsEntry::default(); binning_instructions.n_bins()]
})
.collect(),
)
}),
),
};
// `n_rounds_trained` counts completed rounds; it becomes the row count of the multiclass tree array.
let mut n_rounds_trained = 0;
let mut trees: Vec<TrainTree> = Vec::new();
// Losses are only recorded when `train_options.compute_losses` is set.
let mut losses: Option<Vec<f32>> = if train_options.compute_losses {
Some(Vec::new())
} else {
None
};
// Each row of `predictions`, one per training example, starts at the biases.
for mut predictions in predictions.axis_iter_mut(Axis(0)) {
predictions.assign(&biases)
}
(progress.handle_progress_event)(TrainProgressEvent::InitializeDone);
// Report training progress with a counter whose total is the maximum number of rounds.
let round_counter = ProgressCounter::new(train_options.max_rounds.to_u64().unwrap());
(progress.handle_progress_event)(TrainProgressEvent::Train(round_counter.clone()));
// Train rounds of trees until `max_rounds` is reached, early stopping triggers, or the kill chip is activated.
for _ in 0..train_options.max_rounds {
round_counter.inc(1);
let mut trees_for_round = Vec::with_capacity(n_trees_per_round);
for tree_per_round_index in 0..n_trees_per_round {
// Compute the gradients and hessians of the current predictions for this tree.
// Regression and binary classification read column 0 of `predictions`; multiclass
// receives the whole array together with the index of the tree within the round.
#[cfg(feature = "timing")]
let start = std::time::Instant::now();
match task {
Task::Regression => {
let labels_train = labels_train.as_number().unwrap();
crate::regressor::compute_gradients_and_hessians(
gradients.as_slice_mut().unwrap(),
hessians.as_slice_mut().unwrap(),
labels_train.as_slice(),
predictions.column(0).as_slice().unwrap(),
);
}
Task::BinaryClassification => {
let labels_train = labels_train.as_enum().unwrap();
crate::binary_classifier::compute_gradients_and_hessians(
gradients.as_slice_mut().unwrap(),
hessians.as_slice_mut().unwrap(),
labels_train.as_slice(),
predictions.column(0).as_slice().unwrap(),
);
}
Task::MulticlassClassification { .. } => {
let labels_train = labels_train.as_enum().unwrap();
crate::multiclass_classifier::compute_gradients_and_hessians(
tree_per_round_index,
gradients.as_slice_mut().unwrap(),
hessians.as_slice_mut().unwrap(),
labels_train.as_slice(),
predictions.view(),
);
}
};
#[cfg(feature = "timing")]
timing.compute_gradients_and_hessians.inc(start.elapsed());
// Reset the examples index to 0, 1, 2, ... before training each tree.
examples_index
.as_slice_mut()
.unwrap()
.par_iter_mut()
.enumerate()
.for_each(|(index, value)| {
*value = index.to_u32().unwrap();
});
// Train the tree.
let tree = train_tree(TrainTreeOptions {
binning_instructions: &used_features_binning_instructions,
binned_features_row_major: &binned_features_row_major,
binned_features_column_major: &compute_binned_features_column_major_output
.binned_features,
gradients: gradients.as_slice().unwrap(),
hessians: hessians.as_slice().unwrap(),
gradients_ordered_buffer: gradients_ordered_buffer.as_slice_mut().unwrap(),
hessians_ordered_buffer: hessians_ordered_buffer.as_slice_mut().unwrap(),
examples_index: examples_index.as_slice_mut().unwrap(),
examples_index_left_buffer: examples_index_left_buffer.as_slice_mut().unwrap(),
examples_index_right_buffer: examples_index_right_buffer.as_slice_mut().unwrap(),
bin_stats_pool: &bin_stats_pool,
hessians_are_constant,
train_options,
#[cfg(feature = "timing")]
timing: &timing,
});
// Add the new tree's leaf values to this tree's column of the predictions.
// `examples_index` is passed as `train_tree` left it, because the tree's leaf ranges index into it.
update_predictions_with_tree(
predictions
.column_mut(tree_per_round_index)
.as_slice_mut()
.unwrap(),
examples_index.as_slice().unwrap(),
&tree,
#[cfg(feature = "timing")]
&timing,
);
trees_for_round.push(tree);
}
// If losses are being recorded, compute the training loss after this round.
if let Some(losses) = losses.as_mut() {
let loss = match task {
Task::Regression => {
let labels_train = labels_train.as_number().unwrap();
let labels_train = labels_train.as_slice().into();
crate::regressor::compute_loss(predictions.view(), labels_train)
}
Task::BinaryClassification => {
let labels_train = labels_train.as_enum().unwrap();
let labels_train = labels_train.as_slice().into();
crate::binary_classifier::compute_loss(predictions.view(), labels_train)
}
Task::MulticlassClassification { .. } => {
let labels_train = labels_train.as_enum().unwrap();
let labels_train = labels_train.as_slice().into();
crate::multiclass_classifier::compute_loss(predictions.view(), labels_train)
}
};
losses.push(loss);
}
// If early stopping is enabled, update the early stopping predictions with this round's
// trees, compute the metric on the held out set, and ask the monitor whether to stop.
let should_stop = if early_stopping_enabled {
let features_early_stopping = features_early_stopping.as_ref().unwrap();
let labels_early_stopping = labels_early_stopping.as_ref().unwrap();
let predictions_early_stopping = predictions_early_stopping.as_mut().unwrap();
let early_stopping_monitor = early_stopping_monitor.as_mut().unwrap();
let value = compute_early_stopping_metric(
&task,
trees_for_round.as_slice(),
features_early_stopping.view(),
labels_early_stopping.view(),
predictions_early_stopping.view_mut(),
);
early_stopping_monitor.update(value)
} else {
false
};
// The trees of the round that triggered the stop are kept in the model.
trees.extend(trees_for_round);
n_rounds_trained += 1;
if should_stop {
break;
}
// The kill chip is checked once per round, after the round's trees have been added.
if progress.kill_chip.is_activated() {
break;
}
}
(progress.handle_progress_event)(TrainProgressEvent::TrainDone);
// NOTE(review): `n_features` is the column count before unused features were dropped, and
// the trees here are still `TrainTree`s; their feature indexes are only mapped through
// `used_feature_indexes` in the conversion below. Confirm that `compute_feature_importances`
// attributes importances to the intended feature indexes.
let feature_importances = Some(compute_feature_importances(&trees, n_features));
#[cfg(feature = "timing")]
eprintln!("{:?}", timing);
// Convert the trees to the output representation, mapping feature indexes through `used_feature_indexes`.
let trees: Vec<Tree> = trees
.into_iter()
.map(|train_tree| {
tree_from_train_tree(
train_tree,
compute_binned_features_column_major_output
.used_feature_indexes
.as_slice(),
)
})
.collect();
// Assemble the output for the task.
match task {
Task::Regression => TrainOutput::Regressor(RegressorTrainOutput {
model: Regressor {
bias: *biases.get(0).unwrap(),
trees,
},
feature_importances,
losses,
}),
Task::BinaryClassification => TrainOutput::BinaryClassifier(BinaryClassifierTrainOutput {
model: BinaryClassifier {
bias: *biases.get(0).unwrap(),
trees,
},
feature_importances,
losses,
}),
Task::MulticlassClassification { .. } => {
// Trees were pushed round by round, so they reshape to (n_rounds_trained, n_trees_per_round).
let trees =
Array2::from_shape_vec((n_rounds_trained, n_trees_per_round), trees).unwrap();
TrainOutput::MulticlassClassifier(MulticlassClassifierTrainOutput {
model: MulticlassClassifier { biases, trees },
feature_importances,
losses,
})
}
}
}
/// Adds each leaf's value to the predictions of the examples assigned to that leaf.
///
/// `tree.leaf_values` holds `(range, value)` pairs. For each pair, `examples_index[range]`
/// lists the examples whose prediction is increased by `value`. Leaves are processed in
/// parallel.
///
/// The parallel writes are only sound if no example index appears under more than one leaf
/// and every index is less than `predictions.len()`. The caller must uphold both.
///
/// # Panics
///
/// Panics if a leaf's range is out of bounds for `examples_index`.
fn update_predictions_with_tree(
    predictions: &mut [f32],
    examples_index: &[u32],
    tree: &TrainTree,
    #[cfg(feature = "timing")] timing: &Timing,
) {
    #[cfg(feature = "timing")]
    let start = std::time::Instant::now();
    // Pointer to the first prediction, shared between the worker threads. Writing through a
    // raw element pointer avoids creating a `&mut [f32]` over the whole slice in every thread,
    // which would produce several live aliasing mutable references.
    struct PredictionsPtr(*mut f32);
    // SAFETY: the pointer comes from the exclusive borrow `predictions`, which outlives the
    // parallel loop, and the threads only write to distinct elements (see the function docs).
    unsafe impl Send for PredictionsPtr {}
    unsafe impl Sync for PredictionsPtr {}
    impl PredictionsPtr {
        // Reading the pointer through a method makes closures capture the whole wrapper, which
        // is `Sync`, instead of the raw pointer field, which is not.
        fn get(&self) -> *mut f32 {
            self.0
        }
    }
    let n_predictions = predictions.len();
    let predictions_ptr = PredictionsPtr(predictions.as_mut_ptr());
    tree.leaf_values.par_iter().for_each(|(range, value)| {
        let leaf_value = *value as f32;
        for example_index in examples_index[range.clone()].iter() {
            let example_index = example_index.to_usize().unwrap();
            debug_assert!(example_index < n_predictions);
            // SAFETY: `example_index` is in bounds for `predictions` and no other thread
            // accesses this element, per the invariants in the function docs.
            unsafe {
                *predictions_ptr.get().add(example_index) += leaf_value;
            }
        }
    });
    #[cfg(feature = "timing")]
    timing.update_predictions.inc(start.elapsed());
}
/// Tracks the early stopping metric across rounds and decides when training should stop.
///
/// A round counts as an improvement when the metric decreased by at least `tolerance`
/// compared with the previous round. Training should stop once `max_rounds_no_improve`
/// consecutive rounds have passed without an improvement.
#[derive(Clone)]
pub struct EarlyStoppingMonitor {
    /// The minimum decrease in the metric that counts as an improvement.
    tolerance: f32,
    /// The number of consecutive rounds without improvement after which to stop.
    max_rounds_no_improve: usize,
    /// The metric passed to the most recent call to `update`, if any.
    previous_stopping_metric: Option<f32>,
    /// The current number of consecutive rounds without improvement.
    num_rounds_no_improve: usize,
}

impl EarlyStoppingMonitor {
    /// Creates a monitor that has not seen any metric yet.
    pub fn new(tolerance: f32, max_rounds_no_improve: usize) -> EarlyStoppingMonitor {
        EarlyStoppingMonitor {
            tolerance,
            max_rounds_no_improve,
            previous_stopping_metric: None,
            num_rounds_no_improve: 0,
        }
    }

    /// Records the metric for the latest round and returns `true` if training should stop.
    ///
    /// The first call only stores the metric and always returns `false`. A metric that cannot
    /// be ordered against the previous one, such as NaN from a diverged loss or infinity
    /// following infinity, counts as a round without improvement. It does not reset the
    /// counter.
    pub fn update(&mut self, value: f32) -> bool {
        // Store the new metric and get the one from the previous round in one step.
        let previous = match self.previous_stopping_metric.replace(value) {
            Some(previous) => previous,
            None => return false,
        };
        // Both comparisons are false when either operand is NaN, so an invalid metric is never
        // treated as an improvement. For finite values this is the same test as "not greater
        // than the previous metric and at least `tolerance` away from it".
        let improved = value <= previous && previous - value >= self.tolerance;
        if improved {
            self.num_rounds_no_improve = 0;
            return false;
        }
        self.num_rounds_no_improve += 1;
        self.num_rounds_no_improve >= self.max_rounds_no_improve
    }
}
/// Splits the dataset into a training part and an early stopping part.
///
/// The first `early_stopping_fraction * labels.len()` rows, converted to `usize`, become the
/// early stopping set and the remaining rows become the training set.
///
/// Returns `(features_train, labels_train, features_early_stopping, labels_early_stopping)`.
fn train_early_stopping_split<'features, 'labels>(
    features: TableView<'features>,
    labels: TableColumnView<'labels>,
    early_stopping_fraction: f32,
) -> (
    TableView<'features>,
    TableColumnView<'labels>,
    TableView<'features>,
    TableColumnView<'labels>,
) {
    // Work out how many leading rows are held out for early stopping.
    let n_examples = labels.len().to_f32().unwrap();
    let n_examples_early_stopping = (early_stopping_fraction * n_examples)
        .to_usize()
        .unwrap();
    // Features and labels are cut at the same row so they stay aligned.
    let (features_early_stopping, features_train) =
        features.split_at_row(n_examples_early_stopping);
    let (labels_early_stopping, labels_train) = labels.split_at_row(n_examples_early_stopping);
    (
        features_train,
        labels_train,
        features_early_stopping,
        labels_early_stopping,
    )
}
/// Applies this round's trees to the early stopping predictions and returns the resulting loss.
///
/// `predictions` is updated in place by the task specific `update_logits` function, so it
/// accumulates the contribution of every round it has been passed to. The loss is then
/// computed by the task specific `compute_loss` function.
///
/// # Panics
///
/// Panics if `labels` is not a number column for regression or not an enum column for the
/// classification tasks.
fn compute_early_stopping_metric(
    task: &Task,
    trees_for_round: &[TrainTree],
    features: ArrayView2<TableValue>,
    labels: TableColumnView,
    mut predictions: ArrayViewMut2<f32>,
) -> f32 {
    match *task {
        Task::Regression => {
            let number_labels = labels.as_number().unwrap();
            let targets = number_labels.as_slice().into();
            crate::regressor::update_logits(trees_for_round, features, predictions.view_mut());
            crate::regressor::compute_loss(predictions.view(), targets)
        }
        Task::BinaryClassification => {
            let enum_labels = labels.as_enum().unwrap();
            let targets = enum_labels.as_slice().into();
            crate::binary_classifier::update_logits(
                trees_for_round,
                features,
                predictions.view_mut(),
            );
            crate::binary_classifier::compute_loss(predictions.view(), targets)
        }
        Task::MulticlassClassification { .. } => {
            let enum_labels = labels.as_enum().unwrap();
            let targets = enum_labels.as_slice().into();
            crate::multiclass_classifier::update_logits(
                trees_for_round,
                features,
                predictions.view_mut(),
            );
            crate::multiclass_classifier::compute_loss(predictions.view(), targets)
        }
    }
}
/// Converts a `TrainTree` into a `Tree` by converting each of its nodes in order.
///
/// `train_feature_index_to_feature_index` is indexed by the feature index stored in a train
/// node and yields the feature index to store in the output node.
fn tree_from_train_tree(
    train_tree: TrainTree,
    train_feature_index_to_feature_index: &[usize],
) -> Tree {
    let convert_node =
        |train_node| node_from_train_node(train_node, train_feature_index_to_feature_index);
    Tree {
        nodes: train_tree.nodes.into_iter().map(convert_node).collect(),
    }
}
/// Converts a `TrainNode` into a `Node`.
///
/// Leaf nodes keep their `value` and `examples_fraction`. Branch nodes keep their child
/// indexes, `examples_fraction`, and split, with the split's feature index looked up in
/// `train_feature_index_to_feature_index`. Any other fields of the train node are dropped.
///
/// # Panics
///
/// Panics if a branch node's child index is `None` or if a split's feature index is out of
/// bounds for `train_feature_index_to_feature_index`.
fn node_from_train_node(
    train_node: TrainNode,
    train_feature_index_to_feature_index: &[usize],
) -> Node {
    // Leaves need no index translation, so handle them first and return.
    let branch = match train_node {
        TrainNode::Leaf(leaf) => {
            return Node::Leaf(LeafNode {
                value: leaf.value,
                examples_fraction: leaf.examples_fraction,
            });
        }
        TrainNode::Branch(branch) => branch,
    };
    let left_child_index = branch.left_child_index.unwrap();
    let right_child_index = branch.right_child_index.unwrap();
    // Translate the split's feature index and carry the remaining split data over unchanged.
    let split = match branch.split {
        TrainBranchSplit::Continuous(continuous) => {
            BranchSplit::Continuous(BranchSplitContinuous {
                feature_index: train_feature_index_to_feature_index[continuous.feature_index],
                split_value: continuous.split_value,
                invalid_values_direction: continuous.invalid_values_direction,
            })
        }
        TrainBranchSplit::Discrete(discrete) => BranchSplit::Discrete(BranchSplitDiscrete {
            feature_index: train_feature_index_to_feature_index[discrete.feature_index],
            directions: discrete.directions,
        }),
    };
    Node::Branch(BranchNode {
        left_child_index,
        right_child_index,
        split,
        examples_fraction: branch.examples_fraction,
    })
}