forust_ml/
splitter.rs

1use std::collections::HashSet;
2
3use crate::constraints::{Constraint, ConstraintMap};
4use crate::data::{JaggedMatrix, Matrix};
5use crate::gradientbooster::MissingNodeTreatment;
6use crate::histogram::HistogramMatrix;
7use crate::node::SplittableNode;
8use crate::tree::Tree;
9use crate::utils::{
10    between, bound_to_parent, constrained_weight, cull_gain, gain_given_weight, pivot_on_split,
11    pivot_on_split_exclude_missing,
12};
13
14#[derive(Debug)]
15pub struct SplitInfo {
16    pub split_gain: f32,
17    pub split_feature: usize,
18    pub split_value: f64,
19    pub split_bin: u16,
20    pub left_node: NodeInfo,
21    pub right_node: NodeInfo,
22    pub missing_node: MissingInfo,
23}
24
/// Summary statistics for one prospective child node of a split.
#[derive(Debug)]
pub struct NodeInfo {
    /// Sum of gradients of the records routed to this child.
    pub grad: f32,
    /// Gain of this child, as returned by `gain_given_weight`.
    pub gain: f32,
    /// Sum of hessians of the records routed to this child.
    pub cover: f32,
    /// Leaf weight computed for this child (via `constrained_weight`).
    pub weight: f32,
    /// `(lower, upper)` bounds inherited by this child's descendants.
    pub bounds: (f32, f32),
}
33
34#[derive(Debug)]
35pub enum MissingInfo {
36    Left,
37    Right,
38    Leaf(NodeInfo),
39    Branch(NodeInfo),
40}
41
42pub trait Splitter {
43    /// When a split happens, how many leaves will the tree increase by?
44    /// For example, if a binary split happens, the split will increase the
45    /// number of leaves by 1, if a ternary split happens, the number of leaves will
46    /// increase by 2.
47    fn new_leaves_added(&self) -> usize {
48        1
49    }
50    fn get_constraint(&self, feature: &usize) -> Option<&Constraint>;
51    // fn get_allow_missing_splits(&self) -> bool;
52    fn get_gamma(&self) -> f32;
53    fn get_l1(&self) -> f32;
54    fn get_l2(&self) -> f32;
55    fn get_max_delta_step(&self) -> f32;
56    fn get_learning_rate(&self) -> f32;
57
58    /// Perform any post processing on the tree that is
59    /// relevant for the specific splitter, empty default
60    /// implementation so that it can be called even if it's
61    /// not used.
62    fn clean_up_splits(&self, _tree: &mut Tree) {}
63
64    /// Find the best possible split, considering all feature histograms.
65    /// If we wanted to add Column sampling, this is probably where
66    /// we would need to do it, otherwise, it would be at the tree level.
67    fn best_split(&self, node: &SplittableNode, col_index: &[usize]) -> Option<SplitInfo> {
68        let mut best_split_info = None;
69        let mut best_gain = 0.0;
70        for (idx, feature) in col_index.iter().enumerate() {
71            let split_info = self.best_feature_split(node, *feature, idx);
72            match split_info {
73                Some(info) => {
74                    if info.split_gain > best_gain {
75                        best_gain = info.split_gain;
76                        best_split_info = Some(info);
77                    }
78                }
79                None => continue,
80            }
81        }
82        best_split_info
83    }
84
85    /// Evaluate a split, returning the node info for the left, and right splits,
86    /// as well as the node info the missing data of a feature.
87    #[allow(clippy::too_many_arguments)]
88    fn evaluate_split(
89        &self,
90        left_gradient: f32,
91        left_hessian: f32,
92        right_gradient: f32,
93        right_hessian: f32,
94        missing_gradient: f32,
95        missing_hessian: f32,
96        lower_bound: f32,
97        upper_bound: f32,
98        parent_weight: f32,
99        constraint: Option<&Constraint>,
100    ) -> Option<(NodeInfo, NodeInfo, MissingInfo)>;
101
102    /// The idx is the index of the feature in the histogram data, whereas feature
103    /// is the index of the actual feature in the data.
104    fn best_feature_split(
105        &self,
106        node: &SplittableNode,
107        feature: usize,
108        idx: usize,
109    ) -> Option<SplitInfo> {
110        let mut split_info: Option<SplitInfo> = None;
111        let mut max_gain: Option<f32> = None;
112
113        let HistogramMatrix(histograms) = &node.histograms;
114        let histogram = histograms.get_col(idx);
115
116        // We also know we will have a missing bin.
117        let missing = &histogram[0];
118        let mut cuml_grad = 0.0; // first_bin.gradient_sum;
119        let mut cuml_hess = 0.0; // first_bin.hessian_sum;
120        let constraint = self.get_constraint(&feature);
121
122        let elements = histogram.len();
123        assert!(elements == histogram.len());
124
125        for (i, bin) in histogram[1..].iter().enumerate() {
126            let left_gradient = cuml_grad;
127            let left_hessian = cuml_hess;
128            let right_gradient = node.gradient_sum - cuml_grad - missing.gradient_sum;
129            let right_hessian = node.hessian_sum - cuml_hess - missing.hessian_sum;
130            cuml_grad += bin.gradient_sum;
131            cuml_hess += bin.hessian_sum;
132
133            let (mut left_node_info, mut right_node_info, missing_info) = match self.evaluate_split(
134                left_gradient,
135                left_hessian,
136                right_gradient,
137                right_hessian,
138                missing.gradient_sum,
139                missing.hessian_sum,
140                node.lower_bound,
141                node.upper_bound,
142                node.weight_value,
143                constraint,
144            ) {
145                None => {
146                    continue;
147                }
148                Some(v) => v,
149            };
150
151            let split_gain = node.get_split_gain(
152                &left_node_info,
153                &right_node_info,
154                &missing_info,
155                self.get_gamma(),
156            );
157
158            // Check monotonicity holds
159            let split_gain = cull_gain(
160                split_gain,
161                left_node_info.weight,
162                right_node_info.weight,
163                constraint,
164            );
165
166            if split_gain <= 0.0 {
167                continue;
168            }
169
170            let mid = (left_node_info.weight + right_node_info.weight) / 2.0;
171            let (left_bounds, right_bounds) = match constraint {
172                None | Some(Constraint::Unconstrained) => (
173                    (node.lower_bound, node.upper_bound),
174                    (node.lower_bound, node.upper_bound),
175                ),
176                Some(Constraint::Negative) => ((mid, node.upper_bound), (node.lower_bound, mid)),
177                Some(Constraint::Positive) => ((node.lower_bound, mid), (mid, node.upper_bound)),
178            };
179            left_node_info.bounds = left_bounds;
180            right_node_info.bounds = right_bounds;
181
182            // If split gain is NaN, one of the sides is empty, do not allow
183            // this split.
184            let split_gain = if split_gain.is_nan() { 0.0 } else { split_gain };
185            if max_gain.is_none() || split_gain > max_gain.unwrap() {
186                max_gain = Some(split_gain);
187                split_info = Some(SplitInfo {
188                    split_gain,
189                    split_feature: feature,
190                    split_value: bin.cut_value,
191                    split_bin: (i + 1) as u16,
192                    left_node: left_node_info,
193                    right_node: right_node_info,
194                    missing_node: missing_info,
195                });
196            }
197        }
198        split_info
199    }
200
201    /// Handle the split info, creating the children nodes, this function
202    /// will return a vector of new splitable nodes, that can be added to the
203    /// growable stack, and further split, or converted to leaf nodes.
204    #[allow(clippy::too_many_arguments)]
205    fn handle_split_info(
206        &self,
207        split_info: SplitInfo,
208        n_nodes: &usize,
209        node: &mut SplittableNode,
210        index: &mut [usize],
211        col_index: &[usize],
212        data: &Matrix<u16>,
213        cuts: &JaggedMatrix<f64>,
214        grad: &[f32],
215        hess: &[f32],
216        parallel: bool,
217    ) -> Vec<SplittableNode>;
218
219    /// Split the node, if we cant find a best split, we will need to
220    /// return an empty vector, this node is a leaf.
221    #[allow(clippy::too_many_arguments)]
222    fn split_node(
223        &self,
224        n_nodes: &usize,
225        node: &mut SplittableNode,
226        index: &mut [usize],
227        col_index: &[usize],
228        data: &Matrix<u16>,
229        cuts: &JaggedMatrix<f64>,
230        grad: &[f32],
231        hess: &[f32],
232        parallel: bool,
233    ) -> Vec<SplittableNode> {
234        match self.best_split(node, col_index) {
235            Some(split_info) => self.handle_split_info(
236                split_info, n_nodes, node, index, col_index, data, cuts, grad, hess, parallel,
237            ),
238            None => Vec::new(),
239        }
240    }
241}
242
243/// Missing branch splitter
244/// Always creates a separate branch for the missing values of a feature.
245/// This results, in every node having a specific "missing", direction.
246/// If this node is able, it will be split further, otherwise it will
247/// a leaf node will be generated.
248pub struct MissingBranchSplitter {
249    pub l1: f32,
250    pub l2: f32,
251    pub max_delta_step: f32,
252    pub gamma: f32,
253    pub min_leaf_weight: f32,
254    pub learning_rate: f32,
255    pub allow_missing_splits: bool,
256    pub constraints_map: ConstraintMap,
257    pub terminate_missing_features: HashSet<usize>,
258    pub missing_node_treatment: MissingNodeTreatment,
259    pub force_children_to_bound_parent: bool,
260}
261
262impl MissingBranchSplitter {
263    pub fn new_leaves_added(&self) -> usize {
264        2
265    }
266    pub fn update_average_missing_nodes(tree: &mut Tree, node_idx: usize) -> f64 {
267        let node = &tree.nodes[node_idx];
268
269        if node.is_leaf {
270            return node.weight_value as f64;
271        }
272
273        let right = node.right_child;
274        let left = node.left_child;
275        let current_node = node.num;
276        let missing = node.missing_node;
277
278        let right_hessian = tree.nodes[right].hessian_sum as f64;
279        let right_avg_weight = Self::update_average_missing_nodes(tree, right);
280
281        let left_hessian = tree.nodes[left].hessian_sum as f64;
282        let left_avg_weight = Self::update_average_missing_nodes(tree, left);
283
284        // This way this process supports missing branches that terminate (and will neutralize)
285        // and then if missing is split further those values will have non-zero contributions.
286        let (missing_hessian, missing_avg_weight, missing_leaf) = if tree.nodes[missing].is_leaf {
287            (0., 0., true)
288        } else {
289            (
290                tree.nodes[missing].hessian_sum as f64,
291                Self::update_average_missing_nodes(tree, missing),
292                false,
293            )
294        };
295
296        let update = (right_avg_weight * right_hessian
297            + left_avg_weight * left_hessian
298            + missing_avg_weight * missing_hessian)
299            / (left_hessian + right_hessian + missing_hessian);
300
301        // Update current node, and the missing value
302        if let Some(n) = tree.nodes.get_mut(current_node) {
303            n.weight_value = update as f32;
304        }
305        // Only update the missing node if it's a leaf, otherwise we will auto-update
306        // them via the recursion called earlier.
307        if missing_leaf {
308            if let Some(m) = tree.nodes.get_mut(missing) {
309                m.weight_value = update as f32;
310            }
311        }
312
313        update
314    }
315}
316
317impl Splitter for MissingBranchSplitter {
318    fn clean_up_splits(&self, tree: &mut Tree) {
319        if let MissingNodeTreatment::AverageLeafWeight = self.missing_node_treatment {
320            MissingBranchSplitter::update_average_missing_nodes(tree, 0);
321        }
322    }
323
324    fn get_constraint(&self, feature: &usize) -> Option<&Constraint> {
325        self.constraints_map.get(feature)
326    }
327
328    fn get_gamma(&self) -> f32 {
329        self.gamma
330    }
331
332    fn get_l1(&self) -> f32 {
333        self.l1
334    }
335
336    fn get_l2(&self) -> f32 {
337        self.l2
338    }
339    fn get_max_delta_step(&self) -> f32 {
340        self.max_delta_step
341    }
342
343    fn get_learning_rate(&self) -> f32 {
344        self.learning_rate
345    }
346
347    fn evaluate_split(
348        &self,
349        left_gradient: f32,
350        left_hessian: f32,
351        right_gradient: f32,
352        right_hessian: f32,
353        missing_gradient: f32,
354        missing_hessian: f32,
355        lower_bound: f32,
356        upper_bound: f32,
357        parent_weight: f32,
358        constraint: Option<&Constraint>,
359    ) -> Option<(NodeInfo, NodeInfo, MissingInfo)> {
360        // If there is no info right, or there is no
361        // info left, there is nothing to split on,
362        // and so we should continue.
363        if (left_gradient == 0.0) && (left_hessian == 0.0)
364            || (right_gradient == 0.0) && (right_hessian == 0.0)
365        {
366            return None;
367        }
368
369        let mut left_weight = constrained_weight(
370            &self.l1,
371            &self.l2,
372            &self.max_delta_step,
373            left_gradient,
374            left_hessian,
375            lower_bound,
376            upper_bound,
377            constraint,
378        );
379        let mut right_weight = constrained_weight(
380            &self.l1,
381            &self.l2,
382            &self.max_delta_step,
383            right_gradient,
384            right_hessian,
385            lower_bound,
386            upper_bound,
387            constraint,
388        );
389
390        if self.force_children_to_bound_parent {
391            (left_weight, right_weight) = bound_to_parent(parent_weight, left_weight, right_weight);
392            assert!(between(lower_bound, upper_bound, left_weight));
393            assert!(between(lower_bound, upper_bound, right_weight));
394        }
395
396        let left_gain = gain_given_weight(&self.l2, left_gradient, left_hessian, left_weight);
397        let right_gain = gain_given_weight(&self.l2, right_gradient, right_hessian, right_weight);
398
399        // Check the min_hessian constraint first
400        if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
401            // Update for new value
402            return None;
403        }
404
405        // We have not considered missing at all up until this point, we could if we wanted
406        // to give more predictive power probably to missing.
407        // If we don't want to allow the missing branch to be split further,
408        // we will default to creating an empty branch.
409
410        // Set weight based on the missing node treatment.
411        let missing_weight = match self.missing_node_treatment {
412            MissingNodeTreatment::AssignToParent => constrained_weight(
413                &self.get_l1(),
414                &self.get_l2(),
415                &self.max_delta_step,
416                missing_gradient + left_gradient + right_gradient,
417                missing_hessian + left_hessian + right_hessian,
418                lower_bound,
419                upper_bound,
420                constraint,
421            ),
422            // Calculate the local leaf average for now, after training the tree.
423            // Recursively assign to the leaf weights underneath.
424            MissingNodeTreatment::AverageLeafWeight | MissingNodeTreatment::AverageNodeWeight => {
425                (right_weight * right_hessian + left_weight * left_hessian)
426                    / (right_hessian + left_hessian)
427            }
428            MissingNodeTreatment::None => {
429                // If there are no missing records, just default
430                // to the parent weight.
431                if missing_hessian == 0. || missing_gradient == 0. {
432                    parent_weight
433                } else {
434                    constrained_weight(
435                        &self.get_l1(),
436                        &self.get_l2(),
437                        &self.max_delta_step,
438                        missing_gradient,
439                        missing_hessian,
440                        lower_bound,
441                        upper_bound,
442                        constraint,
443                    )
444                }
445            }
446        };
447        let missing_gain = gain_given_weight(
448            &self.get_l2(),
449            missing_gradient,
450            missing_hessian,
451            missing_weight,
452        );
453        let missing_info = NodeInfo {
454            grad: missing_gradient,
455            gain: missing_gain,
456            cover: missing_hessian,
457            weight: missing_weight,
458            // Constrain to the same bounds as the parent.
459            // This will ensure that splits further down in the missing only
460            // branch are monotonic.
461            bounds: (lower_bound, upper_bound),
462        };
463        let missing_node = // Check Missing direction
464        if ((missing_gradient != 0.0) || (missing_hessian != 0.0)) && self.allow_missing_splits {
465            MissingInfo::Branch(
466                missing_info
467            )
468        } else {
469            MissingInfo::Leaf(
470                missing_info
471            )
472        };
473
474        if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
475            // Update for new value
476            return None;
477        }
478        Some((
479            NodeInfo {
480                grad: left_gradient,
481                gain: left_gain,
482                cover: left_hessian,
483                weight: left_weight,
484                bounds: (f32::NEG_INFINITY, f32::INFINITY),
485            },
486            NodeInfo {
487                grad: right_gradient,
488                gain: right_gain,
489                cover: right_hessian,
490                weight: right_weight,
491                bounds: (f32::NEG_INFINITY, f32::INFINITY),
492            },
493            missing_node,
494        ))
495    }
496
497    fn handle_split_info(
498        &self,
499        split_info: SplitInfo,
500        n_nodes: &usize,
501        node: &mut SplittableNode,
502        index: &mut [usize],
503        col_index: &[usize],
504        data: &Matrix<u16>,
505        cuts: &JaggedMatrix<f64>,
506        grad: &[f32],
507        hess: &[f32],
508        parallel: bool,
509    ) -> Vec<SplittableNode> {
510        let missing_child = *n_nodes;
511        let left_child = missing_child + 1;
512        let right_child = missing_child + 2;
513        node.update_children(missing_child, left_child, right_child, &split_info);
514
515        let (mut missing_is_leaf, mut missing_info) = match split_info.missing_node {
516            MissingInfo::Branch(i) => {
517                if self
518                    .terminate_missing_features
519                    .contains(&split_info.split_feature)
520                {
521                    (true, i)
522                } else {
523                    (false, i)
524                }
525            }
526            MissingInfo::Leaf(i) => (true, i),
527            _ => unreachable!(),
528        };
529        // Set missing weight to parent weight value...
530        // This essentially neutralizes missing.
531        // Manually calculating it, was leading to some small numeric
532        // rounding differences...
533        if let MissingNodeTreatment::AssignToParent = self.missing_node_treatment {
534            missing_info.weight = node.weight_value;
535        }
536        // We need to move all of the index's above and below our
537        // split value.
538        // pivot the sub array that this node has on our split value
539        // Missing all falls to the bottom.
540        let (mut missing_split_idx, mut split_idx) = pivot_on_split_exclude_missing(
541            &mut index[node.start_idx..node.stop_idx],
542            data.get_col(split_info.split_feature),
543            split_info.split_bin,
544        );
545        // Calculate histograms
546        let total_recs = node.stop_idx - node.start_idx;
547        let n_right = total_recs - split_idx;
548        let n_left = total_recs - n_right - missing_split_idx;
549        let n_missing = total_recs - (n_right + n_left);
550        let max_ = match [n_missing, n_left, n_right]
551            .iter()
552            .enumerate()
553            .max_by(|(_, i), (_, j)| i.cmp(j))
554        {
555            Some((i, _)) => i,
556            // if we can't compare them, it doesn't
557            // really matter, build the histogram on
558            // any of them.
559            None => 0,
560        };
561
562        // Now that we have calculated the number of records
563        // add the start index, to make the split_index
564        // relative to the entire index array
565        split_idx += node.start_idx;
566        missing_split_idx += node.start_idx;
567
568        // Build the histograms for the smaller node.
569        let left_histograms: HistogramMatrix;
570        let right_histograms: HistogramMatrix;
571        let missing_histograms: HistogramMatrix;
572        if n_missing == 0 {
573            // If there are no missing records, we know the missing value
574            // will be a leaf, assign this node as a leaf.
575            missing_is_leaf = true;
576            if max_ == 1 {
577                missing_histograms = HistogramMatrix::empty();
578                right_histograms = HistogramMatrix::new(
579                    data,
580                    cuts,
581                    grad,
582                    hess,
583                    &index[split_idx..node.stop_idx],
584                    col_index,
585                    parallel,
586                    true,
587                );
588                left_histograms =
589                    HistogramMatrix::from_parent_child(&node.histograms, &right_histograms);
590            } else {
591                missing_histograms = HistogramMatrix::empty();
592                left_histograms = HistogramMatrix::new(
593                    data,
594                    cuts,
595                    grad,
596                    hess,
597                    &index[missing_split_idx..split_idx],
598                    col_index,
599                    parallel,
600                    true,
601                );
602                right_histograms =
603                    HistogramMatrix::from_parent_child(&node.histograms, &left_histograms);
604            }
605        } else if max_ == 0 {
606            // Max is missing, calculate the other two
607            // levels histograms.
608            left_histograms = HistogramMatrix::new(
609                data,
610                cuts,
611                grad,
612                hess,
613                &index[missing_split_idx..split_idx],
614                col_index,
615                parallel,
616                true,
617            );
618            right_histograms = HistogramMatrix::new(
619                data,
620                cuts,
621                grad,
622                hess,
623                &index[split_idx..node.stop_idx],
624                col_index,
625                parallel,
626                true,
627            );
628            missing_histograms = HistogramMatrix::from_parent_two_children(
629                &node.histograms,
630                &left_histograms,
631                &right_histograms,
632            )
633        } else if max_ == 1 {
634            missing_histograms = HistogramMatrix::new(
635                data,
636                cuts,
637                grad,
638                hess,
639                &index[node.start_idx..missing_split_idx],
640                col_index,
641                parallel,
642                true,
643            );
644            right_histograms = HistogramMatrix::new(
645                data,
646                cuts,
647                grad,
648                hess,
649                &index[split_idx..node.stop_idx],
650                col_index,
651                parallel,
652                true,
653            );
654            left_histograms = HistogramMatrix::from_parent_two_children(
655                &node.histograms,
656                &missing_histograms,
657                &right_histograms,
658            )
659        } else {
660            // right is the largest
661            missing_histograms = HistogramMatrix::new(
662                data,
663                cuts,
664                grad,
665                hess,
666                &index[node.start_idx..missing_split_idx],
667                col_index,
668                parallel,
669                true,
670            );
671            left_histograms = HistogramMatrix::new(
672                data,
673                cuts,
674                grad,
675                hess,
676                &index[missing_split_idx..split_idx],
677                col_index,
678                parallel,
679                true,
680            );
681            right_histograms = HistogramMatrix::from_parent_two_children(
682                &node.histograms,
683                &missing_histograms,
684                &left_histograms,
685            )
686        }
687
688        let mut missing_node = SplittableNode::from_node_info(
689            missing_child,
690            missing_histograms,
691            node.depth + 1,
692            node.start_idx,
693            missing_split_idx,
694            missing_info,
695        );
696        missing_node.is_missing_leaf = missing_is_leaf;
697        let left_node = SplittableNode::from_node_info(
698            left_child,
699            left_histograms,
700            node.depth + 1,
701            missing_split_idx,
702            split_idx,
703            split_info.left_node,
704        );
705        let right_node = SplittableNode::from_node_info(
706            right_child,
707            right_histograms,
708            node.depth + 1,
709            split_idx,
710            node.stop_idx,
711            split_info.right_node,
712        );
713        vec![missing_node, left_node, right_node]
714    }
715}
716
717/// Missing imputer splitter
718/// Splitter that imputes missing values, by sending
719/// them down either the right or left branch, depending
720/// on which results in a higher increase in gain.
721pub struct MissingImputerSplitter {
722    pub l1: f32,
723    pub l2: f32,
724    pub max_delta_step: f32,
725    pub gamma: f32,
726    pub min_leaf_weight: f32,
727    pub learning_rate: f32,
728    pub allow_missing_splits: bool,
729    pub constraints_map: ConstraintMap,
730}
731
732impl MissingImputerSplitter {
733    /// Generate a new missing imputer splitter object.
734    #[allow(clippy::too_many_arguments)]
735    pub fn new(
736        l1: f32,
737        l2: f32,
738        max_delta_step: f32,
739        gamma: f32,
740        min_leaf_weight: f32,
741        learning_rate: f32,
742        allow_missing_splits: bool,
743        constraints_map: ConstraintMap,
744    ) -> Self {
745        MissingImputerSplitter {
746            l1,
747            l2,
748            max_delta_step,
749            gamma,
750            min_leaf_weight,
751            learning_rate,
752            allow_missing_splits,
753            constraints_map,
754        }
755    }
756}
757
758impl Splitter for MissingImputerSplitter {
759    fn get_constraint(&self, feature: &usize) -> Option<&Constraint> {
760        self.constraints_map.get(feature)
761    }
762
763    fn get_gamma(&self) -> f32 {
764        self.gamma
765    }
766
767    fn get_l1(&self) -> f32 {
768        self.l1
769    }
770
771    fn get_l2(&self) -> f32 {
772        self.l2
773    }
774    fn get_max_delta_step(&self) -> f32 {
775        self.max_delta_step
776    }
777
778    fn get_learning_rate(&self) -> f32 {
779        self.learning_rate
780    }
781
782    #[allow(clippy::too_many_arguments)]
783    fn evaluate_split(
784        &self,
785        left_gradient: f32,
786        left_hessian: f32,
787        right_gradient: f32,
788        right_hessian: f32,
789        missing_gradient: f32,
790        missing_hessian: f32,
791        lower_bound: f32,
792        upper_bound: f32,
793        _parent_weight: f32,
794        constraint: Option<&Constraint>,
795    ) -> Option<(NodeInfo, NodeInfo, MissingInfo)> {
796        // If there is no info right, or there is no
797        // info left, we will possibly lead to a missing only
798        // split, if we don't want this, bomb.
799        if ((left_gradient == 0.0) && (left_hessian == 0.0)
800            || (right_gradient == 0.0) && (right_hessian == 0.0))
801            && !self.allow_missing_splits
802        {
803            return None;
804        }
805
806        // By default missing values will go into the right node.
807        let mut missing_info = MissingInfo::Right;
808
809        let mut left_gradient = left_gradient;
810        let mut left_hessian = left_hessian;
811
812        let mut right_gradient = right_gradient;
813        let mut right_hessian = right_hessian;
814
815        let mut left_weight = constrained_weight(
816            &self.l1,
817            &self.l2,
818            &self.max_delta_step,
819            left_gradient,
820            left_hessian,
821            lower_bound,
822            upper_bound,
823            constraint,
824        );
825        let mut right_weight = constrained_weight(
826            &self.l1,
827            &self.l2,
828            &self.max_delta_step,
829            right_gradient,
830            right_hessian,
831            lower_bound,
832            upper_bound,
833            constraint,
834        );
835
836        let mut left_gain = gain_given_weight(&self.l2, left_gradient, left_hessian, left_weight);
837        let mut right_gain =
838            gain_given_weight(&self.l2, right_gradient, right_hessian, right_weight);
839
840        if !self.allow_missing_splits {
841            // Check the min_hessian constraint first, if we do not
842            // want to allow missing only splits.
843            if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
844                // Update for new value
845                return None;
846            }
847        }
848
849        // Check Missing direction
850        // Don't even worry about it, if there are no missing values
851        // in this bin.
852        if (missing_gradient != 0.0) || (missing_hessian != 0.0) {
853            // If
854            // TODO: Consider making this safer, by casting to f64, summing, and then
855            // back to f32...
856            // The weight if missing went left
857            let missing_left_weight = constrained_weight(
858                &self.l1,
859                &self.l2,
860                &self.max_delta_step,
861                left_gradient + missing_gradient,
862                left_hessian + missing_hessian,
863                lower_bound,
864                upper_bound,
865                constraint,
866            );
867            // The gain if missing went left
868            let missing_left_gain = gain_given_weight(
869                &self.l2,
870                left_gradient + missing_gradient,
871                left_hessian + missing_hessian,
872                missing_left_weight,
873            );
874            // Confirm this wouldn't break monotonicity.
875            let missing_left_gain = cull_gain(
876                missing_left_gain,
877                missing_left_weight,
878                right_weight,
879                constraint,
880            );
881
882            // The gain if missing went right
883            let missing_right_weight = constrained_weight(
884                &self.l1,
885                &self.l2,
886                &self.max_delta_step,
887                right_gradient + missing_gradient,
888                right_hessian + missing_hessian,
889                lower_bound,
890                upper_bound,
891                constraint,
892            );
893            // The gain is missing went right
894            let missing_right_gain = gain_given_weight(
895                &self.l2,
896                right_gradient + missing_gradient,
897                right_hessian + missing_hessian,
898                missing_right_weight,
899            );
900            // Confirm this wouldn't break monotonicity.
901            let missing_right_gain = cull_gain(
902                missing_right_gain,
903                left_weight,
904                missing_right_weight,
905                constraint,
906            );
907
908            if (missing_right_gain - right_gain) < (missing_left_gain - left_gain) {
909                // Missing goes left
910                left_gradient += missing_gradient;
911                left_hessian += missing_hessian;
912                left_gain = missing_left_gain;
913                left_weight = missing_left_weight;
914                missing_info = MissingInfo::Left;
915            } else {
916                // Missing goes right
917                right_gradient += missing_gradient;
918                right_hessian += missing_hessian;
919                right_gain = missing_right_gain;
920                right_weight = missing_right_weight;
921                missing_info = MissingInfo::Right;
922            }
923        }
924
925        if (right_hessian < self.min_leaf_weight) || (left_hessian < self.min_leaf_weight) {
926            // Update for new value
927            return None;
928        }
929        Some((
930            NodeInfo {
931                grad: left_gradient,
932                gain: left_gain,
933                cover: left_hessian,
934                weight: left_weight,
935                bounds: (f32::NEG_INFINITY, f32::INFINITY),
936            },
937            NodeInfo {
938                grad: right_gradient,
939                gain: right_gain,
940                cover: right_hessian,
941                weight: right_weight,
942                bounds: (f32::NEG_INFINITY, f32::INFINITY),
943            },
944            missing_info,
945        ))
946    }
947
948    fn handle_split_info(
949        &self,
950        split_info: SplitInfo,
951        n_nodes: &usize,
952        node: &mut SplittableNode,
953        index: &mut [usize],
954        col_index: &[usize],
955        data: &Matrix<u16>,
956        cuts: &JaggedMatrix<f64>,
957        grad: &[f32],
958        hess: &[f32],
959        parallel: bool,
960    ) -> Vec<SplittableNode> {
961        let left_child = *n_nodes;
962        let right_child = left_child + 1;
963
964        let missing_right = match split_info.missing_node {
965            MissingInfo::Left => false,
966            MissingInfo::Right => true,
967            _ => unreachable!(),
968        };
969
970        // We need to move all of the index's above and below our
971        // split value.
972        // pivot the sub array that this node has on our split value
973        // Here we assign missing to a specific direction.
974        // This will need to be refactored once we add a
975        // separate missing branch.
976        let mut split_idx = pivot_on_split(
977            &mut index[node.start_idx..node.stop_idx],
978            data.get_col(split_info.split_feature),
979            split_info.split_bin,
980            missing_right,
981        );
982        // Calculate histograms
983        let total_recs = node.stop_idx - node.start_idx;
984        let n_right = total_recs - split_idx;
985        let n_left = total_recs - n_right;
986
987        // Now that we have calculated the number of records
988        // add the start index, to make the split_index
989        // relative to the entire index array
990        split_idx += node.start_idx;
991
992        // Build the histograms for the smaller node.
993        let left_histograms: HistogramMatrix;
994        let right_histograms: HistogramMatrix;
995        if n_left < n_right {
996            left_histograms = HistogramMatrix::new(
997                data,
998                cuts,
999                grad,
1000                hess,
1001                &index[node.start_idx..split_idx],
1002                col_index,
1003                parallel,
1004                true,
1005            );
1006            right_histograms =
1007                HistogramMatrix::from_parent_child(&node.histograms, &left_histograms);
1008        } else {
1009            right_histograms = HistogramMatrix::new(
1010                data,
1011                cuts,
1012                grad,
1013                hess,
1014                &index[split_idx..node.stop_idx],
1015                col_index,
1016                parallel,
1017                true,
1018            );
1019            left_histograms =
1020                HistogramMatrix::from_parent_child(&node.histograms, &right_histograms);
1021        }
1022        let missing_child = if missing_right {
1023            right_child
1024        } else {
1025            left_child
1026        };
1027        node.update_children(missing_child, left_child, right_child, &split_info);
1028
1029        let left_node = SplittableNode::from_node_info(
1030            left_child,
1031            left_histograms,
1032            node.depth + 1,
1033            node.start_idx,
1034            split_idx,
1035            split_info.left_node,
1036        );
1037        let right_node = SplittableNode::from_node_info(
1038            right_child,
1039            right_histograms,
1040            node.depth + 1,
1041            split_idx,
1042            node.stop_idx,
1043            split_info.right_node,
1044        );
1045        vec![left_node, right_node]
1046    }
1047}
1048
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binning::bin_matrix;
    use crate::data::Matrix;
    use crate::node::SplittableNode;
    use crate::objective::{LogLoss, ObjectiveFunction};
    use crate::utils::gain;
    use crate::utils::weight;
    use std::fs;

    /// Splitter with no regularisation, no minimum leaf weight and a unit
    /// learning rate, shared by the small hand-checked fixtures.
    fn unregularised_splitter() -> MissingImputerSplitter {
        MissingImputerSplitter {
            l1: 0.0,
            l2: 0.0,
            max_delta_step: 0.,
            gamma: 0.0,
            min_leaf_weight: 0.0,
            learning_rate: 1.0,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        }
    }

    /// Root node covering every row, with the fixed weight/gain used by the
    /// small fixtures and the gradient/hessian totals taken from `grad`/`hess`.
    fn toy_root(hists: HistogramMatrix, grad: &[f32], hess: &[f32]) -> SplittableNode {
        SplittableNode::new(
            0,
            hists,
            0.0,
            0.14,
            grad.iter().sum::<f32>(),
            hess.iter().sum::<f32>(),
            0,
            0,
            grad.len(),
            f32::NEG_INFINITY,
            f32::INFINITY,
        )
    }

    #[test]
    fn test_best_feature_split() {
        // One feature, seven rows.
        let d = vec![4., 2., 3., 4., 5., 1., 4.];
        let data = Matrix::new(&d, 7, 1);
        let y = vec![0., 0., 0., 1., 1., 0., 1.];
        let yhat = vec![0.; 7];
        let w = vec![1.; y.len()];
        let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let binned = bin_matrix(&data, &w, 10, f64::NAN).unwrap();
        let bdata = Matrix::new(&binned.binned_data, data.rows, data.cols);
        let index = data.index.to_owned();
        let hists = HistogramMatrix::new(
            &bdata,
            &binned.cuts,
            &grad,
            &hess,
            &index,
            &[0],
            true,
            true,
        );

        let splitter = unregularised_splitter();
        let mut root = toy_root(hists, &grad, &hess);
        let split = splitter.best_feature_split(&mut root, 0, 0).unwrap();

        assert_eq!(split.split_value, 4.0);
        assert_eq!(split.left_node.cover, 0.75);
        assert_eq!(split.right_node.cover, 1.0);
        assert_eq!(split.left_node.gain, 3.0);
        assert_eq!(split.right_node.gain, 1.0);
        assert_eq!(split.split_gain, 3.86);
    }

    #[test]
    fn test_best_split() {
        // Two features; the second column is the same data as the
        // single-feature fixture above, so it should win.
        let d: Vec<f64> = vec![0., 0., 0., 1., 0., 0., 0., 4., 2., 3., 4., 5., 1., 4.];
        let data = Matrix::new(&d, 7, 2);
        let y = vec![0., 0., 0., 1., 1., 0., 1.];
        let yhat = vec![0.; 7];
        let w = vec![1.; y.len()];
        let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let binned = bin_matrix(&data, &w, 10, f64::NAN).unwrap();
        let bdata = Matrix::new(&binned.binned_data, data.rows, data.cols);
        let index = data.index.to_owned();
        let hists = HistogramMatrix::new(
            &bdata,
            &binned.cuts,
            &grad,
            &hess,
            &index,
            &[0, 1],
            true,
            true,
        );
        println!("{:?}", hists);

        let splitter = unregularised_splitter();
        let mut root = toy_root(hists, &grad, &hess);
        let split = splitter.best_split(&mut root, &[0, 1]).unwrap();
        println!("{:?}", split);

        assert_eq!(split.split_feature, 1);
        assert_eq!(split.split_value, 4.);
        assert_eq!(split.left_node.cover, 0.75);
        assert_eq!(split.right_node.cover, 1.);
        assert_eq!(split.left_node.gain, 3.);
        assert_eq!(split.right_node.gain, 1.);
        assert_eq!(split.split_gain, 3.86);
    }

    #[test]
    fn test_data_split() {
        // Each resource file holds one numeric value per line.
        let read_column = |path: &str| -> Vec<f64> {
            fs::read_to_string(path)
                .expect("Something went wrong reading the file")
                .lines()
                .map(|x| x.parse::<f64>().unwrap())
                .collect()
        };
        let data_vec = read_column("resources/contiguous_no_missing.csv");
        let y = read_column("resources/performance.csv");

        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 3.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };

        // Root statistics are computed once and reused for the node below.
        let gradient_sum: f32 = grad.iter().copied().sum();
        let hessian_sum: f32 = hess.iter().copied().sum();
        let root_weight = weight(
            &splitter.l1,
            &splitter.l2,
            &splitter.max_delta_step,
            gradient_sum,
            hessian_sum,
        );
        let root_gain = gain(&splitter.l2, gradient_sum, hessian_sum);

        let data = Matrix::new(&data_vec, 891, 5);
        let binned = bin_matrix(&data, &w, 10, f64::NAN).unwrap();
        let bdata = Matrix::new(&binned.binned_data, data.rows, data.cols);
        let index = data.index.to_owned();
        let col_index: Vec<usize> = (0..data.cols).collect();
        let hists = HistogramMatrix::new(
            &bdata,
            &binned.cuts,
            &grad,
            &hess,
            &index,
            &col_index,
            true,
            false,
        );

        let mut root = SplittableNode::new(
            0,
            hists,
            root_weight,
            root_gain,
            gradient_sum,
            hessian_sum,
            0,
            0,
            grad.len(),
            f32::NEG_INFINITY,
            f32::INFINITY,
        );
        let split = splitter.best_split(&mut root, &col_index).unwrap();
        println!("{:?}", split);
        root.update_children(2, 1, 2, &split);
        assert_eq!(0, split.split_feature);
    }
}