tangram_tree/
train_tree.rs

1use crate::{
2	choose_best_split::{
3		choose_best_split_root, choose_best_splits_not_root, ChooseBestSplitOutput,
4		ChooseBestSplitRootOptions, ChooseBestSplitSuccess, ChooseBestSplitsNotRootOptions,
5	},
6	compute_bin_stats::BinStats,
7	compute_binned_features::{BinnedFeaturesColumnMajor, BinnedFeaturesRowMajor},
8	compute_binning_instructions::BinningInstruction,
9	pool::{Pool, PoolItem},
10	rearrange_examples_index::rearrange_examples_index,
11	SplitDirection, TrainOptions,
12};
13use bitvec::prelude::*;
14use num::ToPrimitive;
15use std::{cmp::Ordering, collections::BinaryHeap, ops::Range};
16
17#[derive(Debug)]
18pub struct TrainTree {
19	pub nodes: Vec<TrainNode>,
20	pub leaf_values: Vec<(Range<usize>, f64)>,
21}
22
23impl TrainTree {
24	/// Make a prediction.
25	pub fn predict(&self, example: &[tangram_table::TableValue]) -> f32 {
26		// Start at the root node.
27		let mut node_index = 0;
28		// Traverse the tree until we get to a leaf.
29		loop {
30			match &self.nodes.get(node_index).unwrap() {
31				// This branch uses a continuous split.
32				TrainNode::Branch(TrainBranchNode {
33					left_child_index,
34					right_child_index,
35					split:
36						TrainBranchSplit::Continuous(TrainBranchSplitContinuous {
37							feature_index,
38							split_value,
39							..
40						}),
41					..
42				}) => {
43					node_index = if example[*feature_index].as_number().unwrap() <= split_value {
44						left_child_index.unwrap()
45					} else {
46						right_child_index.unwrap()
47					};
48				}
49				// This branch uses a discrete split.
50				TrainNode::Branch(TrainBranchNode {
51					left_child_index,
52					right_child_index,
53					split:
54						TrainBranchSplit::Discrete(TrainBranchSplitDiscrete {
55							feature_index,
56							directions,
57							..
58						}),
59					..
60				}) => {
61					let bin_index =
62						if let Some(bin_index) = example[*feature_index].as_enum().unwrap() {
63							bin_index.get()
64						} else {
65							0
66						};
67					node_index = match (*directions.get(bin_index).unwrap()).into() {
68						SplitDirection::Left => left_child_index.unwrap(),
69						SplitDirection::Right => right_child_index.unwrap(),
70					};
71				}
72				// We made it to a leaf! The prediction is the leaf's value.
73				TrainNode::Leaf(TrainLeafNode { value, .. }) => return *value as f32,
74			}
75		}
76	}
77}
78
79#[derive(Debug)]
80pub enum TrainNode {
81	Branch(TrainBranchNode),
82	Leaf(TrainLeafNode),
83}
84
85impl TrainNode {
86	pub fn as_branch_mut(&mut self) -> Option<&mut TrainBranchNode> {
87		match self {
88			TrainNode::Branch(s) => Some(s),
89			_ => None,
90		}
91	}
92}
93
94#[derive(Debug)]
95pub struct TrainBranchNode {
96	pub left_child_index: Option<usize>,
97	pub right_child_index: Option<usize>,
98	pub split: TrainBranchSplit,
99	pub examples_fraction: f32,
100}
101
102#[derive(Clone, Debug)]
103pub enum TrainBranchSplit {
104	Continuous(TrainBranchSplitContinuous),
105	Discrete(TrainBranchSplitDiscrete),
106}
107
108#[derive(Clone, Debug)]
109pub struct TrainBranchSplitContinuous {
110	pub feature_index: usize,
111	pub split_value: f32,
112	pub bin_index: usize,
113	pub invalid_values_direction: SplitDirection,
114}
115
116#[derive(Clone, Debug)]
117pub struct TrainBranchSplitDiscrete {
118	pub feature_index: usize,
119	pub directions: BitVec<Lsb0, u8>,
120}
121
122#[derive(Debug)]
123pub struct TrainLeafNode {
124	pub value: f64,
125	pub examples_fraction: f32,
126}
127
128struct QueueItem {
129	/// The priority queue will be sorted by the gain of the split.
130	pub gain: f32,
131	/// The queue item holds a reference to its parent so that it can update the parent's left or right child index if the queue item becomes a node added to the tree.
132	pub parent_index: Option<usize>,
133	/// Will this node be a left or right child of its parent?
134	pub split_direction: Option<SplitDirection>,
135	/// This is the depth of the item in the tree.
136	pub depth: usize,
137	/// The bin_stats consisting of aggregate hessian/gradient statistics of the training examples that reach this node.
138	pub bin_stats: PoolItem<BinStats>,
139	/// The examples_index_range tells you what range of entries in the examples index correspond to this node.
140	pub examples_index_range: std::ops::Range<usize>,
141	/// This is the sum of the gradients for the training examples that pass through this node.
142	pub sum_gradients: f64,
143	/// This is the sum of the hessians for the training examples that pass through this node.
144	pub sum_hessians: f64,
145	/// This is the best split that was chosen for this node.
146	pub split: TrainBranchSplit,
147	/// This is the number of training examples that were sent to the left child.
148	pub left_n_examples: usize,
149	/// This is the sum of the gradients for the training examples that were sent to the left child.
150	pub left_sum_gradients: f64,
151	/// This is the sum of the hessians for the training examples that were sent to the left child.
152	pub left_sum_hessians: f64,
153	/// This is the number of training examples that were sent to the right child.
154	pub right_n_examples: usize,
155	/// This is the sum of the gradients for the training examples that were sent to the right child.
156	pub right_sum_gradients: f64,
157	/// This is the sum of the hessians for the training examples that were sent to the right child.
158	pub right_sum_hessians: f64,
159	/// These are the features that are still splittable.
160	pub splittable_features: Vec<bool>,
161}
162
163impl PartialEq for QueueItem {
164	fn eq(&self, other: &Self) -> bool {
165		self.gain == other.gain
166	}
167}
168
169impl Eq for QueueItem {}
170
171impl std::cmp::PartialOrd for QueueItem {
172	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
173		self.gain.partial_cmp(&other.gain)
174	}
175}
176
177impl std::cmp::Ord for QueueItem {
178	fn cmp(&self, other: &Self) -> Ordering {
179		self.partial_cmp(other).unwrap()
180	}
181}
182
183pub struct TrainTreeOptions<'a> {
184	pub bin_stats_pool: &'a Pool<BinStats>,
185	pub binned_features_row_major: &'a Option<BinnedFeaturesRowMajor>,
186	pub binned_features_column_major: &'a BinnedFeaturesColumnMajor,
187	pub binning_instructions: &'a [BinningInstruction],
188	pub examples_index_left_buffer: &'a mut [u32],
189	pub examples_index_right_buffer: &'a mut [u32],
190	pub examples_index: &'a mut [u32],
191	pub gradients_ordered_buffer: &'a mut [f32],
192	pub gradients: &'a [f32],
193	pub hessians_are_constant: bool,
194	pub hessians_ordered_buffer: &'a mut [f32],
195	pub hessians: &'a [f32],
196	#[cfg(feature = "timing")]
197	pub timing: &'a crate::timing::Timing,
198	pub train_options: &'a TrainOptions,
199}
200
201/// Train a tree.
202pub fn train_tree(options: TrainTreeOptions) -> TrainTree {
203	let TrainTreeOptions {
204		bin_stats_pool,
205		binned_features_row_major,
206		binned_features_column_major,
207		binning_instructions,
208		examples_index_left_buffer,
209		examples_index_right_buffer,
210		examples_index,
211		gradients_ordered_buffer,
212		gradients,
213		hessians_are_constant,
214		hessians_ordered_buffer,
215		hessians,
216		train_options,
217		..
218	} = options;
219	#[cfg(feature = "timing")]
220	let timing = options.timing;
221	// These are the nodes in the tree returned by this function
222	let mut nodes = Vec::new();
223	// This priority queue stores the potential nodes to split ordered by their gain.
224	let mut queue: BinaryHeap<QueueItem> = BinaryHeap::new();
225	// To update the gradients and hessians we need to make predictions. Rather than running each example through the tree, we can reuse the mapping from example index to leaf value previously computed.
226	let mut leaf_values: Vec<(Range<usize>, f64)> = Vec::new();
227
228	let n_examples_root = examples_index.len();
229	let examples_index_range_root = 0..n_examples_root;
230
231	// Choose the best split for the root node.
232	let choose_best_split_output_root = choose_best_split_root(ChooseBestSplitRootOptions {
233		bin_stats_pool,
234		binned_features_column_major,
235		binned_features_row_major,
236		binning_instructions,
237		examples_index,
238		gradients,
239		hessians_are_constant,
240		hessians,
241		#[cfg(feature = "timing")]
242		timing,
243		train_options,
244	});
245
246	// If we were able to find a split for the root node, add it to the queue and proceed to the loop. Otherwise, return a tree with a single node.
247	match choose_best_split_output_root {
248		ChooseBestSplitOutput::Success(output) => {
249			add_queue_item(AddQueueItemOptions {
250				depth: 0,
251				examples_index_range: examples_index_range_root,
252				output,
253				parent_index: None,
254				queue: &mut queue,
255				split_direction: None,
256			});
257		}
258		ChooseBestSplitOutput::Failure(output) => {
259			add_leaf(AddLeafOptions {
260				examples_index_range: examples_index_range_root,
261				leaf_values: &mut leaf_values,
262				n_examples_root,
263				nodes: &mut nodes,
264				train_options,
265				parent_node_index: None,
266				split_direction: None,
267				sum_gradients: output.sum_gradients,
268				sum_hessians: output.sum_hessians,
269			});
270			return TrainTree { nodes, leaf_values };
271		}
272	}
273
274	// This is the training loop for a tree.
275	loop {
276		// If we will hit the maximum number of leaf nodes by adding the remaining queue items as leaves then exit the loop.
277		let n_leaf_nodes = leaf_values.len() + queue.len();
278		let max_leaf_nodes_reached = n_leaf_nodes == train_options.max_leaf_nodes;
279		if max_leaf_nodes_reached {
280			break;
281		}
282
283		// Pop an item off the queue.
284		let node_index = nodes.len();
285		let queue_item = if let Some(queue_item) = queue.pop() {
286			queue_item
287		} else {
288			break;
289		};
290
291		// Create the new branch node.
292		let examples_fraction = queue_item.examples_index_range.len().to_f32().unwrap()
293			/ n_examples_root.to_f32().unwrap();
294		nodes.push(TrainNode::Branch(TrainBranchNode {
295			split: queue_item.split.clone(),
296			left_child_index: None,
297			right_child_index: None,
298			examples_fraction,
299		}));
300		if let Some(parent_index) = queue_item.parent_index {
301			let parent = nodes
302				.get_mut(parent_index)
303				.unwrap()
304				.as_branch_mut()
305				.unwrap();
306			let split_direction = queue_item.split_direction.unwrap();
307			match split_direction {
308				SplitDirection::Left => parent.left_child_index = Some(node_index),
309				SplitDirection::Right => parent.right_child_index = Some(node_index),
310			}
311		}
312
313		// Rearrange the examples index.
314		#[cfg(feature = "timing")]
315		let start = std::time::Instant::now();
316		let (left, right) = rearrange_examples_index(
317			binned_features_column_major,
318			&queue_item.split,
319			examples_index
320				.get_mut(queue_item.examples_index_range.clone())
321				.unwrap(),
322			examples_index_left_buffer
323				.get_mut(queue_item.examples_index_range.clone())
324				.unwrap(),
325			examples_index_right_buffer
326				.get_mut(queue_item.examples_index_range.clone())
327				.unwrap(),
328		);
329		// The left and right ranges are local to the node, so add the node's start to make them global.
330		let branch_examples_index_range_start = queue_item.examples_index_range.start;
331		let left_child_examples_index_range = branch_examples_index_range_start + left.start
332			..branch_examples_index_range_start + left.end;
333		let right_child_examples_index_range = branch_examples_index_range_start + right.start
334			..branch_examples_index_range_start + right.end;
335		let left_child_examples_index = examples_index
336			.get(left_child_examples_index_range.clone())
337			.unwrap();
338		let right_child_examples_index = examples_index
339			.get(right_child_examples_index_range.clone())
340			.unwrap();
341		#[cfg(feature = "timing")]
342		timing.rearrange_examples_index.inc(start.elapsed());
343
344		// Choose the best splits for each of the right and left children of this new branch.
345		#[cfg(feature = "timing")]
346		let start = std::time::Instant::now();
347		let (left_child_best_split_output, right_child_best_split_output) =
348			choose_best_splits_not_root(ChooseBestSplitsNotRootOptions {
349				bin_stats_pool,
350				binned_features_column_major,
351				binned_features_row_major,
352				binning_instructions,
353				gradients_ordered_buffer,
354				gradients,
355				hessians_are_constant,
356				hessians_ordered_buffer,
357				hessians,
358				left_child_examples_index,
359				splittable_features: queue_item.splittable_features.as_slice(),
360				left_child_n_examples: queue_item.left_n_examples,
361				left_child_sum_gradients: queue_item.left_sum_gradients,
362				left_child_sum_hessians: queue_item.left_sum_hessians,
363				parent_bin_stats: queue_item.bin_stats,
364				parent_depth: queue_item.depth,
365				right_child_examples_index,
366				right_child_n_examples: queue_item.right_n_examples,
367				right_child_sum_gradients: queue_item.right_sum_gradients,
368				right_child_sum_hessians: queue_item.right_sum_hessians,
369				#[cfg(feature = "timing")]
370				timing,
371				train_options,
372			});
373		#[cfg(feature = "timing")]
374		timing.choose_best_split_not_root.inc(start.elapsed());
375
376		// Add a queue item or leaf for the left child.
377		match left_child_best_split_output {
378			ChooseBestSplitOutput::Success(output) => {
379				add_queue_item(AddQueueItemOptions {
380					depth: queue_item.depth + 1,
381					examples_index_range: left_child_examples_index_range,
382					output,
383					parent_index: Some(node_index),
384					queue: &mut queue,
385					split_direction: Some(SplitDirection::Left),
386				});
387			}
388			ChooseBestSplitOutput::Failure(output) => {
389				add_leaf(AddLeafOptions {
390					examples_index_range: left_child_examples_index_range,
391					leaf_values: &mut leaf_values,
392					n_examples_root,
393					nodes: &mut nodes,
394					train_options,
395					parent_node_index: Some(node_index),
396					split_direction: Some(SplitDirection::Left),
397					sum_gradients: output.sum_gradients,
398					sum_hessians: output.sum_hessians,
399				});
400			}
401		}
402
403		// Add a queue item or leaf for the right child.
404		match right_child_best_split_output {
405			ChooseBestSplitOutput::Success(output) => {
406				add_queue_item(AddQueueItemOptions {
407					depth: queue_item.depth + 1,
408					examples_index_range: right_child_examples_index_range,
409					output,
410					parent_index: Some(node_index),
411					queue: &mut queue,
412					split_direction: Some(SplitDirection::Right),
413				});
414			}
415			ChooseBestSplitOutput::Failure(output) => {
416				add_leaf(AddLeafOptions {
417					examples_index_range: right_child_examples_index_range,
418					leaf_values: &mut leaf_values,
419					n_examples_root,
420					nodes: &mut nodes,
421					train_options,
422					parent_node_index: Some(node_index),
423					split_direction: Some(SplitDirection::Right),
424					sum_gradients: output.sum_gradients,
425					sum_hessians: output.sum_hessians,
426				});
427			}
428		}
429	}
430
431	// The remaining items on the queue should all be made into leaves.
432	while let Some(queue_item) = queue.pop() {
433		add_leaf(AddLeafOptions {
434			examples_index_range: queue_item.examples_index_range,
435			leaf_values: &mut leaf_values,
436			n_examples_root,
437			nodes: &mut nodes,
438			train_options,
439			parent_node_index: Some(queue_item.parent_index.unwrap()),
440			split_direction: Some(queue_item.split_direction.unwrap()),
441			sum_gradients: queue_item.sum_gradients,
442			sum_hessians: queue_item.sum_hessians,
443		});
444	}
445
446	TrainTree { nodes, leaf_values }
447}
448
449struct AddQueueItemOptions<'a> {
450	depth: usize,
451	examples_index_range: Range<usize>,
452	output: ChooseBestSplitSuccess,
453	parent_index: Option<usize>,
454	queue: &'a mut BinaryHeap<QueueItem>,
455	split_direction: Option<SplitDirection>,
456}
457
458/// Add a queue item to the queue.
459fn add_queue_item(options: AddQueueItemOptions) {
460	options.queue.push(QueueItem {
461		gain: options.output.gain,
462		splittable_features: options.output.splittable_features,
463		parent_index: options.parent_index,
464		split_direction: options.split_direction,
465		depth: options.depth,
466		bin_stats: options.output.bin_stats,
467		examples_index_range: options.examples_index_range,
468		sum_gradients: options.output.sum_gradients,
469		sum_hessians: options.output.sum_hessians,
470		split: options.output.split,
471		left_n_examples: options.output.left_n_examples,
472		left_sum_gradients: options.output.left_sum_gradients,
473		left_sum_hessians: options.output.left_sum_hessians,
474		right_n_examples: options.output.right_n_examples,
475		right_sum_gradients: options.output.right_sum_gradients,
476		right_sum_hessians: options.output.right_sum_hessians,
477	});
478}
479
480struct AddLeafOptions<'a> {
481	examples_index_range: Range<usize>,
482	leaf_values: &'a mut Vec<(Range<usize>, f64)>,
483	n_examples_root: usize,
484	nodes: &'a mut Vec<TrainNode>,
485	train_options: &'a TrainOptions,
486	parent_node_index: Option<usize>,
487	split_direction: Option<SplitDirection>,
488	sum_gradients: f64,
489	sum_hessians: f64,
490}
491
492/// Add a leaf to the list of nodes and update the parent to refer to it.
493fn add_leaf(options: AddLeafOptions) {
494	let AddLeafOptions {
495		examples_index_range,
496		leaf_values,
497		n_examples_root,
498		nodes,
499		train_options,
500		parent_node_index,
501		split_direction,
502		sum_gradients,
503		sum_hessians,
504	} = options;
505	// This is the index this leaf will have in the `nodes` array.
506	let leaf_index = nodes.len();
507	// Compute the leaf's value.
508	let value = -train_options.learning_rate as f64 * sum_gradients
509		/ (sum_hessians
510			+ train_options.l2_regularization_for_continuous_splits as f64
511			+ std::f64::EPSILON);
512	let examples_fraction =
513		examples_index_range.len().to_f32().unwrap() / n_examples_root.to_f32().unwrap();
514	let node = TrainNode::Leaf(TrainLeafNode {
515		value,
516		examples_fraction,
517	});
518	leaf_values.push((examples_index_range, value));
519	nodes.push(node);
520	// Update the parent's left or right child index to refer to this leaf's index.
521	if let Some(parent_node_index) = parent_node_index {
522		let parent = nodes
523			.get_mut(parent_node_index)
524			.unwrap()
525			.as_branch_mut()
526			.unwrap();
527		match split_direction.unwrap() {
528			SplitDirection::Left => parent.left_child_index = Some(leaf_index),
529			SplitDirection::Right => parent.right_child_index = Some(leaf_index),
530		}
531	}
532}