forust_ml/
tree.rs

1use crate::data::{JaggedMatrix, Matrix};
2use crate::gradientbooster::GrowPolicy;
3use crate::grower::Grower;
4use crate::histogram::HistogramMatrix;
5use crate::node::{Node, SplittableNode};
6use crate::partial_dependence::tree_partial_dependence;
7use crate::sampler::SampleMethod;
8use crate::splitter::Splitter;
9use crate::utils::fast_f64_sum;
10use crate::utils::{gain, odds, weight};
11use rayon::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::collections::{BinaryHeap, HashMap, VecDeque};
14use std::fmt::{self, Display};
15
16#[derive(Deserialize, Serialize)]
17pub struct Tree {
18    pub nodes: Vec<Node>,
19}
20
21impl Default for Tree {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl Tree {
28    pub fn new() -> Self {
29        Tree { nodes: Vec::new() }
30    }
31
32    #[allow(clippy::too_many_arguments)]
33    pub fn fit<T: Splitter>(
34        &mut self,
35        data: &Matrix<u16>,
36        mut index: Vec<usize>,
37        col_index: &[usize],
38        cuts: &JaggedMatrix<f64>,
39        grad: &[f32],
40        hess: &[f32],
41        splitter: &T,
42        max_leaves: usize,
43        max_depth: usize,
44        parallel: bool,
45        sample_method: &SampleMethod,
46        grow_policy: &GrowPolicy,
47    ) {
48        // Recreating the index for each tree, ensures that the tree construction is faster
49        // for the root node. This also ensures that sorting the records is always fast,
50        // because we are starting from a nearly sorted array.
51        let (gradient_sum, hessian_sum, sort) = match sample_method {
52            // We don't need to sort, if we are not sampling. This is because
53            // the data is already sorted.
54            SampleMethod::None => (fast_f64_sum(grad), fast_f64_sum(hess), false),
55            _ => {
56                // Accumulate using f64 for numeric fidelity.
57                let mut gs: f64 = 0.;
58                let mut hs: f64 = 0.;
59                for i in index.iter() {
60                    let i_ = *i;
61                    gs += grad[i_] as f64;
62                    hs += hess[i_] as f64;
63                }
64                (gs as f32, hs as f32, true)
65            }
66        };
67
68        let mut n_nodes = 1;
69        let root_gain = gain(&splitter.get_l2(), gradient_sum, hessian_sum);
70        let root_weight = weight(
71            &splitter.get_l1(),
72            &splitter.get_l2(),
73            &splitter.get_max_delta_step(),
74            gradient_sum,
75            hessian_sum,
76        );
77        // Calculate the histograms for the root node.
78        let root_hists =
79            HistogramMatrix::new(data, cuts, grad, hess, &index, col_index, parallel, sort);
80        let root_node = SplittableNode::new(
81            0,
82            root_hists,
83            root_weight,
84            root_gain,
85            gradient_sum,
86            hessian_sum,
87            0,
88            0,
89            index.len(),
90            f32::NEG_INFINITY,
91            f32::INFINITY,
92        );
93        // Add the first node to the tree nodes.
94        self.nodes
95            .push(root_node.as_node(splitter.get_learning_rate()));
96        let mut n_leaves = 1;
97
98        let mut growable: Box<dyn Grower> = match grow_policy {
99            GrowPolicy::DepthWise => Box::<VecDeque<SplittableNode>>::default(),
100            GrowPolicy::LossGuide => Box::<BinaryHeap<SplittableNode>>::default(),
101        };
102
103        growable.add_node(root_node);
104        while !growable.is_empty() {
105            // If this will push us over the max leaves parameter, break.
106            if (n_leaves + splitter.new_leaves_added()) > max_leaves {
107                break;
108            }
109            // We know there is a value here, because of how the
110            // while loop is setup.
111            // Grab a splitable node from the stack
112            // If we can split it, and update the corresponding
113            // tree nodes children.
114            let mut node = growable.get_next_node();
115            let n_idx = node.num;
116
117            let depth = node.depth + 1;
118
119            // If we have hit max depth, skip this node
120            // but keep going, because there may be other
121            // valid shallower nodes.
122            if depth > max_depth {
123                continue;
124            }
125
126            // For max_leaves, subtract 1 from the n_leaves
127            // every time we pop from the growable stack
128            // then, if we can add two children, add two to
129            // n_leaves. If we can't split the node any
130            // more, then just add 1 back to n_leaves
131            n_leaves -= 1;
132
133            let new_nodes = splitter.split_node(
134                &n_nodes, &mut node, &mut index, col_index, data, cuts, grad, hess, parallel,
135            );
136
137            let n_new_nodes = new_nodes.len();
138            if n_new_nodes == 0 {
139                n_leaves += 1;
140            } else {
141                self.nodes[n_idx].make_parent_node(node);
142                n_leaves += n_new_nodes;
143                n_nodes += n_new_nodes;
144                for n in new_nodes {
145                    self.nodes.push(n.as_node(splitter.get_learning_rate()));
146                    if !n.is_missing_leaf {
147                        growable.add_node(n)
148                    }
149                }
150            }
151        }
152
153        // Any final post processing required.
154        splitter.clean_up_splits(self);
155    }
156
157    pub fn predict_contributions_row_probability_change(
158        &self,
159        row: &[f64],
160        contribs: &mut [f64],
161        missing: &f64,
162        current_logodds: f64,
163    ) -> f64 {
164        contribs[contribs.len() - 1] +=
165            odds(current_logodds + self.nodes[0].weight_value as f64) - odds(current_logodds);
166        let mut node_idx = 0;
167        let mut lo = current_logodds;
168        loop {
169            let node = &self.nodes[node_idx];
170            let node_odds = odds(node.weight_value as f64 + current_logodds);
171            if node.is_leaf {
172                lo += node.weight_value as f64;
173                break;
174            }
175            // Get change of weight given child's weight.
176            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
177            let child_odds = odds(self.nodes[child_idx].weight_value as f64 + current_logodds);
178            let delta = child_odds - node_odds;
179            contribs[node.split_feature] += delta;
180            node_idx = child_idx;
181        }
182        lo
183    }
184
185    // Branch average difference predictions
186    pub fn predict_contributions_row_midpoint_difference(
187        &self,
188        row: &[f64],
189        contribs: &mut [f64],
190        missing: &f64,
191    ) {
192        // Bias term is left as 0.
193
194        let mut node_idx = 0;
195        loop {
196            let node = &self.nodes[node_idx];
197            if node.is_leaf {
198                break;
199            }
200            // Get change of weight given child's weight.
201            //       p
202            //    / | \
203            //   l  m  r
204            //
205            // where l < r and we are going down r
206            // The contribution for a would be r - l.
207
208            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
209            let child = &self.nodes[child_idx];
210            // If we are going down the missing branch, do nothing and leave
211            // it at zero.
212            if node.has_missing_branch() && child_idx == node.missing_node {
213                node_idx = child_idx;
214                continue;
215            }
216            let other_child = if child_idx == node.left_child {
217                &self.nodes[node.right_child]
218            } else {
219                &self.nodes[node.left_child]
220            };
221            let mid = (child.weight_value * child.hessian_sum
222                + other_child.weight_value * other_child.hessian_sum)
223                / (child.hessian_sum + other_child.hessian_sum);
224            let delta = child.weight_value - mid;
225            contribs[node.split_feature] += delta as f64;
226            node_idx = child_idx;
227        }
228    }
229
230    // Branch difference predictions.
231    pub fn predict_contributions_row_branch_difference(
232        &self,
233        row: &[f64],
234        contribs: &mut [f64],
235        missing: &f64,
236    ) {
237        // Bias term is left as 0.
238
239        let mut node_idx = 0;
240        loop {
241            let node = &self.nodes[node_idx];
242            if node.is_leaf {
243                break;
244            }
245            // Get change of weight given child's weight.
246            //       p
247            //    / | \
248            //   l  m  r
249            //
250            // where l < r and we are going down r
251            // The contribution for a would be r - l.
252
253            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
254            // If we are going down the missing branch, do nothing and leave
255            // it at zero.
256            if node.has_missing_branch() && child_idx == node.missing_node {
257                node_idx = child_idx;
258                continue;
259            }
260            let other_child = if child_idx == node.left_child {
261                &self.nodes[node.right_child]
262            } else {
263                &self.nodes[node.left_child]
264            };
265            let delta = self.nodes[child_idx].weight_value - other_child.weight_value;
266            contribs[node.split_feature] += delta as f64;
267            node_idx = child_idx;
268        }
269    }
270
271    // How does the travelled childs weight change relative to the
272    // mode branch.
273    pub fn predict_contributions_row_mode_difference(
274        &self,
275        row: &[f64],
276        contribs: &mut [f64],
277        missing: &f64,
278    ) {
279        // Bias term is left as 0.
280        let mut node_idx = 0;
281        loop {
282            let node = &self.nodes[node_idx];
283            if node.is_leaf {
284                break;
285            }
286
287            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
288            // If we are going down the missing branch, do nothing and leave
289            // it at zero.
290            if node.has_missing_branch() && child_idx == node.missing_node {
291                node_idx = child_idx;
292                continue;
293            }
294            let left_node = &self.nodes[node.left_child];
295            let right_node = &self.nodes[node.right_child];
296            let child_weight = self.nodes[child_idx].weight_value;
297
298            let delta = if left_node.hessian_sum == right_node.hessian_sum {
299                0.
300            } else if left_node.hessian_sum > right_node.hessian_sum {
301                child_weight - left_node.weight_value
302            } else {
303                child_weight - right_node.weight_value
304            };
305            contribs[node.split_feature] += delta as f64;
306            node_idx = child_idx;
307        }
308    }
309
310    pub fn predict_contributions_row_weight(
311        &self,
312        row: &[f64],
313        contribs: &mut [f64],
314        missing: &f64,
315    ) {
316        // Add the bias term first...
317        contribs[contribs.len() - 1] += self.nodes[0].weight_value as f64;
318        let mut node_idx = 0;
319        loop {
320            let node = &self.nodes[node_idx];
321            if node.is_leaf {
322                break;
323            }
324            // Get change of weight given child's weight.
325            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
326            let node_weight = self.nodes[node_idx].weight_value as f64;
327            let child_weight = self.nodes[child_idx].weight_value as f64;
328            let delta = child_weight - node_weight;
329            contribs[node.split_feature] += delta;
330            node_idx = child_idx
331        }
332    }
333
334    pub fn predict_contributions_weight(
335        &self,
336        data: &Matrix<f64>,
337        contribs: &mut [f64],
338        missing: &f64,
339    ) {
340        // There needs to always be at least 2 trees
341        data.index
342            .par_iter()
343            .zip(contribs.par_chunks_mut(data.cols + 1))
344            .for_each(|(row, contribs)| {
345                self.predict_contributions_row_weight(&data.get_row(*row), contribs, missing)
346            })
347    }
348
349    /// This is the method that XGBoost uses.
350    pub fn predict_contributions_row_average(
351        &self,
352        row: &[f64],
353        contribs: &mut [f64],
354        weights: &[f64],
355        missing: &f64,
356    ) {
357        // Add the bias term first...
358        contribs[contribs.len() - 1] += weights[0];
359        let mut node_idx = 0;
360        loop {
361            let node = &self.nodes[node_idx];
362            if node.is_leaf {
363                break;
364            }
365            // Get change of weight given child's weight.
366            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
367            let node_weight = weights[node_idx];
368            let child_weight = weights[child_idx];
369            let delta = child_weight - node_weight;
370            contribs[node.split_feature] += delta;
371            node_idx = child_idx
372        }
373    }
374
375    pub fn predict_contributions_average(
376        &self,
377        data: &Matrix<f64>,
378        contribs: &mut [f64],
379        weights: &[f64],
380        missing: &f64,
381    ) {
382        // There needs to always be at least 2 trees
383        data.index
384            .par_iter()
385            .zip(contribs.par_chunks_mut(data.cols + 1))
386            .for_each(|(row, contribs)| {
387                self.predict_contributions_row_average(
388                    &data.get_row(*row),
389                    contribs,
390                    weights,
391                    missing,
392                )
393            })
394    }
395
396    fn predict_leaf(&self, data: &Matrix<f64>, row: usize, missing: &f64) -> &Node {
397        let mut node_idx = 0;
398        loop {
399            let node = &self.nodes[node_idx];
400            if node.is_leaf {
401                return node;
402            } else {
403                node_idx = node.get_child_idx(data.get(row, node.split_feature), missing);
404            }
405        }
406    }
407
408    pub fn predict_row_from_row_slice(&self, row: &[f64], missing: &f64) -> f64 {
409        let mut node_idx = 0;
410        loop {
411            let node = &self.nodes[node_idx];
412            if node.is_leaf {
413                return node.weight_value as f64;
414            } else {
415                node_idx = node.get_child_idx(&row[node.split_feature], missing);
416            }
417        }
418    }
419
420    fn predict_single_threaded(&self, data: &Matrix<f64>, missing: &f64) -> Vec<f64> {
421        data.index
422            .iter()
423            .map(|i| self.predict_leaf(data, *i, missing).weight_value as f64)
424            .collect()
425    }
426
427    fn predict_parallel(&self, data: &Matrix<f64>, missing: &f64) -> Vec<f64> {
428        data.index
429            .par_iter()
430            .map(|i| self.predict_leaf(data, *i, missing).weight_value as f64)
431            .collect()
432    }
433
434    pub fn predict(&self, data: &Matrix<f64>, parallel: bool, missing: &f64) -> Vec<f64> {
435        if parallel {
436            self.predict_parallel(data, missing)
437        } else {
438            self.predict_single_threaded(data, missing)
439        }
440    }
441
442    pub fn predict_leaf_indices(&self, data: &Matrix<f64>, missing: &f64) -> Vec<usize> {
443        data.index
444            .par_iter()
445            .map(|i| self.predict_leaf(data, *i, missing).num)
446            .collect()
447    }
448
449    pub fn value_partial_dependence(&self, feature: usize, value: f64, missing: &f64) -> f64 {
450        tree_partial_dependence(self, 0, feature, value, 1.0, missing)
451    }
452    fn distribute_node_leaf_weights(&self, i: usize, weights: &mut [f64]) -> f64 {
453        let node = &self.nodes[i];
454        let mut w = node.weight_value as f64;
455        if !node.is_leaf {
456            let left_node = &self.nodes[node.left_child];
457            let right_node = &self.nodes[node.right_child];
458            w = left_node.hessian_sum as f64
459                * self.distribute_node_leaf_weights(node.left_child, weights);
460            w += right_node.hessian_sum as f64
461                * self.distribute_node_leaf_weights(node.right_child, weights);
462            // If this a tree with a missing branch.
463            if node.has_missing_branch() {
464                let missing_node = &self.nodes[node.missing_node];
465                w += missing_node.hessian_sum as f64
466                    * self.distribute_node_leaf_weights(node.missing_node, weights);
467            }
468            w /= node.hessian_sum as f64;
469        }
470        weights[i] = w;
471        w
472    }
473    pub fn distribute_leaf_weights(&self) -> Vec<f64> {
474        let mut weights = vec![0.; self.nodes.len()];
475        self.distribute_node_leaf_weights(0, &mut weights);
476        weights
477    }
478
479    pub fn get_average_leaf_weights(&self, i: usize) -> f64 {
480        let node = &self.nodes[i];
481        let mut w = node.weight_value as f64;
482        if node.is_leaf {
483            w
484        } else {
485            let left_node = &self.nodes[node.left_child];
486            let right_node = &self.nodes[node.right_child];
487            w = left_node.hessian_sum as f64 * self.get_average_leaf_weights(node.left_child);
488            w += right_node.hessian_sum as f64 * self.get_average_leaf_weights(node.right_child);
489            // If this a tree with a missing branch.
490            if node.has_missing_branch() {
491                let missing_node = &self.nodes[node.missing_node];
492                w += missing_node.hessian_sum as f64
493                    * self.get_average_leaf_weights(node.missing_node);
494            }
495            w /= node.hessian_sum as f64;
496            w
497        }
498    }
499
500    fn calc_feature_node_stats<F>(
501        &self,
502        calc_stat: &F,
503        node: &Node,
504        stats: &mut HashMap<usize, (f32, usize)>,
505    ) where
506        F: Fn(&Node) -> f32,
507    {
508        if node.is_leaf {
509            return;
510        }
511        stats
512            .entry(node.split_feature)
513            .and_modify(|(v, c)| {
514                *v += calc_stat(node);
515                *c += 1;
516            })
517            .or_insert((calc_stat(node), 1));
518        self.calc_feature_node_stats(calc_stat, &self.nodes[node.left_child], stats);
519        self.calc_feature_node_stats(calc_stat, &self.nodes[node.right_child], stats);
520        if node.has_missing_branch() {
521            self.calc_feature_node_stats(calc_stat, &self.nodes[node.missing_node], stats);
522        }
523    }
524
525    fn get_node_stats<F>(&self, calc_stat: &F, stats: &mut HashMap<usize, (f32, usize)>)
526    where
527        F: Fn(&Node) -> f32,
528    {
529        self.calc_feature_node_stats(calc_stat, &self.nodes[0], stats);
530    }
531
532    pub fn calculate_importance_weight(&self, stats: &mut HashMap<usize, (f32, usize)>) {
533        self.get_node_stats(&|_: &Node| 1., stats);
534    }
535
536    pub fn calculate_importance_gain(&self, stats: &mut HashMap<usize, (f32, usize)>) {
537        self.get_node_stats(&|n: &Node| n.split_gain, stats);
538    }
539
540    pub fn calculate_importance_cover(&self, stats: &mut HashMap<usize, (f32, usize)>) {
541        self.get_node_stats(&|n: &Node| n.hessian_sum, stats);
542    }
543}
544
545impl Display for Tree {
546    // This trait requires `fmt` with this exact signature.
547    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
548        let mut print_buffer: Vec<usize> = vec![0];
549        let mut r = String::new();
550        while let Some(idx) = print_buffer.pop() {
551            let node = &self.nodes[idx];
552            if node.is_leaf {
553                r += format!("{}{}\n", "      ".repeat(node.depth).as_str(), node).as_str();
554            } else {
555                r += format!("{}{}\n", "      ".repeat(node.depth).as_str(), node).as_str();
556                print_buffer.push(node.right_child);
557                print_buffer.push(node.left_child);
558                if node.has_missing_branch() {
559                    print_buffer.push(node.missing_node);
560                }
561            }
562        }
563        write!(f, "{}", r)
564    }
565}
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570    use crate::binning::bin_matrix;
571    use crate::constraints::{Constraint, ConstraintMap};
572    use crate::objective::{LogLoss, ObjectiveFunction};
573    use crate::sampler::{RandomSampler, Sampler};
574    use crate::splitter::MissingImputerSplitter;
575    use crate::utils::precision_round;
576    use rand::rngs::StdRng;
577    use rand::SeedableRng;
578    use std::fs;
579    #[test]
580    fn test_tree_fit_with_subsample() {
581        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
582            .expect("Something went wrong reading the file");
583        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
584        let file = fs::read_to_string("resources/performance.csv")
585            .expect("Something went wrong reading the file");
586        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
587        let yhat = vec![0.5; y.len()];
588        let w = vec![1.; y.len()];
589        let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
590        // let mut h = LogLoss::calc_hess(&y, &yhat, &w);
591
592        let data = Matrix::new(&data_vec, 891, 5);
593        let splitter = MissingImputerSplitter {
594            l1: 0.0,
595            l2: 1.0,
596            max_delta_step: 0.,
597            gamma: 3.0,
598            min_leaf_weight: 1.0,
599            learning_rate: 0.3,
600            allow_missing_splits: true,
601            constraints_map: ConstraintMap::new(),
602        };
603        let mut tree = Tree::new();
604
605        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
606        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
607        let mut rng = StdRng::seed_from_u64(0);
608        let (index, excluded) =
609            RandomSampler::new(0.5).sample(&mut rng, &data.index, &mut g, &mut h);
610        assert!(excluded.len() > 0);
611        let col_index: Vec<usize> = (0..data.cols).collect();
612        tree.fit(
613            &bdata,
614            index,
615            &col_index,
616            &b.cuts,
617            &g,
618            &h,
619            &splitter,
620            usize::MAX,
621            5,
622            true,
623            &SampleMethod::Random,
624            &GrowPolicy::DepthWise,
625        );
626    }
627
628    #[test]
629    fn test_tree_fit() {
630        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
631            .expect("Something went wrong reading the file");
632        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
633        let file = fs::read_to_string("resources/performance.csv")
634            .expect("Something went wrong reading the file");
635        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
636        let yhat = vec![0.5; y.len()];
637        let w = vec![1.; y.len()];
638        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
639
640        let data = Matrix::new(&data_vec, 891, 5);
641        let splitter = MissingImputerSplitter {
642            l1: 0.0,
643            l2: 1.0,
644            max_delta_step: 0.,
645            gamma: 3.0,
646            min_leaf_weight: 1.0,
647            learning_rate: 0.3,
648            allow_missing_splits: true,
649            constraints_map: ConstraintMap::new(),
650        };
651        let mut tree = Tree::new();
652
653        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
654        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
655        let col_index: Vec<usize> = (0..data.cols).collect();
656        tree.fit(
657            &bdata,
658            data.index.to_owned(),
659            &col_index,
660            &b.cuts,
661            &g,
662            &h,
663            &splitter,
664            usize::MAX,
665            5,
666            true,
667            &SampleMethod::None,
668            &GrowPolicy::DepthWise,
669        );
670
671        // println!("{}", tree);
672        // let preds = tree.predict(&data, false);
673        // println!("{:?}", &preds[0..10]);
674        assert_eq!(25, tree.nodes.len());
675        // Test contributions prediction...
676        let weights = tree.distribute_leaf_weights();
677        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
678        tree.predict_contributions_average(&data, &mut contribs, &weights, &f64::NAN);
679        let full_preds = tree.predict(&data, true, &f64::NAN);
680        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
681
682        let contribs_preds: Vec<f64> = contribs
683            .chunks(data.cols + 1)
684            .map(|i| i.iter().sum())
685            .collect();
686        println!("{:?}", &contribs[0..10]);
687        println!("{:?}", &contribs_preds[0..10]);
688
689        assert_eq!(contribs_preds.len(), full_preds.len());
690        for (i, j) in full_preds.iter().zip(contribs_preds) {
691            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
692        }
693
694        // Weight contributions
695        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
696        tree.predict_contributions_weight(&data, &mut contribs, &f64::NAN);
697        let full_preds = tree.predict(&data, true, &f64::NAN);
698        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
699
700        let contribs_preds: Vec<f64> = contribs
701            .chunks(data.cols + 1)
702            .map(|i| i.iter().sum())
703            .collect();
704        println!("{:?}", &contribs[0..10]);
705        println!("{:?}", &contribs_preds[0..10]);
706
707        assert_eq!(contribs_preds.len(), full_preds.len());
708        for (i, j) in full_preds.iter().zip(contribs_preds) {
709            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
710        }
711    }
712
713    #[test]
714    fn test_tree_colsample() {
715        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
716            .expect("Something went wrong reading the file");
717        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
718        let file = fs::read_to_string("resources/performance.csv")
719            .expect("Something went wrong reading the file");
720        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
721        let yhat = vec![0.5; y.len()];
722        let w = vec![1.; y.len()];
723        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
724
725        let data = Matrix::new(&data_vec, 891, 5);
726        let splitter = MissingImputerSplitter {
727            l1: 0.0,
728            l2: 1.0,
729            max_delta_step: 0.,
730            gamma: 3.0,
731            min_leaf_weight: 1.0,
732            learning_rate: 0.3,
733            allow_missing_splits: true,
734            constraints_map: ConstraintMap::new(),
735        };
736        let mut tree = Tree::new();
737
738        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
739        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
740        let col_index: Vec<usize> = vec![1, 3];
741        tree.fit(
742            &bdata,
743            data.index.to_owned(),
744            &col_index,
745            &b.cuts,
746            &g,
747            &h,
748            &splitter,
749            usize::MAX,
750            5,
751            false,
752            &SampleMethod::None,
753            &GrowPolicy::DepthWise,
754        );
755        for n in tree.nodes {
756            if !n.is_leaf {
757                assert!((n.split_feature == 1) || (n.split_feature == 3))
758            }
759        }
760    }
761
762    #[test]
763    fn test_tree_fit_monotone() {
764        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
765            .expect("Something went wrong reading the file");
766        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
767        let file = fs::read_to_string("resources/performance.csv")
768            .expect("Something went wrong reading the file");
769        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
770        let yhat = vec![0.5; y.len()];
771        let w = vec![1.; y.len()];
772        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
773        println!("GRADIENT -- {:?}", h);
774
775        let data_ = Matrix::new(&data_vec, 891, 5);
776        let data = Matrix::new(data_.get_col(1), 891, 1);
777        let map = ConstraintMap::from([(0, Constraint::Negative)]);
778        let splitter = MissingImputerSplitter {
779            l1: 0.0,
780            l2: 1.0,
781            max_delta_step: 0.,
782            gamma: 0.0,
783            min_leaf_weight: 1.0,
784            learning_rate: 0.3,
785            allow_missing_splits: true,
786            constraints_map: map,
787        };
788        let mut tree = Tree::new();
789
790        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
791        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
792        let col_index: Vec<usize> = (0..data.cols).collect();
793        tree.fit(
794            &bdata,
795            data.index.to_owned(),
796            &col_index,
797            &b.cuts,
798            &g,
799            &h,
800            &splitter,
801            usize::MAX,
802            5,
803            true,
804            &SampleMethod::None,
805            &GrowPolicy::DepthWise,
806        );
807
808        // println!("{}", tree);
809        let mut pred_data_vec = data.get_col(0).to_owned();
810        pred_data_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
811        pred_data_vec.dedup();
812        let pred_data = Matrix::new(&pred_data_vec, pred_data_vec.len(), 1);
813
814        let preds = tree.predict(&pred_data, false, &f64::NAN);
815        let increasing = preds.windows(2).all(|a| a[0] >= a[1]);
816        assert!(increasing);
817
818        let weights = tree.distribute_leaf_weights();
819
820        // Average contributions
821        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
822        tree.predict_contributions_average(&data, &mut contribs, &weights, &f64::NAN);
823        let full_preds = tree.predict(&data, true, &f64::NAN);
824        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
825        let contribs_preds: Vec<f64> = contribs
826            .chunks(data.cols + 1)
827            .map(|i| i.iter().sum())
828            .collect();
829        assert_eq!(contribs_preds.len(), full_preds.len());
830        for (i, j) in full_preds.iter().zip(contribs_preds) {
831            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
832        }
833
834        // Weight contributions
835        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
836        tree.predict_contributions_weight(&data, &mut contribs, &f64::NAN);
837        let full_preds = tree.predict(&data, true, &f64::NAN);
838        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
839        let contribs_preds: Vec<f64> = contribs
840            .chunks(data.cols + 1)
841            .map(|i| i.iter().sum())
842            .collect();
843        assert_eq!(contribs_preds.len(), full_preds.len());
844        for (i, j) in full_preds.iter().zip(contribs_preds) {
845            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
846        }
847    }
848
849    #[test]
850    fn test_tree_fit_lossguide() {
851        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
852            .expect("Something went wrong reading the file");
853        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
854        let file = fs::read_to_string("resources/performance.csv")
855            .expect("Something went wrong reading the file");
856        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
857        let yhat = vec![0.5; y.len()];
858        let w = vec![1.; y.len()];
859        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
860
861        let data = Matrix::new(&data_vec, 891, 5);
862        let splitter = MissingImputerSplitter {
863            l1: 0.0,
864            l2: 1.0,
865            max_delta_step: 0.,
866            gamma: 3.0,
867            min_leaf_weight: 1.0,
868            learning_rate: 0.3,
869            allow_missing_splits: false,
870            constraints_map: ConstraintMap::new(),
871        };
872        let mut tree = Tree::new();
873
874        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
875        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
876        let col_index: Vec<usize> = (0..data.cols).collect();
877        tree.fit(
878            &bdata,
879            data.index.to_owned(),
880            &col_index,
881            &b.cuts,
882            &g,
883            &h,
884            &splitter,
885            usize::MAX,
886            usize::MAX,
887            true,
888            &SampleMethod::None,
889            &GrowPolicy::LossGuide,
890        );
891
892        println!("{}", tree);
893        // let preds = tree.predict(&data, false);
894        // println!("{:?}", &preds[0..10]);
895        // assert_eq!(25, tree.nodes.len());
896        // Test contributions prediction...
897        let weights = tree.distribute_leaf_weights();
898        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
899        tree.predict_contributions_average(&data, &mut contribs, &weights, &f64::NAN);
900        let full_preds = tree.predict(&data, true, &f64::NAN);
901        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
902
903        let contribs_preds: Vec<f64> = contribs
904            .chunks(data.cols + 1)
905            .map(|i| i.iter().sum())
906            .collect();
907        println!("{:?}", &contribs[0..10]);
908        println!("{:?}", &contribs_preds[0..10]);
909
910        assert_eq!(contribs_preds.len(), full_preds.len());
911        for (i, j) in full_preds.iter().zip(contribs_preds) {
912            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
913        }
914
915        // Weight contributions
916        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
917        tree.predict_contributions_weight(&data, &mut contribs, &f64::NAN);
918        let full_preds = tree.predict(&data, true, &f64::NAN);
919        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
920
921        let contribs_preds: Vec<f64> = contribs
922            .chunks(data.cols + 1)
923            .map(|i| i.iter().sum())
924            .collect();
925        println!("{:?}", &contribs[0..10]);
926        println!("{:?}", &contribs_preds[0..10]);
927
928        assert_eq!(contribs_preds.len(), full_preds.len());
929        for (i, j) in full_preds.iter().zip(contribs_preds) {
930            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
931        }
932    }
933}