1use crate::{
2 choose_best_split::{
3 choose_best_split_root, choose_best_splits_not_root, ChooseBestSplitOutput,
4 ChooseBestSplitRootOptions, ChooseBestSplitSuccess, ChooseBestSplitsNotRootOptions,
5 },
6 compute_bin_stats::BinStats,
7 compute_binned_features::{BinnedFeaturesColumnMajor, BinnedFeaturesRowMajor},
8 compute_binning_instructions::BinningInstruction,
9 pool::{Pool, PoolItem},
10 rearrange_examples_index::rearrange_examples_index,
11 SplitDirection, TrainOptions,
12};
13use bitvec::prelude::*;
14use num::ToPrimitive;
15use std::{cmp::Ordering, collections::BinaryHeap, ops::Range};
16
17#[derive(Debug)]
18pub struct TrainTree {
19 pub nodes: Vec<TrainNode>,
20 pub leaf_values: Vec<(Range<usize>, f64)>,
21}
22
23impl TrainTree {
24 pub fn predict(&self, example: &[tangram_table::TableValue]) -> f32 {
26 let mut node_index = 0;
28 loop {
30 match &self.nodes.get(node_index).unwrap() {
31 TrainNode::Branch(TrainBranchNode {
33 left_child_index,
34 right_child_index,
35 split:
36 TrainBranchSplit::Continuous(TrainBranchSplitContinuous {
37 feature_index,
38 split_value,
39 ..
40 }),
41 ..
42 }) => {
43 node_index = if example[*feature_index].as_number().unwrap() <= split_value {
44 left_child_index.unwrap()
45 } else {
46 right_child_index.unwrap()
47 };
48 }
49 TrainNode::Branch(TrainBranchNode {
51 left_child_index,
52 right_child_index,
53 split:
54 TrainBranchSplit::Discrete(TrainBranchSplitDiscrete {
55 feature_index,
56 directions,
57 ..
58 }),
59 ..
60 }) => {
61 let bin_index =
62 if let Some(bin_index) = example[*feature_index].as_enum().unwrap() {
63 bin_index.get()
64 } else {
65 0
66 };
67 node_index = match (*directions.get(bin_index).unwrap()).into() {
68 SplitDirection::Left => left_child_index.unwrap(),
69 SplitDirection::Right => right_child_index.unwrap(),
70 };
71 }
72 TrainNode::Leaf(TrainLeafNode { value, .. }) => return *value as f32,
74 }
75 }
76 }
77}
78
79#[derive(Debug)]
80pub enum TrainNode {
81 Branch(TrainBranchNode),
82 Leaf(TrainLeafNode),
83}
84
85impl TrainNode {
86 pub fn as_branch_mut(&mut self) -> Option<&mut TrainBranchNode> {
87 match self {
88 TrainNode::Branch(s) => Some(s),
89 _ => None,
90 }
91 }
92}
93
94#[derive(Debug)]
95pub struct TrainBranchNode {
96 pub left_child_index: Option<usize>,
97 pub right_child_index: Option<usize>,
98 pub split: TrainBranchSplit,
99 pub examples_fraction: f32,
100}
101
102#[derive(Clone, Debug)]
103pub enum TrainBranchSplit {
104 Continuous(TrainBranchSplitContinuous),
105 Discrete(TrainBranchSplitDiscrete),
106}
107
108#[derive(Clone, Debug)]
109pub struct TrainBranchSplitContinuous {
110 pub feature_index: usize,
111 pub split_value: f32,
112 pub bin_index: usize,
113 pub invalid_values_direction: SplitDirection,
114}
115
116#[derive(Clone, Debug)]
117pub struct TrainBranchSplitDiscrete {
118 pub feature_index: usize,
119 pub directions: BitVec<Lsb0, u8>,
120}
121
/// A terminal node of a [`TrainTree`].
#[derive(Debug)]
pub struct TrainLeafNode {
    /// The value this leaf contributes to a prediction.
    pub value: f64,
    /// Number of examples that reached this leaf divided by the number of
    /// examples at the root.
    pub examples_fraction: f32,
}
127
128struct QueueItem {
129 pub gain: f32,
131 pub parent_index: Option<usize>,
133 pub split_direction: Option<SplitDirection>,
135 pub depth: usize,
137 pub bin_stats: PoolItem<BinStats>,
139 pub examples_index_range: std::ops::Range<usize>,
141 pub sum_gradients: f64,
143 pub sum_hessians: f64,
145 pub split: TrainBranchSplit,
147 pub left_n_examples: usize,
149 pub left_sum_gradients: f64,
151 pub left_sum_hessians: f64,
153 pub right_n_examples: usize,
155 pub right_sum_gradients: f64,
157 pub right_sum_hessians: f64,
159 pub splittable_features: Vec<bool>,
161}
162
163impl PartialEq for QueueItem {
164 fn eq(&self, other: &Self) -> bool {
165 self.gain == other.gain
166 }
167}
168
169impl Eq for QueueItem {}
170
171impl std::cmp::PartialOrd for QueueItem {
172 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
173 self.gain.partial_cmp(&other.gain)
174 }
175}
176
177impl std::cmp::Ord for QueueItem {
178 fn cmp(&self, other: &Self) -> Ordering {
179 self.partial_cmp(other).unwrap()
180 }
181}
182
183pub struct TrainTreeOptions<'a> {
184 pub bin_stats_pool: &'a Pool<BinStats>,
185 pub binned_features_row_major: &'a Option<BinnedFeaturesRowMajor>,
186 pub binned_features_column_major: &'a BinnedFeaturesColumnMajor,
187 pub binning_instructions: &'a [BinningInstruction],
188 pub examples_index_left_buffer: &'a mut [u32],
189 pub examples_index_right_buffer: &'a mut [u32],
190 pub examples_index: &'a mut [u32],
191 pub gradients_ordered_buffer: &'a mut [f32],
192 pub gradients: &'a [f32],
193 pub hessians_are_constant: bool,
194 pub hessians_ordered_buffer: &'a mut [f32],
195 pub hessians: &'a [f32],
196 #[cfg(feature = "timing")]
197 pub timing: &'a crate::timing::Timing,
198 pub train_options: &'a TrainOptions,
199}
200
201pub fn train_tree(options: TrainTreeOptions) -> TrainTree {
203 let TrainTreeOptions {
204 bin_stats_pool,
205 binned_features_row_major,
206 binned_features_column_major,
207 binning_instructions,
208 examples_index_left_buffer,
209 examples_index_right_buffer,
210 examples_index,
211 gradients_ordered_buffer,
212 gradients,
213 hessians_are_constant,
214 hessians_ordered_buffer,
215 hessians,
216 train_options,
217 ..
218 } = options;
219 #[cfg(feature = "timing")]
220 let timing = options.timing;
221 let mut nodes = Vec::new();
223 let mut queue: BinaryHeap<QueueItem> = BinaryHeap::new();
225 let mut leaf_values: Vec<(Range<usize>, f64)> = Vec::new();
227
228 let n_examples_root = examples_index.len();
229 let examples_index_range_root = 0..n_examples_root;
230
231 let choose_best_split_output_root = choose_best_split_root(ChooseBestSplitRootOptions {
233 bin_stats_pool,
234 binned_features_column_major,
235 binned_features_row_major,
236 binning_instructions,
237 examples_index,
238 gradients,
239 hessians_are_constant,
240 hessians,
241 #[cfg(feature = "timing")]
242 timing,
243 train_options,
244 });
245
246 match choose_best_split_output_root {
248 ChooseBestSplitOutput::Success(output) => {
249 add_queue_item(AddQueueItemOptions {
250 depth: 0,
251 examples_index_range: examples_index_range_root,
252 output,
253 parent_index: None,
254 queue: &mut queue,
255 split_direction: None,
256 });
257 }
258 ChooseBestSplitOutput::Failure(output) => {
259 add_leaf(AddLeafOptions {
260 examples_index_range: examples_index_range_root,
261 leaf_values: &mut leaf_values,
262 n_examples_root,
263 nodes: &mut nodes,
264 train_options,
265 parent_node_index: None,
266 split_direction: None,
267 sum_gradients: output.sum_gradients,
268 sum_hessians: output.sum_hessians,
269 });
270 return TrainTree { nodes, leaf_values };
271 }
272 }
273
274 loop {
276 let n_leaf_nodes = leaf_values.len() + queue.len();
278 let max_leaf_nodes_reached = n_leaf_nodes == train_options.max_leaf_nodes;
279 if max_leaf_nodes_reached {
280 break;
281 }
282
283 let node_index = nodes.len();
285 let queue_item = if let Some(queue_item) = queue.pop() {
286 queue_item
287 } else {
288 break;
289 };
290
291 let examples_fraction = queue_item.examples_index_range.len().to_f32().unwrap()
293 / n_examples_root.to_f32().unwrap();
294 nodes.push(TrainNode::Branch(TrainBranchNode {
295 split: queue_item.split.clone(),
296 left_child_index: None,
297 right_child_index: None,
298 examples_fraction,
299 }));
300 if let Some(parent_index) = queue_item.parent_index {
301 let parent = nodes
302 .get_mut(parent_index)
303 .unwrap()
304 .as_branch_mut()
305 .unwrap();
306 let split_direction = queue_item.split_direction.unwrap();
307 match split_direction {
308 SplitDirection::Left => parent.left_child_index = Some(node_index),
309 SplitDirection::Right => parent.right_child_index = Some(node_index),
310 }
311 }
312
313 #[cfg(feature = "timing")]
315 let start = std::time::Instant::now();
316 let (left, right) = rearrange_examples_index(
317 binned_features_column_major,
318 &queue_item.split,
319 examples_index
320 .get_mut(queue_item.examples_index_range.clone())
321 .unwrap(),
322 examples_index_left_buffer
323 .get_mut(queue_item.examples_index_range.clone())
324 .unwrap(),
325 examples_index_right_buffer
326 .get_mut(queue_item.examples_index_range.clone())
327 .unwrap(),
328 );
329 let branch_examples_index_range_start = queue_item.examples_index_range.start;
331 let left_child_examples_index_range = branch_examples_index_range_start + left.start
332 ..branch_examples_index_range_start + left.end;
333 let right_child_examples_index_range = branch_examples_index_range_start + right.start
334 ..branch_examples_index_range_start + right.end;
335 let left_child_examples_index = examples_index
336 .get(left_child_examples_index_range.clone())
337 .unwrap();
338 let right_child_examples_index = examples_index
339 .get(right_child_examples_index_range.clone())
340 .unwrap();
341 #[cfg(feature = "timing")]
342 timing.rearrange_examples_index.inc(start.elapsed());
343
344 #[cfg(feature = "timing")]
346 let start = std::time::Instant::now();
347 let (left_child_best_split_output, right_child_best_split_output) =
348 choose_best_splits_not_root(ChooseBestSplitsNotRootOptions {
349 bin_stats_pool,
350 binned_features_column_major,
351 binned_features_row_major,
352 binning_instructions,
353 gradients_ordered_buffer,
354 gradients,
355 hessians_are_constant,
356 hessians_ordered_buffer,
357 hessians,
358 left_child_examples_index,
359 splittable_features: queue_item.splittable_features.as_slice(),
360 left_child_n_examples: queue_item.left_n_examples,
361 left_child_sum_gradients: queue_item.left_sum_gradients,
362 left_child_sum_hessians: queue_item.left_sum_hessians,
363 parent_bin_stats: queue_item.bin_stats,
364 parent_depth: queue_item.depth,
365 right_child_examples_index,
366 right_child_n_examples: queue_item.right_n_examples,
367 right_child_sum_gradients: queue_item.right_sum_gradients,
368 right_child_sum_hessians: queue_item.right_sum_hessians,
369 #[cfg(feature = "timing")]
370 timing,
371 train_options,
372 });
373 #[cfg(feature = "timing")]
374 timing.choose_best_split_not_root.inc(start.elapsed());
375
376 match left_child_best_split_output {
378 ChooseBestSplitOutput::Success(output) => {
379 add_queue_item(AddQueueItemOptions {
380 depth: queue_item.depth + 1,
381 examples_index_range: left_child_examples_index_range,
382 output,
383 parent_index: Some(node_index),
384 queue: &mut queue,
385 split_direction: Some(SplitDirection::Left),
386 });
387 }
388 ChooseBestSplitOutput::Failure(output) => {
389 add_leaf(AddLeafOptions {
390 examples_index_range: left_child_examples_index_range,
391 leaf_values: &mut leaf_values,
392 n_examples_root,
393 nodes: &mut nodes,
394 train_options,
395 parent_node_index: Some(node_index),
396 split_direction: Some(SplitDirection::Left),
397 sum_gradients: output.sum_gradients,
398 sum_hessians: output.sum_hessians,
399 });
400 }
401 }
402
403 match right_child_best_split_output {
405 ChooseBestSplitOutput::Success(output) => {
406 add_queue_item(AddQueueItemOptions {
407 depth: queue_item.depth + 1,
408 examples_index_range: right_child_examples_index_range,
409 output,
410 parent_index: Some(node_index),
411 queue: &mut queue,
412 split_direction: Some(SplitDirection::Right),
413 });
414 }
415 ChooseBestSplitOutput::Failure(output) => {
416 add_leaf(AddLeafOptions {
417 examples_index_range: right_child_examples_index_range,
418 leaf_values: &mut leaf_values,
419 n_examples_root,
420 nodes: &mut nodes,
421 train_options,
422 parent_node_index: Some(node_index),
423 split_direction: Some(SplitDirection::Right),
424 sum_gradients: output.sum_gradients,
425 sum_hessians: output.sum_hessians,
426 });
427 }
428 }
429 }
430
431 while let Some(queue_item) = queue.pop() {
433 add_leaf(AddLeafOptions {
434 examples_index_range: queue_item.examples_index_range,
435 leaf_values: &mut leaf_values,
436 n_examples_root,
437 nodes: &mut nodes,
438 train_options,
439 parent_node_index: Some(queue_item.parent_index.unwrap()),
440 split_direction: Some(queue_item.split_direction.unwrap()),
441 sum_gradients: queue_item.sum_gradients,
442 sum_hessians: queue_item.sum_hessians,
443 });
444 }
445
446 TrainTree { nodes, leaf_values }
447}
448
449struct AddQueueItemOptions<'a> {
450 depth: usize,
451 examples_index_range: Range<usize>,
452 output: ChooseBestSplitSuccess,
453 parent_index: Option<usize>,
454 queue: &'a mut BinaryHeap<QueueItem>,
455 split_direction: Option<SplitDirection>,
456}
457
458fn add_queue_item(options: AddQueueItemOptions) {
460 options.queue.push(QueueItem {
461 gain: options.output.gain,
462 splittable_features: options.output.splittable_features,
463 parent_index: options.parent_index,
464 split_direction: options.split_direction,
465 depth: options.depth,
466 bin_stats: options.output.bin_stats,
467 examples_index_range: options.examples_index_range,
468 sum_gradients: options.output.sum_gradients,
469 sum_hessians: options.output.sum_hessians,
470 split: options.output.split,
471 left_n_examples: options.output.left_n_examples,
472 left_sum_gradients: options.output.left_sum_gradients,
473 left_sum_hessians: options.output.left_sum_hessians,
474 right_n_examples: options.output.right_n_examples,
475 right_sum_gradients: options.output.right_sum_gradients,
476 right_sum_hessians: options.output.right_sum_hessians,
477 });
478}
479
480struct AddLeafOptions<'a> {
481 examples_index_range: Range<usize>,
482 leaf_values: &'a mut Vec<(Range<usize>, f64)>,
483 n_examples_root: usize,
484 nodes: &'a mut Vec<TrainNode>,
485 train_options: &'a TrainOptions,
486 parent_node_index: Option<usize>,
487 split_direction: Option<SplitDirection>,
488 sum_gradients: f64,
489 sum_hessians: f64,
490}
491
492fn add_leaf(options: AddLeafOptions) {
494 let AddLeafOptions {
495 examples_index_range,
496 leaf_values,
497 n_examples_root,
498 nodes,
499 train_options,
500 parent_node_index,
501 split_direction,
502 sum_gradients,
503 sum_hessians,
504 } = options;
505 let leaf_index = nodes.len();
507 let value = -train_options.learning_rate as f64 * sum_gradients
509 / (sum_hessians
510 + train_options.l2_regularization_for_continuous_splits as f64
511 + std::f64::EPSILON);
512 let examples_fraction =
513 examples_index_range.len().to_f32().unwrap() / n_examples_root.to_f32().unwrap();
514 let node = TrainNode::Leaf(TrainLeafNode {
515 value,
516 examples_fraction,
517 });
518 leaf_values.push((examples_index_range, value));
519 nodes.push(node);
520 if let Some(parent_node_index) = parent_node_index {
522 let parent = nodes
523 .get_mut(parent_node_index)
524 .unwrap()
525 .as_branch_mut()
526 .unwrap();
527 match split_direction.unwrap() {
528 SplitDirection::Left => parent.left_child_index = Some(leaf_index),
529 SplitDirection::Right => parent.right_child_index = Some(leaf_index),
530 }
531 }
532}