tangram_tree 0.5.0

Tangram is an automated machine learning framework designed for programmers.
#[cfg(feature = "timing")]
use crate::timing::Timing;
use crate::{
	binary_classifier::{BinaryClassifier, BinaryClassifierTrainOutput},
	compute_bin_stats::{BinStats, BinStatsEntry},
	compute_binned_features::{
		compute_binned_features_column_major, compute_binned_features_row_major,
	},
	compute_binning_instructions::compute_binning_instructions,
	compute_feature_importances::compute_feature_importances,
	multiclass_classifier::{MulticlassClassifier, MulticlassClassifierTrainOutput},
	pool::Pool,
	regressor::{Regressor, RegressorTrainOutput},
	train_tree::{
		train_tree, TrainBranchNode, TrainBranchSplit, TrainBranchSplitContinuous,
		TrainBranchSplitDiscrete, TrainLeafNode, TrainNode, TrainTree, TrainTreeOptions,
	},
	BinnedFeaturesLayout, BranchNode, BranchSplit, BranchSplitContinuous, BranchSplitDiscrete,
	LeafNode, Node, Progress, TrainOptions, TrainProgressEvent, Tree,
};
use ndarray::prelude::*;
use num::ToPrimitive;
use rayon::prelude::*;
use tangram_progress_counter::ProgressCounter;
use tangram_table::prelude::*;

/// This enum is used by the common `train` function below to customize the training code slightly for each task.
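///
/// For example, a three-class problem is represented as
/// `Task::MulticlassClassification { n_classes: 3 }`.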
#[derive(Clone, Copy, Debug)]
pub enum Task {
	Regression,
	BinaryClassification,
	MulticlassClassification { n_classes: usize },
}

/// This is the return type of the common `train` function.
#[derive(Debug)]
pub enum TrainOutput {
	Regressor(RegressorTrainOutput),
	BinaryClassifier(BinaryClassifierTrainOutput),
	MulticlassClassifier(MulticlassClassifierTrainOutput),
}

/// To avoid code duplication, this shared `train` function is called by `Regressor::train`, `BinaryClassifier::train`, and `MulticlassClassifier::train`.
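///
/// A minimal sketch of a call site (hypothetical `features`, `labels`, and
/// `progress` values; assumes `TrainOptions` implements `Default`):
///
/// ```ignore
/// let output = train(
/// 	Task::BinaryClassification,
/// 	features, // a `TableView` over the training features
/// 	labels,   // a `TableColumnView` over the labels
/// 	&TrainOptions::default(),
/// 	progress,
/// );
/// let model = match output {
/// 	TrainOutput::BinaryClassifier(output) => output.model,
/// 	_ => unreachable!(),
/// };
/// ```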
pub fn train(
	task: Task,
	features: TableView,
	labels: TableColumnView,
	train_options: &TrainOptions,
	progress: Progress,
) -> TrainOutput {
	#[cfg(feature = "timing")]
	let timing = Timing::new();

	// If early stopping is enabled, split the features and labels into train and early stopping sets.
	let early_stopping_enabled = train_options.early_stopping_options.is_some();
	let (
		features_train,
		labels_train,
		features_early_stopping,
		labels_early_stopping,
		mut early_stopping_monitor,
	) = if let Some(early_stopping_options) = &train_options.early_stopping_options {
		let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
			train_early_stopping_split(
				features,
				labels,
				early_stopping_options.early_stopping_fraction,
			);
		let early_stopping_monitor = EarlyStoppingMonitor::new(
			early_stopping_options.min_decrease_in_loss_for_significant_change,
			early_stopping_options.n_rounds_without_improvement_to_stop,
		);
		(
			features_train,
			labels_train,
			Some(features_early_stopping),
			Some(labels_early_stopping),
			Some(early_stopping_monitor),
		)
	} else {
		(features, labels, None, None, None)
	};

	let n_features = features_train.ncols();
	let n_examples_train = features_train.nrows();

	// Determine how to bin each feature.
	#[cfg(feature = "timing")]
	let start = std::time::Instant::now();
	let binning_instructions = compute_binning_instructions(&features_train, &train_options);
	#[cfg(feature = "timing")]
	timing.compute_binning_instructions.inc(start.elapsed());

	// Use the binning instructions from the previous step to compute the binned features.
	let binned_features_layout = train_options.binned_features_layout;
	let progress_counter = ProgressCounter::new(features_train.nrows().to_u64().unwrap());
	(progress.handle_progress_event)(TrainProgressEvent::Initialize(progress_counter.clone()));
	#[cfg(feature = "timing")]
	let start = std::time::Instant::now();
	let compute_binned_features_column_major_output = compute_binned_features_column_major(
		&features_train,
		&binning_instructions,
		train_options,
		&|| progress_counter.inc(1),
	);
	let features_train = features_train
		.view_columns(&compute_binned_features_column_major_output.used_feature_indexes);
	let features_early_stopping = features_early_stopping
		.as_ref()
		.map(|features_early_stopping| {
			features_early_stopping
				.view_columns(&compute_binned_features_column_major_output.used_feature_indexes)
				.to_rows()
		});

	let used_features_binning_instructions = compute_binned_features_column_major_output
		.used_feature_indexes
		.iter()
		.map(|original_feature_index| binning_instructions[*original_feature_index].clone())
		.collect::<Vec<_>>();
	let binned_features_row_major =
		if let BinnedFeaturesLayout::RowMajor = train_options.binned_features_layout {
			Some(compute_binned_features_row_major(
				&features_train,
				&used_features_binning_instructions,
				&|| progress_counter.inc(1),
			))
		} else {
			None
		};
	#[cfg(feature = "timing")]
	timing.compute_binned_features.inc(start.elapsed());

	// Regression and binary classification train one tree per round. Multiclass classification trains one tree per class per round.
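	// For example, a three-class classifier trained for 100 rounds produces 100 × 3 = 300 trees.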
	let n_trees_per_round = match task {
		Task::Regression => 1,
		Task::BinaryClassification => 1,
		Task::MulticlassClassification { n_classes } => n_classes,
	};

	// The mean square error loss used in regression has a constant second derivative, so there is no need to use hessians for regression tasks.
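	// (With squared error written as ½(p − y)², the gradient with respect to the
	// prediction p is p − y, and the second derivative is the constant 1.)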
	let hessians_are_constant = match task {
		Task::Regression => true,
		Task::BinaryClassification => false,
		Task::MulticlassClassification { .. } => false,
	};

	// Compute the biases. A tree model's prediction will be a bias plus the sum of the outputs of each tree. The bias will produce the baseline prediction.
	let biases = match task {
		// For regression, the bias is the mean of the labels.
		Task::Regression => {
			let labels_train = labels_train.as_number().unwrap();
			let labels_train = labels_train.as_slice().into();
			crate::regressor::compute_biases(labels_train)
		}
		// For binary classification, the bias is the log of the ratio of positive examples to negative examples in the training set, so the baseline prediction is the majority class.
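		// For example (hypothetical counts), with 75 positive and 25 negative
		// training examples the bias is ln(75 / 25) ≈ 1.1, and sigmoid(1.1) ≈ 0.75
		// recovers the positive-class proportion.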
		Task::BinaryClassification => {
			let labels_train = labels_train.as_enum().unwrap();
			let labels_train = labels_train.as_slice().into();
			crate::binary_classifier::compute_biases(labels_train)
		}
		// For multiclass classification, the biases are the logs of each class's proportion in the training set, so the baseline prediction is the majority class.
		Task::MulticlassClassification { .. } => {
			let labels_train = labels_train.as_enum().unwrap();
			let labels_train = labels_train.as_slice().into();
			crate::multiclass_classifier::compute_biases(labels_train, n_trees_per_round)
		}
	};

	// Pre-allocate memory to be used in training.
	let mut predictions =
		unsafe { Array::uninit((n_examples_train, n_trees_per_round).f()).assume_init() };
	let mut gradients = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut hessians = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut gradients_ordered_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut hessians_ordered_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut examples_index = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut examples_index_left_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut examples_index_right_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
	let mut predictions_early_stopping = if early_stopping_enabled {
		let mut predictions_early_stopping = unsafe {
			Array::uninit((
				n_trees_per_round,
				labels_early_stopping.as_ref().unwrap().len(),
			))
			.assume_init()
		};
		for mut predictions in predictions_early_stopping.axis_iter_mut(Axis(1)) {
			predictions.assign(&biases);
		}
		Some(predictions_early_stopping)
	} else {
		None
	};
	let binning_instructions_for_pool = used_features_binning_instructions.clone();
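	// The pool holds up to `max_leaf_nodes` reusable `BinStats` buffers.
	// Column-major bin stats keep one `Vec<BinStatsEntry>` per feature, while
	// row-major flattens every feature's bins into one contiguous vector (note
	// `map` vs `flat_map` below).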
	let bin_stats_pool = match binned_features_layout {
		BinnedFeaturesLayout::ColumnMajor => Pool::new(
			train_options.max_leaf_nodes,
			Box::new(move || {
				BinStats::ColumnMajor(
					binning_instructions_for_pool
						.iter()
						.map(|binning_instructions| {
							vec![BinStatsEntry::default(); binning_instructions.n_bins()]
						})
						.collect(),
				)
			}),
		),
		BinnedFeaturesLayout::RowMajor => Pool::new(
			train_options.max_leaf_nodes,
			Box::new(move || {
				BinStats::RowMajor(
					binning_instructions_for_pool
						.iter()
						.flat_map(|binning_instructions| {
							vec![BinStatsEntry::default(); binning_instructions.n_bins()]
						})
						.collect(),
				)
			}),
		),
	};

	// This is the total number of rounds that have been trained thus far.
	let mut n_rounds_trained = 0;
	// These are the trees in round-major order. After training this will be converted to an array of shape (n_rounds, n_trees_per_round).
	let mut trees: Vec<TrainTree> = Vec::new();
	// Collect the loss on the training dataset for each round if enabled.
	let mut losses: Option<Vec<f32>> = if train_options.compute_losses {
		Some(Vec::new())
	} else {
		None
	};

	// Before the first round, fill the predictions with the biases, which are the baseline predictions.
	for mut predictions in predictions.axis_iter_mut(Axis(0)) {
		predictions.assign(&biases);
	}

	(progress.handle_progress_event)(TrainProgressEvent::InitializeDone);

	// Train rounds of trees until we hit max_rounds or the early stopping monitor indicates we should stop early.
	let round_counter = ProgressCounter::new(train_options.max_rounds.to_u64().unwrap());
	(progress.handle_progress_event)(TrainProgressEvent::Train(round_counter.clone()));
	for _ in 0..train_options.max_rounds {
		round_counter.inc(1);
		// Train n_trees_per_round trees.
		let mut trees_for_round = Vec::with_capacity(n_trees_per_round);
		for tree_per_round_index in 0..n_trees_per_round {
			// Before training the next tree, determine the value we would like the tree to learn for each example.
			#[cfg(feature = "timing")]
			let start = std::time::Instant::now();
			match task {
				Task::Regression => {
					let labels_train = labels_train.as_number().unwrap();
					crate::regressor::compute_gradients_and_hessians(
						gradients.as_slice_mut().unwrap(),
						hessians.as_slice_mut().unwrap(),
						labels_train.as_slice(),
						predictions.column(0).as_slice().unwrap(),
					);
				}
				Task::BinaryClassification => {
					let labels_train = labels_train.as_enum().unwrap();
					crate::binary_classifier::compute_gradients_and_hessians(
						gradients.as_slice_mut().unwrap(),
						hessians.as_slice_mut().unwrap(),
						labels_train.as_slice(),
						predictions.column(0).as_slice().unwrap(),
					);
				}
				Task::MulticlassClassification { .. } => {
					let labels_train = labels_train.as_enum().unwrap();
					crate::multiclass_classifier::compute_gradients_and_hessians(
						tree_per_round_index,
						gradients.as_slice_mut().unwrap(),
						hessians.as_slice_mut().unwrap(),
						labels_train.as_slice(),
						predictions.view(),
					);
				}
			};
			#[cfg(feature = "timing")]
			timing.compute_gradients_and_hessians.inc(start.elapsed());
			// Reset the examples_index.
			examples_index
				.as_slice_mut()
				.unwrap()
				.par_iter_mut()
				.enumerate()
				.for_each(|(index, value)| {
					*value = index.to_u32().unwrap();
				});
			// Train the tree.
			let tree = train_tree(TrainTreeOptions {
				binning_instructions: &used_features_binning_instructions,
				binned_features_row_major: &binned_features_row_major,
				binned_features_column_major: &compute_binned_features_column_major_output
					.binned_features,
				gradients: gradients.as_slice().unwrap(),
				hessians: hessians.as_slice().unwrap(),
				gradients_ordered_buffer: gradients_ordered_buffer.as_slice_mut().unwrap(),
				hessians_ordered_buffer: hessians_ordered_buffer.as_slice_mut().unwrap(),
				examples_index: examples_index.as_slice_mut().unwrap(),
				examples_index_left_buffer: examples_index_left_buffer.as_slice_mut().unwrap(),
				examples_index_right_buffer: examples_index_right_buffer.as_slice_mut().unwrap(),
				bin_stats_pool: &bin_stats_pool,
				hessians_are_constant,
				train_options: &train_options,
				#[cfg(feature = "timing")]
				timing: &timing,
			});
			// Update the predictions using the leaf values from the tree.
			update_predictions_with_tree(
				predictions
					.column_mut(tree_per_round_index)
					.as_slice_mut()
					.unwrap(),
				examples_index.as_slice().unwrap(),
				&tree,
				#[cfg(feature = "timing")]
				&timing,
			);
			trees_for_round.push(tree);
		}
		// If loss computation is enabled, compute the loss for this round.
		if let Some(losses) = losses.as_mut() {
			let loss = match task {
				Task::Regression => {
					let labels_train = labels_train.as_number().unwrap();
					let labels_train = labels_train.as_slice().into();
					crate::regressor::compute_loss(predictions.view(), labels_train)
				}
				Task::BinaryClassification => {
					let labels_train = labels_train.as_enum().unwrap();
					let labels_train = labels_train.as_slice().into();
					crate::binary_classifier::compute_loss(predictions.view(), labels_train)
				}
				Task::MulticlassClassification { .. } => {
					let labels_train = labels_train.as_enum().unwrap();
					let labels_train = labels_train.as_slice().into();
					crate::multiclass_classifier::compute_loss(predictions.view(), labels_train)
				}
			};
			losses.push(loss);
		}
		// If early stopping is enabled, compute the early stopping metric and update the early stopping monitor to see if we should stop training at this round.
		let should_stop = if early_stopping_enabled {
			let features_early_stopping = features_early_stopping.as_ref().unwrap();
			let labels_early_stopping = labels_early_stopping.as_ref().unwrap();
			let predictions_early_stopping = predictions_early_stopping.as_mut().unwrap();
			let early_stopping_monitor = early_stopping_monitor.as_mut().unwrap();
			let value = compute_early_stopping_metric(
				&task,
				trees_for_round.as_slice(),
				features_early_stopping.view(),
				labels_early_stopping.view(),
				predictions_early_stopping.view_mut(),
			);
			early_stopping_monitor.update(value)
		} else {
			false
		};
		// Add the trees for this round to the list of trees.
		trees.extend(trees_for_round);
		n_rounds_trained += 1;
		// Exit the training loop if we should stop.
		if should_stop {
			break;
		}
		// Stop training if the external kill signal has been activated.
		if progress.kill_chip.is_activated() {
			break;
		}
	}

	(progress.handle_progress_event)(TrainProgressEvent::TrainDone);

	// Compute the feature importances.
	let feature_importances = Some(compute_feature_importances(&trees, n_features));

	// Print the timing information if the `timing` feature is enabled.
	#[cfg(feature = "timing")]
	eprintln!("{:?}", timing);

	// Assemble the model.
	let trees: Vec<Tree> = trees
		.into_iter()
		.map(|train_tree| {
			tree_from_train_tree(
				train_tree,
				compute_binned_features_column_major_output
					.used_feature_indexes
					.as_slice(),
			)
		})
		.collect();
	match task {
		Task::Regression => TrainOutput::Regressor(RegressorTrainOutput {
			model: Regressor {
				bias: *biases.get(0).unwrap(),
				trees,
			},
			feature_importances,
			losses,
		}),
		Task::BinaryClassification => TrainOutput::BinaryClassifier(BinaryClassifierTrainOutput {
			model: BinaryClassifier {
				bias: *biases.get(0).unwrap(),
				trees,
			},
			feature_importances,
			losses,
		}),
		Task::MulticlassClassification { .. } => {
			let trees =
				Array2::from_shape_vec((n_rounds_trained, n_trees_per_round), trees).unwrap();
			TrainOutput::MulticlassClassifier(MulticlassClassifierTrainOutput {
				model: MulticlassClassifier { biases, trees },
				feature_importances,
				losses,
			})
		}
	}
}

fn update_predictions_with_tree(
	predictions: &mut [f32],
	examples_index: &[u32],
	tree: &TrainTree,
	#[cfg(feature = "timing")] timing: &Timing,
) {
	#[cfg(feature = "timing")]
	let start = std::time::Instant::now();
	struct PredictionsPtr(*mut [f32]);
	unsafe impl Send for PredictionsPtr {}
	unsafe impl Sync for PredictionsPtr {}
	let predictions_ptr = PredictionsPtr(predictions);
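	// SAFETY (apparent invariant): each `(range, value)` pair in `leaf_values`
	// covers a distinct, non-overlapping range of `examples_index`, so the
	// parallel closures below write to disjoint entries of `predictions`. The
	// `PredictionsPtr` wrapper exists only to mark the raw pointer `Send` and
	// `Sync` so rayon can share it across threads.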
	tree.leaf_values.par_iter().for_each(|(range, value)| {
		examples_index[range.clone()]
			.iter()
			.for_each(|example_index| unsafe {
				let predictions = &mut *predictions_ptr.0;
				let example_index = example_index.to_usize().unwrap();
				*predictions.get_unchecked_mut(example_index) += *value as f32;
			});
	});
	#[cfg(feature = "timing")]
	timing.update_predictions.inc(start.elapsed());
}

#[derive(Clone)]
pub struct EarlyStoppingMonitor {
	tolerance: f32,
	max_rounds_no_improve: usize,
	previous_stopping_metric: Option<f32>,
	num_rounds_no_improve: usize,
}

impl EarlyStoppingMonitor {
	/// Create a new `EarlyStoppingMonitor`.
	pub fn new(tolerance: f32, max_rounds_no_improve: usize) -> EarlyStoppingMonitor {
		EarlyStoppingMonitor {
			tolerance,
			max_rounds_no_improve,
			previous_stopping_metric: None,
			num_rounds_no_improve: 0,
		}
	}

	/// Update the monitor with this round's metric value. Returns `true` if training should stop.
	pub fn update(&mut self, value: f32) -> bool {
		let stopping_metric = value;
		let result = if let Some(previous_stopping_metric) = self.previous_stopping_metric {
			if stopping_metric > previous_stopping_metric
				|| f32::abs(stopping_metric - previous_stopping_metric) < self.tolerance
			{
				self.num_rounds_no_improve += 1;
				self.num_rounds_no_improve >= self.max_rounds_no_improve
			} else {
				self.num_rounds_no_improve = 0;
				false
			}
		} else {
			false
		};
		self.previous_stopping_metric = Some(stopping_metric);
		result
	}
}
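
#[cfg(test)]
mod early_stopping_monitor_test {
	use super::*;

	// An illustrative sketch of the monitor's semantics: the metric is treated
	// as a loss, so lower values count as improvement.
	#[test]
	fn test_early_stopping_monitor() {
		let mut monitor = EarlyStoppingMonitor::new(0.01, 2);
		// The first round never stops, because there is no previous metric.
		assert!(!monitor.update(1.0));
		// An improvement larger than the tolerance resets the counter.
		assert!(!monitor.update(0.5));
		// A change within the tolerance counts as a round without improvement.
		assert!(!monitor.update(0.499));
		// A worse value is the second round without improvement, so stop.
		assert!(monitor.update(0.6));
	}
}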

/// Split the features and labels into train and early stopping datasets, where the early stopping dataset will have `early_stopping_fraction * features.nrows()` rows.
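///
/// For example (hypothetical sizes), with 100 rows and an
/// `early_stopping_fraction` of 0.2, the first 20 rows become the early
/// stopping set and the remaining 80 rows become the train set.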
fn train_early_stopping_split<'features, 'labels>(
	features: TableView<'features>,
	labels: TableColumnView<'labels>,
	early_stopping_fraction: f32,
) -> (
	TableView<'features>,
	TableColumnView<'labels>,
	TableView<'features>,
	TableColumnView<'labels>,
) {
	let split_index = (early_stopping_fraction * labels.len().to_f32().unwrap())
		.to_usize()
		.unwrap();
	let (features_early_stopping, features_train) = features.split_at_row(split_index);
	let (labels_early_stopping, labels_train) = labels.split_at_row(split_index);
	(
		features_train,
		labels_train,
		features_early_stopping,
		labels_early_stopping,
	)
}

/// Compute the early stopping metric value for the set of trees that have been trained thus far.
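/// This adds the contribution of this round's trees to the running logits in
/// `predictions`, then evaluates the task-specific loss against the early
/// stopping labels.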
fn compute_early_stopping_metric(
	task: &Task,
	trees_for_round: &[TrainTree],
	features: ArrayView2<TableValue>,
	labels: TableColumnView,
	mut predictions: ArrayViewMut2<f32>,
) -> f32 {
	match task {
		Task::Regression => {
			let labels = labels.as_number().unwrap();
			let labels = labels.as_slice().into();
			crate::regressor::update_logits(
				trees_for_round,
				features.view(),
				predictions.view_mut(),
			);
			crate::regressor::compute_loss(predictions.view(), labels)
		}
		Task::BinaryClassification => {
			let labels = labels.as_enum().unwrap();
			let labels = labels.as_slice().into();
			crate::binary_classifier::update_logits(
				trees_for_round,
				features.view(),
				predictions.view_mut(),
			);
			crate::binary_classifier::compute_loss(predictions.view(), labels)
		}
		Task::MulticlassClassification { .. } => {
			let labels = labels.as_enum().unwrap();
			let labels = labels.as_slice().into();
			crate::multiclass_classifier::update_logits(
				trees_for_round,
				features.view(),
				predictions.view_mut(),
			);
			crate::multiclass_classifier::compute_loss(predictions.view(), labels)
		}
	}
}

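/// Convert a `TrainTree` produced during training into a final `Tree`,
/// remapping each split's feature index from the used-features index space
/// back to the original feature index space.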
fn tree_from_train_tree(
	train_tree: TrainTree,
	train_feature_index_to_feature_index: &[usize],
) -> Tree {
	let nodes = train_tree
		.nodes
		.into_iter()
		.map(|node| node_from_train_node(node, train_feature_index_to_feature_index))
		.collect();
	Tree { nodes }
}

fn node_from_train_node(
	train_node: TrainNode,
	train_feature_index_to_feature_index: &[usize],
) -> Node {
	match train_node {
		TrainNode::Branch(TrainBranchNode {
			left_child_index,
			right_child_index,
			split,
			examples_fraction,
			..
		}) => Node::Branch(BranchNode {
			left_child_index: left_child_index.unwrap(),
			right_child_index: right_child_index.unwrap(),
			split: match split {
				TrainBranchSplit::Continuous(TrainBranchSplitContinuous {
					feature_index,
					invalid_values_direction,
					split_value,
					..
				}) => BranchSplit::Continuous(BranchSplitContinuous {
					feature_index: train_feature_index_to_feature_index[feature_index],
					split_value,
					invalid_values_direction,
				}),
				TrainBranchSplit::Discrete(TrainBranchSplitDiscrete {
					feature_index,
					directions,
					..
				}) => BranchSplit::Discrete(BranchSplitDiscrete {
					feature_index: train_feature_index_to_feature_index[feature_index],
					directions,
				}),
			},
			examples_fraction,
		}),
		TrainNode::Leaf(TrainLeafNode {
			value,
			examples_fraction,
			..
		}) => Node::Leaf(LeafNode {
			value,
			examples_fraction,
		}),
	}
}