1use std::collections::HashSet;
2
3use crate::constraints::{Constraint, ConstraintMap};
4use crate::data::{JaggedMatrix, Matrix};
5use crate::gradientbooster::MissingNodeTreatment;
6use crate::histogram::HistogramMatrix;
7use crate::node::SplittableNode;
8use crate::tree::Tree;
9use crate::utils::{
10 between, bound_to_parent, constrained_weight, cull_gain, gain_given_weight, pivot_on_split,
11 pivot_on_split_exclude_missing,
12};
13
/// The best split found for one feature of a node, as produced by
/// `Splitter::best_feature_split` and consumed by `handle_split_info`.
#[derive(Debug)]
pub struct SplitInfo {
    /// Gain of the split: `SplittableNode::get_split_gain` (given the
    /// splitter's gamma) followed by `cull_gain`.
    pub split_gain: f32,
    /// Feature (column) index the split is made on.
    pub split_feature: usize,
    /// `cut_value` of the histogram bin the split is made at.
    pub split_value: f64,
    /// Histogram bin index of the split. Bin 0 holds missing values, so this
    /// is always at least 1.
    pub split_bin: u16,
    /// Statistics of the left child.
    pub left_node: NodeInfo,
    /// Statistics of the right child.
    pub right_node: NodeInfo,
    /// Where missing values go, or the statistics of a dedicated missing child.
    pub missing_node: MissingInfo,
}
24
/// Summary statistics for one child of a candidate split.
#[derive(Debug)]
pub struct NodeInfo {
    /// Sum of the gradients of the records in this child.
    pub grad: f32,
    /// Gain of this child, from `gain_given_weight`.
    pub gain: f32,
    /// Sum of the hessians of the records in this child.
    pub cover: f32,
    /// Leaf weight of this child, from `constrained_weight` (possibly adjusted
    /// by the splitter, e.g. bounded to the parent).
    pub weight: f32,
    /// `(lower, upper)` bounds for weights below this child. Under a monotone
    /// constraint, one side is tightened to the midpoint of the two sibling
    /// weights.
    pub bounds: (f32, f32),
}
33
/// How a split treats records whose feature value is missing.
#[derive(Debug)]
pub enum MissingInfo {
    /// Missing records are merged into the left child (`MissingImputerSplitter`).
    Left,
    /// Missing records are merged into the right child (`MissingImputerSplitter`).
    Right,
    /// Missing records get their own child, which will not be split further
    /// (`MissingBranchSplitter`).
    Leaf(NodeInfo),
    /// Missing records get their own child, which may be split further unless
    /// the feature is listed in `terminate_missing_features`
    /// (`MissingBranchSplitter`).
    Branch(NodeInfo),
}
41
42pub trait Splitter {
43 fn new_leaves_added(&self) -> usize {
48 1
49 }
50 fn get_constraint(&self, feature: &usize) -> Option<&Constraint>;
51 fn get_gamma(&self) -> f32;
53 fn get_l1(&self) -> f32;
54 fn get_l2(&self) -> f32;
55 fn get_max_delta_step(&self) -> f32;
56 fn get_learning_rate(&self) -> f32;
57
58 fn clean_up_splits(&self, _tree: &mut Tree) {}
63
64 fn best_split(&self, node: &SplittableNode, col_index: &[usize]) -> Option<SplitInfo> {
68 let mut best_split_info = None;
69 let mut best_gain = 0.0;
70 for (idx, feature) in col_index.iter().enumerate() {
71 let split_info = self.best_feature_split(node, *feature, idx);
72 match split_info {
73 Some(info) => {
74 if info.split_gain > best_gain {
75 best_gain = info.split_gain;
76 best_split_info = Some(info);
77 }
78 }
79 None => continue,
80 }
81 }
82 best_split_info
83 }
84
85 #[allow(clippy::too_many_arguments)]
88 fn evaluate_split(
89 &self,
90 left_gradient: f32,
91 left_hessian: f32,
92 right_gradient: f32,
93 right_hessian: f32,
94 missing_gradient: f32,
95 missing_hessian: f32,
96 lower_bound: f32,
97 upper_bound: f32,
98 parent_weight: f32,
99 constraint: Option<&Constraint>,
100 ) -> Option<(NodeInfo, NodeInfo, MissingInfo)>;
101
102 fn best_feature_split(
105 &self,
106 node: &SplittableNode,
107 feature: usize,
108 idx: usize,
109 ) -> Option<SplitInfo> {
110 let mut split_info: Option<SplitInfo> = None;
111 let mut max_gain: Option<f32> = None;
112
113 let HistogramMatrix(histograms) = &node.histograms;
114 let histogram = histograms.get_col(idx);
115
116 let missing = &histogram[0];
118 let mut cuml_grad = 0.0; let mut cuml_hess = 0.0; let constraint = self.get_constraint(&feature);
121
122 let elements = histogram.len();
123 assert!(elements == histogram.len());
124
125 for (i, bin) in histogram[1..].iter().enumerate() {
126 let left_gradient = cuml_grad;
127 let left_hessian = cuml_hess;
128 let right_gradient = node.gradient_sum - cuml_grad - missing.gradient_sum;
129 let right_hessian = node.hessian_sum - cuml_hess - missing.hessian_sum;
130 cuml_grad += bin.gradient_sum;
131 cuml_hess += bin.hessian_sum;
132
133 let (mut left_node_info, mut right_node_info, missing_info) = match self.evaluate_split(
134 left_gradient,
135 left_hessian,
136 right_gradient,
137 right_hessian,
138 missing.gradient_sum,
139 missing.hessian_sum,
140 node.lower_bound,
141 node.upper_bound,
142 node.weight_value,
143 constraint,
144 ) {
145 None => {
146 continue;
147 }
148 Some(v) => v,
149 };
150
151 let split_gain = node.get_split_gain(
152 &left_node_info,
153 &right_node_info,
154 &missing_info,
155 self.get_gamma(),
156 );
157
158 let split_gain = cull_gain(
160 split_gain,
161 left_node_info.weight,
162 right_node_info.weight,
163 constraint,
164 );
165
166 if split_gain <= 0.0 {
167 continue;
168 }
169
170 let mid = (left_node_info.weight + right_node_info.weight) / 2.0;
171 let (left_bounds, right_bounds) = match constraint {
172 None | Some(Constraint::Unconstrained) => (
173 (node.lower_bound, node.upper_bound),
174 (node.lower_bound, node.upper_bound),
175 ),
176 Some(Constraint::Negative) => ((mid, node.upper_bound), (node.lower_bound, mid)),
177 Some(Constraint::Positive) => ((node.lower_bound, mid), (mid, node.upper_bound)),
178 };
179 left_node_info.bounds = left_bounds;
180 right_node_info.bounds = right_bounds;
181
182 let split_gain = if split_gain.is_nan() { 0.0 } else { split_gain };
185 if max_gain.is_none() || split_gain > max_gain.unwrap() {
186 max_gain = Some(split_gain);
187 split_info = Some(SplitInfo {
188 split_gain,
189 split_feature: feature,
190 split_value: bin.cut_value,
191 split_bin: (i + 1) as u16,
192 left_node: left_node_info,
193 right_node: right_node_info,
194 missing_node: missing_info,
195 });
196 }
197 }
198 split_info
199 }
200
201 #[allow(clippy::too_many_arguments)]
205 fn handle_split_info(
206 &self,
207 split_info: SplitInfo,
208 n_nodes: &usize,
209 node: &mut SplittableNode,
210 index: &mut [usize],
211 col_index: &[usize],
212 data: &Matrix<u16>,
213 cuts: &JaggedMatrix<f64>,
214 grad: &[f32],
215 hess: &[f32],
216 parallel: bool,
217 ) -> Vec<SplittableNode>;
218
219 #[allow(clippy::too_many_arguments)]
222 fn split_node(
223 &self,
224 n_nodes: &usize,
225 node: &mut SplittableNode,
226 index: &mut [usize],
227 col_index: &[usize],
228 data: &Matrix<u16>,
229 cuts: &JaggedMatrix<f64>,
230 grad: &[f32],
231 hess: &[f32],
232 parallel: bool,
233 ) -> Vec<SplittableNode> {
234 match self.best_split(node, col_index) {
235 Some(split_info) => self.handle_split_info(
236 split_info, n_nodes, node, index, col_index, data, cuts, grad, hess, parallel,
237 ),
238 None => Vec::new(),
239 }
240 }
241}
242
/// Splitter that sends missing values to a dedicated third child, so each
/// split produces a missing, a left and a right node.
pub struct MissingBranchSplitter {
    /// L1 regularization, passed to `constrained_weight`.
    pub l1: f32,
    /// L2 regularization, passed to `constrained_weight` and `gain_given_weight`.
    pub l2: f32,
    /// Passed to `constrained_weight` as its max-delta-step argument.
    pub max_delta_step: f32,
    /// Returned by `get_gamma`; forwarded to `SplittableNode::get_split_gain`.
    pub gamma: f32,
    /// Minimum hessian sum required on both the left and the right child.
    pub min_leaf_weight: f32,
    /// Returned by `get_learning_rate`.
    pub learning_rate: f32,
    /// When set, a missing child that holds records may be split further
    /// (`MissingInfo::Branch`); otherwise it is always a leaf.
    pub allow_missing_splits: bool,
    /// Per-feature constraints, looked up by feature index.
    pub constraints_map: ConstraintMap,
    /// Features whose missing child is always made a leaf, even when
    /// `allow_missing_splits` is set.
    pub terminate_missing_features: HashSet<usize>,
    /// Selects how the missing child's weight is chosen and whether weights
    /// are re-averaged in `clean_up_splits`.
    pub missing_node_treatment: MissingNodeTreatment,
    /// When set, the left/right weights are passed through `bound_to_parent`
    /// with the parent's weight.
    pub force_children_to_bound_parent: bool,
}
261
262impl MissingBranchSplitter {
263 pub fn new_leaves_added(&self) -> usize {
264 2
265 }
266 pub fn update_average_missing_nodes(tree: &mut Tree, node_idx: usize) -> f64 {
267 let node = &tree.nodes[node_idx];
268
269 if node.is_leaf {
270 return node.weight_value as f64;
271 }
272
273 let right = node.right_child;
274 let left = node.left_child;
275 let current_node = node.num;
276 let missing = node.missing_node;
277
278 let right_hessian = tree.nodes[right].hessian_sum as f64;
279 let right_avg_weight = Self::update_average_missing_nodes(tree, right);
280
281 let left_hessian = tree.nodes[left].hessian_sum as f64;
282 let left_avg_weight = Self::update_average_missing_nodes(tree, left);
283
284 let (missing_hessian, missing_avg_weight, missing_leaf) = if tree.nodes[missing].is_leaf {
287 (0., 0., true)
288 } else {
289 (
290 tree.nodes[missing].hessian_sum as f64,
291 Self::update_average_missing_nodes(tree, missing),
292 false,
293 )
294 };
295
296 let update = (right_avg_weight * right_hessian
297 + left_avg_weight * left_hessian
298 + missing_avg_weight * missing_hessian)
299 / (left_hessian + right_hessian + missing_hessian);
300
301 if let Some(n) = tree.nodes.get_mut(current_node) {
303 n.weight_value = update as f32;
304 }
305 if missing_leaf {
308 if let Some(m) = tree.nodes.get_mut(missing) {
309 m.weight_value = update as f32;
310 }
311 }
312
313 update
314 }
315}
316
317impl Splitter for MissingBranchSplitter {
318 fn clean_up_splits(&self, tree: &mut Tree) {
319 if let MissingNodeTreatment::AverageLeafWeight = self.missing_node_treatment {
320 MissingBranchSplitter::update_average_missing_nodes(tree, 0);
321 }
322 }
323
324 fn get_constraint(&self, feature: &usize) -> Option<&Constraint> {
325 self.constraints_map.get(feature)
326 }
327
328 fn get_gamma(&self) -> f32 {
329 self.gamma
330 }
331
332 fn get_l1(&self) -> f32 {
333 self.l1
334 }
335
336 fn get_l2(&self) -> f32 {
337 self.l2
338 }
339 fn get_max_delta_step(&self) -> f32 {
340 self.max_delta_step
341 }
342
343 fn get_learning_rate(&self) -> f32 {
344 self.learning_rate
345 }
346
347 fn evaluate_split(
348 &self,
349 left_gradient: f32,
350 left_hessian: f32,
351 right_gradient: f32,
352 right_hessian: f32,
353 missing_gradient: f32,
354 missing_hessian: f32,
355 lower_bound: f32,
356 upper_bound: f32,
357 parent_weight: f32,
358 constraint: Option<&Constraint>,
359 ) -> Option<(NodeInfo, NodeInfo, MissingInfo)> {
360 if (left_gradient == 0.0) && (left_hessian == 0.0)
364 || (right_gradient == 0.0) && (right_hessian == 0.0)
365 {
366 return None;
367 }
368
369 let mut left_weight = constrained_weight(
370 &self.l1,
371 &self.l2,
372 &self.max_delta_step,
373 left_gradient,
374 left_hessian,
375 lower_bound,
376 upper_bound,
377 constraint,
378 );
379 let mut right_weight = constrained_weight(
380 &self.l1,
381 &self.l2,
382 &self.max_delta_step,
383 right_gradient,
384 right_hessian,
385 lower_bound,
386 upper_bound,
387 constraint,
388 );
389
390 if self.force_children_to_bound_parent {
391 (left_weight, right_weight) = bound_to_parent(parent_weight, left_weight, right_weight);
392 assert!(between(lower_bound, upper_bound, left_weight));
393 assert!(between(lower_bound, upper_bound, right_weight));
394 }
395
396 let left_gain = gain_given_weight(&self.l2, left_gradient, left_hessian, left_weight);
397 let right_gain = gain_given_weight(&self.l2, right_gradient, right_hessian, right_weight);
398
399 if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
401 return None;
403 }
404
405 let missing_weight = match self.missing_node_treatment {
412 MissingNodeTreatment::AssignToParent => constrained_weight(
413 &self.get_l1(),
414 &self.get_l2(),
415 &self.max_delta_step,
416 missing_gradient + left_gradient + right_gradient,
417 missing_hessian + left_hessian + right_hessian,
418 lower_bound,
419 upper_bound,
420 constraint,
421 ),
422 MissingNodeTreatment::AverageLeafWeight | MissingNodeTreatment::AverageNodeWeight => {
425 (right_weight * right_hessian + left_weight * left_hessian)
426 / (right_hessian + left_hessian)
427 }
428 MissingNodeTreatment::None => {
429 if missing_hessian == 0. || missing_gradient == 0. {
432 parent_weight
433 } else {
434 constrained_weight(
435 &self.get_l1(),
436 &self.get_l2(),
437 &self.max_delta_step,
438 missing_gradient,
439 missing_hessian,
440 lower_bound,
441 upper_bound,
442 constraint,
443 )
444 }
445 }
446 };
447 let missing_gain = gain_given_weight(
448 &self.get_l2(),
449 missing_gradient,
450 missing_hessian,
451 missing_weight,
452 );
453 let missing_info = NodeInfo {
454 grad: missing_gradient,
455 gain: missing_gain,
456 cover: missing_hessian,
457 weight: missing_weight,
458 bounds: (lower_bound, upper_bound),
462 };
463 let missing_node = if ((missing_gradient != 0.0) || (missing_hessian != 0.0)) && self.allow_missing_splits {
465 MissingInfo::Branch(
466 missing_info
467 )
468 } else {
469 MissingInfo::Leaf(
470 missing_info
471 )
472 };
473
474 if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
475 return None;
477 }
478 Some((
479 NodeInfo {
480 grad: left_gradient,
481 gain: left_gain,
482 cover: left_hessian,
483 weight: left_weight,
484 bounds: (f32::NEG_INFINITY, f32::INFINITY),
485 },
486 NodeInfo {
487 grad: right_gradient,
488 gain: right_gain,
489 cover: right_hessian,
490 weight: right_weight,
491 bounds: (f32::NEG_INFINITY, f32::INFINITY),
492 },
493 missing_node,
494 ))
495 }
496
497 fn handle_split_info(
498 &self,
499 split_info: SplitInfo,
500 n_nodes: &usize,
501 node: &mut SplittableNode,
502 index: &mut [usize],
503 col_index: &[usize],
504 data: &Matrix<u16>,
505 cuts: &JaggedMatrix<f64>,
506 grad: &[f32],
507 hess: &[f32],
508 parallel: bool,
509 ) -> Vec<SplittableNode> {
510 let missing_child = *n_nodes;
511 let left_child = missing_child + 1;
512 let right_child = missing_child + 2;
513 node.update_children(missing_child, left_child, right_child, &split_info);
514
515 let (mut missing_is_leaf, mut missing_info) = match split_info.missing_node {
516 MissingInfo::Branch(i) => {
517 if self
518 .terminate_missing_features
519 .contains(&split_info.split_feature)
520 {
521 (true, i)
522 } else {
523 (false, i)
524 }
525 }
526 MissingInfo::Leaf(i) => (true, i),
527 _ => unreachable!(),
528 };
529 if let MissingNodeTreatment::AssignToParent = self.missing_node_treatment {
534 missing_info.weight = node.weight_value;
535 }
536 let (mut missing_split_idx, mut split_idx) = pivot_on_split_exclude_missing(
541 &mut index[node.start_idx..node.stop_idx],
542 data.get_col(split_info.split_feature),
543 split_info.split_bin,
544 );
545 let total_recs = node.stop_idx - node.start_idx;
547 let n_right = total_recs - split_idx;
548 let n_left = total_recs - n_right - missing_split_idx;
549 let n_missing = total_recs - (n_right + n_left);
550 let max_ = match [n_missing, n_left, n_right]
551 .iter()
552 .enumerate()
553 .max_by(|(_, i), (_, j)| i.cmp(j))
554 {
555 Some((i, _)) => i,
556 None => 0,
560 };
561
562 split_idx += node.start_idx;
566 missing_split_idx += node.start_idx;
567
568 let left_histograms: HistogramMatrix;
570 let right_histograms: HistogramMatrix;
571 let missing_histograms: HistogramMatrix;
572 if n_missing == 0 {
573 missing_is_leaf = true;
576 if max_ == 1 {
577 missing_histograms = HistogramMatrix::empty();
578 right_histograms = HistogramMatrix::new(
579 data,
580 cuts,
581 grad,
582 hess,
583 &index[split_idx..node.stop_idx],
584 col_index,
585 parallel,
586 true,
587 );
588 left_histograms =
589 HistogramMatrix::from_parent_child(&node.histograms, &right_histograms);
590 } else {
591 missing_histograms = HistogramMatrix::empty();
592 left_histograms = HistogramMatrix::new(
593 data,
594 cuts,
595 grad,
596 hess,
597 &index[missing_split_idx..split_idx],
598 col_index,
599 parallel,
600 true,
601 );
602 right_histograms =
603 HistogramMatrix::from_parent_child(&node.histograms, &left_histograms);
604 }
605 } else if max_ == 0 {
606 left_histograms = HistogramMatrix::new(
609 data,
610 cuts,
611 grad,
612 hess,
613 &index[missing_split_idx..split_idx],
614 col_index,
615 parallel,
616 true,
617 );
618 right_histograms = HistogramMatrix::new(
619 data,
620 cuts,
621 grad,
622 hess,
623 &index[split_idx..node.stop_idx],
624 col_index,
625 parallel,
626 true,
627 );
628 missing_histograms = HistogramMatrix::from_parent_two_children(
629 &node.histograms,
630 &left_histograms,
631 &right_histograms,
632 )
633 } else if max_ == 1 {
634 missing_histograms = HistogramMatrix::new(
635 data,
636 cuts,
637 grad,
638 hess,
639 &index[node.start_idx..missing_split_idx],
640 col_index,
641 parallel,
642 true,
643 );
644 right_histograms = HistogramMatrix::new(
645 data,
646 cuts,
647 grad,
648 hess,
649 &index[split_idx..node.stop_idx],
650 col_index,
651 parallel,
652 true,
653 );
654 left_histograms = HistogramMatrix::from_parent_two_children(
655 &node.histograms,
656 &missing_histograms,
657 &right_histograms,
658 )
659 } else {
660 missing_histograms = HistogramMatrix::new(
662 data,
663 cuts,
664 grad,
665 hess,
666 &index[node.start_idx..missing_split_idx],
667 col_index,
668 parallel,
669 true,
670 );
671 left_histograms = HistogramMatrix::new(
672 data,
673 cuts,
674 grad,
675 hess,
676 &index[missing_split_idx..split_idx],
677 col_index,
678 parallel,
679 true,
680 );
681 right_histograms = HistogramMatrix::from_parent_two_children(
682 &node.histograms,
683 &missing_histograms,
684 &left_histograms,
685 )
686 }
687
688 let mut missing_node = SplittableNode::from_node_info(
689 missing_child,
690 missing_histograms,
691 node.depth + 1,
692 node.start_idx,
693 missing_split_idx,
694 missing_info,
695 );
696 missing_node.is_missing_leaf = missing_is_leaf;
697 let left_node = SplittableNode::from_node_info(
698 left_child,
699 left_histograms,
700 node.depth + 1,
701 missing_split_idx,
702 split_idx,
703 split_info.left_node,
704 );
705 let right_node = SplittableNode::from_node_info(
706 right_child,
707 right_histograms,
708 node.depth + 1,
709 split_idx,
710 node.stop_idx,
711 split_info.right_node,
712 );
713 vec![missing_node, left_node, right_node]
714 }
715}
716
/// Splitter that makes binary splits and imputes missing values. Missing
/// records go to whichever of the left or right child gains more from
/// receiving them.
pub struct MissingImputerSplitter {
    /// L1 regularization, passed to `constrained_weight`.
    pub l1: f32,
    /// L2 regularization, passed to `constrained_weight` and `gain_given_weight`.
    pub l2: f32,
    /// Passed to `constrained_weight` as its max-delta-step argument.
    pub max_delta_step: f32,
    /// Returned by `get_gamma`; forwarded to `SplittableNode::get_split_gain`.
    pub gamma: f32,
    /// Minimum hessian sum required on both children.
    pub min_leaf_weight: f32,
    /// Returned by `get_learning_rate`.
    pub learning_rate: f32,
    /// When set, candidates whose left or right side has no non-missing
    /// records are still evaluated. The `min_leaf_weight` check is then
    /// applied only after missing records have been assigned to a side.
    pub allow_missing_splits: bool,
    /// Per-feature constraints, looked up by feature index.
    pub constraints_map: ConstraintMap,
}
731
732impl MissingImputerSplitter {
733 #[allow(clippy::too_many_arguments)]
735 pub fn new(
736 l1: f32,
737 l2: f32,
738 max_delta_step: f32,
739 gamma: f32,
740 min_leaf_weight: f32,
741 learning_rate: f32,
742 allow_missing_splits: bool,
743 constraints_map: ConstraintMap,
744 ) -> Self {
745 MissingImputerSplitter {
746 l1,
747 l2,
748 max_delta_step,
749 gamma,
750 min_leaf_weight,
751 learning_rate,
752 allow_missing_splits,
753 constraints_map,
754 }
755 }
756}
757
758impl Splitter for MissingImputerSplitter {
759 fn get_constraint(&self, feature: &usize) -> Option<&Constraint> {
760 self.constraints_map.get(feature)
761 }
762
763 fn get_gamma(&self) -> f32 {
764 self.gamma
765 }
766
767 fn get_l1(&self) -> f32 {
768 self.l1
769 }
770
771 fn get_l2(&self) -> f32 {
772 self.l2
773 }
774 fn get_max_delta_step(&self) -> f32 {
775 self.max_delta_step
776 }
777
778 fn get_learning_rate(&self) -> f32 {
779 self.learning_rate
780 }
781
782 #[allow(clippy::too_many_arguments)]
783 fn evaluate_split(
784 &self,
785 left_gradient: f32,
786 left_hessian: f32,
787 right_gradient: f32,
788 right_hessian: f32,
789 missing_gradient: f32,
790 missing_hessian: f32,
791 lower_bound: f32,
792 upper_bound: f32,
793 _parent_weight: f32,
794 constraint: Option<&Constraint>,
795 ) -> Option<(NodeInfo, NodeInfo, MissingInfo)> {
796 if ((left_gradient == 0.0) && (left_hessian == 0.0)
800 || (right_gradient == 0.0) && (right_hessian == 0.0))
801 && !self.allow_missing_splits
802 {
803 return None;
804 }
805
806 let mut missing_info = MissingInfo::Right;
808
809 let mut left_gradient = left_gradient;
810 let mut left_hessian = left_hessian;
811
812 let mut right_gradient = right_gradient;
813 let mut right_hessian = right_hessian;
814
815 let mut left_weight = constrained_weight(
816 &self.l1,
817 &self.l2,
818 &self.max_delta_step,
819 left_gradient,
820 left_hessian,
821 lower_bound,
822 upper_bound,
823 constraint,
824 );
825 let mut right_weight = constrained_weight(
826 &self.l1,
827 &self.l2,
828 &self.max_delta_step,
829 right_gradient,
830 right_hessian,
831 lower_bound,
832 upper_bound,
833 constraint,
834 );
835
836 let mut left_gain = gain_given_weight(&self.l2, left_gradient, left_hessian, left_weight);
837 let mut right_gain =
838 gain_given_weight(&self.l2, right_gradient, right_hessian, right_weight);
839
840 if !self.allow_missing_splits {
841 if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
844 return None;
846 }
847 }
848
849 if (missing_gradient != 0.0) || (missing_hessian != 0.0) {
853 let missing_left_weight = constrained_weight(
858 &self.l1,
859 &self.l2,
860 &self.max_delta_step,
861 left_gradient + missing_gradient,
862 left_hessian + missing_hessian,
863 lower_bound,
864 upper_bound,
865 constraint,
866 );
867 let missing_left_gain = gain_given_weight(
869 &self.l2,
870 left_gradient + missing_gradient,
871 left_hessian + missing_hessian,
872 missing_left_weight,
873 );
874 let missing_left_gain = cull_gain(
876 missing_left_gain,
877 missing_left_weight,
878 right_weight,
879 constraint,
880 );
881
882 let missing_right_weight = constrained_weight(
884 &self.l1,
885 &self.l2,
886 &self.max_delta_step,
887 right_gradient + missing_gradient,
888 right_hessian + missing_hessian,
889 lower_bound,
890 upper_bound,
891 constraint,
892 );
893 let missing_right_gain = gain_given_weight(
895 &self.l2,
896 right_gradient + missing_gradient,
897 right_hessian + missing_hessian,
898 missing_right_weight,
899 );
900 let missing_right_gain = cull_gain(
902 missing_right_gain,
903 left_weight,
904 missing_right_weight,
905 constraint,
906 );
907
908 if (missing_right_gain - right_gain) < (missing_left_gain - left_gain) {
909 left_gradient += missing_gradient;
911 left_hessian += missing_hessian;
912 left_gain = missing_left_gain;
913 left_weight = missing_left_weight;
914 missing_info = MissingInfo::Left;
915 } else {
916 right_gradient += missing_gradient;
918 right_hessian += missing_hessian;
919 right_gain = missing_right_gain;
920 right_weight = missing_right_weight;
921 missing_info = MissingInfo::Right;
922 }
923 }
924
925 if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
926 return None;
928 }
929 Some((
930 NodeInfo {
931 grad: left_gradient,
932 gain: left_gain,
933 cover: left_hessian,
934 weight: left_weight,
935 bounds: (f32::NEG_INFINITY, f32::INFINITY),
936 },
937 NodeInfo {
938 grad: right_gradient,
939 gain: right_gain,
940 cover: right_hessian,
941 weight: right_weight,
942 bounds: (f32::NEG_INFINITY, f32::INFINITY),
943 },
944 missing_info,
945 ))
946 }
947
948 fn handle_split_info(
949 &self,
950 split_info: SplitInfo,
951 n_nodes: &usize,
952 node: &mut SplittableNode,
953 index: &mut [usize],
954 col_index: &[usize],
955 data: &Matrix<u16>,
956 cuts: &JaggedMatrix<f64>,
957 grad: &[f32],
958 hess: &[f32],
959 parallel: bool,
960 ) -> Vec<SplittableNode> {
961 let left_child = *n_nodes;
962 let right_child = left_child + 1;
963
964 let missing_right = match split_info.missing_node {
965 MissingInfo::Left => false,
966 MissingInfo::Right => true,
967 _ => unreachable!(),
968 };
969
970 let mut split_idx = pivot_on_split(
977 &mut index[node.start_idx..node.stop_idx],
978 data.get_col(split_info.split_feature),
979 split_info.split_bin,
980 missing_right,
981 );
982 let total_recs = node.stop_idx - node.start_idx;
984 let n_right = total_recs - split_idx;
985 let n_left = total_recs - n_right;
986
987 split_idx += node.start_idx;
991
992 let left_histograms: HistogramMatrix;
994 let right_histograms: HistogramMatrix;
995 if n_left < n_right {
996 left_histograms = HistogramMatrix::new(
997 data,
998 cuts,
999 grad,
1000 hess,
1001 &index[node.start_idx..split_idx],
1002 col_index,
1003 parallel,
1004 true,
1005 );
1006 right_histograms =
1007 HistogramMatrix::from_parent_child(&node.histograms, &left_histograms);
1008 } else {
1009 right_histograms = HistogramMatrix::new(
1010 data,
1011 cuts,
1012 grad,
1013 hess,
1014 &index[split_idx..node.stop_idx],
1015 col_index,
1016 parallel,
1017 true,
1018 );
1019 left_histograms =
1020 HistogramMatrix::from_parent_child(&node.histograms, &right_histograms);
1021 }
1022 let missing_child = if missing_right {
1023 right_child
1024 } else {
1025 left_child
1026 };
1027 node.update_children(missing_child, left_child, right_child, &split_info);
1028
1029 let left_node = SplittableNode::from_node_info(
1030 left_child,
1031 left_histograms,
1032 node.depth + 1,
1033 node.start_idx,
1034 split_idx,
1035 split_info.left_node,
1036 );
1037 let right_node = SplittableNode::from_node_info(
1038 right_child,
1039 right_histograms,
1040 node.depth + 1,
1041 split_idx,
1042 node.stop_idx,
1043 split_info.right_node,
1044 );
1045 vec![left_node, right_node]
1046 }
1047}
1048
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binning::bin_matrix;
    use crate::data::Matrix;
    use crate::node::SplittableNode;
    use crate::objective::{LogLoss, ObjectiveFunction};
    use crate::utils::gain;
    use crate::utils::weight;
    use std::fs;

    /// Best split on a single 7-row feature, checked against fixed expected values.
    #[test]
    fn test_best_feature_split() {
        let d = vec![4., 2., 3., 4., 5., 1., 4.];
        let data = Matrix::new(&d, 7, 1);
        let y = vec![0., 0., 0., 1., 1., 0., 1.];
        let yhat = vec![0.; 7];
        let w = vec![1.; y.len()];
        let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, &w);
        let b = bin_matrix(&data, &w, 10, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let index = data.index.to_owned();
        let hists = HistogramMatrix::new(&bdata, &b.cuts, &grad, &hess, &index, &[0], true, true);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 0.0,
            max_delta_step: 0.,
            gamma: 0.0,
            min_leaf_weight: 0.0,
            learning_rate: 1.0,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };
        // The split search only reads the node, so it is borrowed immutably.
        let n = SplittableNode::new(
            0,
            hists,
            0.0,
            0.14,
            grad.iter().sum::<f32>(),
            hess.iter().sum::<f32>(),
            0,
            0,
            grad.len(),
            f32::NEG_INFINITY,
            f32::INFINITY,
        );
        let s = splitter.best_feature_split(&n, 0, 0).unwrap();
        assert_eq!(s.split_value, 4.0);
        assert_eq!(s.left_node.cover, 0.75);
        assert_eq!(s.right_node.cover, 1.0);
        assert_eq!(s.left_node.gain, 3.0);
        assert_eq!(s.right_node.gain, 1.0);
        assert_eq!(s.split_gain, 3.86);
    }

    /// Across two features, the informative second feature must be chosen.
    #[test]
    fn test_best_split() {
        let d: Vec<f64> = vec![0., 0., 0., 1., 0., 0., 0., 4., 2., 3., 4., 5., 1., 4.];
        let data = Matrix::new(&d, 7, 2);
        let y = vec![0., 0., 0., 1., 1., 0., 1.];
        let yhat = vec![0.; 7];
        let w = vec![1.; y.len()];
        let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let b = bin_matrix(&data, &w, 10, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let index = data.index.to_owned();
        let hists =
            HistogramMatrix::new(&bdata, &b.cuts, &grad, &hess, &index, &[0, 1], true, true);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 0.0,
            max_delta_step: 0.,
            gamma: 0.0,
            min_leaf_weight: 0.0,
            learning_rate: 1.0,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };
        let n = SplittableNode::new(
            0,
            hists,
            0.0,
            0.14,
            grad.iter().sum::<f32>(),
            hess.iter().sum::<f32>(),
            0,
            0,
            grad.len(),
            f32::NEG_INFINITY,
            f32::INFINITY,
        );
        let s = splitter.best_split(&n, &[0, 1]).unwrap();
        assert_eq!(s.split_feature, 1);
        assert_eq!(s.split_value, 4.);
        assert_eq!(s.left_node.cover, 0.75);
        assert_eq!(s.right_node.cover, 1.);
        assert_eq!(s.left_node.gain, 3.);
        assert_eq!(s.right_node.gain, 1.);
        assert_eq!(s.split_gain, 3.86);
    }

    /// Split search on the 891x5 fixture data; expects feature 0 to be chosen.
    #[test]
    fn test_data_split() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 3.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };
        let gradient_sum = grad.iter().copied().sum();
        let hessian_sum = hess.iter().copied().sum();
        let root_weight = weight(
            &splitter.l1,
            &splitter.l2,
            &splitter.max_delta_step,
            gradient_sum,
            hessian_sum,
        );
        let root_gain = gain(&splitter.l2, gradient_sum, hessian_sum);
        let data = Matrix::new(&data_vec, 891, 5);

        let b = bin_matrix(&data, &w, 10, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let index = data.index.to_owned();
        let col_index: Vec<usize> = (0..data.cols).collect();
        let hists = HistogramMatrix::new(
            &bdata, &b.cuts, &grad, &hess, &index, &col_index, true, false,
        );

        // `mut` is needed for `update_children` below.
        let mut n = SplittableNode::new(
            0,
            hists,
            root_weight,
            root_gain,
            grad.iter().copied().sum::<f32>(),
            hess.iter().copied().sum::<f32>(),
            0,
            0,
            grad.len(),
            f32::NEG_INFINITY,
            f32::INFINITY,
        );
        let s = splitter.best_split(&n, &col_index).unwrap();
        n.update_children(2, 1, 2, &s);
        assert_eq!(0, s.split_feature);
    }
}